diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/__init__.py b/egs/librispeech/ASR/pruned_transducer_stateless3/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/train.py b/egs/librispeech/ASR/pruned_transducer_stateless5/train.py index 7252ee4363..5d198f806a 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/train.py @@ -452,7 +452,7 @@ def get_transducer_model(params: AttributeDict) -> nn.Module: def load_checkpoint_if_available( params: AttributeDict, model: nn.Module, - model_avg: nn.Module = None, + model_avg: Optional[nn.Module] = None, optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, ) -> Optional[Dict[str, Any]]: diff --git a/egs/librispeech/ASR/transducer_lstm/beam_search.py b/egs/librispeech/ASR/transducer_lstm/beam_search.py deleted file mode 100644 index 3531a96334..0000000000 --- a/egs/librispeech/ASR/transducer_lstm/beam_search.py +++ /dev/null @@ -1,222 +0,0 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from dataclasses import dataclass -from typing import Dict, List, Optional, Tuple - -import torch -from model import Transducer - - -def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]: - """ - Args: - model: - An instance of `Transducer`. - encoder_out: - A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. - Returns: - Return the decoded result. - """ - assert encoder_out.ndim == 3 - - # support only batch_size == 1 for now - assert encoder_out.size(0) == 1, encoder_out.size(0) - blank_id = model.decoder.blank_id - device = model.device - - sos = torch.tensor([blank_id], device=device, dtype=torch.int64).reshape( - 1, 1 - ) - decoder_out, (h, c) = model.decoder(sos) - T = encoder_out.size(1) - t = 0 - hyp = [] - - sym_per_frame = 0 - sym_per_utt = 0 - - max_sym_per_utt = 1000 - max_sym_per_frame = 3 - - while t < T and sym_per_utt < max_sym_per_utt: - # fmt: off - current_encoder_out = encoder_out[:, t:t+1, :] - # fmt: on - logits = model.joiner(current_encoder_out, decoder_out) - # logits is (1, 1, 1, vocab_size) - - log_prob = logits.log_softmax(dim=-1) - # log_prob is (1, 1, 1, vocab_size) - # TODO: Use logits.argmax() - y = log_prob.argmax() - if y != blank_id: - hyp.append(y.item()) - y = y.reshape(1, 1) - decoder_out, (h, c) = model.decoder(y, (h, c)) - - sym_per_utt += 1 - sym_per_frame += 1 - - if y == blank_id or sym_per_frame > max_sym_per_frame: - sym_per_frame = 0 - t += 1 - - return hyp - - -@dataclass -class Hypothesis: - ys: List[int] # the predicted sequences so far - log_prob: float # The log prob of ys - - # Optional decoder state. 
We assume it is LSTM for now, - # so the state is a tuple (h, c) - decoder_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None - - -def beam_search( - model: Transducer, - encoder_out: torch.Tensor, - beam: int = 5, -) -> List[int]: - """ - It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf - - espnet/nets/beam_search_transducer.py#L247 is used as a reference. - - Args: - model: - An instance of `Transducer`. - encoder_out: - A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. - beam: - Beam size. - Returns: - Return the decoded result. - """ - assert encoder_out.ndim == 3 - - # support only batch_size == 1 for now - assert encoder_out.size(0) == 1, encoder_out.size(0) - blank_id = model.decoder.blank_id - sos_id = model.decoder.sos_id - device = model.device - - sos = torch.tensor([blank_id], device=device).reshape(1, 1) - decoder_out, (h, c) = model.decoder(sos) - T = encoder_out.size(1) - t = 0 - B = [Hypothesis(ys=[blank_id], log_prob=0.0, decoder_state=None)] - max_u = 20000 # terminate after this number of steps - u = 0 - - cache: Dict[ - str, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] - ] = {} - - while t < T and u < max_u: - # fmt: off - current_encoder_out = encoder_out[:, t:t+1, :] - # fmt: on - A = B - B = [] - # for hyp in A: - # for h in A: - # if h.ys == hyp.ys[:-1]: - # # update the score of hyp - # decoder_input = torch.tensor( - # [h.ys[-1]], device=device - # ).reshape(1, 1) - # decoder_out, _ = model.decoder( - # decoder_input, h.decoder_state - # ) - # logits = model.joiner(current_encoder_out, decoder_out) - # log_prob = logits.log_softmax(dim=-1) - # log_prob = log_prob.squeeze() - # hyp.log_prob += h.log_prob + log_prob[hyp.ys[-1]].item() - - while u < max_u: - y_star = max(A, key=lambda hyp: hyp.log_prob) - A.remove(y_star) - - # Note: y_star.ys is unhashable, i.e., cannot be used - # as a key into a dict - cached_key = "_".join(map(str, y_star.ys)) - - if cached_key not in cache: - decoder_input = torch.tensor( - [y_star.ys[-1]], device=device - ).reshape(1, 1) - - decoder_out, decoder_state = model.decoder( - decoder_input, - y_star.decoder_state, - ) - cache[cached_key] = (decoder_out, decoder_state) - else: - decoder_out, decoder_state = cache[cached_key] - - logits = model.joiner(current_encoder_out, decoder_out) - log_prob = logits.log_softmax(dim=-1) - # log_prob is (1, 1, 1, vocab_size) - log_prob = log_prob.squeeze() - # Now log_prob is (vocab_size,) - - # If we choose blank here, add the new hypothesis to B. 
- # Otherwise, add the new hypothesis to A - - # First, choose blank - skip_log_prob = log_prob[blank_id] - new_y_star_log_prob = y_star.log_prob + skip_log_prob.item() - - # ys[:] returns a copy of ys - new_y_star = Hypothesis( - ys=y_star.ys[:], - log_prob=new_y_star_log_prob, - # Caution: Use y_star.decoder_state here - decoder_state=y_star.decoder_state, - ) - B.append(new_y_star) - - # Second, choose other labels - for i, v in enumerate(log_prob.tolist()): - if i in (blank_id, sos_id): - continue - new_ys = y_star.ys + [i] - new_log_prob = y_star.log_prob + v - new_hyp = Hypothesis( - ys=new_ys, - log_prob=new_log_prob, - decoder_state=decoder_state, - ) - A.append(new_hyp) - u += 1 - # check whether B contains more than "beam" elements more probable - # than the most probable in A - A_most_probable = max(A, key=lambda hyp: hyp.log_prob) - B = sorted( - [hyp for hyp in B if hyp.log_prob > A_most_probable.log_prob], - key=lambda hyp: hyp.log_prob, - reverse=True, - ) - if len(B) >= beam: - B = B[:beam] - break - t += 1 - best_hyp = max(B, key=lambda hyp: hyp.log_prob / len(hyp.ys[1:])) - ys = best_hyp.ys[1:] # [1:] to remove the blank - return ys diff --git a/egs/librispeech/ASR/transducer_lstm/beam_search.py b/egs/librispeech/ASR/transducer_lstm/beam_search.py new file mode 120000 index 0000000000..8554e44ccf --- /dev/null +++ b/egs/librispeech/ASR/transducer_lstm/beam_search.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/beam_search.py \ No newline at end of file diff --git a/egs/librispeech/ASR/transducer_lstm/decode.py b/egs/librispeech/ASR/transducer_lstm/decode.py index 18ae5234c7..30d5d15a4b 100755 --- a/egs/librispeech/ASR/transducer_lstm/decode.py +++ b/egs/librispeech/ASR/transducer_lstm/decode.py @@ -19,20 +19,40 @@ Usage: (1) greedy search ./transducer_lstm/decode.py \ - --epoch 14 \ - --avg 7 \ + --epoch 28 \ + --avg 15 \ --exp-dir ./transducer_lstm/exp \ --max-duration 100 \ --decoding-method greedy_search -(2) beam search +(2) beam search ./transducer_lstm/decode.py \ - --epoch 14 \ - --avg 7 \ + --epoch 28 \ + --avg 15 \ --exp-dir ./transducer_lstm/exp \ --max-duration 100 \ --decoding-method beam_search \ - --beam-size 8 + --beam-size 4 + +(3) modified beam search +./transducer_lstm/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./transducer_lstm/exp \ + --max-duration 100 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(4) fast beam search +./transducer_lstm/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./transducer_lstm/exp \ + --max-duration 1500 \ + --decoding-method fast_beam_search \ + --beam 4 \ + --max-contexts 4 \ + --max-states 8 """ @@ -40,20 +60,27 @@ import logging from collections import defaultdict from pathlib import Path -from typing import Dict, List, Tuple +from typing import Dict, List, Optional, Tuple +import k2 import sentencepiece as spm import torch import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule -from beam_search import beam_search, greedy_search -from decoder import Decoder -from encoder import LstmEncoder -from joiner import Joiner -from model import Transducer - -from icefall.checkpoint import average_checkpoints, load_checkpoint -from icefall.env import get_env_info +from beam_search import ( + beam_search, + fast_beam_search, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from train import get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + find_checkpoints, + load_checkpoint, +) from icefall.utils import ( AttributeDict, 
setup_logger, @@ -70,17 +97,29 @@ def get_parser(): parser.add_argument( "--epoch", type=int, - default=77, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", + default=28, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + parser.add_argument( "--avg", type=int, - default=55, + default=15, help="Number of checkpoints to average. Automatically select " "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", + "'--epoch' and '--iter'", ) parser.add_argument( @@ -104,84 +143,62 @@ def get_parser(): help="""Possible values are: - greedy_search - beam_search + - modified_beam_search + - fast_beam_search """, ) parser.add_argument( "--beam-size", type=int, - default=5, - help="Used only when --decoding-method is beam_search", + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", ) - return parser - - -def get_params() -> AttributeDict: - params = AttributeDict( - { - # parameters for conformer - "feature_dim": 80, - "encoder_out_dim": 512, - "subsampling_factor": 4, - "encoder_hidden_size": 1024, - "num_encoder_layers": 4, - "proj_size": 512, - "vgg_frontend": False, - # decoder params - "decoder_embedding_dim": 1024, - "num_decoder_layers": 4, - "decoder_hidden_dim": 512, - "env_info": get_env_info(), - } + parser.add_argument( + "--beam", + type=float, + default=4, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search""", ) - return params - -def get_encoder_model(params: AttributeDict): - encoder = LstmEncoder( - num_features=params.feature_dim, - hidden_size=params.encoder_hidden_size, - output_dim=params.encoder_out_dim, - subsampling_factor=params.subsampling_factor, - num_encoder_layers=params.num_encoder_layers, - vgg_frontend=params.vgg_frontend, - ) - return encoder - - -def get_decoder_model(params: AttributeDict): - decoder = Decoder( - vocab_size=params.vocab_size, - embedding_dim=params.decoder_embedding_dim, - blank_id=params.blank_id, - sos_id=params.sos_id, - num_layers=params.num_decoder_layers, - hidden_dim=params.decoder_hidden_dim, - output_dim=params.encoder_out_dim, + parser.add_argument( + "--max-contexts", + type=int, + default=4, + help="""Used only when --decoding-method is + fast_beam_search""", ) - return decoder - -def get_joiner_model(params: AttributeDict): - joiner = Joiner( - input_dim=params.encoder_out_dim, - output_dim=params.vocab_size, + parser.add_argument( + "--max-states", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search""", ) - return joiner - -def get_transducer_model(params: AttributeDict): - encoder = get_encoder_model(params) - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) - - model = Transducer( - encoder=encoder, - decoder=decoder, - joiner=joiner, + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 
1 means bigram; " + "2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding_method is greedy_search""", ) - return model + + return parser def decode_one_batch( @@ -189,6 +206,7 @@ def decode_one_batch( model: nn.Module, sp: spm.SentencePieceProcessor, batch: dict, + decoding_graph: Optional[k2.Fsa] = None, ) -> Dict[str, List[List[str]]]: """Decode one batch and return the result in a dict. The dict has the following format: @@ -211,6 +229,9 @@ def decode_one_batch( It is the return value from iterating `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation for the format of the `batch`. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search. Returns: Return the decoding result. See above description for the format of the returned dict. @@ -229,28 +250,74 @@ def decode_one_batch( x=feature, x_lens=feature_lens ) hyps = [] - batch_size = encoder_out.size(0) - - for i in range(batch_size): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.decoding_method == "greedy_search": - hyp = greedy_search(model=model, encoder_out=encoder_out_i) - elif params.decoding_method == "beam_search": - hyp = beam_search( - model=model, encoder_out=encoder_out_i, beam=params.beam_size - ) - else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) - hyps.append(sp.decode(hyp).split()) + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + beam=params.beam_size, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append(sp.decode(hyp).split()) if params.decoding_method == "greedy_search": return {"greedy_search": hyps} + elif params.decoding_method == "fast_beam_search": + return { + ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ): hyps + } else: - return {f"beam_{params.beam_size}": hyps} + return {f"beam_size_{params.beam_size}": hyps} def decode_dataset( @@ -258,6 +325,7 @@ def decode_dataset( params: AttributeDict, model: nn.Module, sp: spm.SentencePieceProcessor, + decoding_graph: Optional[k2.Fsa] = None, ) -> Dict[str, 
List[Tuple[List[str], List[str]]]]: """Decode dataset. @@ -270,6 +338,9 @@ def decode_dataset( The neural model. sp: The BPE model. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search. Returns: Return a dict, whose key may be "greedy_search" if greedy search is used, or it may be "beam_7" if beam size of 7 is used. @@ -297,6 +368,7 @@ def decode_dataset( params=params, model=model, sp=sp, + decoding_graph=decoding_graph, batch=batch, ) @@ -374,12 +446,30 @@ def main(): params = get_params() params.update(vars(args)) - assert params.decoding_method in ("greedy_search", "beam_search") + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "fast_beam_search", + "modified_beam_search", + ) params.res_dir = params.exp_dir / params.decoding_method - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - if params.decoding_method == "beam_search": - params.suffix += f"-beam-{params.beam_size}" + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + elif "beam_search" in params.decoding_method: + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") logging.info("Decoding started") @@ -393,9 +483,9 @@ def main(): sp = spm.SentencePieceProcessor() sp.load(params.bpe_model) - # <blk> and <sos/eos> are defined in local/train_bpe_model.py + # <blk> and <unk> are defined in local/train_bpe_model.py params.blank_id = sp.piece_to_id("<blk>") - params.sos_id = sp.piece_to_id("<sos/eos>") + params.unk_id = sp.piece_to_id("<unk>") params.vocab_size = sp.get_piece_size() logging.info(params) @@ -403,7 +493,24 @@ def main(): logging.info("About to create model") model = get_transducer_model(params) - if params.avg == 1: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) else: start = params.epoch - params.avg + 1 @@ -419,6 +526,11 @@ def main(): model.eval() model.device = device + if params.decoding_method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") @@ -439,6 +551,7 @@ def main(): params=params, model=model, sp=sp, + decoding_graph=decoding_graph, ) save_results( @@ -450,8 +563,5 @@ def main(): logging.info("Done!") -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - if __name__ == "__main__": main() diff --git a/egs/librispeech/ASR/transducer_lstm/decoder.py
b/egs/librispeech/ASR/transducer_lstm/decoder.py deleted file mode 100644 index 4d531bde11..0000000000 --- a/egs/librispeech/ASR/transducer_lstm/decoder.py +++ /dev/null @@ -1,101 +0,0 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Optional, Tuple - -import torch -import torch.nn as nn - - -# TODO(fangjun): Support switching between LSTM and GRU -class Decoder(nn.Module): - def __init__( - self, - vocab_size: int, - embedding_dim: int, - blank_id: int, - sos_id: int, - num_layers: int, - hidden_dim: int, - output_dim: int, - embedding_dropout: float = 0.0, - rnn_dropout: float = 0.0, - ): - """ - Args: - vocab_size: - Number of tokens of the modeling unit including blank. - embedding_dim: - Dimension of the input embedding. - blank_id: - The ID of the blank symbol. - sos_id: - The ID of the SOS symbol. - num_layers: - Number of LSTM layers. - hidden_dim: - Hidden dimension of LSTM layers. - output_dim: - Output dimension of the decoder. - embedding_dropout: - Dropout rate for the embedding layer. - rnn_dropout: - Dropout for LSTM layers. - """ - super().__init__() - self.embedding = nn.Embedding( - num_embeddings=vocab_size, - embedding_dim=embedding_dim, - padding_idx=blank_id, - ) - self.embedding_dropout = nn.Dropout(embedding_dropout) - # TODO(fangjun): Use layer normalized LSTM - self.rnn = nn.LSTM( - input_size=embedding_dim, - hidden_size=hidden_dim, - num_layers=num_layers, - batch_first=True, - dropout=rnn_dropout, - ) - self.blank_id = blank_id - self.sos_id = sos_id - self.output_linear = nn.Linear(hidden_dim, output_dim) - - def forward( - self, - y: torch.Tensor, - states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - """ - Args: - y: - A 2-D tensor of shape (N, U) with BOS prepended. - states: - A tuple of two tensors containing the states information of - LSTM layers in this decoder. - Returns: - Return a tuple containing: - - - rnn_output, a tensor of shape (N, U, C) - - (h, c), containing the state information for LSTM layers. 
- Both are of shape (num_layers, N, C) - """ - embedding_out = self.embedding(y) - embedding_out = self.embedding_dropout(embedding_out) - rnn_out, (h, c) = self.rnn(embedding_out, states) - out = self.output_linear(rnn_out) - - return out, (h, c) diff --git a/egs/librispeech/ASR/transducer_lstm/decoder.py b/egs/librispeech/ASR/transducer_lstm/decoder.py new file mode 120000 index 0000000000..0793c5709c --- /dev/null +++ b/egs/librispeech/ASR/transducer_lstm/decoder.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/decoder.py \ No newline at end of file diff --git a/egs/librispeech/ASR/transducer_lstm/encoder.py b/egs/librispeech/ASR/transducer_lstm/encoder.py index 3dc992dd29..ec00b0a7a2 100644 --- a/egs/librispeech/ASR/transducer_lstm/encoder.py +++ b/egs/librispeech/ASR/transducer_lstm/encoder.py @@ -29,13 +29,11 @@ def __init__( hidden_size: int, output_dim: int, subsampling_factor: int = 4, - num_encoder_layers: int = 12, + num_encoder_layers: int = 6, dropout: float = 0.1, vgg_frontend: bool = False, - proj_size: int = 0, ): super().__init__() - real_hidden_size = proj_size if proj_size > 0 else hidden_size assert ( subsampling_factor == 4 ), "Only subsampling_factor==4 is supported at present" @@ -46,28 +44,21 @@ def __init__( # (1) subsampling: T -> T//subsampling_factor # (2) embedding: num_features -> d_model if vgg_frontend: - self.encoder_embed = VggSubsampling(num_features, real_hidden_size) + self.encoder_embed = VggSubsampling(num_features, output_dim) else: - self.encoder_embed = Conv2dSubsampling( - num_features, real_hidden_size - ) + self.encoder_embed = Conv2dSubsampling(num_features, output_dim) self.rnn = nn.LSTM( - input_size=hidden_size, + input_size=output_dim, hidden_size=hidden_size, num_layers=num_encoder_layers, bias=True, - proj_size=proj_size, + proj_size=output_dim, batch_first=True, dropout=dropout, bidirectional=False, ) - self.encoder_output_layer = nn.Sequential( - nn.Dropout(p=dropout), - nn.Linear(real_hidden_size, output_dim), - ) - def forward( self, x: torch.Tensor, x_lens: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: @@ -87,29 +78,21 @@ def forward( x = self.encoder_embed(x) # Caution: We assume the subsampling factor is 4! - lengths = ((x_lens - 1) // 2 - 1) // 2 + + lengths = (((x_lens - 1) >> 1) - 1) >> 1 assert x.size(1) == lengths.max().item(), ( x.size(1), lengths.max(), ) - if False: - # It is commented out as DDP complains that not all parameters are - # used. Need more checks later for the reason. - # - # Caution: We assume the dataloader returns utterances with - # duration being sorted in decreasing order - packed_x = pack_padded_sequence( - input=x, - lengths=lengths.cpu(), - batch_first=True, - enforce_sorted=True, - ) + packed_x = pack_padded_sequence( + input=x, + lengths=lengths.cpu(), + batch_first=True, + enforce_sorted=False, + ) - packed_rnn_out, _ = self.rnn(packed_x) - rnn_out, _ = pad_packed_sequence(packed_rnn_out, batch_first=True) - else: - rnn_out, _ = self.rnn(x) + packed_rnn_out, _ = self.rnn(packed_x) + rnn_out, _ = pad_packed_sequence(packed_rnn_out, batch_first=True) - logits = self.encoder_output_layer(rnn_out) - return logits, lengths + return rnn_out, lengths diff --git a/egs/librispeech/ASR/transducer_lstm/joiner.py b/egs/librispeech/ASR/transducer_lstm/joiner.py deleted file mode 100644 index 0422f8a6fe..0000000000 --- a/egs/librispeech/ASR/transducer_lstm/joiner.py +++ /dev/null @@ -1,55 +0,0 @@ -# Copyright 2021 Xiaomi Corp. 
(authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch -import torch.nn as nn -import torch.nn.functional as F - - -class Joiner(nn.Module): - def __init__(self, input_dim: int, output_dim: int): - super().__init__() - - self.output_linear = nn.Linear(input_dim, output_dim) - - def forward( - self, encoder_out: torch.Tensor, decoder_out: torch.Tensor - ) -> torch.Tensor: - """ - Args: - encoder_out: - Output from the encoder. Its shape is (N, T, C). - decoder_out: - Output from the decoder. Its shape is (N, U, C). - Returns: - Return a tensor of shape (N, T, U, C). - """ - assert encoder_out.ndim == decoder_out.ndim == 3 - assert encoder_out.size(0) == decoder_out.size(0) - assert encoder_out.size(2) == decoder_out.size(2) - - encoder_out = encoder_out.unsqueeze(2) - # Now encoder_out is (N, T, 1, C) - - decoder_out = decoder_out.unsqueeze(1) - # Now decoder_out is (N, 1, U, C) - - logit = encoder_out + decoder_out - logit = F.relu(logit) - - output = self.output_linear(logit) - - return output diff --git a/egs/librispeech/ASR/transducer_lstm/joiner.py b/egs/librispeech/ASR/transducer_lstm/joiner.py new file mode 120000 index 0000000000..815fd4bb6f --- /dev/null +++ b/egs/librispeech/ASR/transducer_lstm/joiner.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/joiner.py \ No newline at end of file diff --git a/egs/librispeech/ASR/transducer_lstm/model.py b/egs/librispeech/ASR/transducer_lstm/model.py index e37558a980..f59032db80 100644 --- a/egs/librispeech/ASR/transducer_lstm/model.py +++ b/egs/librispeech/ASR/transducer_lstm/model.py @@ -1,4 +1,4 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -14,18 +14,40 @@ # See the License for the specific language governing permissions and # limitations under the License. -""" -Note we use `rnnt_loss` from torchaudio, which exists only in -torchaudio >= v0.10.0. It also means you have to use torch >= v1.10.0 -""" + +from typing import Optional + import k2 import torch import torch.nn as nn -import torchaudio -import torchaudio.functional from encoder_interface import EncoderInterface +from scaling import ScaledLinear + +from icefall.utils import add_sos, make_pad_mask -from icefall.utils import add_sos + +def compute_teacher_student_loss( + encoder_out: torch.Tensor, + teacher_encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, +) -> torch.Tensor: + """ + Args: + encoder_out: + Encoder output of the student. Its shape is (N, T, C) + teacher_encoder_out: + Encoder output of the teacher. Its shape is also (N, T, C) + encoder_out_lens: + A 1-D tensor containing the number of valid frames in encoder_out before + padding. + Returns: + Return the l1 loss between encoder_out and teacher_encoder_out. 
+ """ + loss = (encoder_out - teacher_encoder_out).abs().sum(dim=-1) + mask = make_pad_mask(encoder_out_lens) + loss.masked_fill_(mask, 0) + + return loss.sum() / encoder_out.size(-1) class Transducer(nn.Module): @@ -38,37 +60,50 @@ def __init__( encoder: EncoderInterface, decoder: nn.Module, joiner: nn.Module, + encoder_dim: int, + decoder_dim: int, + joiner_dim: int, + vocab_size: int, ): """ Args: encoder: It is the transcription network in the paper. Its accepts - two inputs: `x` of (N, T, C) and `x_lens` of shape (N,). - It returns two tensors: `logits` of shape (N, T, C) and + two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,). + It returns two tensors: `logits` of shape (N, T, encoder_dm) and `logit_lens` of shape (N,). decoder: It is the prediction network in the paper. Its input shape - is (N, U) and its output shape is (N, U, C). It should contain - two attributes: `blank_id` and `sos_id`. + is (N, U) and its output shape is (N, U, decoder_dim). + It should contain one attribute: `blank_id`. joiner: - It has two inputs with shapes: (N, T, C) and (N, U, C). Its - output shape is (N, T, U, C). Note that its output contains - unnormalized probs, i.e., not processed by log-softmax. + It has two inputs with shapes: (N, T, encoder_dim) and + (N, U, decoder_dim). + Its output shape is (N, T, U, vocab_size). Note that its output + contains unnormalized probs, i.e., not processed by log-softmax. """ super().__init__() - assert isinstance(encoder, EncoderInterface) + assert isinstance(encoder, EncoderInterface), type(encoder) assert hasattr(decoder, "blank_id") - assert hasattr(decoder, "sos_id") self.encoder = encoder self.decoder = decoder self.joiner = joiner + self.simple_am_proj = ScaledLinear( + encoder_dim, vocab_size, initial_speed=0.5 + ) + self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size) + def forward( self, x: torch.Tensor, x_lens: torch.Tensor, y: k2.RaggedTensor, + prune_range: int = 5, + am_scale: float = 0.0, + lm_scale: float = 0.0, + teacher_model: Optional[torch.jit.ScriptModule] = None, ) -> torch.Tensor: """ Args: @@ -80,8 +115,25 @@ def forward( y: A ragged tensor with 2 axes [utt][label]. It contains labels of each utterance. + prune_range: + The prune range for rnnt loss, it means how many symbols(context) + we are considering for each frame to compute the loss. + am_scale: + The scale to smooth the loss with am (output of encoder network) + part + lm_scale: + The scale to smooth the loss with lm (output of predictor network) + part + teacher_model: + The teacher model. Returns: Return the transducer loss. 
+ + Note: + Regarding am_scale & lm_scale, it will make the loss-function one of + the form: + lm_scale * lm_probs + am_scale * am_probs + + (1-lm_scale-am_scale) * combined_probs """ assert x.ndim == 3, x.shape assert x_lens.ndim == 1, x_lens.shape @@ -89,8 +141,20 @@ def forward( assert x.size(0) == x_lens.size(0) == y.dim0 - encoder_out, x_lens = self.encoder(x, x_lens) - assert torch.all(x_lens > 0) + encoder_out, encoder_out_lens = self.encoder(x, x_lens) + assert torch.all(encoder_out_lens > 0) + + if self.training is True: + with torch.no_grad(): + teacher_encoder_out, _ = teacher_model.encoder(x, x_lens) + + ts_loss = compute_teacher_student_loss( + encoder_out, + teacher_encoder_out, + encoder_out_lens, + ) + else: + ts_loss = torch.tensor([0.0]) # Now for the decoder, i.e., the prediction network row_splits = y.shape.row_splits(1) @@ -99,29 +163,69 @@ def forward( blank_id = self.decoder.blank_id sos_y = add_sos(y, sos_id=blank_id) + # sos_y_padded: [B, S + 1], start with SOS. sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id) - sos_y_padded = sos_y_padded.to(torch.int64) - - decoder_out, _ = self.decoder(sos_y_padded) - logits = self.joiner(encoder_out, decoder_out) + # decoder_out: [B, S + 1, decoder_dim] + decoder_out = self.decoder(sos_y_padded) - # rnnt_loss requires 0 padded targets # Note: y does not start with SOS + # y_padded : [B, S] y_padded = y.pad(mode="constant", padding_value=0) - assert hasattr(torchaudio.functional, "rnnt_loss"), ( - f"Current torchaudio version: {torchaudio.__version__}\n" - "Please install a version >= 0.10.0" + y_padded = y_padded.to(torch.int64) + boundary = torch.zeros( + (x.size(0), 4), dtype=torch.int64, device=x.device + ) + boundary[:, 2] = y_lens + boundary[:, 3] = encoder_out_lens + + lm = self.simple_lm_proj(decoder_out) + am = self.simple_am_proj(encoder_out) + + with torch.cuda.amp.autocast(enabled=False): + simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( + lm=lm.float(), + am=am.float(), + symbols=y_padded, + termination_symbol=blank_id, + lm_only_scale=lm_scale, + am_only_scale=am_scale, + boundary=boundary, + reduction="sum", + return_grad=True, + ) + + # ranges : [B, T, prune_range] + ranges = k2.get_rnnt_prune_ranges( + px_grad=px_grad, + py_grad=py_grad, + boundary=boundary, + s_range=prune_range, ) - loss = torchaudio.functional.rnnt_loss( - logits=logits, - targets=y_padded, - logit_lengths=x_lens, - target_lengths=y_lens, - blank=blank_id, - reduction="sum", + # am_pruned : [B, T, prune_range, encoder_dim] + # lm_pruned : [B, T, prune_range, decoder_dim] + am_pruned, lm_pruned = k2.do_rnnt_pruning( + am=self.joiner.encoder_proj(encoder_out), + lm=self.joiner.decoder_proj(decoder_out), + ranges=ranges, ) - return loss + # logits : [B, T, prune_range, vocab_size] + + # project_input=False since we applied the decoder's input projections + # prior to do_rnnt_pruning (this is an optimization for speed). 
+ logits = self.joiner(am_pruned, lm_pruned, project_input=False) + + with torch.cuda.amp.autocast(enabled=False): + pruned_loss = k2.rnnt_loss_pruned( + logits=logits.float(), + symbols=y_padded, + ranges=ranges, + termination_symbol=blank_id, + boundary=boundary, + reduction="sum", + ) + + return (simple_loss, pruned_loss, ts_loss) diff --git a/egs/librispeech/ASR/transducer_lstm/optim.py b/egs/librispeech/ASR/transducer_lstm/optim.py new file mode 120000 index 0000000000..e2deb44925 --- /dev/null +++ b/egs/librispeech/ASR/transducer_lstm/optim.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/optim.py \ No newline at end of file diff --git a/egs/librispeech/ASR/transducer_lstm/scaling.py b/egs/librispeech/ASR/transducer_lstm/scaling.py new file mode 120000 index 0000000000..09d802cc44 --- /dev/null +++ b/egs/librispeech/ASR/transducer_lstm/scaling.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/scaling.py \ No newline at end of file diff --git a/egs/librispeech/ASR/transducer_lstm/teacher_model.py b/egs/librispeech/ASR/transducer_lstm/teacher_model.py new file mode 100644 index 0000000000..0ad31ba7e1 --- /dev/null +++ b/egs/librispeech/ASR/transducer_lstm/teacher_model.py @@ -0,0 +1,25 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + + +def get_teacher_model() -> torch.jit.ScriptModule: + filename = "/ceph-fj/fangjun/open-source-2/icefall-master-2/egs/librispeech/ASR/pruned_transducer_stateless3/exp/cpu_jit.pt" + model = torch.jit.load(filename) + + return model diff --git a/egs/librispeech/ASR/transducer_lstm/test_encoder.py b/egs/librispeech/ASR/transducer_lstm/test_encoder.py index cad5f1148a..e0e2b2747d 100755 --- a/egs/librispeech/ASR/transducer_lstm/test_encoder.py +++ b/egs/librispeech/ASR/transducer_lstm/test_encoder.py @@ -1,6 +1,5 @@ #!/usr/bin/env python3 -# -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -15,6 +14,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+ + """ To run this file, do: @@ -22,26 +23,38 @@ python ./transducer_lstm/test_encoder.py """ -from encoder import LstmEncoder +import torch +from train import get_encoder_model, get_params + + +def test_encoder_model(): + params = get_params() + params.vocab_size = 500 + params.blank_id = 0 + params.context_size = 2 + encoder = get_encoder_model(params) + num_param = sum([p.numel() for p in encoder.parameters()]) + print(f"Number of encoder model parameters: {num_param}") + + N = 3 + T = 500 + C = 80 + + x = torch.rand(N, T, C) + x_lens = torch.tensor([100, 500, 300]) + y, y_lens = encoder(x, x_lens) + print(y.shape) + expected_y_lens = (((x_lens - 1) >> 1) - 1) >> 1 -def test_encoder(): - encoder = LstmEncoder( - num_features=80, - hidden_size=1024, - proj_size=512, - output_dim=512, - subsampling_factor=4, - num_encoder_layers=12, + assert torch.all(torch.eq(y_lens, expected_y_lens)), ( + y_lens, + expected_y_lens, ) - num_params = sum(p.numel() for p in encoder.parameters() if p.requires_grad) - print(num_params) - # 93979284 - # 66427392 def main(): - test_encoder() + test_encoder_model() if __name__ == "__main__": diff --git a/egs/librispeech/ASR/transducer_lstm/test_model.py b/egs/librispeech/ASR/transducer_lstm/test_model.py new file mode 100755 index 0000000000..071671f273 --- /dev/null +++ b/egs/librispeech/ASR/transducer_lstm/test_model.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +To run this file, do: + + cd icefall/egs/librispeech/ASR + python ./transducer_lstm/test_model.py +""" + +from train import get_params, get_transducer_model + + +def test_model(): + params = get_params() + params.vocab_size = 500 + params.blank_id = 0 + params.context_size = 2 + model = get_transducer_model(params) + num_param = sum([p.numel() for p in model.parameters()]) + print(f"Number of model parameters: {num_param}") + + +def main(): + test_model() + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/transducer_lstm/test_teacher_model.py b/egs/librispeech/ASR/transducer_lstm/test_teacher_model.py new file mode 100755 index 0000000000..3204c11f47 --- /dev/null +++ b/egs/librispeech/ASR/transducer_lstm/test_teacher_model.py @@ -0,0 +1,59 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +To run this file, do: + + cd icefall/egs/librispeech/ASR + python ./transducer_lstm/test_teacher_model.py +""" + +import warnings + +import torch +from teacher_model import get_teacher_model + + +def test_teacher_model(): + model = get_teacher_model() + num_param = sum([p.numel() for p in model.parameters()]) + print(f"Number of encoder model parameters: {num_param}") + + N = 3 + T = 500 + C = 80 + + x = torch.rand(N, T, C) + x_lens = torch.tensor([100, 500, 300]) + + y, y_lens = model.encoder(x, x_lens) + print(y.shape) + expected_y_lens = (((x_lens - 1) >> 1) - 1) >> 1 + + assert torch.all(torch.eq(y_lens, expected_y_lens)), ( + y_lens, + expected_y_lens, + ) + + +def main(): + test_teacher_model() + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/transducer_lstm/train.py b/egs/librispeech/ASR/transducer_lstm/train.py index eef4d34308..5692f91a1f 100755 --- a/egs/librispeech/ASR/transducer_lstm/train.py +++ b/egs/librispeech/ASR/transducer_lstm/train.py @@ -16,31 +16,42 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """ Usage: -export CUDA_VISIBLE_DEVICES="0,1,2" +export CUDA_VISIBLE_DEVICES="0,1,2,3" ./transducer_lstm/train.py \ - --world-size 3 \ + --world-size 4 \ --num-epochs 30 \ --start-epoch 0 \ --exp-dir transducer_lstm/exp \ --full-libri 1 \ - --max-duration 400 \ - --lr-factor 3 -""" + --max-duration 300 + +# For mix precision training: + +./transducer_lstm/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 0 \ + --use-fp16 1 \ + --exp-dir transducer_lstm/exp \ + --full-libri 1 \ + --max-duration 550 +""" import argparse +import copy import logging import warnings from pathlib import Path from shutil import copyfile -from typing import Optional, Tuple +from typing import Any, Dict, Optional, Tuple, Union import k2 +import optim import sentencepiece as spm import torch import torch.multiprocessing as mp @@ -50,19 +61,38 @@ from encoder import LstmEncoder from joiner import Joiner from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler from lhotse.utils import fix_random_seed from model import Transducer -from noam import Noam +from optim import Eden, Eve +from teacher_model import get_teacher_model from torch import Tensor +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP -from torch.nn.utils import clip_grad_norm_ from torch.utils.tensorboard import SummaryWriter -from icefall.checkpoint import load_checkpoint +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool +from icefall.utils import ( + AttributeDict, + MetricsTracker, + measure_gradient_norms, + measure_weight_norms, + optim_step_and_measure_param_change, + setup_logger, + str2bool, +) + +LRSchedulerType = Union[ + torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler +] def get_parser(): @@ -101,10 +131,19 @@ def get_parser(): parser.add_argument( "--start-epoch", type=int, + 
default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, default=0, - help="""Resume training from from this epoch. - If it is positive, it will load checkpoint from - transducer_lstm/exp/epoch-{start_epoch-1}.pt + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt """, ) @@ -126,10 +165,79 @@ def get_parser(): ) parser.add_argument( - "--lr-factor", + "--initial-lr", + type=float, + default=0.003, + help="The initial learning rate. This value should not need to " + "be changed.", + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=5000, + help="""Number of steps that affects how rapidly the learning + rate decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=6, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network)" + "part.", + ) + + parser.add_argument( + "--simple-loss-scale", type=float, - default=3.0, - help="The lr_factor for Noam optimizer", + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--ts-loss-scale", + type=float, + default=0.1, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", ) parser.add_argument( @@ -139,6 +247,65 @@ def get_parser(): help="The seed for random generators intended for reproducibility", ) + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--log-diagnostics", + type=str2bool, + default=False, + help="True to also log parameter norm and " + "gradient norm to tensorboard.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=4000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 0. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. 
+ For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=100, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + return parser @@ -180,15 +347,10 @@ def get_params() -> AttributeDict: - subsampling_factor: The subsampling factor for the model. - - use_feat_batchnorm: Whether to do batch normalization for the - input features. - - - attention_dim: Hidden dim for multi-head attention model. + - encoder_dim: Hidden dim for multi-head attention model. - num_decoder_layers: Number of decoder layer of transformer decoder. - - weight_decay: The weight_decay for the optimizer. - - warm_step: The warm_step for Noam optimizer. """ params = AttributeDict( @@ -200,22 +362,21 @@ def get_params() -> AttributeDict: "batch_idx_train": 0, "log_interval": 50, "reset_interval": 200, - "valid_interval": 3000, # For the 100h subset, use 800 - # parameters for conformer + "valid_interval": 3000, # For the 100h subset, use 1600 + # parameters for encoder "feature_dim": 80, - "encoder_out_dim": 512, "subsampling_factor": 4, - "encoder_hidden_size": 1024, - "num_encoder_layers": 4, - "proj_size": 512, + "encoder_dim": 512, + "encoder_hidden_size": 2048, + "num_encoder_layers": 6, + "dropout": 0.1, "vgg_frontend": False, - # decoder params - "decoder_embedding_dim": 1024, - "num_decoder_layers": 4, - "decoder_hidden_dim": 512, + # parameters for decoder + "decoder_dim": 512, + # parameters for joiner + "joiner_dim": 512, # parameters for Noam - "weight_decay": 1e-6, - "warm_step": 80000, # For the 100h subset, use 8k + "model_warm_step": 3000, # arg given to model, not for lrate "env_info": get_env_info(), } ) @@ -223,40 +384,40 @@ def get_params() -> AttributeDict: return params -def get_encoder_model(params: AttributeDict): +def get_encoder_model(params: AttributeDict) -> nn.Module: encoder = LstmEncoder( num_features=params.feature_dim, hidden_size=params.encoder_hidden_size, - output_dim=params.encoder_out_dim, + output_dim=params.encoder_dim, subsampling_factor=params.subsampling_factor, num_encoder_layers=params.num_encoder_layers, + dropout=params.dropout, vgg_frontend=params.vgg_frontend, ) return encoder -def get_decoder_model(params: AttributeDict): +def get_decoder_model(params: AttributeDict) -> nn.Module: decoder = Decoder( vocab_size=params.vocab_size, - embedding_dim=params.decoder_embedding_dim, + decoder_dim=params.decoder_dim, blank_id=params.blank_id, - sos_id=params.sos_id, - num_layers=params.num_decoder_layers, - hidden_dim=params.decoder_hidden_dim, - output_dim=params.encoder_out_dim, + context_size=params.context_size, ) return decoder -def get_joiner_model(params: AttributeDict): +def get_joiner_model(params: AttributeDict) -> nn.Module: joiner = Joiner( - input_dim=params.encoder_out_dim, - output_dim=params.vocab_size, + encoder_dim=params.encoder_dim, + decoder_dim=params.decoder_dim, + 
joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, ) return joiner -def get_transducer_model(params: AttributeDict): +def get_transducer_model(params: AttributeDict) -> nn.Module: encoder = get_encoder_model(params) decoder = get_decoder_model(params) joiner = get_joiner_model(params) @@ -265,6 +426,10 @@ def get_transducer_model(params: AttributeDict): encoder=encoder, decoder=decoder, joiner=joiner, + encoder_dim=params.encoder_dim, + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, ) return model @@ -272,16 +437,19 @@ def get_transducer_model(params: AttributeDict): def load_checkpoint_if_available( params: AttributeDict, model: nn.Module, + model_avg: Optional[nn.Module] = None, optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, -) -> None: + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: """Load checkpoint from file. - If params.start_epoch is positive, it will load the checkpoint from - `params.start_epoch - 1`. Otherwise, this function does nothing. + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. - Apart from loading state dict for `model`, `optimizer` and `scheduler`, - it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, and `best_valid_loss` in `params`. Args: @@ -289,20 +457,28 @@ def load_checkpoint_if_available( The return value of :func:`get_params`. model: The training model. + model_avg: + The stored model averaged from the start of training. optimizer: The optimizer that we are using. scheduler: - The learning rate scheduler we are using. + The scheduler that we are using. Returns: - Return None. + Return a dict containing previously saved training info. """ - if params.start_epoch <= 0: - return + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" saved_params = load_checkpoint( filename, model=model, + model_avg=model_avg, optimizer=optimizer, scheduler=scheduler, ) @@ -317,14 +493,24 @@ def load_checkpoint_if_available( for k in keys: params[k] = saved_params[k] + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + if "cur_batch_idx" in saved_params: + params["cur_batch_idx"] = saved_params["cur_batch_idx"] + return saved_params def save_checkpoint( params: AttributeDict, - model: nn.Module, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. @@ -334,6 +520,14 @@ def save_checkpoint( It is returned by :func:`get_params`. model: The training model. 
+ model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. """ if rank != 0: return @@ -341,9 +535,12 @@ def save_checkpoint( save_checkpoint_impl( filename=filename, model=model, + model_avg=model_avg, params=params, optimizer=optimizer, scheduler=scheduler, + sampler=sampler, + scaler=scaler, rank=rank, ) @@ -362,6 +559,7 @@ def compute_loss( sp: spm.SentencePieceProcessor, batch: dict, is_training: bool, + teacher_model: Optional[torch.jit.ScriptModule] = None, ) -> Tuple[Tensor, MetricsTracker]: """ Compute CTC loss given the model and its inputs. @@ -378,8 +576,10 @@ def compute_loss( True for training. False for validation. When it is True, this function enables autograd during computation; when it is False, it disables autograd. + teacher_model: + The teacher model. """ - device = model.device + device = params.device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -393,7 +593,18 @@ def compute_loss( y = k2.RaggedTensor(y).to(device) with torch.set_grad_enabled(is_training): - loss = model(x=feature, x_lens=feature_lens, y=y) + simple_loss, pruned_loss, ts_loss = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + teacher_model=teacher_model, + ) + + loss = params.simple_loss_scale * simple_loss + pruned_loss + loss = loss + params.ts_loss_scale * ts_loss assert loss.requires_grad == is_training @@ -406,6 +617,9 @@ def compute_loss( # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + info["ts_loss"] = ts_loss.detach().cpu().item() return loss, info @@ -426,6 +640,7 @@ def compute_validation_loss( loss, loss_info = compute_loss( params=params, model=model, + teacher_model=None, sp=sp, batch=batch, is_training=False, @@ -447,12 +662,17 @@ def compute_validation_loss( def train_one_epoch( params: AttributeDict, model: nn.Module, + teacher_model: Optional[torch.jit.ScriptModule], optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, sp: spm.SentencePieceProcessor, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, + rank: int = 0, ) -> None: """Train the model for one epoch. @@ -465,53 +685,167 @@ def train_one_epoch( It is returned by :func:`get_params`. model: The model for training. + teacher_model: + The teacher model. optimizer: The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. train_dl: Dataloader for the training dataset. valid_dl: Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. tb_writer: Writer to write log messages to tensorboard. world_size: Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. 
""" model.train() tot_loss = MetricsTracker() + def maybe_log_gradients(tag: str): + if ( + params.log_diagnostics + and tb_writer is not None + and params.batch_idx_train % (params.log_interval * 5) == 0 + ): + tb_writer.add_scalars( + tag, + measure_gradient_norms(model, norm="l2"), + global_step=params.batch_idx_train, + ) + + def maybe_log_weights(tag: str): + if ( + params.log_diagnostics + and tb_writer is not None + and params.batch_idx_train % (params.log_interval * 5) == 0 + ): + tb_writer.add_scalars( + tag, + measure_weight_norms(model, norm="l2"), + global_step=params.batch_idx_train, + ) + + cur_batch_idx = params.get("cur_batch_idx", 0) + for batch_idx, batch in enumerate(train_dl): + if batch_idx < cur_batch_idx: + continue + cur_batch_idx = batch_idx + params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) - loss, loss_info = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=True, - ) - # summary stats - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + teacher_model=teacher_model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + + maybe_log_weights("train/param_norms") + maybe_log_gradients("train/grad_norms") + + old_parameters = None + if ( + params.log_diagnostics + and tb_writer is not None + and params.batch_idx_train % (params.log_interval * 5) == 0 + ): + old_parameters = { + n: p.detach().clone() for n, p in model.named_parameters() + } + + scheduler.step_batch(params.batch_idx_train) + scaler.step(optimizer) + scaler.update() + + if old_parameters is not None: + deltas = optim_step_and_measure_param_change( + model, old_parameters + ) + tb_writer.add_scalars( + "train/relative_param_change_per_minibatch", + deltas, + global_step=params.batch_idx_train, + ) + + optimizer.zero_grad() + except: # noqa + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. 
+        if (
+            rank == 0
+            and params.batch_idx_train > 0
+            and params.batch_idx_train % params.average_period == 0
+        ):
+            update_averaged_model(
+                params=params,
+                model_cur=model,
+                model_avg=model_avg,
+            )
 
-        optimizer.zero_grad()
-        loss.backward()
-        clip_grad_norm_(model.parameters(), 5.0, 2.0)
-        optimizer.step()
+        if (
+            params.batch_idx_train > 0
+            and params.batch_idx_train % params.save_every_n == 0
+        ):
+            params.cur_batch_idx = batch_idx
+            save_checkpoint_with_global_batch_idx(
+                out_dir=params.exp_dir,
+                global_batch_idx=params.batch_idx_train,
+                model=model,
+                model_avg=model_avg,
+                params=params,
+                optimizer=optimizer,
+                scheduler=scheduler,
+                sampler=train_dl.sampler,
+                scaler=scaler,
+                rank=rank,
+            )
+            del params.cur_batch_idx
+            remove_checkpoints(
+                out_dir=params.exp_dir,
+                topk=params.keep_last_k,
+                rank=rank,
+            )
 
         if batch_idx % params.log_interval == 0:
+            cur_lr = scheduler.get_last_lr()[0]
             logging.info(
                 f"Epoch {params.cur_epoch}, "
                 f"batch {batch_idx}, loss[{loss_info}], "
-                f"tot_loss[{tot_loss}], batch size: {batch_size}"
+                f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+                f"lr: {cur_lr:.2e}"
             )
 
-        if batch_idx % params.log_interval == 0:
             if tb_writer is not None:
+                tb_writer.add_scalar(
+                    "train/learning_rate", cur_lr, params.batch_idx_train
+                )
+
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
@@ -557,8 +891,7 @@ def run(rank, world_size, args):
     params = get_params()
     params.update(vars(args))
     if params.full_libri is False:
-        params.valid_interval = 800
-        params.warm_step = 8000
+        params.valid_interval = 1600
 
     fix_random_seed(params.seed)
     if world_size > 1:
@@ -577,42 +910,68 @@ def run(rank, world_size, args):
        device = torch.device("cuda", rank)
     logging.info(f"Device: {device}")
 
+    params.device = device
+
     sp = spm.SentencePieceProcessor()
     sp.load(params.bpe_model)
 
-    # <blk> and <sos/eos> are defined in local/train_bpe_model.py
+    # <blk> is defined in local/train_bpe_model.py
     params.blank_id = sp.piece_to_id("<blk>")
-    params.sos_id = sp.piece_to_id("<sos/eos>")
     params.vocab_size = sp.get_piece_size()
 
     logging.info(params)
 
     logging.info("About to create model")
     model = get_transducer_model(params)
+    teacher_model = get_teacher_model()
 
-    checkpoints = load_checkpoint_if_available(params=params, model=model)
-
-    num_param = sum([p.numel() for p in model.parameters() if p.requires_grad])
+    num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
 
+    num_teacher_param = sum([p.numel() for p in teacher_model.parameters()])
+    logging.info(f"Number of teacher model parameters: {num_teacher_param}")
+
+    assert params.save_every_n >= params.average_period
+    model_avg: Optional[nn.Module] = None
+    if rank == 0:
+        # model_avg is only used with rank 0
+        model_avg = copy.deepcopy(model)
+
+    assert params.start_epoch > 0, params.start_epoch
+    checkpoints = load_checkpoint_if_available(
+        params=params,
+        model=model,
+        model_avg=model_avg,
+    )
+
     model.to(device)
+    teacher_model.to(device)
     if world_size > 1:
         logging.info("Using DDP")
         model = DDP(model, device_ids=[rank])
-    model.device = device
 
-    optimizer = Noam(
-        model.parameters(),
-        model_size=params.encoder_hidden_size,
-        factor=params.lr_factor,
-        warm_step=params.warm_step,
-        weight_decay=params.weight_decay,
-    )
+    optimizer = Eve(model.parameters(), lr=params.initial_lr)
+
+    scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
 
     if checkpoints and "optimizer" in checkpoints:
         logging.info("Loading optimizer state dict")
         optimizer.load_state_dict(checkpoints["optimizer"])
+    if (
+        checkpoints
+        and "scheduler" in checkpoints
+        and checkpoints["scheduler"] is not None
+    ):
+        logging.info("Loading scheduler state dict")
+        scheduler.load_state_dict(checkpoints["scheduler"])
+
+    if params.print_diagnostics:
+        opts = diagnostics.TensorDiagnosticOptions(
+            2 ** 22
+        )  # allow 4 megabytes per sub-module
+        diagnostic = diagnostics.attach_diagnostics(model, opts)
+
     librispeech = LibriSpeechAsrDataModule(args)
 
     train_cuts = librispeech.train_clean_100_cuts()
@@ -622,65 +981,85 @@ def remove_short_and_long_utt(c: Cut):
         # Keep only utterances with duration between 1 second and 20 seconds
+        #
+        # Caution: There is a reason to select 20.0 here. Please see
+        # ../local/display_manifest_statistics.py
+        #
+        # You should use ../local/display_manifest_statistics.py to get
+        # an utterance duration distribution for your dataset to select
+        # the threshold
         return 1.0 <= c.duration <= 20.0
 
-    num_in_total = len(train_cuts)
-
     train_cuts = train_cuts.filter(remove_short_and_long_utt)
 
-    num_left = len(train_cuts)
-    num_removed = num_in_total - num_left
-    removed_percent = num_removed / num_in_total * 100
-
-    logging.info(f"Before removing short and long utterances: {num_in_total}")
-    logging.info(f"After removing short and long utterances: {num_left}")
-    logging.info(f"Removed {num_removed} utterances ({removed_percent:.5f}%)")
+    if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
+        # We only load the sampler's state dict when resuming from a
+        # checkpoint saved in the middle of an epoch
+        sampler_state_dict = checkpoints["sampler"]
+    else:
+        sampler_state_dict = None
 
-    train_dl = librispeech.train_dataloaders(train_cuts)
+    train_dl = librispeech.train_dataloaders(
+        train_cuts, sampler_state_dict=sampler_state_dict
+    )
 
     valid_cuts = librispeech.dev_clean_cuts()
     valid_cuts += librispeech.dev_other_cuts()
     valid_dl = librispeech.valid_dataloaders(valid_cuts)
 
-    scan_pessimistic_batches_for_oom(
-        model=model,
-        train_dl=train_dl,
-        optimizer=optimizer,
-        sp=sp,
-        params=params,
-    )
+    if not params.print_diagnostics:
+        scan_pessimistic_batches_for_oom(
+            model=model,
+            teacher_model=teacher_model,
+            train_dl=train_dl,
+            optimizer=optimizer,
+            sp=sp,
+            params=params,
+        )
+
+    scaler = GradScaler(enabled=params.use_fp16)
+    if checkpoints and "grad_scaler" in checkpoints:
+        logging.info("Loading grad scaler state dict")
+        scaler.load_state_dict(checkpoints["grad_scaler"])
 
-    for epoch in range(params.start_epoch, params.num_epochs):
-        fix_random_seed(params.seed + epoch)
-        train_dl.sampler.set_epoch(epoch)
+    for epoch in range(params.start_epoch, params.num_epochs + 1):
+        scheduler.step_epoch(epoch - 1)
+        fix_random_seed(params.seed + epoch - 1)
+        train_dl.sampler.set_epoch(epoch - 1)
 
-        cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar(
-                "train/learning_rate", cur_lr, params.batch_idx_train
-            )
            tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
-        if rank == 0:
-            logging.info("epoch {}, learning rate {}".format(epoch, cur_lr))
-
         params.cur_epoch = epoch
 
         train_one_epoch(
             params=params,
            model=model,
+            teacher_model=teacher_model,
+            model_avg=model_avg,
            optimizer=optimizer,
+            scheduler=scheduler,
            sp=sp,
            train_dl=train_dl,
            valid_dl=valid_dl,
+            scaler=scaler,
            tb_writer=tb_writer,
            world_size=world_size,
+            rank=rank,
        )
 
+        if params.print_diagnostics:
+            diagnostic.print_diagnostics()
+            break
+
        save_checkpoint(
            params=params,
            model=model,
+            model_avg=model_avg,
            optimizer=optimizer,
+            scheduler=scheduler,
+            sampler=train_dl.sampler,
+            scaler=scaler,
             rank=rank,
         )
 
@@ -691,8 +1070,41 @@ def remove_short_and_long_utt(c: Cut):
        cleanup_dist()
 
 
+def display_and_save_batch(
+    batch: dict,
+    params: AttributeDict,
+    sp: spm.SentencePieceProcessor,
+) -> None:
+    """Display the batch statistics and save the batch into disk.
+
+    Args:
+      batch:
+        A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+        for the content in it.
+      params:
+        Parameters for training. See :func:`get_params`.
+      sp:
+        The BPE model.
+    """
+    from lhotse.utils import uuid4
+
+    filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
+    logging.info(f"Saving batch to {filename}")
+    torch.save(batch, filename)
+
+    supervisions = batch["supervisions"]
+    features = batch["inputs"]
+
+    logging.info(f"features shape: {features.shape}")
+
+    y = sp.encode(supervisions["text"], out_type=int)
+    num_tokens = sum(len(i) for i in y)
+    logging.info(f"num tokens: {num_tokens}")
+
+
 def scan_pessimistic_batches_for_oom(
     model: nn.Module,
+    teacher_model: torch.jit.ScriptModule,
     train_dl: torch.utils.data.DataLoader,
     optimizer: torch.optim.Optimizer,
     sp: spm.SentencePieceProcessor,
@@ -707,17 +1119,18 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            optimizer.zero_grad()
-            loss, _ = compute_loss(
-                params=params,
-                model=model,
-                sp=sp,
-                batch=batch,
-                is_training=True,
-            )
+            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+                loss, _ = compute_loss(
+                    params=params,
+                    model=model,
+                    teacher_model=teacher_model,
+                    sp=sp,
+                    batch=batch,
+                    is_training=True,
+                )
             loss.backward()
-            clip_grad_norm_(model.parameters(), 5.0, 2.0)
             optimizer.step()
+            optimizer.zero_grad()
        except RuntimeError as e:
            if "CUDA out of memory" in str(e):
                logging.error(
@@ -727,6 +1140,7 @@ def scan_pessimistic_batches_for_oom(
                    f"Failing criterion: {criterion} "
                    f"(={crit_values[criterion]}) ..."
                )
+                display_and_save_batch(batch, params=params, sp=sp)
            raise
 
 
@@ -747,5 +1161,9 @@ def main():
 torch.set_num_threads(1)
 torch.set_num_interop_threads(1)
 
+torch._C._jit_set_profiling_executor(False)
+torch._C._jit_set_profiling_mode(False)
+torch._C._set_graph_executor_optimize(False)
+
 if __name__ == "__main__":
     main()
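
For reference, the distillation objective wired into compute_loss above combines three terms: the simple and pruned transducer losses plus a teacher-student (ts) loss, weighted by params.simple_loss_scale and params.ts_loss_scale. The snippet below is a minimal, self-contained sketch of that combination only; the default scale values and the dummy tensors are placeholders for illustration, not the recipe's defaults.

import torch


def combine_losses(
    simple_loss: torch.Tensor,
    pruned_loss: torch.Tensor,
    ts_loss: torch.Tensor,
    simple_loss_scale: float = 0.5,
    ts_loss_scale: float = 1.0,
) -> torch.Tensor:
    # loss = simple_scale * simple + pruned + ts_scale * ts, mirroring compute_loss()
    loss = simple_loss_scale * simple_loss + pruned_loss
    return loss + ts_loss_scale * ts_loss


if __name__ == "__main__":
    # Dummy per-batch loss sums; the real values come from the transducer model.
    print(combine_losses(torch.tensor(12.3), torch.tensor(45.6), torch.tensor(7.8)))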
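The training loop now runs the forward pass under torch.cuda.amp.autocast and routes backward/step through a GradScaler, with the scheduler stepped once per batch. Below is a self-contained sketch of that update pattern with a toy linear model and a stock PyTorch StepLR scheduler standing in for the real transducer model and icefall's Eden scheduler (whose step_batch(batch_idx) call is specific to icefall and not reproduced here).

import torch


def amp_train_steps(num_steps: int = 5) -> None:
    use_fp16 = torch.cuda.is_available()  # autocast/GradScaler are disabled on CPU
    device = torch.device("cuda" if use_fp16 else "cpu")

    model = torch.nn.Linear(80, 500).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1000)
    scaler = torch.cuda.amp.GradScaler(enabled=use_fp16)

    for _ in range(num_steps):
        x = torch.randn(8, 80, device=device)
        with torch.cuda.amp.autocast(enabled=use_fp16):
            loss = model(x).pow(2).mean()
        scaler.scale(loss).backward()  # backward on the scaled loss
        scaler.step(optimizer)         # unscales gradients, then optimizer.step()
        scaler.update()
        scheduler.step()               # the patch instead calls scheduler.step_batch(batch_idx)
        optimizer.zero_grad()


if __name__ == "__main__":
    amp_train_steps()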
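The resume logic in load_checkpoint_if_available is also worth spelling out: a positive --start-batch takes precedence and loads exp_dir/checkpoint-{start_batch}.pt, otherwise --start-epoch greater than 1 loads exp_dir/epoch-{start_epoch-1}.pt, otherwise nothing is loaded. The helper below is an invented name used only to illustrate that filename-selection rule.

from pathlib import Path
from typing import Optional


def resume_filename(exp_dir: Path, start_batch: int, start_epoch: int) -> Optional[Path]:
    # Mirrors load_checkpoint_if_available: a positive start_batch wins;
    # otherwise start_epoch > 1 resumes from the previous epoch's checkpoint.
    if start_batch > 0:
        return exp_dir / f"checkpoint-{start_batch}.pt"
    if start_epoch > 1:
        return exp_dir / f"epoch-{start_epoch - 1}.pt"
    return None


assert resume_filename(Path("exp"), start_batch=0, start_epoch=3) == Path("exp/epoch-2.pt")
assert resume_filename(Path("exp"), start_batch=8000, start_epoch=3) == Path("exp/checkpoint-8000.pt")
assert resume_filename(Path("exp"), start_batch=0, start_epoch=1) is None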
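Finally, the loop calls update_averaged_model every params.average_period batches on rank 0, and asserts save_every_n >= average_period so that model_avg is always consistent when a checkpoint is written. The sketch below shows one simple running-average scheme purely to illustrate the idea; it is an assumption for demonstration and not necessarily the exact formula icefall's update_averaged_model uses.

import copy

import torch


@torch.no_grad()
def simple_update_averaged_model(
    model_cur: torch.nn.Module,
    model_avg: torch.nn.Module,
    num_updates: int,
) -> None:
    # Plain running mean over the parameter snapshots seen so far (illustrative only).
    weight = 1.0 / max(num_updates, 1)
    for p_avg, p_cur in zip(model_avg.parameters(), model_cur.parameters()):
        p_avg.mul_(1.0 - weight).add_(p_cur, alpha=weight)


if __name__ == "__main__":
    m = torch.nn.Linear(4, 4)
    m_avg = copy.deepcopy(m)
    simple_update_averaged_model(m, m_avg, num_updates=10)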