-<Q> = 
+<Q> =
 ; # 言語音と独立に講演者の笑いが生じている場合
-<笑> = 
+<笑> =
 ; # 言語音と独立に講演者の咳が生じている場合
-<咳> = 
+<咳> =
 ; # 言語音と独立に講演者の息が生じている場合
-<息> = 
+<息> =
 ; # 講演者の泣き声
-<泣> = 
+<泣> =
 ; # 聴衆(司会者なども含む)の発話
-<フロア発話> = 
+<フロア発話> =
 ; # 聴衆の笑い
-<フロア笑> = 
+<フロア笑> =
 ; # 聴衆の拍手
-<拍手> = 
+<拍手> =
 ; # 講演者が発表中に用いたデモンストレーションの音声
-<デモ> = 
+<デモ> =
 ; # 学会講演に発表時間を知らせるためにならすベルの音
-<ベル> = 
+<ベル> =
 ; # 転記単位全体が再度読み直された場合
-<朗読間違い> = 
+<朗読間違い> =
 ; # 上記以外の音で特に目立った音
-<雑音> = 
+<雑音> =
 ; # 0.2秒以上のポーズ
-<P:> = 
+<P:> =
 ; # Redacted information, for R
 ; # It is \x00D7 multiplication sign, not your normal 'x'
 × = ×
@@ -318,4 +318,3 @@ spk_id = 2
 ャ = ǐa
 ュ = ǐu
 ョ = ǐo
-
diff --git a/egs/csj/ASR/local/conf/fluent.ini b/egs/csj/ASR/local/conf/fluent.ini
index 5d22f9eb8a..f7f27f5bc6 100644
--- a/egs/csj/ASR/local/conf/fluent.ini
+++ b/egs/csj/ASR/local/conf/fluent.ini
@@ -1,17 +1,17 @@
 ; # This section is ignored if this file is not supplied as the first config file to
-; # lhotse prepare csj 
+; # lhotse prepare csj

 [SEGMENTS]
 ; # Allowed period of nonverbal noise. If exceeded, a new segment is created.
 gap = 0.5
 ; # Maximum length of segment (s).
 maxlen = 10
-; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently. 
+; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently.
 minlen = 0.02
-; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`. 
-; # Pass an empty string to avoid adding any symbol. It was "<sp>" in kaldi. 
-; # If you intend to use a multicharacter string for gap_sym, remember to register the 
-; # multicharacter string as part of userdef-string in prepare_lang_char.py. 
-gap_sym = 
+; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`.
+; # Pass an empty string to avoid adding any symbol. It was "<sp>" in kaldi.
+; # If you intend to use a multicharacter string for gap_sym, remember to register the
+; # multicharacter string as part of userdef-string in prepare_lang_char.py.
+gap_sym = 

 [CONSTANTS]
 ; # Name of this mode
@@ -115,59 +115,59 @@ B^ = 0
 ; # 0 to remain, 1 to delete
 ; # Example: '(笑 ナニガ)', '(笑 (F エー)+ソー+イッ+タ+ヨー+ナ)'
 笑 = 0
-; # Example: 'コク(笑 サイ+(D オン))', 
+; # Example: 'コク(笑 サイ+(D オン))',
 笑^ = 0
 ; # 泣きながら発話
 ; # 0 to remain, 1 to delete
-; # Example: '(泣 ドンナニ)' 
+; # Example: '(泣 ドンナニ)'
 泣 = 0
 泣^ = 0
 ; # 咳をしながら発話
 ; # 0 to remain, 1 to delete
-; # Example: 'シャ(咳 リン) ノ' 
+; # Example: 'シャ(咳 リン) ノ'
 咳 = 0
 ; # Example: 'イッ(咳 パン)', 'ワズ(咳 カ)'
 咳^ = 0
 ; # ささやき声や独り言などの小さな声
 ; # 0 to remain, 1 to delete
-; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))' 
+; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))'
 L = 0
 ; # Example: 'デ(L ス)', 'ッ(L テ+コ)ト'
 L^ = 0

 [REPLACEMENTS]
 ; # ボーカルフライなどで母音が同定できない場合
-<FV> = 
+<FV> =
 ; # 「うん/うーん/ふーん」の音の特定が困難な場合
-<VN> = 
+<VN> =
 ; # 非語彙的な母音の引き延ばし
-<H> = 
+<H> =
 ; # 非語彙的な子音の引き延ばし
-<Q> = 
+<Q> =
 ; # 言語音と独立に講演者の笑いが生じている場合
-<笑> = 
+<笑> =
 ; # 言語音と独立に講演者の咳が生じている場合
-<咳> = 
+<咳> =
 ; # 言語音と独立に講演者の息が生じている場合
-<息> = 
+<息> =
 ; # 講演者の泣き声
-<泣> = 
+<泣> =
 ; # 聴衆(司会者なども含む)の発話
-<フロア発話> = 
+<フロア発話> =
 ; # 聴衆の笑い
-<フロア笑> = 
+<フロア笑> =
 ; # 聴衆の拍手
-<拍手> = 
+<拍手> =
 ; # 講演者が発表中に用いたデモンストレーションの音声
-<デモ> = 
+<デモ> =
 ; # 学会講演に発表時間を知らせるためにならすベルの音
-<ベル> = 
+<ベル> =
 ; # 転記単位全体が再度読み直された場合
-<朗読間違い> = 
+<朗読間違い> =
 ; # 上記以外の音で特に目立った音
-<雑音> = 
+<雑音> =
 ; # 0.2秒以上のポーズ
-<P:> = 
+<P:> =
 ; # Redacted information, for R
 ; # It is \x00D7 multiplication sign, not your normal 'x'
 × = ×
@@ -318,4 +318,3 @@ spk_id = 2
 ャ = ǐa
 ュ = ǐu
 ョ = ǐo
-
diff --git a/egs/csj/ASR/local/conf/number.ini b/egs/csj/ASR/local/conf/number.ini
index 2613c34095..cf9038f620 100644
--- a/egs/csj/ASR/local/conf/number.ini
+++ b/egs/csj/ASR/local/conf/number.ini
@@ -1,17 +1,17 @@
 ; # This section is ignored if this file is not supplied as the first config file to
-; # lhotse prepare csj 
+; # lhotse prepare csj

 [SEGMENTS]
 ; # Allowed period of nonverbal noise. If exceeded, a new segment is created.
 gap = 0.5
 ; # Maximum length of segment (s).
 maxlen = 10
-; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently. 
+; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently.
 minlen = 0.02
-; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`. 
-; # Pass an empty string to avoid adding any symbol. It was "<sp>" in kaldi. 
-; # If you intend to use a multicharacter string for gap_sym, remember to register the 
-; # multicharacter string as part of userdef-string in prepare_lang_char.py. 
-gap_sym = 
+; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`.
+; # Pass an empty string to avoid adding any symbol. It was "<sp>" in kaldi.
+; # If you intend to use a multicharacter string for gap_sym, remember to register the
+; # multicharacter string as part of userdef-string in prepare_lang_char.py.
+gap_sym = 

 [CONSTANTS]
 ; # Name of this mode
@@ -115,59 +115,59 @@ B^ = 0
 ; # 0 to remain, 1 to delete
 ; # Example: '(笑 ナニガ)', '(笑 (F エー)+ソー+イッ+タ+ヨー+ナ)'
 笑 = 0
-; # Example: 'コク(笑 サイ+(D オン))', 
+; # Example: 'コク(笑 サイ+(D オン))',
 笑^ = 0
 ; # 泣きながら発話
 ; # 0 to remain, 1 to delete
-; # Example: '(泣 ドンナニ)' 
+; # Example: '(泣 ドンナニ)'
 泣 = 0
 泣^ = 0
 ; # 咳をしながら発話
 ; # 0 to remain, 1 to delete
-; # Example: 'シャ(咳 リン) ノ' 
+; # Example: 'シャ(咳 リン) ノ'
 咳 = 0
 ; # Example: 'イッ(咳 パン)', 'ワズ(咳 カ)'
 咳^ = 0
 ; # ささやき声や独り言などの小さな声
 ; # 0 to remain, 1 to delete
-; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))' 
+; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))'
 L = 0
 ; # Example: 'デ(L ス)', 'ッ(L テ+コ)ト'
 L^ = 0

 [REPLACEMENTS]
 ; # ボーカルフライなどで母音が同定できない場合
-<FV> = 
+<FV> =
 ; # 「うん/うーん/ふーん」の音の特定が困難な場合
-<VN> = 
+<VN> =
 ; # 非語彙的な母音の引き延ばし
-<H> = 
+<H> =
 ; # 非語彙的な子音の引き延ばし
-<Q> = 
+<Q> =
 ; # 言語音と独立に講演者の笑いが生じている場合
-<笑> = 
+<笑> =
 ; # 言語音と独立に講演者の咳が生じている場合
-<咳> = 
+<咳> =
 ; # 言語音と独立に講演者の息が生じている場合
-<息> = 
+<息> =
 ; # 講演者の泣き声
-<泣> = 
+<泣> =
 ; # 聴衆(司会者なども含む)の発話
-<フロア発話> = 
+<フロア発話> =
 ; # 聴衆の笑い
-<フロア笑> = 
+<フロア笑> =
 ; # 聴衆の拍手
-<拍手> = 
+<拍手> =
 ; # 講演者が発表中に用いたデモンストレーションの音声
-<デモ> = 
+<デモ> =
 ; # 学会講演に発表時間を知らせるためにならすベルの音
-<ベル> = 
+<ベル> =
 ; # 転記単位全体が再度読み直された場合
-<朗読間違い> = 
+<朗読間違い> =
 ; # 上記以外の音で特に目立った音
-<雑音> = 
+<雑音> =
 ; # 0.2秒以上のポーズ
-<P:> = 
+<P:> =
 ; # Redacted information, for R
 ; # It is \x00D7 multiplication sign, not your normal 'x'
 × = ×
@@ -318,4 +318,3 @@ spk_id = 2
 ャ = ǐa
 ュ = ǐu
 ョ = ǐo
-
diff --git a/egs/csj/ASR/local/conf/symbol.ini b/egs/csj/ASR/local/conf/symbol.ini
index 8ba451dd5e..f9801284bb 100644
--- a/egs/csj/ASR/local/conf/symbol.ini
+++ b/egs/csj/ASR/local/conf/symbol.ini
@@ -1,17 +1,17 @@
 ; # This section is ignored if this file is not supplied as the first config file to
-; # lhotse prepare csj 
+; # lhotse prepare csj

 [SEGMENTS]
 ; # Allowed period of nonverbal noise. If exceeded, a new segment is created.
 gap = 0.5
 ; # Maximum length of segment (s).
 maxlen = 10
-; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently. 
+; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently.
 minlen = 0.02
-; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`. 
-; # Pass an empty string to avoid adding any symbol. It was "<sp>" in kaldi. 
-; # If you intend to use a multicharacter string for gap_sym, remember to register the 
-; # multicharacter string as part of userdef-string in prepare_lang_char.py. 
-gap_sym = 
+; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`.
+; # Pass an empty string to avoid adding any symbol. It was "<sp>" in kaldi.
+; # If you intend to use a multicharacter string for gap_sym, remember to register the
+; # multicharacter string as part of userdef-string in prepare_lang_char.py.
+gap_sym = 

 [CONSTANTS]
 ; # Name of this mode
@@ -116,59 +116,59 @@ B^ = 0
 ; # 0 to remain, 1 to delete
 ; # Example: '(笑 ナニガ)', '(笑 (F エー)+ソー+イッ+タ+ヨー+ナ)'
 笑 = 0
-; # Example: 'コク(笑 サイ+(D オン))', 
+; # Example: 'コク(笑 サイ+(D オン))',
 笑^ = 0
 ; # 泣きながら発話
 ; # 0 to remain, 1 to delete
-; # Example: '(泣 ドンナニ)' 
+; # Example: '(泣 ドンナニ)'
 泣 = 0
 泣^ = 0
 ; # 咳をしながら発話
 ; # 0 to remain, 1 to delete
-; # Example: 'シャ(咳 リン) ノ' 
+; # Example: 'シャ(咳 リン) ノ'
 咳 = 0
 ; # Example: 'イッ(咳 パン)', 'ワズ(咳 カ)'
 咳^ = 0
 ; # ささやき声や独り言などの小さな声
 ; # 0 to remain, 1 to delete
-; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))' 
+; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))'
 L = 0
 ; # Example: 'デ(L ス)', 'ッ(L テ+コ)ト'
 L^ = 0

 [REPLACEMENTS]
 ; # ボーカルフライなどで母音が同定できない場合
-<FV> = 
+<FV> =
 ; # 「うん/うーん/ふーん」の音の特定が困難な場合
-<VN> = 
+<VN> =
 ; # 非語彙的な母音の引き延ばし
-<H> = 
+<H> =
 ; # 非語彙的な子音の引き延ばし
-<Q> = 
+<Q> =
 ; # 言語音と独立に講演者の笑いが生じている場合
-<笑> = 
+<笑> =
 ; # 言語音と独立に講演者の咳が生じている場合
-<咳> = 
+<咳> =
 ; # 言語音と独立に講演者の息が生じている場合
-<息> = 
+<息> =
 ; # 講演者の泣き声
-<泣> = 
+<泣> =
 ; # 聴衆(司会者なども含む)の発話
-<フロア発話> = 
+<フロア発話> =
 ; # 聴衆の笑い
-<フロア笑> = 
+<フロア笑> =
 ; # 聴衆の拍手
-<拍手> = 
+<拍手> =
 ; # 講演者が発表中に用いたデモンストレーションの音声
-<デモ> = 
+<デモ> =
 ; # 学会講演に発表時間を知らせるためにならすベルの音
-<ベル> = 
+<ベル> =
 ; # 転記単位全体が再度読み直された場合
-<朗読間違い> = 
+<朗読間違い> =
 ; # 上記以外の音で特に目立った音
-<雑音> = 
+<雑音> =
 ; # 0.2秒以上のポーズ
-<P:> = 
+<P:> =
 ; # Redacted information, for R
 ; # It is \x00D7 multiplication sign, not your normal 'x'
 × = ×
@@ -319,4 +319,3 @@ spk_id = 2
 ャ = ǐa
 ュ = ǐu
 ョ = ǐo
-
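The gap_sym comments in the four configs above tie into prepare_lang_char.py, diffed below: transcripts are split on whitespace, tokens found in userdef_string survive as single symbols, and everything else is decomposed into characters. A minimal sketch of that rule, where the multicharacter gap symbol "<gap>" and the registry contents are assumptions for illustration, not part of the patch:

# Sketch only; an unregistered multicharacter gap_sym would be split into
# single characters by the loop in prepare_lang_char.py.
userdef_string = {"<gap>"}

def to_units(text: str) -> list:
    units = []
    for t in text.split():
        if t in userdef_string:
            units.append(t)  # registered: kept as one vocabulary unit
        else:
            units.extend(t)  # unregistered: one unit per character
    return units

assert to_units("<gap> ネコ") == ["<gap>", "ネ", "コ"]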
diff --git a/egs/csj/ASR/local/display_manifest_statistics.py b/egs/csj/ASR/local/display_manifest_statistics.py
index c9de21073b..c043cf8531 100644
--- a/egs/csj/ASR/local/display_manifest_statistics.py
+++ b/egs/csj/ASR/local/display_manifest_statistics.py
@@ -37,9 +37,7 @@ def get_parser():
         formatter_class=argparse.ArgumentDefaultsHelpFormatter,
     )

-    parser.add_argument(
-        "--manifest-dir", type=Path, help="Path to cutset manifests"
-    )
+    parser.add_argument("--manifest-dir", type=Path, help="Path to cutset manifests")

     return parser.parse_args()

diff --git a/egs/csj/ASR/local/prepare_lang_char.py b/egs/csj/ASR/local/prepare_lang_char.py
index e4d9968711..ef91f6e431 100644
--- a/egs/csj/ASR/local/prepare_lang_char.py
+++ b/egs/csj/ASR/local/prepare_lang_char.py
@@ -87,9 +87,7 @@ def main():
     args = get_args()

     logging.basicConfig(
-        format=(
-            "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] " "%(message)s"
-        ),
+        format=("%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] " "%(message)s"),
         level=logging.INFO,
     )

@@ -111,8 +109,7 @@ def main():
     words = set()
     logging.info(
-        f"Creating vocabulary from {args.train_cut.name}"
-        f" at {args.trans_mode} mode."
+        f"Creating vocabulary from {args.train_cut.name}" f" at {args.trans_mode} mode."
     )
     for cut in train_set:
         try:
@@ -123,8 +120,7 @@ def main():
             )
         except KeyError:
             raise KeyError(
-                f"Could not find {args.trans_mode} in "
-                f"{cut.supervisions[0].custom}"
+                f"Could not find {args.trans_mode} in " f"{cut.supervisions[0].custom}"
             )
         for t in text.split():
             if t in args.userdef_string:
@@ -143,9 +139,7 @@ def main():
     (args.lang_dir / "words_len").write_text(f"{len(words)}")

-    (args.lang_dir / "userdef_string").write_text(
-        "\n".join(args.userdef_string)
-    )
+    (args.lang_dir / "userdef_string").write_text("\n".join(args.userdef_string))

     (args.lang_dir / "trans_mode").write_text(args.trans_mode)
     logging.info("Done.")
diff --git a/egs/csj/ASR/local/validate_manifest.py b/egs/csj/ASR/local/validate_manifest.py
index 0c4c6c1ea3..7f67c64b6c 100644
--- a/egs/csj/ASR/local/validate_manifest.py
+++ b/egs/csj/ASR/local/validate_manifest.py
@@ -89,9 +89,7 @@ def main():


 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

     logging.basicConfig(format=formatter, level=logging.INFO)

diff --git a/egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py b/egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py
index d78e262406..72dcd772a4 100644
--- a/egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py
+++ b/egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py
@@ -183,23 +183,18 @@ def add_arguments(cls, parser: argparse.ArgumentParser):
             "--small-dev",
             type=str2bool,
             default=False,
-            help="Should we use only 1000 utterances for dev "
-            "(speeds up training)",
+            help="Should we use only 1000 utterances for dev " "(speeds up training)",
         )

     def train_dataloaders(self, cuts_train: CutSet) -> DataLoader:
         logging.info("About to get Musan cuts")
-        cuts_musan = load_manifest(
-            self.args.manifest_dir / "musan_cuts.jsonl.gz"
-        )
+        cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")

         transforms = []
         if self.args.enable_musan:
             logging.info("Enable MUSAN")
             transforms.append(
-                CutMix(
-                    cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
-                )
+                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
             )
         else:
             logging.info("Disable MUSAN")
@@ -221,9 +216,7 @@ def train_dataloaders(self, cuts_train: CutSet) -> DataLoader:
         input_transforms = []
         if self.args.enable_spec_aug:
             logging.info("Enable SpecAugment")
-            logging.info(
-                f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
-            )
+            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
             input_transforms.append(
                 SpecAugment(
                     time_warp_factor=self.args.spec_aug_time_warp_factor,
@@ -256,9 +249,7 @@ def train_dataloaders(self, cuts_train: CutSet) -> DataLoader:
             # Drop feats to be on the safe side.
             train = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 input_transforms=input_transforms,
                 return_cuts=self.args.return_cuts,
             )
@@ -304,9 +295,7 @@ def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
         if self.args.on_the_fly_feats:
             validate = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 return_cuts=self.args.return_cuts,
             )
         else:
@@ -362,9 +351,7 @@ def train_cuts(self) -> CutSet:
     @lru_cache()
     def dev_cuts(self) -> CutSet:
         logging.info("About to get dev cuts")
-        cuts_valid = load_manifest_lazy(
-            self.args.manifest_dir / "cuts_DEV.jsonl.gz"
-        )
+        cuts_valid = load_manifest_lazy(self.args.manifest_dir / "cuts_DEV.jsonl.gz")
         if self.args.small_dev:
             return cuts_valid.subset(first=1000)
         else:
diff --git a/egs/gigaspeech/ASR/conformer_ctc/conformer.py b/egs/gigaspeech/ASR/conformer_ctc/conformer.py
index 6fac07f93c..a1cfe6e75b 100644
--- a/egs/gigaspeech/ASR/conformer_ctc/conformer.py
+++ b/egs/gigaspeech/ASR/conformer_ctc/conformer.py
@@ -160,9 +160,7 @@ def __init__(
         use_conv_batchnorm: bool = False,
     ) -> None:
         super(ConformerEncoderLayer, self).__init__()
-        self.self_attn = RelPositionMultiheadAttention(
-            d_model, nhead, dropout=0.0
-        )
+        self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0)

         self.feed_forward = nn.Sequential(
             nn.Linear(d_model, dim_feedforward),
@@ -182,18 +180,14 @@ def __init__(
             d_model, cnn_module_kernel, use_batchnorm=use_conv_batchnorm
         )

-        self.norm_ff_macaron = nn.LayerNorm(
-            d_model
-        )  # for the macaron style FNN module
+        self.norm_ff_macaron = nn.LayerNorm(d_model)  # for the macaron style FNN module
         self.norm_ff = nn.LayerNorm(d_model)  # for the FNN module
         self.norm_mha = nn.LayerNorm(d_model)  # for the MHA module

         self.ff_scale = 0.5

         self.norm_conv = nn.LayerNorm(d_model)  # for the CNN module
-        self.norm_final = nn.LayerNorm(
-            d_model
-        )  # for the final output of the block
+        self.norm_final = nn.LayerNorm(d_model)  # for the final output of the block

         self.dropout = nn.Dropout(dropout)

@@ -227,9 +221,7 @@ def forward(
         residual = src
         if self.normalize_before:
             src = self.norm_ff_macaron(src)
-        src = residual + self.ff_scale * self.dropout(
-            self.feed_forward_macaron(src)
-        )
+        src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src))
         if not self.normalize_before:
             src = self.norm_ff_macaron(src)

@@ -348,9 +340,7 @@ class RelPositionalEncoding(torch.nn.Module):
     """

-    def __init__(
-        self, d_model: int, dropout_rate: float, max_len: int = 5000
-    ) -> None:
+    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None:
         """Construct an PositionalEncoding object."""
         super(RelPositionalEncoding, self).__init__()
         self.d_model = d_model
@@ -366,9 +356,7 @@ def extend_pe(self, x: Tensor) -> None:
             # the length of self.pe is 2 * input_len - 1
             if self.pe.size(1) >= x.size(1) * 2 - 1:
                 # Note: TorchScript doesn't implement operator== for torch.Device
-                if self.pe.dtype != x.dtype or str(self.pe.device) != str(
-                    x.device
-                ):
+                if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device):
                     self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                 return
         # Suppose `i` means to the position of query vector and `j` means the
@@ -638,9 +626,9 @@ def multi_head_attention_forward(

     if torch.equal(query, key) and torch.equal(key, value):
         # self-attention
-        q, k, v = nn.functional.linear(
-            query, in_proj_weight, in_proj_bias
-        ).chunk(3, dim=-1)
+        q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk(
+            3, dim=-1
+        )

     elif torch.equal(key, value):
         # encoder-decoder attention
@@ -708,31 +696,22 @@ def multi_head_attention_forward(
         if attn_mask.dim() == 2:
             attn_mask = attn_mask.unsqueeze(0)
             if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
-                raise RuntimeError(
-                    "The size of the 2D attn_mask is not correct."
-                )
+                raise RuntimeError("The size of the 2D attn_mask is not correct.")
         elif attn_mask.dim() == 3:
             if list(attn_mask.size()) != [
                 bsz * num_heads,
                 query.size(0),
                 key.size(0),
             ]:
-                raise RuntimeError(
-                    "The size of the 3D attn_mask is not correct."
-                )
+                raise RuntimeError("The size of the 3D attn_mask is not correct.")
         else:
             raise RuntimeError(
-                "attn_mask's dimension {} is not supported".format(
-                    attn_mask.dim()
-                )
+                "attn_mask's dimension {} is not supported".format(attn_mask.dim())
             )
         # attn_mask's dim is 3 now.

     # convert ByteTensor key_padding_mask to bool
-    if (
-        key_padding_mask is not None
-        and key_padding_mask.dtype == torch.uint8
-    ):
+    if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
         warnings.warn(
             "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead."
         )
@@ -771,9 +750,7 @@ def multi_head_attention_forward(
         # first compute matrix a and matrix c
         # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
         k = k.permute(1, 2, 3, 0)  # (batch, head, d_k, time2)
-        matrix_ac = torch.matmul(
-            q_with_bias_u, k
-        )  # (batch, head, time1, time2)
+        matrix_ac = torch.matmul(q_with_bias_u, k)  # (batch, head, time1, time2)

         # compute matrix b and matrix d
         matrix_bd = torch.matmul(
@@ -785,9 +762,7 @@ def multi_head_attention_forward(
             matrix_ac + matrix_bd
         ) * scaling  # (batch, head, time1, time2)

-        attn_output_weights = attn_output_weights.view(
-            bsz * num_heads, tgt_len, -1
-        )
+        attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1)

         assert list(attn_output_weights.size()) == [
             bsz * num_heads,
@@ -821,13 +796,9 @@ def multi_head_attention_forward(
     attn_output = torch.bmm(attn_output_weights, v)
     assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
     attn_output = (
-        attn_output.transpose(0, 1)
-        .contiguous()
-        .view(tgt_len, bsz, embed_dim)
-    )
-    attn_output = nn.functional.linear(
-        attn_output, out_proj_weight, out_proj_bias
+        attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
     )
+    attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias)

     if need_weights:
         # average attention weights over heads
diff --git a/egs/gigaspeech/ASR/conformer_ctc/decode.py b/egs/gigaspeech/ASR/conformer_ctc/decode.py
index 9c1418baae..d7035a1f8e 100755
--- a/egs/gigaspeech/ASR/conformer_ctc/decode.py
+++ b/egs/gigaspeech/ASR/conformer_ctc/decode.py
@@ -476,9 +476,7 @@ def decode_dataset(

                 results[lm_scale].extend(this_batch)
         else:
-            assert (
-                len(results) > 0
-            ), "It should not decode to empty in the first batch!"
+            assert len(results) > 0, "It should not decode to empty in the first batch!"
             this_batch = []
             hyp_words = []
             for cut_id, ref_text in zip(cut_ids, texts):
@@ -493,9 +491,7 @@ def decode_dataset(

         if batch_idx % 100 == 0:
             batch_str = f"{batch_idx}/{num_batches}"
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results

@@ -528,9 +524,7 @@ def save_results(
         test_set_wers[key] = wer

         if enable_log:
-            logging.info(
-                "Wrote detailed error stats to {}".format(errs_filename)
-            )
+            logging.info("Wrote detailed error stats to {}".format(errs_filename))

     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt"
@@ -705,9 +699,7 @@ def main():
             eos_id=eos_id,
         )

-        save_results(
-            params=params, test_set_name=test_set, results_dict=results_dict
-        )
+        save_results(params=params, test_set_name=test_set, results_dict=results_dict)

     logging.info("Done!")
diff --git a/egs/gigaspeech/ASR/conformer_ctc/label_smoothing.py b/egs/gigaspeech/ASR/conformer_ctc/label_smoothing.py
index cdc85ce9a6..3b94f0c4b0 100644
--- a/egs/gigaspeech/ASR/conformer_ctc/label_smoothing.py
+++ b/egs/gigaspeech/ASR/conformer_ctc/label_smoothing.py
@@ -78,13 +78,10 @@ def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         ignored = target == self.ignore_index
         target[ignored] = 0

-        true_dist = torch.nn.functional.one_hot(
-            target, num_classes=num_classes
-        ).to(x)
+        true_dist = torch.nn.functional.one_hot(target, num_classes=num_classes).to(x)

         true_dist = (
-            true_dist * (1 - self.label_smoothing)
-            + self.label_smoothing / num_classes
+            true_dist * (1 - self.label_smoothing) + self.label_smoothing / num_classes
         )
         # Set the value of ignored indexes to 0
         true_dist[ignored] = 0
diff --git a/egs/gigaspeech/ASR/conformer_ctc/subsampling.py b/egs/gigaspeech/ASR/conformer_ctc/subsampling.py
index 542fb0364e..8e0f73d05b 100644
--- a/egs/gigaspeech/ASR/conformer_ctc/subsampling.py
+++ b/egs/gigaspeech/ASR/conformer_ctc/subsampling.py
@@ -42,13 +42,9 @@ def __init__(self, idim: int, odim: int) -> None:
         assert idim >= 7
         super().__init__()
         self.conv = nn.Sequential(
-            nn.Conv2d(
-                in_channels=1, out_channels=odim, kernel_size=3, stride=2
-            ),
+            nn.Conv2d(in_channels=1, out_channels=odim, kernel_size=3, stride=2),
             nn.ReLU(),
-            nn.Conv2d(
-                in_channels=odim, out_channels=odim, kernel_size=3, stride=2
-            ),
+            nn.Conv2d(in_channels=odim, out_channels=odim, kernel_size=3, stride=2),
             nn.ReLU(),
         )
         self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim)
@@ -132,17 +128,13 @@ def __init__(self, idim: int, odim: int) -> None:
                 )
             )
             layers.append(
-                torch.nn.MaxPool2d(
-                    kernel_size=2, stride=2, padding=0, ceil_mode=True
-                )
+                torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True)
             )
             cur_channels = block_dim

         self.layers = nn.Sequential(*layers)

-        self.out = nn.Linear(
-            block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim
-        )
+        self.out = nn.Linear(block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Subsample x.
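The Linear width in both subsampling hunks above comes from two kernel-3, stride-2 convolutions on the feature axis, each mapping n to (n - 1) // 2. A quick check of that arithmetic under assumed values:

# Assumed values (idim=80, odim=512); only the arithmetic is being checked.
idim, odim = 80, 512
after_conv1 = (idim - 1) // 2  # floor((80 - 3) / 2) + 1 == 39
after_conv2 = (after_conv1 - 1) // 2  # floor((39 - 3) / 2) + 1 == 19
assert odim * after_conv2 == odim * (((idim - 1) // 2 - 1) // 2)  # 512 * 19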
diff --git a/egs/gigaspeech/ASR/conformer_ctc/train.py b/egs/gigaspeech/ASR/conformer_ctc/train.py
index 2965cde18b..4883d04d8f 100755
--- a/egs/gigaspeech/ASR/conformer_ctc/train.py
+++ b/egs/gigaspeech/ASR/conformer_ctc/train.py
@@ -386,9 +386,7 @@ def compute_loss(
             #
             # See https://github.com/k2-fsa/icefall/issues/97
             # for more details
-            unsorted_token_ids = graph_compiler.texts_to_ids(
-                supervisions["text"]
-            )
+            unsorted_token_ids = graph_compiler.texts_to_ids(supervisions["text"])
             att_loss = mmodel.decoder_forward(
                 encoder_memory,
                 memory_mask,
@@ -521,9 +519,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)

         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -641,9 +637,7 @@ def run(rank, world_size, args):
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar(
-                "train/learning_rate", cur_lr, params.batch_idx_train
-            )
+            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)

         if rank == 0:
diff --git a/egs/gigaspeech/ASR/conformer_ctc/transformer.py b/egs/gigaspeech/ASR/conformer_ctc/transformer.py
index 00ca027a7e..0566cfc81c 100644
--- a/egs/gigaspeech/ASR/conformer_ctc/transformer.py
+++ b/egs/gigaspeech/ASR/conformer_ctc/transformer.py
@@ -151,9 +151,7 @@ def __init__(
                 norm=decoder_norm,
             )

-            self.decoder_output_layer = torch.nn.Linear(
-                d_model, self.decoder_num_class
-            )
+            self.decoder_output_layer = torch.nn.Linear(d_model, self.decoder_num_class)

             self.decoder_criterion = LabelSmoothingLoss()
         else:
@@ -181,18 +179,13 @@ def forward(
             memory_key_padding_mask for the decoder. Its shape is (N, T).
             It is None if `supervision` is None.
         """
-        if (
-            isinstance(self.use_feat_batchnorm, bool)
-            and self.use_feat_batchnorm
-        ):
+        if isinstance(self.use_feat_batchnorm, bool) and self.use_feat_batchnorm:
             x = x.permute(0, 2, 1)  # (N, T, C) -> (N, C, T)
             x = self.feat_batchnorm(x)
             x = x.permute(0, 2, 1)  # (N, C, T) -> (N, T, C)
         if isinstance(self.use_feat_batchnorm, float):
             x *= self.use_feat_batchnorm
-        encoder_memory, memory_key_padding_mask = self.run_encoder(
-            x, supervision
-        )
+        encoder_memory, memory_key_padding_mask = self.run_encoder(x, supervision)
         x = self.ctc_output(encoder_memory)
         return x, encoder_memory, memory_key_padding_mask

@@ -273,23 +266,17 @@ def decoder_forward(
         """
         ys_in = add_sos(token_ids, sos_id=sos_id)
         ys_in = [torch.tensor(y) for y in ys_in]
-        ys_in_pad = pad_sequence(
-            ys_in, batch_first=True, padding_value=float(eos_id)
-        )
+        ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id))

         ys_out = add_eos(token_ids, eos_id=eos_id)
         ys_out = [torch.tensor(y) for y in ys_out]
-        ys_out_pad = pad_sequence(
-            ys_out, batch_first=True, padding_value=float(-1)
-        )
+        ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1))

         device = memory.device
         ys_in_pad = ys_in_pad.to(device)
         ys_out_pad = ys_out_pad.to(device)

-        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(
-            device
-        )
+        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device)

         tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
         # TODO: Use length information to create the decoder padding mask
@@ -350,23 +337,17 @@ def decoder_nll(

         ys_in = add_sos(token_ids, sos_id=sos_id)
         ys_in = [torch.tensor(y) for y in ys_in]
-        ys_in_pad = pad_sequence(
-            ys_in, batch_first=True, padding_value=float(eos_id)
-        )
+        ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id))

         ys_out = add_eos(token_ids, eos_id=eos_id)
         ys_out = [torch.tensor(y) for y in ys_out]
-        ys_out_pad = pad_sequence(
-            ys_out, batch_first=True, padding_value=float(-1)
-        )
+        ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1))

         device = memory.device
         ys_in_pad = ys_in_pad.to(device, dtype=torch.int64)
         ys_out_pad = ys_out_pad.to(device, dtype=torch.int64)

-        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(
-            device
-        )
+        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device)

         tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
         # TODO: Use length information to create the decoder padding mask
@@ -639,9 +620,7 @@ def _get_activation_fn(activation: str):
     elif activation == "gelu":
         return nn.functional.gelu

-    raise RuntimeError(
-        "activation should be relu/gelu, not {}".format(activation)
-    )
+    raise RuntimeError("activation should be relu/gelu, not {}".format(activation))


 class PositionalEncoding(nn.Module):
@@ -843,9 +822,7 @@ def encoder_padding_mask(
         1,
     ).to(torch.int32)

-    lengths = [
-        0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)
-    ]
+    lengths = [0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)]
     for idx in range(supervision_segments.size(0)):
         # Note: TorchScript doesn't allow to unpack tensors as tuples
         sequence_idx = supervision_segments[idx, 0].item()
@@ -866,9 +843,7 @@ def encoder_padding_mask(
     return mask


-def decoder_padding_mask(
-    ys_pad: torch.Tensor, ignore_id: int = -1
-) -> torch.Tensor:
+def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor:
     """Generate a length mask for input.

     The masked position are filled with True,
diff --git a/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_dev_test.py b/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_dev_test.py
index 8209ee3ece..07beeb1f0e 100755
--- a/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_dev_test.py
+++ b/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_dev_test.py
@@ -77,9 +77,7 @@ def compute_fbank_gigaspeech_dev_test():


 def main():
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
     logging.basicConfig(format=formatter, level=logging.INFO)

     compute_fbank_gigaspeech_dev_test()
diff --git a/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_splits.py b/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_splits.py
index 6410249db4..1c71be0f97 100755
--- a/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_splits.py
+++ b/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_splits.py
@@ -134,9 +134,7 @@ def main():
     date_time = now.strftime("%Y-%m-%d-%H-%M-%S")

     log_filename = "log-compute_fbank_gigaspeech_splits"
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
     log_filename = f"{log_filename}-{date_time}"

     logging.basicConfig(
diff --git a/egs/gigaspeech/ASR/local/preprocess_gigaspeech.py b/egs/gigaspeech/ASR/local/preprocess_gigaspeech.py
index 48d10a1579..31abe7fff0 100755
--- a/egs/gigaspeech/ASR/local/preprocess_gigaspeech.py
+++ b/egs/gigaspeech/ASR/local/preprocess_gigaspeech.py
@@ -98,19 +98,13 @@ def preprocess_giga_speech():
                     f"Speed perturb for {partition} with factors 0.9 and 1.1 "
                     "(Perturbing may take 8 minutes and saving may take 20 minutes)"
                 )
-                cut_set = (
-                    cut_set
-                    + cut_set.perturb_speed(0.9)
-                    + cut_set.perturb_speed(1.1)
-                )
+                cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
             logging.info(f"Saving to {raw_cuts_path}")
             cut_set.to_file(raw_cuts_path)


 def main():
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
     logging.basicConfig(format=formatter, level=logging.INFO)

     preprocess_giga_speech()
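The preprocess hunk above only collapses the speed-perturb expression onto one line; semantically it still triples the manifest. A hedged lhotse sketch of the same pattern (the partition loop and raw_cuts_path bookkeeping from preprocess_gigaspeech.py are omitted):

from lhotse import CutSet

def with_speed_perturb(cut_set: CutSet) -> CutSet:
    # Appends 0.9x and 1.1x copies, so the result is 3x the input size;
    # lhotse derives new cut IDs for the perturbed copies to keep them unique.
    return cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)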
SpecAugment") - logging.info( - f"Time warp factor: {self.args.spec_aug_time_warp_factor}" - ) + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") # Set the value of num_frame_masks according to Lhotse's version. # In different Lhotse's versions, the default of num_frame_masks is # different. @@ -289,9 +282,7 @@ def train_dataloaders( # Drop feats to be on the safe side. train = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), input_transforms=input_transforms, return_cuts=self.args.return_cuts, ) @@ -347,9 +338,7 @@ def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: if self.args.on_the_fly_feats: validate = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), return_cuts=self.args.return_cuts, ) else: @@ -405,9 +394,7 @@ def train_cuts(self) -> CutSet: @lru_cache() def dev_cuts(self) -> CutSet: logging.info("About to get dev cuts") - cuts_valid = load_manifest_lazy( - self.args.manifest_dir / "cuts_DEV.jsonl.gz" - ) + cuts_valid = load_manifest_lazy(self.args.manifest_dir / "cuts_DEV.jsonl.gz") if self.args.small_dev: return cuts_valid.subset(first=1000) else: diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py index 5849a34715..c0b17750e5 100755 --- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py @@ -77,11 +77,7 @@ from gigaspeech_scoring import asr_text_post_processing from train import get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.utils import ( AttributeDict, setup_logger, @@ -188,8 +184,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; " "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -258,9 +253,7 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] if params.decoding_method == "fast_beam_search": @@ -275,10 +268,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -398,9 +388,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -434,8 +422,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -511,8 +498,7 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/export.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/export.py index cff9c73776..3d1e7bc18b 100755 --- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/export.py +++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/export.py @@ -51,11 +51,7 @@ import torch from train import get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.utils import str2bool @@ -120,8 +116,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; " "2 means tri-gram", ) return parser @@ -160,8 +155,7 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -209,9 +203,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py index 83ae255613..f51584120f 100755 --- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py @@ -77,9 +77,7 @@ from icefall.env import get_env_info from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def get_parser(): @@ -178,8 +176,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -202,8 +199,7 @@ def get_parser(): "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)" "part.", ) parser.add_argument( @@ -553,23 +549,16 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # Note: We use reduction=sum while computing the loss. 
info["loss"] = loss.detach().cpu().item() @@ -732,9 +721,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") diff --git a/egs/librispeech/ASR/conformer_ctc/ali.py b/egs/librispeech/ASR/conformer_ctc/ali.py index 2828e309e8..42e14abac9 100755 --- a/egs/librispeech/ASR/conformer_ctc/ali.py +++ b/egs/librispeech/ASR/conformer_ctc/ali.py @@ -231,9 +231,7 @@ def compute_alignments( labels_ali = get_alignments(best_path, kind="labels") aux_labels_ali = get_alignments(best_path, kind="aux_labels") assert len(labels_ali) == len(aux_labels_ali) == len(cut_list) - for cut, labels, aux_labels in zip( - cut_list, labels_ali, aux_labels_ali - ): + for cut, labels, aux_labels in zip(cut_list, labels_ali, aux_labels_ali): cut.labels_alignment = labels_writer.store_array( key=cut.id, value=np.asarray(labels, dtype=np.int32), @@ -258,9 +256,7 @@ def compute_alignments( if batch_idx % 100 == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return CutSet.from_cuts(cuts) @@ -289,9 +285,7 @@ def main(): out_labels_ali_filename = out_dir / f"labels_{params.dataset}.h5" out_aux_labels_ali_filename = out_dir / f"aux_labels_{params.dataset}.h5" - out_manifest_filename = ( - out_dir / f"librispeech_cuts_{params.dataset}.jsonl.gz" - ) + out_manifest_filename = out_dir / f"librispeech_cuts_{params.dataset}.jsonl.gz" for f in ( out_labels_ali_filename, diff --git a/egs/librispeech/ASR/conformer_ctc/conformer.py b/egs/librispeech/ASR/conformer_ctc/conformer.py index 6fac07f93c..a1cfe6e75b 100644 --- a/egs/librispeech/ASR/conformer_ctc/conformer.py +++ b/egs/librispeech/ASR/conformer_ctc/conformer.py @@ -160,9 +160,7 @@ def __init__( use_conv_batchnorm: bool = False, ) -> None: super(ConformerEncoderLayer, self).__init__() - self.self_attn = RelPositionMultiheadAttention( - d_model, nhead, dropout=0.0 - ) + self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) self.feed_forward = nn.Sequential( nn.Linear(d_model, dim_feedforward), @@ -182,18 +180,14 @@ def __init__( d_model, cnn_module_kernel, use_batchnorm=use_conv_batchnorm ) - self.norm_ff_macaron = nn.LayerNorm( - d_model - ) # for the macaron style FNN module + self.norm_ff_macaron = nn.LayerNorm(d_model) # for the macaron style FNN module self.norm_ff = nn.LayerNorm(d_model) # for the FNN module self.norm_mha = nn.LayerNorm(d_model) # for the MHA module self.ff_scale = 0.5 self.norm_conv = nn.LayerNorm(d_model) # for the CNN module - self.norm_final = nn.LayerNorm( - d_model - ) # for the final output of the block + self.norm_final = nn.LayerNorm(d_model) # for the final output of the block self.dropout = nn.Dropout(dropout) @@ -227,9 +221,7 @@ def forward( residual = src if self.normalize_before: src = self.norm_ff_macaron(src) - src = residual + self.ff_scale * self.dropout( - self.feed_forward_macaron(src) - ) + src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src)) if not self.normalize_before: src = self.norm_ff_macaron(src) @@ -348,9 +340,7 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__( - self, d_model: int, dropout_rate: float, 
diff --git a/egs/librispeech/ASR/conformer_ctc/ali.py b/egs/librispeech/ASR/conformer_ctc/ali.py
index 2828e309e8..42e14abac9 100755
--- a/egs/librispeech/ASR/conformer_ctc/ali.py
+++ b/egs/librispeech/ASR/conformer_ctc/ali.py
@@ -231,9 +231,7 @@ def compute_alignments(
             labels_ali = get_alignments(best_path, kind="labels")
             aux_labels_ali = get_alignments(best_path, kind="aux_labels")
             assert len(labels_ali) == len(aux_labels_ali) == len(cut_list)
-            for cut, labels, aux_labels in zip(
-                cut_list, labels_ali, aux_labels_ali
-            ):
+            for cut, labels, aux_labels in zip(cut_list, labels_ali, aux_labels_ali):
                 cut.labels_alignment = labels_writer.store_array(
                     key=cut.id,
                     value=np.asarray(labels, dtype=np.int32),
@@ -258,9 +256,7 @@ def compute_alignments(

         if batch_idx % 100 == 0:
             batch_str = f"{batch_idx}/{num_batches}"
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")

     return CutSet.from_cuts(cuts)

@@ -289,9 +285,7 @@ def main():

     out_labels_ali_filename = out_dir / f"labels_{params.dataset}.h5"
     out_aux_labels_ali_filename = out_dir / f"aux_labels_{params.dataset}.h5"
-    out_manifest_filename = (
-        out_dir / f"librispeech_cuts_{params.dataset}.jsonl.gz"
-    )
+    out_manifest_filename = out_dir / f"librispeech_cuts_{params.dataset}.jsonl.gz"

     for f in (
         out_labels_ali_filename,
diff --git a/egs/librispeech/ASR/conformer_ctc/conformer.py b/egs/librispeech/ASR/conformer_ctc/conformer.py
index 6fac07f93c..a1cfe6e75b 100644
--- a/egs/librispeech/ASR/conformer_ctc/conformer.py
+++ b/egs/librispeech/ASR/conformer_ctc/conformer.py
@@ -160,9 +160,7 @@ def __init__(
         use_conv_batchnorm: bool = False,
     ) -> None:
         super(ConformerEncoderLayer, self).__init__()
-        self.self_attn = RelPositionMultiheadAttention(
-            d_model, nhead, dropout=0.0
-        )
+        self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0)

         self.feed_forward = nn.Sequential(
             nn.Linear(d_model, dim_feedforward),
@@ -182,18 +180,14 @@ def __init__(
             d_model, cnn_module_kernel, use_batchnorm=use_conv_batchnorm
         )

-        self.norm_ff_macaron = nn.LayerNorm(
-            d_model
-        )  # for the macaron style FNN module
+        self.norm_ff_macaron = nn.LayerNorm(d_model)  # for the macaron style FNN module
         self.norm_ff = nn.LayerNorm(d_model)  # for the FNN module
         self.norm_mha = nn.LayerNorm(d_model)  # for the MHA module

         self.ff_scale = 0.5

         self.norm_conv = nn.LayerNorm(d_model)  # for the CNN module
-        self.norm_final = nn.LayerNorm(
-            d_model
-        )  # for the final output of the block
+        self.norm_final = nn.LayerNorm(d_model)  # for the final output of the block

         self.dropout = nn.Dropout(dropout)

@@ -227,9 +221,7 @@ def forward(
         residual = src
         if self.normalize_before:
             src = self.norm_ff_macaron(src)
-        src = residual + self.ff_scale * self.dropout(
-            self.feed_forward_macaron(src)
-        )
+        src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src))
         if not self.normalize_before:
             src = self.norm_ff_macaron(src)

@@ -348,9 +340,7 @@ class RelPositionalEncoding(torch.nn.Module):
     """

-    def __init__(
-        self, d_model: int, dropout_rate: float, max_len: int = 5000
-    ) -> None:
+    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None:
         """Construct an PositionalEncoding object."""
         super(RelPositionalEncoding, self).__init__()
         self.d_model = d_model
@@ -366,9 +356,7 @@ def extend_pe(self, x: Tensor) -> None:
             # the length of self.pe is 2 * input_len - 1
             if self.pe.size(1) >= x.size(1) * 2 - 1:
                 # Note: TorchScript doesn't implement operator== for torch.Device
-                if self.pe.dtype != x.dtype or str(self.pe.device) != str(
-                    x.device
-                ):
+                if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device):
                     self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                 return
         # Suppose `i` means to the position of query vector and `j` means the
@@ -638,9 +626,9 @@ def multi_head_attention_forward(

     if torch.equal(query, key) and torch.equal(key, value):
         # self-attention
-        q, k, v = nn.functional.linear(
-            query, in_proj_weight, in_proj_bias
-        ).chunk(3, dim=-1)
+        q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk(
+            3, dim=-1
+        )

     elif torch.equal(key, value):
         # encoder-decoder attention
@@ -708,31 +696,22 @@ def multi_head_attention_forward(
         if attn_mask.dim() == 2:
             attn_mask = attn_mask.unsqueeze(0)
             if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
-                raise RuntimeError(
-                    "The size of the 2D attn_mask is not correct."
-                )
+                raise RuntimeError("The size of the 2D attn_mask is not correct.")
         elif attn_mask.dim() == 3:
             if list(attn_mask.size()) != [
                 bsz * num_heads,
                 query.size(0),
                 key.size(0),
             ]:
-                raise RuntimeError(
-                    "The size of the 3D attn_mask is not correct."
-                )
+                raise RuntimeError("The size of the 3D attn_mask is not correct.")
         else:
             raise RuntimeError(
-                "attn_mask's dimension {} is not supported".format(
-                    attn_mask.dim()
-                )
+                "attn_mask's dimension {} is not supported".format(attn_mask.dim())
             )
         # attn_mask's dim is 3 now.

     # convert ByteTensor key_padding_mask to bool
-    if (
-        key_padding_mask is not None
-        and key_padding_mask.dtype == torch.uint8
-    ):
+    if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
         warnings.warn(
             "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead."
         )
@@ -771,9 +750,7 @@ def multi_head_attention_forward(
         # first compute matrix a and matrix c
         # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
         k = k.permute(1, 2, 3, 0)  # (batch, head, d_k, time2)
-        matrix_ac = torch.matmul(
-            q_with_bias_u, k
-        )  # (batch, head, time1, time2)
+        matrix_ac = torch.matmul(q_with_bias_u, k)  # (batch, head, time1, time2)

         # compute matrix b and matrix d
         matrix_bd = torch.matmul(
@@ -785,9 +762,7 @@ def multi_head_attention_forward(
             matrix_ac + matrix_bd
         ) * scaling  # (batch, head, time1, time2)

-        attn_output_weights = attn_output_weights.view(
-            bsz * num_heads, tgt_len, -1
-        )
+        attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1)

         assert list(attn_output_weights.size()) == [
             bsz * num_heads,
@@ -821,13 +796,9 @@ def multi_head_attention_forward(
     attn_output = torch.bmm(attn_output_weights, v)
     assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
     attn_output = (
-        attn_output.transpose(0, 1)
-        .contiguous()
-        .view(tgt_len, bsz, embed_dim)
-    )
-    attn_output = nn.functional.linear(
-        attn_output, out_proj_weight, out_proj_bias
+        attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
     )
+    attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias)

     if need_weights:
         # average attention weights over heads
diff --git a/egs/librispeech/ASR/conformer_ctc/decode.py b/egs/librispeech/ASR/conformer_ctc/decode.py
index 3f3b1acda1..7e0bf5b7be 100755
--- a/egs/librispeech/ASR/conformer_ctc/decode.py
+++ b/egs/librispeech/ASR/conformer_ctc/decode.py
@@ -551,9 +551,7 @@ def decode_dataset(

                 results[lm_scale].extend(this_batch)
         else:
-            assert (
-                len(results) > 0
-            ), "It should not decode to empty in the first batch!"
+            assert len(results) > 0, "It should not decode to empty in the first batch!"
             this_batch = []
             hyp_words = []
             for ref_text in texts:
@@ -568,9 +566,7 @@ def decode_dataset(

         if batch_idx % 100 == 0:
             batch_str = f"{batch_idx}/{num_batches}"
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results

@@ -602,9 +598,7 @@ def save_results(
         test_set_wers[key] = wer

         if enable_log:
-            logging.info(
-                "Wrote detailed error stats to {}".format(errs_filename)
-            )
+            logging.info("Wrote detailed error stats to {}".format(errs_filename))

     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt"
@@ -809,9 +803,7 @@ def main():
             eos_id=eos_id,
         )

-        save_results(
-            params=params, test_set_name=test_set, results_dict=results_dict
-        )
+        save_results(params=params, test_set_name=test_set, results_dict=results_dict)

     logging.info("Done!")
diff --git a/egs/librispeech/ASR/conformer_ctc/export.py b/egs/librispeech/ASR/conformer_ctc/export.py
index 28c28df015..fbcbd7b29a 100755
--- a/egs/librispeech/ASR/conformer_ctc/export.py
+++ b/egs/librispeech/ASR/conformer_ctc/export.py
@@ -157,9 +157,7 @@ def main():

 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/librispeech/ASR/conformer_ctc/label_smoothing.py b/egs/librispeech/ASR/conformer_ctc/label_smoothing.py
index 1f2f3b1374..cb0d6e04dc 100644
--- a/egs/librispeech/ASR/conformer_ctc/label_smoothing.py
+++ b/egs/librispeech/ASR/conformer_ctc/label_smoothing.py
@@ -82,13 +82,10 @@ def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         # for why we don't use target[ignored] = 0 here
         target = torch.where(ignored, torch.zeros_like(target), target)

-        true_dist = torch.nn.functional.one_hot(
-            target, num_classes=num_classes
-        ).to(x)
+        true_dist = torch.nn.functional.one_hot(target, num_classes=num_classes).to(x)

         true_dist = (
-            true_dist * (1 - self.label_smoothing)
-            + self.label_smoothing / num_classes
+            true_dist * (1 - self.label_smoothing) + self.label_smoothing / num_classes
         )

         # Set the value of ignored indexes to 0
diff --git a/egs/librispeech/ASR/conformer_ctc/pretrained.py b/egs/librispeech/ASR/conformer_ctc/pretrained.py
index a2c0a54868..8200af866d 100755
--- a/egs/librispeech/ASR/conformer_ctc/pretrained.py
+++ b/egs/librispeech/ASR/conformer_ctc/pretrained.py
@@ -237,8 +237,7 @@ def read_sound_files(
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
         assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
+            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
         )
         # We use only the first channel
         ans.append(wave[0])
@@ -300,9 +299,7 @@ def main():
     logging.info("Decoding started")
     features = fbank(waves)
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))

     # Note: We don't use key padding mask for attention during decoding
     with torch.no_grad():
@@ -427,9 +424,7 @@ def main():

 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py
index 542fb0364e..8e0f73d05b 100644
--- a/egs/librispeech/ASR/conformer_ctc/subsampling.py
+++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py
@@ -42,13 +42,9 @@ def __init__(self, idim: int, odim: int) -> None:
         assert idim >= 7
         super().__init__()
         self.conv = nn.Sequential(
-            nn.Conv2d(
-                in_channels=1, out_channels=odim, kernel_size=3, stride=2
-            ),
+            nn.Conv2d(in_channels=1, out_channels=odim, kernel_size=3, stride=2),
             nn.ReLU(),
-            nn.Conv2d(
-                in_channels=odim, out_channels=odim, kernel_size=3, stride=2
-            ),
+            nn.Conv2d(in_channels=odim, out_channels=odim, kernel_size=3, stride=2),
             nn.ReLU(),
         )
         self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim)
@@ -132,17 +128,13 @@ def __init__(self, idim: int, odim: int) -> None:
                 )
             )
             layers.append(
-                torch.nn.MaxPool2d(
-                    kernel_size=2, stride=2, padding=0, ceil_mode=True
-                )
+                torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True)
             )
             cur_channels = block_dim

         self.layers = nn.Sequential(*layers)

-        self.out = nn.Linear(
-            block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim
-        )
+        self.out = nn.Linear(block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Subsample x.
diff --git a/egs/librispeech/ASR/conformer_ctc/train.py b/egs/librispeech/ASR/conformer_ctc/train.py
index 6419f6816d..1449bc310c 100755
--- a/egs/librispeech/ASR/conformer_ctc/train.py
+++ b/egs/librispeech/ASR/conformer_ctc/train.py
@@ -393,9 +393,7 @@ def compute_loss(
         # Works with a phone lexicon
         decoding_graph = graph_compiler.compile(texts)
     else:
-        raise ValueError(
-            f"Unsupported type of graph compiler: {type(graph_compiler)}"
-        )
+        raise ValueError(f"Unsupported type of graph compiler: {type(graph_compiler)}")

     dense_fsa_vec = k2.DenseFsaVec(
         nnet_output,
@@ -422,9 +420,7 @@ def compute_loss(
             #
             # See https://github.com/k2-fsa/icefall/issues/97
             # for more details
-            unsorted_token_ids = graph_compiler.texts_to_ids(
-                supervisions["text"]
-            )
+            unsorted_token_ids = graph_compiler.texts_to_ids(supervisions["text"])
             att_loss = mmodel.decoder_forward(
                 encoder_memory,
                 memory_mask,
@@ -453,9 +449,7 @@ def compute_loss(
     info["utt_duration"] = supervisions["num_frames"].sum().item()
     # averaged padding proportion over utterances
     info["utt_pad_proportion"] = (
-        ((feature.size(1) - supervisions["num_frames"]) / feature.size(1))
-        .sum()
-        .item()
+        ((feature.size(1) - supervisions["num_frames"]) / feature.size(1)).sum().item()
     )

     return loss, info
@@ -568,9 +562,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)

         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -733,9 +725,7 @@ def remove_short_and_long_utt(c: Cut):
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar(
-                "train/learning_rate", cur_lr, params.batch_idx_train
-            )
+            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)

         if rank == 0:
diff --git a/egs/librispeech/ASR/conformer_ctc/transformer.py b/egs/librispeech/ASR/conformer_ctc/transformer.py
index 00ca027a7e..0566cfc81c 100644
--- a/egs/librispeech/ASR/conformer_ctc/transformer.py
+++ b/egs/librispeech/ASR/conformer_ctc/transformer.py
@@ -151,9 +151,7 @@ def __init__(
                 norm=decoder_norm,
             )

-            self.decoder_output_layer = torch.nn.Linear(
-                d_model, self.decoder_num_class
-            )
+            self.decoder_output_layer = torch.nn.Linear(d_model, self.decoder_num_class)

             self.decoder_criterion = LabelSmoothingLoss()
         else:
@@ -181,18 +179,13 @@ def forward(
             memory_key_padding_mask for the decoder. Its shape is (N, T).
             It is None if `supervision` is None.
         """
-        if (
-            isinstance(self.use_feat_batchnorm, bool)
-            and self.use_feat_batchnorm
-        ):
+        if isinstance(self.use_feat_batchnorm, bool) and self.use_feat_batchnorm:
             x = x.permute(0, 2, 1)  # (N, T, C) -> (N, C, T)
             x = self.feat_batchnorm(x)
             x = x.permute(0, 2, 1)  # (N, C, T) -> (N, T, C)
         if isinstance(self.use_feat_batchnorm, float):
             x *= self.use_feat_batchnorm
-        encoder_memory, memory_key_padding_mask = self.run_encoder(
-            x, supervision
-        )
+        encoder_memory, memory_key_padding_mask = self.run_encoder(x, supervision)
         x = self.ctc_output(encoder_memory)
         return x, encoder_memory, memory_key_padding_mask

@@ -273,23 +266,17 @@ def decoder_forward(
         """
         ys_in = add_sos(token_ids, sos_id=sos_id)
         ys_in = [torch.tensor(y) for y in ys_in]
-        ys_in_pad = pad_sequence(
-            ys_in, batch_first=True, padding_value=float(eos_id)
-        )
+        ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id))

         ys_out = add_eos(token_ids, eos_id=eos_id)
         ys_out = [torch.tensor(y) for y in ys_out]
-        ys_out_pad = pad_sequence(
-            ys_out, batch_first=True, padding_value=float(-1)
-        )
+        ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1))

         device = memory.device
         ys_in_pad = ys_in_pad.to(device)
         ys_out_pad = ys_out_pad.to(device)

-        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(
-            device
-        )
+        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device)

         tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
         # TODO: Use length information to create the decoder padding mask
@@ -350,23 +337,17 @@ def decoder_nll(

         ys_in = add_sos(token_ids, sos_id=sos_id)
         ys_in = [torch.tensor(y) for y in ys_in]
-        ys_in_pad = pad_sequence(
-            ys_in, batch_first=True, padding_value=float(eos_id)
-        )
+        ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id))

         ys_out = add_eos(token_ids, eos_id=eos_id)
         ys_out = [torch.tensor(y) for y in ys_out]
-        ys_out_pad = pad_sequence(
-            ys_out, batch_first=True, padding_value=float(-1)
-        )
+        ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1))

         device = memory.device
         ys_in_pad = ys_in_pad.to(device, dtype=torch.int64)
         ys_out_pad = ys_out_pad.to(device, dtype=torch.int64)

-        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(
-            device
-        )
+        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device)

         tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
         # TODO: Use length information to create the decoder padding mask
@@ -639,9 +620,7 @@ def _get_activation_fn(activation: str):
     elif activation == "gelu":
         return nn.functional.gelu

-    raise RuntimeError(
-        "activation should be relu/gelu, not {}".format(activation)
-    )
+    raise RuntimeError("activation should be relu/gelu, not {}".format(activation))


 class PositionalEncoding(nn.Module):
@@ -843,9 +822,7 @@ def encoder_padding_mask(
         1,
     ).to(torch.int32)

-    lengths = [
-        0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)
-    ]
+    lengths = [0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)]
     for idx in range(supervision_segments.size(0)):
         # Note: TorchScript doesn't allow to unpack tensors as tuples
         sequence_idx = supervision_segments[idx, 0].item()
@@ -866,9 +843,7 @@ def encoder_padding_mask(
     return mask


-def decoder_padding_mask(
-    ys_pad: torch.Tensor, ignore_id: int = -1
-) -> torch.Tensor:
+def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor:
     """Generate a length mask for input.

     The masked position are filled with True,
diff --git a/egs/librispeech/ASR/conformer_ctc2/attention.py b/egs/librispeech/ASR/conformer_ctc2/attention.py
index 1375d72457..356d3f21bc 100644
--- a/egs/librispeech/ASR/conformer_ctc2/attention.py
+++ b/egs/librispeech/ASR/conformer_ctc2/attention.py
@@ -18,11 +18,10 @@

 import torch
 import torch.nn as nn
+from scaling import ScaledLinear
 from torch import Tensor
 from torch.nn.init import xavier_normal_

-from scaling import ScaledLinear
-

 class MultiheadAttention(nn.Module):
     r"""Allows the model to jointly attend to information
@@ -76,9 +75,7 @@ def __init__(
         self.embed_dim = embed_dim
         self.kdim = kdim if kdim is not None else embed_dim
         self.vdim = vdim if vdim is not None else embed_dim
-        self._qkv_same_embed_dim = (
-            self.kdim == embed_dim and self.vdim == embed_dim
-        )
+        self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim

         self.num_heads = num_heads
         self.dropout = dropout
@@ -94,9 +91,7 @@ def __init__(
             self.v_proj_weight = ScaledLinear(self.vdim, embed_dim, bias=bias)
             self.register_parameter("in_proj_weight", None)
         else:
-            self.in_proj_weight = ScaledLinear(
-                embed_dim, 3 * embed_dim, bias=bias
-            )
+            self.in_proj_weight = ScaledLinear(embed_dim, 3 * embed_dim, bias=bias)
             self.register_parameter("q_proj_weight", None)
             self.register_parameter("k_proj_weight", None)
             self.register_parameter("v_proj_weight", None)
@@ -107,12 +102,8 @@ def __init__(
         self.out_proj = ScaledLinear(embed_dim, embed_dim, bias=bias)

         if add_bias_kv:
-            self.bias_k = nn.Parameter(
-                torch.empty((1, 1, embed_dim), **factory_kwargs)
-            )
-            self.bias_v = nn.Parameter(
-                torch.empty((1, 1, embed_dim), **factory_kwargs)
-            )
+            self.bias_k = nn.Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
+            self.bias_v = nn.Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
         else:
             self.bias_k = self.bias_v = None
diff --git a/egs/librispeech/ASR/conformer_ctc2/conformer.py b/egs/librispeech/ASR/conformer_ctc2/conformer.py
index b906d2650e..09f1eb0004 100644
--- a/egs/librispeech/ASR/conformer_ctc2/conformer.py
+++ b/egs/librispeech/ASR/conformer_ctc2/conformer.py
@@ -29,9 +29,8 @@
     ScaledConv1d,
     ScaledLinear,
 )
-from torch import Tensor, nn
 from subsampling import Conv2dSubsampling
-
+from torch import Tensor, nn
 from transformer import Supervisions, Transformer, encoder_padding_mask

@@ -182,9 +181,7 @@ def __init__(

         self.d_model = d_model

-        self.self_attn = RelPositionMultiheadAttention(
-            d_model, nhead, dropout=0.0
-        )
+        self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0)

         self.feed_forward = nn.Sequential(
             ScaledLinear(d_model, dim_feedforward),
@@ -356,9 +353,7 @@ class RelPositionalEncoding(torch.nn.Module):
     """

-    def __init__(
-        self, d_model: int, dropout_rate: float, max_len: int = 5000
-    ) -> None:
+    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None:
         """Construct an PositionalEncoding object."""
         super(RelPositionalEncoding, self).__init__()
         self.d_model = d_model
@@ -373,9 +368,7 @@ def extend_pe(self, x: Tensor) -> None:
             # the length of self.pe is 2 * input_len - 1
             if self.pe.size(1) >= x.size(1) * 2 - 1:
                 # Note: TorchScript doesn't implement operator== for torch.Device
-                if self.pe.dtype != x.dtype or str(self.pe.device) != str(
-                    x.device
-                ):
+                if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device):
                     self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                 return
         # Suppose `i` means to the position of query vecotr and `j` means the
@@ -650,9 +643,9 @@ def multi_head_attention_forward(

     if torch.equal(query, key) and torch.equal(key, value):
         # self-attention
-        q, k, v = nn.functional.linear(
-            query, in_proj_weight, in_proj_bias
-        ).chunk(3, dim=-1)
+        q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk(
+            3, dim=-1
+        )

     elif torch.equal(key, value):
         # encoder-decoder attention
@@ -721,31 +714,22 @@ def multi_head_attention_forward(
         if attn_mask.dim() == 2:
             attn_mask = attn_mask.unsqueeze(0)
             if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
-                raise RuntimeError(
-                    "The size of the 2D attn_mask is not correct."
-                )
+                raise RuntimeError("The size of the 2D attn_mask is not correct.")
         elif attn_mask.dim() == 3:
             if list(attn_mask.size()) != [
                 bsz * num_heads,
                 query.size(0),
                 key.size(0),
             ]:
-                raise RuntimeError(
-                    "The size of the 3D attn_mask is not correct."
-                )
+                raise RuntimeError("The size of the 3D attn_mask is not correct.")
         else:
             raise RuntimeError(
-                "attn_mask's dimension {} is not supported".format(
-                    attn_mask.dim()
-                )
+                "attn_mask's dimension {} is not supported".format(attn_mask.dim())
             )
         # attn_mask's dim is 3 now.

     # convert ByteTensor key_padding_mask to bool
-    if (
-        key_padding_mask is not None
-        and key_padding_mask.dtype == torch.uint8
-    ):
+    if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
         warnings.warn(
             "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead."
         )
@@ -784,9 +768,7 @@ def multi_head_attention_forward(
         # first compute matrix a and matrix c
         # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
         k = k.permute(1, 2, 3, 0)  # (batch, head, d_k, time2)
-        matrix_ac = torch.matmul(
-            q_with_bias_u, k
-        )  # (batch, head, time1, time2)
+        matrix_ac = torch.matmul(q_with_bias_u, k)  # (batch, head, time1, time2)

         # compute matrix b and matrix d
         matrix_bd = torch.matmul(
@@ -794,13 +776,9 @@ def multi_head_attention_forward(
         )  # (batch, head, time1, 2*time1-1)
         matrix_bd = self.rel_shift(matrix_bd)

-        attn_output_weights = (
-            matrix_ac + matrix_bd
-        )  # (batch, head, time1, time2)
+        attn_output_weights = matrix_ac + matrix_bd  # (batch, head, time1, time2)

-        attn_output_weights = attn_output_weights.view(
-            bsz * num_heads, tgt_len, -1
-        )
+        attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1)

         assert list(attn_output_weights.size()) == [
             bsz * num_heads,
@@ -834,13 +812,9 @@ def multi_head_attention_forward(
     attn_output = torch.bmm(attn_output_weights, v)
     assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
     attn_output = (
-        attn_output.transpose(0, 1)
-        .contiguous()
-        .view(tgt_len, bsz, embed_dim)
-    )
-    attn_output = nn.functional.linear(
-        attn_output, out_proj_weight, out_proj_bias
+        attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
     )
+    attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias)

     if need_weights:
         # average attention weights over heads
@@ -863,9 +837,7 @@ class ConvolutionModule(nn.Module):
     """

-    def __init__(
-        self, channels: int, kernel_size: int, bias: bool = True
-    ) -> None:
+    def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None:
         """Construct an ConvolutionModule object."""
         super(ConvolutionModule, self).__init__()
         # kernerl_size should be a odd number for 'SAME' padding
diff --git a/egs/librispeech/ASR/conformer_ctc2/decode.py b/egs/librispeech/ASR/conformer_ctc2/decode.py
index 97f2f2d392..0b271a51cf 100755
--- a/egs/librispeech/ASR/conformer_ctc2/decode.py
+++ b/egs/librispeech/ASR/conformer_ctc2/decode.py
b/egs/librispeech/ASR/conformer_ctc2/decode.py @@ -658,9 +658,7 @@ def decode_dataset( results[lm_scale].extend(this_batch) else: - assert ( - len(results) > 0 - ), "It should not decode to empty in the first batch!" + assert len(results) > 0, "It should not decode to empty in the first batch!" this_batch = [] hyp_words = [] for ref_text in texts: @@ -675,9 +673,7 @@ def decode_dataset( if batch_idx % 100 == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -709,9 +705,7 @@ def save_results( test_set_wers[key] = wer if enable_log: - logging.info( - "Wrote detailed error stats to {}".format(errs_filename) - ) + logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt" @@ -852,9 +846,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -881,9 +875,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -985,9 +979,7 @@ def main(): eos_id=eos_id, ) - save_results( - params=params, test_set_name=test_set, results_dict=results_dict - ) + save_results(params=params, test_set_name=test_set, results_dict=results_dict) logging.info("Done!") diff --git a/egs/librispeech/ASR/conformer_ctc2/export.py b/egs/librispeech/ASR/conformer_ctc2/export.py index 584b3c3fc5..7892b03c67 100755 --- a/egs/librispeech/ASR/conformer_ctc2/export.py +++ b/egs/librispeech/ASR/conformer_ctc2/export.py @@ -47,6 +47,7 @@ from pathlib import Path import torch +from conformer import Conformer from decode import get_params from icefall.checkpoint import ( @@ -55,10 +56,8 @@ find_checkpoints, load_checkpoint, ) -from conformer import Conformer - -from icefall.utils import str2bool from icefall.lexicon import Lexicon +from icefall.utils import str2bool def get_parser(): @@ -177,9 +176,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -206,9 +205,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -273,9 +272,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" 
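# The decode.py/export.py hunks above repeatedly select the last `--avg`
# checkpoints via find_checkpoints() and fold them into a single model with
# average_checkpoints(). The sketch below is a minimal illustration of what
# such element-wise state-dict averaging computes, assuming checkpoints saved
# as {"model": state_dict}; `average_state_dicts` is a hypothetical helper for
# illustration, not icefall's actual API, which additionally handles the
# --use-averaged-model path shown in these diffs.
import torch

def average_state_dicts(filenames, device="cpu"):
    """Element-wise mean of the "model" state dicts stored in `filenames`."""
    n = len(filenames)
    avg = torch.load(filenames[0], map_location=device)["model"]
    for f in filenames[1:]:
        sd = torch.load(f, map_location=device)["model"]
        for k in avg:
            avg[k] = avg[k] + sd[k]
    for k in avg:
        # integer buffers (e.g. BatchNorm's num_batches_tracked) need floor division
        avg[k] = avg[k] / n if avg[k].is_floating_point() else avg[k] // n
    return avg

# typical use, mirroring the scripts above:
#   model.load_state_dict(average_state_dicts(filenames, device))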
logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/conformer_ctc2/train.py b/egs/librispeech/ASR/conformer_ctc2/train.py index 18fa3e69f2..ceea0c22c6 100755 --- a/egs/librispeech/ASR/conformer_ctc2/train.py +++ b/egs/librispeech/ASR/conformer_ctc2/train.py @@ -69,8 +69,8 @@ from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter -from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler from icefall import diagnostics +from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler from icefall.checkpoint import load_checkpoint, remove_checkpoints from icefall.checkpoint import save_checkpoint as save_checkpoint_impl from icefall.checkpoint import ( @@ -89,9 +89,7 @@ str2bool, ) -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def get_parser(): @@ -498,11 +496,7 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -531,9 +525,7 @@ def compute_loss( # Works with a phone lexicon decoding_graph = graph_compiler.compile(texts) else: - raise ValueError( - f"Unsupported type of graph compiler: {type(graph_compiler)}" - ) + raise ValueError(f"Unsupported type of graph compiler: {type(graph_compiler)}") dense_fsa_vec = k2.DenseFsaVec( nnet_output, @@ -560,9 +552,7 @@ def compute_loss( # # See https://github.com/k2-fsa/icefall/issues/97 # for more details - unsorted_token_ids = graph_compiler.texts_to_ids( - supervisions["text"] - ) + unsorted_token_ids = graph_compiler.texts_to_ids(supervisions["text"]) att_loss = mmodel.decoder_forward( encoder_memory, memory_mask, @@ -580,9 +570,7 @@ def compute_loss( info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() info["ctc_loss"] = ctc_loss.detach().cpu().item() if params.att_rate != 0.0: info["att_loss"] = att_loss.detach().cpu().item() @@ -776,9 +764,9 @@ def train_one_epoch( f"tot_loss[{tot_loss}], batch size: {batch_size}, " f"lr: {cur_lr:.2e}" ) - if loss_info["ctc_loss"] == float("inf") or loss_info[ - "att_loss" - ] == float("inf"): + if loss_info["ctc_loss"] == float("inf") or loss_info["att_loss"] == float( + "inf" + ): logging.error( "Your loss contains inf, something goes wrong" f"failing batch names {batch_name}" @@ -791,9 +779,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") diff --git a/egs/librispeech/ASR/conformer_ctc2/transformer.py b/egs/librispeech/ASR/conformer_ctc2/transformer.py index 3ef7edc234..d3443dc942 100644 --- a/egs/librispeech/ASR/conformer_ctc2/transformer.py +++ b/egs/librispeech/ASR/conformer_ctc2/transformer.py @@ -21,19 
+21,17 @@ import torch import torch.nn as nn -from label_smoothing import LabelSmoothingLoss -from subsampling import Conv2dSubsampling from attention import MultiheadAttention -from torch.nn.utils.rnn import pad_sequence - +from label_smoothing import LabelSmoothingLoss from scaling import ( ActivationBalancer, BasicNorm, DoubleSwish, - ScaledLinear, ScaledEmbedding, + ScaledLinear, ) - +from subsampling import Conv2dSubsampling +from torch.nn.utils.rnn import pad_sequence # Note: TorchScript requires Dict/List/etc. to be fully typed. Supervisions = Dict[str, torch.Tensor] @@ -210,9 +208,7 @@ def run_encoder( x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) mask = encoder_padding_mask(x.size(0), supervisions) mask = mask.to(x.device) if mask is not None else None - x = self.encoder( - x, src_key_padding_mask=mask, warmup=warmup - ) # (T, N, C) + x = self.encoder(x, src_key_padding_mask=mask, warmup=warmup) # (T, N, C) return x, mask @@ -261,23 +257,17 @@ def decoder_forward( """ ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence( - ys_in, batch_first=True, padding_value=float(eos_id) - ) + ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence( - ys_out, batch_first=True, padding_value=float(-1) - ) + ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) device = memory.device ys_in_pad = ys_in_pad.to(device) ys_out_pad = ys_out_pad.to(device) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( - device - ) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -338,23 +328,17 @@ def decoder_nll( ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence( - ys_in, batch_first=True, padding_value=float(eos_id) - ) + ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence( - ys_out, batch_first=True, padding_value=float(-1) - ) + ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) device = memory.device ys_in_pad = ys_in_pad.to(device, dtype=torch.int64) ys_out_pad = ys_out_pad.to(device, dtype=torch.int64) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( - device - ) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -959,9 +943,7 @@ def encoder_padding_mask( 1, ).to(torch.int32) - lengths = [ - 0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1) - ] + lengths = [0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)] for idx in range(supervision_segments.size(0)): # Note: TorchScript doesn't allow to unpack tensors as tuples sequence_idx = supervision_segments[idx, 0].item() @@ -982,9 +964,7 @@ def encoder_padding_mask( return mask -def decoder_padding_mask( - ys_pad: torch.Tensor, ignore_id: int = -1 -) -> torch.Tensor: +def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor: """Generate a length mask for input. 
The masked position are filled with True, diff --git a/egs/librispeech/ASR/conformer_mmi/conformer.py b/egs/librispeech/ASR/conformer_mmi/conformer.py index 97c8d83a21..53e48eb133 100644 --- a/egs/librispeech/ASR/conformer_mmi/conformer.py +++ b/egs/librispeech/ASR/conformer_mmi/conformer.py @@ -156,9 +156,7 @@ def __init__( normalize_before: bool = True, ) -> None: super(ConformerEncoderLayer, self).__init__() - self.self_attn = RelPositionMultiheadAttention( - d_model, nhead, dropout=0.0 - ) + self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) self.feed_forward = nn.Sequential( nn.Linear(d_model, dim_feedforward), @@ -176,18 +174,14 @@ def __init__( self.conv_module = ConvolutionModule(d_model, cnn_module_kernel) - self.norm_ff_macaron = nn.LayerNorm( - d_model - ) # for the macaron style FNN module + self.norm_ff_macaron = nn.LayerNorm(d_model) # for the macaron style FNN module self.norm_ff = nn.LayerNorm(d_model) # for the FNN module self.norm_mha = nn.LayerNorm(d_model) # for the MHA module self.ff_scale = 0.5 self.norm_conv = nn.LayerNorm(d_model) # for the CNN module - self.norm_final = nn.LayerNorm( - d_model - ) # for the final output of the block + self.norm_final = nn.LayerNorm(d_model) # for the final output of the block self.dropout = nn.Dropout(dropout) @@ -221,9 +215,7 @@ def forward( residual = src if self.normalize_before: src = self.norm_ff_macaron(src) - src = residual + self.ff_scale * self.dropout( - self.feed_forward_macaron(src) - ) + src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src)) if not self.normalize_before: src = self.norm_ff_macaron(src) @@ -342,9 +334,7 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__( - self, d_model: int, dropout_rate: float, max_len: int = 5000 - ) -> None: + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -360,9 +350,7 @@ def extend_pe(self, x: Tensor) -> None: # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x.size(1) * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str( - x.device - ): + if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vector and `j` means the @@ -632,9 +620,9 @@ def multi_head_attention_forward( if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear( - query, in_proj_weight, in_proj_bias - ).chunk(3, dim=-1) + q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( + 3, dim=-1 + ) elif torch.equal(key, value): # encoder-decoder attention @@ -702,31 +690,22 @@ def multi_head_attention_forward( if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError( - "The size of the 2D attn_mask is not correct." - ) + raise RuntimeError("The size of the 2D attn_mask is not correct.") elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError( - "The size of the 3D attn_mask is not correct." 
- ) + raise RuntimeError("The size of the 3D attn_mask is not correct.") else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format( - attn_mask.dim() - ) + "attn_mask's dimension {} is not supported".format(attn_mask.dim()) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if ( - key_padding_mask is not None - and key_padding_mask.dtype == torch.uint8 - ): + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: warnings.warn( "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." ) @@ -765,9 +744,7 @@ def multi_head_attention_forward( # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul( - q_with_bias_u, k - ) # (batch, head, time1, time2) + matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) # compute matrix b and matrix d matrix_bd = torch.matmul( @@ -779,9 +756,7 @@ def multi_head_attention_forward( matrix_ac + matrix_bd ) * scaling # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view( - bsz * num_heads, tgt_len, -1 - ) + attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) assert list(attn_output_weights.size()) == [ bsz * num_heads, @@ -815,13 +790,9 @@ def multi_head_attention_forward( attn_output = torch.bmm(attn_output_weights, v) assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] attn_output = ( - attn_output.transpose(0, 1) - .contiguous() - .view(tgt_len, bsz, embed_dim) - ) - attn_output = nn.functional.linear( - attn_output, out_proj_weight, out_proj_bias + attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) ) + attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads @@ -844,9 +815,7 @@ class ConvolutionModule(nn.Module): """ - def __init__( - self, channels: int, kernel_size: int, bias: bool = True - ) -> None: + def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None: """Construct an ConvolutionModule object.""" super(ConvolutionModule, self).__init__() # kernerl_size should be a odd number for 'SAME' padding diff --git a/egs/librispeech/ASR/conformer_mmi/decode.py b/egs/librispeech/ASR/conformer_mmi/decode.py index fc9861489f..e3c7b685ff 100755 --- a/egs/librispeech/ASR/conformer_mmi/decode.py +++ b/egs/librispeech/ASR/conformer_mmi/decode.py @@ -478,9 +478,7 @@ def decode_dataset( if batch_idx % 100 == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -512,9 +510,7 @@ def save_results( test_set_wers[key] = wer if enable_log: - logging.info( - "Wrote detailed error stats to {}".format(errs_filename) - ) + logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt" @@ -653,9 +649,7 @@ def main(): if params.export: logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt") - torch.save( - {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt" - ) + torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt") return model.to(device) @@ -687,9 +681,7 @@ def main(): 
eos_id=eos_id, ) - save_results( - params=params, test_set_name=test_set, results_dict=results_dict - ) + save_results(params=params, test_set_name=test_set, results_dict=results_dict) logging.info("Done!") diff --git a/egs/librispeech/ASR/conformer_mmi/subsampling.py b/egs/librispeech/ASR/conformer_mmi/subsampling.py index 5c3e1222ef..ad94159878 100644 --- a/egs/librispeech/ASR/conformer_mmi/subsampling.py +++ b/egs/librispeech/ASR/conformer_mmi/subsampling.py @@ -25,13 +25,9 @@ def __init__(self, idim: int, odim: int) -> None: assert idim >= 7 super().__init__() self.conv = nn.Sequential( - nn.Conv2d( - in_channels=1, out_channels=odim, kernel_size=3, stride=2 - ), + nn.Conv2d(in_channels=1, out_channels=odim, kernel_size=3, stride=2), nn.ReLU(), - nn.Conv2d( - in_channels=odim, out_channels=odim, kernel_size=3, stride=2 - ), + nn.Conv2d(in_channels=odim, out_channels=odim, kernel_size=3, stride=2), nn.ReLU(), ) self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim) @@ -115,17 +111,13 @@ def __init__(self, idim: int, odim: int) -> None: ) ) layers.append( - torch.nn.MaxPool2d( - kernel_size=2, stride=2, padding=0, ceil_mode=True - ) + torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True) ) cur_channels = block_dim self.layers = nn.Sequential(*layers) - self.out = nn.Linear( - block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim - ) + self.out = nn.Linear(block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim) def forward(self, x: torch.Tensor) -> torch.Tensor: """Subsample x. diff --git a/egs/librispeech/ASR/conformer_mmi/test_subsampling.py b/egs/librispeech/ASR/conformer_mmi/test_subsampling.py index 937845d779..d0bb017ddf 100755 --- a/egs/librispeech/ASR/conformer_mmi/test_subsampling.py +++ b/egs/librispeech/ASR/conformer_mmi/test_subsampling.py @@ -1,8 +1,7 @@ #!/usr/bin/env python3 -from subsampling import Conv2dSubsampling -from subsampling import VggSubsampling import torch +from subsampling import Conv2dSubsampling, VggSubsampling def test_conv2d_subsampling(): diff --git a/egs/librispeech/ASR/conformer_mmi/test_transformer.py b/egs/librispeech/ASR/conformer_mmi/test_transformer.py index 08e6806074..25d18076de 100644 --- a/egs/librispeech/ASR/conformer_mmi/test_transformer.py +++ b/egs/librispeech/ASR/conformer_mmi/test_transformer.py @@ -1,17 +1,16 @@ #!/usr/bin/env python3 import torch +from torch.nn.utils.rnn import pad_sequence from transformer import ( Transformer, + add_eos, + add_sos, + decoder_padding_mask, encoder_padding_mask, generate_square_subsequent_mask, - decoder_padding_mask, - add_sos, - add_eos, ) -from torch.nn.utils.rnn import pad_sequence - def test_encoder_padding_mask(): supervisions = { diff --git a/egs/librispeech/ASR/conformer_mmi/train-with-attention.py b/egs/librispeech/ASR/conformer_mmi/train-with-attention.py index 011dadd73b..f8c94cff9f 100755 --- a/egs/librispeech/ASR/conformer_mmi/train-with-attention.py +++ b/egs/librispeech/ASR/conformer_mmi/train-with-attention.py @@ -36,23 +36,14 @@ from torch.utils.tensorboard import SummaryWriter from transformer import Noam -from icefall.ali import ( - convert_alignments_to_tensor, - load_alignments, - lookup_alignments, -) +from icefall.ali import convert_alignments_to_tensor, load_alignments, lookup_alignments from icefall.checkpoint import load_checkpoint from icefall.checkpoint import save_checkpoint as save_checkpoint_impl from icefall.dist import cleanup_dist, setup_dist from icefall.lexicon import Lexicon from icefall.mmi import LFMMILoss from icefall.mmi_graph_compiler import 
MmiTrainingGraphCompiler -from icefall.utils import ( - AttributeDict, - encode_supervisions, - setup_logger, - str2bool, -) +from icefall.utils import AttributeDict, encode_supervisions, setup_logger, str2bool def get_parser(): @@ -370,10 +361,7 @@ def compute_loss( nnet_output = nnet_output.clone() nnet_output[:, :min_len, :] += ali_scale * mask[:, :min_len, :] - if ( - params.batch_idx_train > params.use_ali_until - and params.beam_size < 8 - ): + if params.batch_idx_train > params.use_ali_until and params.beam_size < 8: # logging.info("Change beam size to 8") params.beam_size = 8 else: @@ -762,19 +750,14 @@ def run(rank, world_size, args): for epoch in range(params.start_epoch, params.num_epochs): train_dl.sampler.set_epoch(epoch) - if ( - params.batch_idx_train >= params.use_ali_until - and train_ali is not None - ): + if params.batch_idx_train >= params.use_ali_until and train_ali is not None: # Delete the alignments to save memory train_ali = None valid_ali = None cur_lr = optimizer._rate if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) + tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) if rank == 0: diff --git a/egs/librispeech/ASR/conformer_mmi/train.py b/egs/librispeech/ASR/conformer_mmi/train.py index 9a5bdcce2f..5cfb2bfc7a 100755 --- a/egs/librispeech/ASR/conformer_mmi/train.py +++ b/egs/librispeech/ASR/conformer_mmi/train.py @@ -36,23 +36,14 @@ from torch.utils.tensorboard import SummaryWriter from transformer import Noam -from icefall.ali import ( - convert_alignments_to_tensor, - load_alignments, - lookup_alignments, -) +from icefall.ali import convert_alignments_to_tensor, load_alignments, lookup_alignments from icefall.checkpoint import load_checkpoint from icefall.checkpoint import save_checkpoint as save_checkpoint_impl from icefall.dist import cleanup_dist, setup_dist from icefall.lexicon import Lexicon from icefall.mmi import LFMMILoss from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler -from icefall.utils import ( - AttributeDict, - encode_supervisions, - setup_logger, - str2bool, -) +from icefall.utils import AttributeDict, encode_supervisions, setup_logger, str2bool def get_parser(): @@ -377,10 +368,7 @@ def compute_loss( nnet_output = nnet_output.clone() nnet_output[:, :min_len, :] += ali_scale * mask[:, :min_len, :] - if ( - params.batch_idx_train > params.use_ali_until - and params.beam_size < 8 - ): + if params.batch_idx_train > params.use_ali_until and params.beam_size < 8: logging.info("Change beam size to 8") params.beam_size = 8 else: @@ -770,19 +758,14 @@ def run(rank, world_size, args): for epoch in range(params.start_epoch, params.num_epochs): fix_random_seed(params.seed + epoch) train_dl.sampler.set_epoch(epoch) - if ( - params.batch_idx_train >= params.use_ali_until - and train_ali is not None - ): + if params.batch_idx_train >= params.use_ali_until and train_ali is not None: # Delete the alignments to save memory train_ali = None valid_ali = None cur_lr = optimizer._rate if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) + tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) if rank == 0: diff --git a/egs/librispeech/ASR/conformer_mmi/transformer.py b/egs/librispeech/ASR/conformer_mmi/transformer.py index 68a4ff65cb..2542d9abe3 100644 --- 
a/egs/librispeech/ASR/conformer_mmi/transformer.py +++ b/egs/librispeech/ASR/conformer_mmi/transformer.py @@ -148,9 +148,7 @@ def __init__( norm=decoder_norm, ) - self.decoder_output_layer = torch.nn.Linear( - d_model, self.decoder_num_class - ) + self.decoder_output_layer = torch.nn.Linear(d_model, self.decoder_num_class) self.decoder_criterion = LabelSmoothingLoss(self.decoder_num_class) else: @@ -182,9 +180,7 @@ def forward( x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T) x = self.feat_batchnorm(x) x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C) - encoder_memory, memory_key_padding_mask = self.run_encoder( - x, supervision - ) + encoder_memory, memory_key_padding_mask = self.run_encoder(x, supervision) x = self.ctc_output(encoder_memory) return x, encoder_memory, memory_key_padding_mask @@ -274,9 +270,7 @@ def decoder_forward( ys_in_pad = ys_in_pad.to(device) ys_out_pad = ys_out_pad.to(device) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( - device - ) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -341,9 +335,7 @@ def decoder_nll( ys_in_pad = ys_in_pad.to(device, dtype=torch.int64) ys_out_pad = ys_out_pad.to(device, dtype=torch.int64) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( - device - ) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -616,9 +608,7 @@ def _get_activation_fn(activation: str): elif activation == "gelu": return nn.functional.gelu - raise RuntimeError( - "activation should be relu/gelu, not {}".format(activation) - ) + raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) class PositionalEncoding(nn.Module): @@ -887,9 +877,7 @@ def encoder_padding_mask( 1, ).to(torch.int32) - lengths = [ - 0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1) - ] + lengths = [0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)] for idx in range(supervision_segments.size(0)): # Note: TorchScript doesn't allow to unpack tensors as tuples sequence_idx = supervision_segments[idx, 0].item() @@ -910,9 +898,7 @@ def encoder_padding_mask( return mask -def decoder_padding_mask( - ys_pad: torch.Tensor, ignore_id: int = -1 -) -> torch.Tensor: +def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor: """Generate a length mask for input. The masked position are filled with True, diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/decode.py index 620d69a190..6854c82d8f 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/decode.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/decode.py @@ -215,8 +215,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -284,9 +283,7 @@ def decode_one_batch( value=LOG_EPS, ) - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] if params.decoding_method == "fast_beam_search": @@ -301,10 +298,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -427,9 +421,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -462,8 +454,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -506,9 +497,7 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -540,9 +529,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -569,9 +558,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py index 8ca7d55683..1aaa3b9cb5 100644 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py @@ -35,7 +35,6 @@ from icefall.utils import make_pad_mask - LOG_EPSILON = math.log(1e-10) @@ -127,9 +126,7 @@ def stack_states( for si, s in enumerate(layer): attn_caches[li][si].append(s) if b == batch_size - 1: - attn_caches[li][si] = torch.stack( - attn_caches[li][si], dim=1 - ) + attn_caches[li][si] = torch.stack(attn_caches[li][si], dim=1) conv_caches = [] for layer in state_list[0][1]: @@ -268,9 +265,7 @@ def _split_right_context( intervals = torch.arange( 0, self.chunk_length * (num_chunks - 1), self.chunk_length ) - first = torch.arange( - self.chunk_length, self.chunk_length + self.cache_size - ) + first = torch.arange(self.chunk_length, self.chunk_length + self.cache_size) indexes = 
intervals.unsqueeze(1) + first.unsqueeze(0) indexes = torch.cat( [indexes, torch.arange(U_ - self.cache_size, U_).unsqueeze(0)] @@ -284,9 +279,7 @@ def _split_right_context( # (num_chunks * B, cache_size + right_context_length, D) return pad_right_context.permute(0, 2, 1) - def _merge_right_context( - self, right_context: torch.Tensor, B: int - ) -> torch.Tensor: + def _merge_right_context(self, right_context: torch.Tensor, B: int) -> torch.Tensor: """ Args: right_context: @@ -337,12 +330,8 @@ def forward( right_context = x[:, :, :R] # (B, D, R) # make causal convolution - cache = torch.zeros( - B, D, self.cache_size, device=x.device, dtype=x.dtype - ) - pad_utterance = torch.cat( - [cache, utterance], dim=2 - ) # (B, D, cache + U) + cache = torch.zeros(B, D, self.cache_size, device=x.device, dtype=x.dtype) + pad_utterance = torch.cat([cache, utterance], dim=2) # (B, D, cache + U) # depth-wise conv on utterance utterance = self.depthwise_conv(pad_utterance) # (B, D, U) @@ -355,9 +344,7 @@ def forward( right_context = self.depthwise_conv( pad_right_context ) # (num_segs * B, D, right_context_length) - right_context = self._merge_right_context( - right_context, B - ) # (B, D, R) + right_context = self._merge_right_context(right_context, B) # (B, D, R) x = torch.cat([right_context, utterance], dim=2) # (B, D, R + U) x = self.deriv_balancer2(x) @@ -458,8 +445,7 @@ def __init__( if embed_dim % nhead != 0: raise ValueError( - f"embed_dim ({embed_dim}) is not a multiple of" - f"nhead ({nhead})." + f"embed_dim ({embed_dim}) is not a multiple of" f"nhead ({nhead})." ) self.embed_dim = embed_dim @@ -469,9 +455,7 @@ def __init__( self.head_dim = embed_dim // nhead self.dropout = dropout - self.emb_to_key_value = ScaledLinear( - embed_dim, 2 * embed_dim, bias=True - ) + self.emb_to_key_value = ScaledLinear(embed_dim, 2 * embed_dim, bias=True) self.emb_to_query = ScaledLinear(embed_dim, embed_dim, bias=True) self.out_proj = ScaledLinear( embed_dim, embed_dim, bias=True, initial_scale=0.25 @@ -513,9 +497,7 @@ def _gen_attention_probs( if padding_mask is not None: Q = attention_weights.size(1) B = attention_weights.size(0) // self.nhead - attention_weights_float = attention_weights_float.view( - B, self.nhead, Q, -1 - ) + attention_weights_float = attention_weights_float.view(B, self.nhead, Q, -1) attention_weights_float = attention_weights_float.masked_fill( padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), self.negative_inf, @@ -551,9 +533,7 @@ def _forward_impl( scaling = float(self.head_dim) ** -0.5 # compute query with [right_context, utterance, summary]. - query = self.emb_to_query( - torch.cat([right_context, utterance, summary]) - ) + query = self.emb_to_query(torch.cat([right_context, utterance, summary])) # compute key and value with [memory, right_context, utterance]. 
key, value = self.emb_to_key_value( torch.cat([memory, right_context, utterance]) @@ -564,16 +544,12 @@ def _forward_impl( # [memory, right context, left context, uttrance] # this is used in inference mode key = torch.cat([key[: M + R], left_context_key, key[M + R :]]) - value = torch.cat( - [value[: M + R], left_context_val, value[M + R :]] - ) + value = torch.cat([value[: M + R], left_context_val, value[M + R :]]) Q = query.size(0) # KV = key.size(0) reshaped_query, reshaped_key, reshaped_value = [ - tensor.contiguous() - .view(-1, B * self.nhead, self.head_dim) - .transpose(0, 1) + tensor.contiguous().view(-1, B * self.nhead, self.head_dim).transpose(0, 1) for tensor in [query, key, value] ] # (B * nhead, Q or KV, head_dim) attention_weights = torch.bmm( @@ -588,9 +564,7 @@ def _forward_impl( # compute attention outputs attention = torch.bmm(attention_probs, reshaped_value) assert attention.shape == (B * self.nhead, Q, self.head_dim) - attention = ( - attention.transpose(0, 1).contiguous().view(Q, B, self.embed_dim) - ) + attention = attention.transpose(0, 1).contiguous().view(Q, B, self.embed_dim) # apply output projection outputs = self.out_proj(attention) @@ -672,12 +646,7 @@ def forward( - output of right context and utterance, with shape (R + U, B, D). - memory output, with shape (M, B, D), where M = S - 1 or M = 0. """ - ( - output_right_context_utterance, - output_memory, - _, - _, - ) = self._forward_impl( + (output_right_context_utterance, output_memory, _, _,) = self._forward_impl( utterance, right_context, summary, @@ -947,13 +916,9 @@ def _apply_attention_module_forward( right_context = right_context_utterance[:R] if self.use_memory: - summary = self.summary_op(utterance.permute(1, 2, 0)).permute( - 2, 0, 1 - ) + summary = self.summary_op(utterance.permute(1, 2, 0)).permute(2, 0, 1) else: - summary = torch.empty(0).to( - dtype=utterance.dtype, device=utterance.device - ) + summary = torch.empty(0).to(dtype=utterance.dtype, device=utterance.device) output_right_context_utterance, output_memory = self.attention( utterance=utterance, right_context=right_context, @@ -992,14 +957,10 @@ def _apply_attention_module_infer( left_context_val = attn_cache[2] if self.use_memory: - summary = self.summary_op(utterance.permute(1, 2, 0)).permute( - 2, 0, 1 - ) + summary = self.summary_op(utterance.permute(1, 2, 0)).permute(2, 0, 1) summary = summary[:1] else: - summary = torch.empty(0).to( - dtype=utterance.dtype, device=utterance.device - ) + summary = torch.empty(0).to(dtype=utterance.dtype, device=utterance.device) ( output_right_context_utterance, output_memory, @@ -1014,9 +975,7 @@ def _apply_attention_module_infer( left_context_val=left_context_val, padding_mask=padding_mask, ) - attn_cache = self._update_attn_cache( - next_key, next_val, memory, attn_cache - ) + attn_cache = self._update_attn_cache(next_key, next_val, memory, attn_cache) return output_right_context_utterance, output_memory, attn_cache def forward( @@ -1151,11 +1110,7 @@ def infer( src = src + self.dropout(self.feed_forward_macaron(src)) # emformer attention module - ( - src_att, - output_memory, - attn_cache, - ) = self._apply_attention_module_infer( + (src_att, output_memory, attn_cache,) = self._apply_attention_module_infer( src, R, memory, attn_cache, padding_mask=padding_mask ) src = src + self.dropout(src_att) @@ -1295,9 +1250,7 @@ def __init__( def _gen_right_context(self, x: torch.Tensor) -> torch.Tensor: """Hard copy each chunk's right context and concat them.""" T = x.shape[0] - num_chunks = math.ceil( - (T 
- self.right_context_length) / self.chunk_length - ) + num_chunks = math.ceil((T - self.right_context_length) / self.chunk_length) # first (num_chunks - 1) right context block intervals = torch.arange( 0, self.chunk_length * (num_chunks - 1), self.chunk_length @@ -1316,9 +1269,7 @@ def _gen_right_context(self, x: torch.Tensor) -> torch.Tensor: right_context_blocks = x[indexes.reshape(-1)] return right_context_blocks - def _gen_attention_mask_col_widths( - self, chunk_idx: int, U: int - ) -> List[int]: + def _gen_attention_mask_col_widths(self, chunk_idx: int, U: int) -> List[int]: """Calculate column widths (key, value) in attention mask for the chunk_idx chunk.""" num_chunks = math.ceil(U / self.chunk_length) @@ -1479,9 +1430,7 @@ def forward( output_lengths = torch.clamp(lengths - self.right_context_length, min=0) attention_mask = self._gen_attention_mask(utterance) memory = ( - self.init_memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[ - :-1 - ] + self.init_memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[:-1] if self.use_memory else torch.empty(0).to(dtype=x.dtype, device=x.device) ) @@ -1643,12 +1592,8 @@ def init_states(self, device: torch.device = torch.device("cpu")): attn_caches = [ [ torch.zeros(self.memory_size, self.d_model, device=device), - torch.zeros( - self.left_context_length, self.d_model, device=device - ), - torch.zeros( - self.left_context_length, self.d_model, device=device - ), + torch.zeros(self.left_context_length, self.d_model, device=device), + torch.zeros(self.left_context_length, self.d_model, device=device), ] for _ in range(self.num_encoder_layers) ] @@ -1693,17 +1638,11 @@ def __init__( raise NotImplementedError( "chunk_length must be a mutiple of subsampling_factor." ) - if ( - left_context_length != 0 - and left_context_length % subsampling_factor != 0 - ): + if left_context_length != 0 and left_context_length % subsampling_factor != 0: raise NotImplementedError( "left_context_length must be 0 or a mutiple of subsampling_factor." # noqa ) - if ( - right_context_length != 0 - and right_context_length % subsampling_factor != 0 - ): + if right_context_length != 0 and right_context_length % subsampling_factor != 0: raise NotImplementedError( "right_context_length must be 0 or a mutiple of subsampling_factor." # noqa ) @@ -1766,9 +1705,7 @@ def forward( x_lens = (((x_lens - 1) >> 1) - 1) >> 1 assert x.size(0) == x_lens.max().item() - output, output_lengths = self.encoder( - x, x_lens, warmup=warmup - ) # (T, N, C) + output, output_lengths = self.encoder(x, x_lens, warmup=warmup) # (T, N, C) output = output.permute(1, 0, 2) # (T, N, C) -> (N, T, C) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/export.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/export.py index 4930881ea3..334682ad62 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/export.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/export.py @@ -136,8 +136,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -181,9 +180,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -210,9 +209,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -279,9 +278,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/stream.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/stream.py index 9494e1fc15..c211b215ea 100644 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/stream.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/stream.py @@ -68,14 +68,12 @@ def __init__( elif params.decoding_method == "fast_beam_search": # feature_len is needed to get partial results. # The rnnt_decoding_stream for fast_beam_search. - self.rnnt_decoding_stream: k2.RnntDecodingStream = ( - k2.RnntDecodingStream(decoding_graph) + self.rnnt_decoding_stream: k2.RnntDecodingStream = k2.RnntDecodingStream( + decoding_graph ) self.hyp: Optional[List[int]] = None else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") self.ground_truth: str = "" diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py index 61dbe86583..621eeb952d 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py @@ -211,8 +211,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; " "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -371,9 +370,7 @@ def modified_beam_search( index=hyps_shape.row_ids(1).to(torch.int64), ) # (num_hyps, encoder_out_dim) - logits = model.joiner( - current_encoder_out, decoder_out, project_input=False - ) + logits = model.joiner(current_encoder_out, decoder_out, project_input=False) # logits is of shape (num_hyps, 1, 1, vocab_size) logits = logits.squeeze(1).squeeze(1) @@ -390,9 +387,7 @@ def modified_beam_search( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor( - shape=log_probs_shape, value=log_probs - ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) for i in range(batch_size): topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) @@ -551,14 +546,10 @@ def decode_one_chunk( feature_list, batch_first=True, padding_value=LOG_EPSILON ).to(device) feature_lens = torch.tensor(feature_len_list, device=device) - num_processed_frames = torch.tensor( - num_processed_frames_list, device=device - ) + num_processed_frames = torch.tensor(num_processed_frames_list, device=device) # Make sure it has at least 1 frame after subsampling, first-and-last-frame cutting, and right context cutting # noqa - tail_length = ( - 3 * params.subsampling_factor + params.right_context_length + 3 - ) + tail_length = 3 * params.subsampling_factor + params.right_context_length + 3 if features.size(1) < tail_length: pad_length = tail_length - features.size(1) feature_lens += pad_length @@ -605,9 +596,7 @@ def decode_one_chunk( max_states=params.max_states, ) else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") # Update cached states of each stream state_list = unstack_states(states) @@ -782,8 +771,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -831,9 +819,7 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -867,9 +853,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -896,9 +882,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py 
b/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py index c07d8f76b0..3d8d4a18a1 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py @@ -95,9 +95,7 @@ from icefall.env import get_env_info from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def add_model_arguments(parser: argparse.ArgumentParser): @@ -265,8 +263,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -289,8 +286,7 @@ def get_parser(): "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)" "part.", ) parser.add_argument( @@ -636,11 +632,7 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -668,23 +660,16 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -871,9 +856,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -981,7 +964,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2 ** 22 + 2**22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/decode.py index 98b8290b5c..d3c001942d 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/decode.py @@ -215,8 +215,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 
1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -284,9 +283,7 @@ def decode_one_batch( value=LOG_EPS, ) - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] if params.decoding_method == "fast_beam_search": @@ -301,10 +298,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -427,9 +421,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -462,8 +454,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -506,9 +497,7 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -540,9 +529,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -569,9 +558,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer.py index f16f5acc71..c3739566fe 100644 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer.py @@ -35,7 +35,6 @@ from icefall.utils import make_pad_mask - LOG_EPSILON = math.log(1e-10) @@ -127,9 +126,7 @@ def stack_states( for si, s in enumerate(layer): attn_caches[li][si].append(s) if b == batch_size - 1: - attn_caches[li][si] = torch.stack( - attn_caches[li][si], dim=1 - ) + attn_caches[li][si] = torch.stack(attn_caches[li][si], dim=1) conv_caches = [] for layer in state_list[0][1]: @@ -268,9 +265,7 @@ def _split_right_context( intervals = torch.arange( 0, self.chunk_length * (num_chunks - 1), self.chunk_length ) - first = torch.arange( - self.chunk_length, self.chunk_length + self.cache_size - ) + 
first = torch.arange(self.chunk_length, self.chunk_length + self.cache_size) indexes = intervals.unsqueeze(1) + first.unsqueeze(0) indexes = torch.cat( [indexes, torch.arange(U_ - self.cache_size, U_).unsqueeze(0)] ) @@ -284,9 +279,7 @@ def _split_right_context( # (num_chunks * B, cache_size + right_context_length, D) return pad_right_context.permute(0, 2, 1) - def _merge_right_context( - self, right_context: torch.Tensor, B: int - ) -> torch.Tensor: + def _merge_right_context(self, right_context: torch.Tensor, B: int) -> torch.Tensor: """ Args: right_context: @@ -337,12 +330,8 @@ def forward( right_context = x[:, :, :R] # (B, D, R) # make causal convolution - cache = torch.zeros( - B, D, self.cache_size, device=x.device, dtype=x.dtype - ) - pad_utterance = torch.cat( - [cache, utterance], dim=2 - ) # (B, D, cache + U) + cache = torch.zeros(B, D, self.cache_size, device=x.device, dtype=x.dtype) + pad_utterance = torch.cat([cache, utterance], dim=2) # (B, D, cache + U) # depth-wise conv on utterance utterance = self.depthwise_conv(pad_utterance) # (B, D, U) @@ -355,9 +344,7 @@ def forward( right_context = self.depthwise_conv( pad_right_context ) # (num_segs * B, D, right_context_length) - right_context = self._merge_right_context( - right_context, B - ) # (B, D, R) + right_context = self._merge_right_context(right_context, B) # (B, D, R) x = torch.cat([right_context, utterance], dim=2) # (B, D, R + U) x = self.deriv_balancer2(x) @@ -458,8 +445,7 @@ def __init__( if embed_dim % nhead != 0: raise ValueError( - f"embed_dim ({embed_dim}) is not a multiple of" - f"nhead ({nhead})." + f"embed_dim ({embed_dim}) is not a multiple of" f"nhead ({nhead})." ) self.embed_dim = embed_dim @@ -469,9 +455,7 @@ def __init__( self.head_dim = embed_dim // nhead self.dropout = dropout - self.emb_to_key_value = ScaledLinear( - embed_dim, 2 * embed_dim, bias=True - ) + self.emb_to_key_value = ScaledLinear(embed_dim, 2 * embed_dim, bias=True) self.emb_to_query = ScaledLinear(embed_dim, embed_dim, bias=True) self.out_proj = ScaledLinear( embed_dim, embed_dim, bias=True, initial_scale=0.25 ) @@ -513,9 +497,7 @@ def _gen_attention_probs( if padding_mask is not None: Q = attention_weights.size(1) B = attention_weights.size(0) // self.nhead - attention_weights_float = attention_weights_float.view( - B, self.nhead, Q, -1 - ) + attention_weights_float = attention_weights_float.view(B, self.nhead, Q, -1) attention_weights_float = attention_weights_float.masked_fill( padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), self.negative_inf, ) @@ -561,16 +543,12 @@ def _forward_impl( # [memory, right context, left context, utterance] # this is used in inference mode key = torch.cat([key[: M + R], left_context_key, key[M + R :]]) - value = torch.cat( - [value[: M + R], left_context_val, value[M + R :]] - ) + value = torch.cat([value[: M + R], left_context_val, value[M + R :]]) Q = query.size(0) # KV = key.size(0) reshaped_query, reshaped_key, reshaped_value = [ - tensor.contiguous() - .view(-1, B * self.nhead, self.head_dim) - .transpose(0, 1) + tensor.contiguous().view(-1, B * self.nhead, self.head_dim).transpose(0, 1) for tensor in [query, key, value] ] # (B * nhead, Q or KV, head_dim) attention_weights = torch.bmm( @@ -585,9 +563,7 @@ def _forward_impl( # compute attention outputs attention = torch.bmm(attention_probs, reshaped_value) assert attention.shape == (B * self.nhead, Q, self.head_dim) - attention = ( - attention.transpose(0, 1).contiguous().view(Q, B, self.embed_dim) - ) + attention = attention.transpose(0, 
1).contiguous().view(Q, B, self.embed_dim) # apply output projection output_right_context_utterance = self.out_proj(attention) @@ -905,13 +881,11 @@ def _apply_attention_module_forward( right_context = right_context_utterance[:R] if self.use_memory: - memory = self.summary_op(utterance.permute(1, 2, 0)).permute( - 2, 0, 1 - )[:-1, :, :] + memory = self.summary_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[ + :-1, :, : + ] else: - memory = torch.empty(0).to( - dtype=utterance.dtype, device=utterance.device - ) + memory = torch.empty(0).to(dtype=utterance.dtype, device=utterance.device) output_right_context_utterance = self.attention( utterance=utterance, right_context=right_context, @@ -948,18 +922,12 @@ def _apply_attention_module_infer( left_context_val = attn_cache[2] if self.use_memory: - memory = self.summary_op(utterance.permute(1, 2, 0)).permute( - 2, 0, 1 - )[:1, :, :] + memory = self.summary_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[ + :1, :, : + ] else: - memory = torch.empty(0).to( - dtype=utterance.dtype, device=utterance.device - ) - ( - output_right_context_utterance, - next_key, - next_val, - ) = self.attention.infer( + memory = torch.empty(0).to(dtype=utterance.dtype, device=utterance.device) + (output_right_context_utterance, next_key, next_val,) = self.attention.infer( utterance=utterance, right_context=right_context, memory=pre_memory, left_context_key=left_context_key, left_context_val=left_context_val, padding_mask=padding_mask, ) - attn_cache = self._update_attn_cache( - next_key, next_val, memory, attn_cache - ) + attn_cache = self._update_attn_cache(next_key, next_val, memory, attn_cache) return output_right_context_utterance, attn_cache def forward( @@ -1226,9 +1192,7 @@ def __init__( def _gen_right_context(self, x: torch.Tensor) -> torch.Tensor: """Hard copy each chunk's right context and concat them.""" T = x.shape[0] - num_chunks = math.ceil( - (T - self.right_context_length) / self.chunk_length - ) + num_chunks = math.ceil((T - self.right_context_length) / self.chunk_length) # first (num_chunks - 1) right context block intervals = torch.arange( 0, self.chunk_length * (num_chunks - 1), self.chunk_length ) right_context_blocks = x[indexes.reshape(-1)] return right_context_blocks - def _gen_attention_mask_col_widths( - self, chunk_idx: int, U: int - ) -> List[int]: + def _gen_attention_mask_col_widths(self, chunk_idx: int, U: int) -> List[int]: """Calculate column widths (key, value) in attention mask for the chunk_idx chunk.""" num_chunks = math.ceil(U / self.chunk_length) @@ -1549,12 +1511,8 @@ def init_states(self, device: torch.device = torch.device("cpu")): attn_caches = [ [ torch.zeros(self.memory_size, self.d_model, device=device), - torch.zeros( - self.left_context_length, self.d_model, device=device - ), - torch.zeros( - self.left_context_length, self.d_model, device=device - ), + torch.zeros(self.left_context_length, self.d_model, device=device), + torch.zeros(self.left_context_length, self.d_model, device=device), ] for _ in range(self.num_encoder_layers) ] @@ -1599,17 +1557,11 @@ def __init__( raise NotImplementedError( "chunk_length must be a multiple of subsampling_factor." ) - if ( - left_context_length != 0 - and left_context_length % subsampling_factor != 0 - ): + if left_context_length != 0 and left_context_length % subsampling_factor != 0: raise NotImplementedError( "left_context_length must be 0 or a multiple of subsampling_factor." 
# noqa ) - if ( - right_context_length != 0 - and right_context_length % subsampling_factor != 0 - ): + if right_context_length != 0 and right_context_length % subsampling_factor != 0: raise NotImplementedError( "right_context_length must be 0 or a multiple of subsampling_factor." # noqa ) @@ -1672,9 +1624,7 @@ def forward( x_lens = (((x_lens - 1) >> 1) - 1) >> 1 assert x.size(0) == x_lens.max().item() - output, output_lengths = self.encoder( - x, x_lens, warmup=warmup - ) # (T, N, C) + output, output_lengths = self.encoder(x, x_lens, warmup=warmup) # (T, N, C) output = output.permute(1, 0, 2) # (T, N, C) -> (N, T, C) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export.py index ab15e02415..998fb6e813 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export.py @@ -136,8 +136,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -181,9 +180,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -210,9 +209,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -279,9 +278,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py index 71150392d0..618d8bb638 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py @@ -211,8 +211,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 
1 means bigram; " "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -371,9 +370,7 @@ def modified_beam_search( index=hyps_shape.row_ids(1).to(torch.int64), ) # (num_hyps, encoder_out_dim) - logits = model.joiner( - current_encoder_out, decoder_out, project_input=False - ) + logits = model.joiner(current_encoder_out, decoder_out, project_input=False) # logits is of shape (num_hyps, 1, 1, vocab_size) logits = logits.squeeze(1).squeeze(1) @@ -390,9 +387,7 @@ def modified_beam_search( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor( - shape=log_probs_shape, value=log_probs - ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) for i in range(batch_size): topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) @@ -551,14 +546,10 @@ def decode_one_chunk( feature_list, batch_first=True, padding_value=LOG_EPSILON ).to(device) feature_lens = torch.tensor(feature_len_list, device=device) - num_processed_frames = torch.tensor( - num_processed_frames_list, device=device - ) + num_processed_frames = torch.tensor(num_processed_frames_list, device=device) # Make sure it has at least 1 frame after subsampling, first-and-last-frame cutting, and right context cutting # noqa - tail_length = ( - 3 * params.subsampling_factor + params.right_context_length + 3 - ) + tail_length = 3 * params.subsampling_factor + params.right_context_length + 3 if features.size(1) < tail_length: pad_length = tail_length - features.size(1) feature_lens += pad_length @@ -605,9 +596,7 @@ def decode_one_chunk( max_states=params.max_states, ) else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") # Update cached states of each stream state_list = unstack_states(states) @@ -782,8 +771,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -831,9 +819,7 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -867,9 +853,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -896,9 +882,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py 
b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py index 2bbc45d780..542f524a9a 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py @@ -95,9 +95,7 @@ from icefall.env import get_env_info from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def add_model_arguments(parser: argparse.ArgumentParser): @@ -265,8 +263,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -289,8 +286,7 @@ def get_parser(): "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)" "part.", ) parser.add_argument( @@ -636,11 +632,7 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -668,23 +660,16 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -871,9 +856,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -981,7 +964,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2 ** 22 + 2**22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/librispeech/ASR/local/add_alignment_librispeech.py b/egs/librispeech/ASR/local/add_alignment_librispeech.py index fe6a26c51b..cc34a72d81 100755 --- a/egs/librispeech/ASR/local/add_alignment_librispeech.py +++ b/egs/librispeech/ASR/local/add_alignment_librispeech.py @@ -157,9 +157,7 @@ def add_alignment( for ali_path in part_ali_dir.rglob("*.alignment.txt"): ali = parse_alignments(ali_path) alignments.update(ali) - 
logging.info( - f"{part} has {len(alignments.keys())} cuts with alignments." - ) + logging.info(f"{part} has {len(alignments.keys())} cuts with alignments.") # add alignment attribute and write out cuts_in = load_manifest_lazy(cuts_in_path) @@ -170,18 +168,14 @@ def add_alignment( if origin_id in alignments: ali = alignments[origin_id] else: - logging.info( - f"Warning: {origin_id} does not have alignment." - ) + logging.info(f"Warning: {origin_id} does not have alignment.") ali = [] subcut.alignment = {"word": ali} writer.write(cut, flush=True) def main(): - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) parser = get_parser() diff --git a/egs/librispeech/ASR/local/compile_hlg.py b/egs/librispeech/ASR/local/compile_hlg.py index c628dfd532..df6c609bbd 100755 --- a/egs/librispeech/ASR/local/compile_hlg.py +++ b/egs/librispeech/ASR/local/compile_hlg.py @@ -57,7 +57,7 @@ def get_args(): return parser.parse_args() -def compile_HLG(lang_dir: str, lm: str="G_3_gram") -> k2.Fsa: +def compile_HLG(lang_dir: str, lm: str = "G_3_gram") -> k2.Fsa: """ Args: lang_dir: @@ -159,9 +159,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/local/compile_lg.py b/egs/librispeech/ASR/local/compile_lg.py index 45c4b7f5fd..19bf3bff44 100755 --- a/egs/librispeech/ASR/local/compile_lg.py +++ b/egs/librispeech/ASR/local/compile_lg.py @@ -132,9 +132,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/local/compute_fbank_gigaspeech_dev_test.py b/egs/librispeech/ASR/local/compute_fbank_gigaspeech_dev_test.py index c0c7ef8c55..97750f3ea4 100644 --- a/egs/librispeech/ASR/local/compute_fbank_gigaspeech_dev_test.py +++ b/egs/librispeech/ASR/local/compute_fbank_gigaspeech_dev_test.py @@ -80,9 +80,7 @@ def compute_fbank_gigaspeech_dev_test(): def main(): - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) compute_fbank_gigaspeech_dev_test() diff --git a/egs/librispeech/ASR/local/compute_fbank_gigaspeech_splits.py b/egs/librispeech/ASR/local/compute_fbank_gigaspeech_splits.py index 5587106e53..ce0ef24e7c 100644 --- a/egs/librispeech/ASR/local/compute_fbank_gigaspeech_splits.py +++ b/egs/librispeech/ASR/local/compute_fbank_gigaspeech_splits.py @@ -144,9 +144,7 @@ def main(): date_time = now.strftime("%Y-%m-%d-%H-%M-%S") log_filename = "log-compute_fbank_gigaspeech_splits" - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" log_filename = f"{log_filename}-{date_time}" logging.basicConfig( diff --git a/egs/librispeech/ASR/local/compute_fbank_librispeech.py b/egs/librispeech/ASR/local/compute_fbank_librispeech.py index ce7d087f0d..9f85038143 100755 --- 
a/egs/librispeech/ASR/local/compute_fbank_librispeech.py +++ b/egs/librispeech/ASR/local/compute_fbank_librispeech.py @@ -112,9 +112,7 @@ def compute_fbank_librispeech(bpe_model: Optional[str] = None): if "train" in partition: cut_set = ( - cut_set - + cut_set.perturb_speed(0.9) - + cut_set.perturb_speed(1.1) + cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) ) cut_set = cut_set.compute_and_store_features( extractor=extractor, @@ -128,9 +126,7 @@ if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) args = get_args() diff --git a/egs/librispeech/ASR/local/compute_fbank_musan.py b/egs/librispeech/ASR/local/compute_fbank_musan.py index 056da29e5c..4a4093ae49 100755 --- a/egs/librispeech/ASR/local/compute_fbank_musan.py +++ b/egs/librispeech/ASR/local/compute_fbank_musan.py @@ -83,9 +83,7 @@ def compute_fbank_musan(): # create chunks of Musan with duration 5 - 10 seconds musan_cuts = ( CutSet.from_manifests( - recordings=combine( - part["recordings"] for part in manifests.values() - ) + recordings=combine(part["recordings"] for part in manifests.values()) ) .cut_into_windows(10.0) .filter(lambda c: c.duration > 5) @@ -101,9 +99,7 @@ if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) compute_fbank_musan() diff --git a/egs/librispeech/ASR/local/convert_transcript_words_to_tokens.py b/egs/librispeech/ASR/local/convert_transcript_words_to_tokens.py index 133499c8bc..a8d5117c97 100755 --- a/egs/librispeech/ASR/local/convert_transcript_words_to_tokens.py +++ b/egs/librispeech/ASR/local/convert_transcript_words_to_tokens.py @@ -51,16 +51,12 @@ def get_args(): "lines. Each line consists of space separated words.", ) parser.add_argument("--lexicon", type=str, help="The input lexicon file.") - parser.add_argument( - "--oov", type=str, default="<UNK>
", help="The OOV word." - ) + parser.add_argument("--oov", type=str, default="<UNK>", help="The OOV word.") return parser.parse_args() -def process_line( - lexicon: Dict[str, List[str]], line: str, oov_token: str -) -> None: +def process_line(lexicon: Dict[str, List[str]], line: str, oov_token: str) -> None: """ Args: lexicon: diff --git a/egs/librispeech/ASR/local/download_lm.py b/egs/librispeech/ASR/local/download_lm.py index 030122aa71..3518db5244 100755 --- a/egs/librispeech/ASR/local/download_lm.py +++ b/egs/librispeech/ASR/local/download_lm.py @@ -87,9 +87,7 @@ def main(out_dir: str): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/local/filter_cuts.py b/egs/librispeech/ASR/local/filter_cuts.py index dff98a954d..b3f0956c3e 100644 --- a/egs/librispeech/ASR/local/filter_cuts.py +++ b/egs/librispeech/ASR/local/filter_cuts.py @@ -79,8 +79,7 @@ def remove_short_and_long_utterances(c: Cut): total += 1 if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. " f"Duration: {c.duration}" ) removed += 1 return False @@ -125,8 +124,7 @@ def remove_short_and_long_utterances(c: Cut): ans = cut_set.filter(remove_short_and_long_utterances).to_eager() ratio = removed / total * 100 logging.info( - f"Removed {removed} cuts from {total} cuts. " - f"{ratio:.3f}% data is removed." + f"Removed {removed} cuts from {total} cuts. " f"{ratio:.3f}% data is removed." ) return ans @@ -155,9 +153,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/local/generate_unique_lexicon.py b/egs/librispeech/ASR/local/generate_unique_lexicon.py index 566c0743db..3459c2f5ad 100755 --- a/egs/librispeech/ASR/local/generate_unique_lexicon.py +++ b/egs/librispeech/ASR/local/generate_unique_lexicon.py @@ -91,9 +91,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/local/prepare_lang_bpe.py b/egs/librispeech/ASR/local/prepare_lang_bpe.py index dec8a74426..e121aefa93 100755 --- a/egs/librispeech/ASR/local/prepare_lang_bpe.py +++ b/egs/librispeech/ASR/local/prepare_lang_bpe.py @@ -150,9 +150,7 @@ def generate_lexicon( words_pieces_ids: List[List[int]] = sp.encode(words, out_type=int) # Now convert word piece IDs back to word piece strings. 
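# --- Illustrative sketch (editor's note, not part of this diff) -------------
# A minimal, self-contained example of the id -> piece round trip that
# generate_lexicon() performs in the hunk above. The model path and word list
# below are assumptions made for illustration only; the calls
# sp.encode(..., out_type=int) and sp.id_to_piece(...) mirror the lines in
# this hunk.
import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.load("data/lang_bpe_500/bpe.model")  # assumed path to a trained BPE model

words = ["HELLO", "WORLD"]  # hypothetical word list
words_pieces_ids = sp.encode(words, out_type=int)  # List[List[int]]
# Convert the word-piece IDs back to word-piece strings, one list per word.
words_pieces = [sp.id_to_piece(ids) for ids in words_pieces_ids]
lexicon = list(zip(words, words_pieces))
# Each entry pairs a word with its pieces, e.g. ("HELLO", ["▁HE", "LL", "O"]);
# the exact split depends on the trained model.
# -----------------------------------------------------------------------------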
- words_pieces: List[List[str]] = [ - sp.id_to_piece(ids) for ids in words_pieces_ids - ] + words_pieces: List[List[str]] = [sp.id_to_piece(ids) for ids in words_pieces_ids] lexicon = [] for word, pieces in zip(words, words_pieces): diff --git a/egs/librispeech/ASR/local/prepare_lm_training_data.py b/egs/librispeech/ASR/local/prepare_lm_training_data.py index 5070341f12..32ae8c5809 100755 --- a/egs/librispeech/ASR/local/prepare_lm_training_data.py +++ b/egs/librispeech/ASR/local/prepare_lm_training_data.py @@ -137,8 +137,7 @@ def main(): for i in range(num_sentences): if step and i % step == 0: logging.info( - f"Processed number of lines: {i} " - f"({i/num_sentences*100: .3f}%)" + f"Processed number of lines: {i} " f"({i/num_sentences*100: .3f}%)" ) word_ids = sentences[i] @@ -154,18 +153,14 @@ def main(): sentence_lengths[i] = token_ids.numel() - output["sentence_lengths"] = torch.tensor( - sentence_lengths, dtype=torch.int32 - ) + output["sentence_lengths"] = torch.tensor(sentence_lengths, dtype=torch.int32) torch.save(output, args.lm_archive) logging.info(f"Saved to {args.lm_archive}") if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/local/preprocess_gigaspeech.py b/egs/librispeech/ASR/local/preprocess_gigaspeech.py index 077f23039b..8aa5e461d4 100644 --- a/egs/librispeech/ASR/local/preprocess_gigaspeech.py +++ b/egs/librispeech/ASR/local/preprocess_gigaspeech.py @@ -119,9 +119,7 @@ def preprocess_giga_speech(): def main(): - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) preprocess_giga_speech() diff --git a/egs/librispeech/ASR/local/test_prepare_lang.py b/egs/librispeech/ASR/local/test_prepare_lang.py index d4cf62bba5..74e025ad78 100755 --- a/egs/librispeech/ASR/local/test_prepare_lang.py +++ b/egs/librispeech/ASR/local/test_prepare_lang.py @@ -88,9 +88,7 @@ def test_read_lexicon(filename: str): fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt") fsa.draw("L.pdf", title="L") - fsa_disambig = lexicon_to_fst( - lexicon_disambig, phone2id=phone2id, word2id=word2id - ) + fsa_disambig = lexicon_to_fst(lexicon_disambig, phone2id=phone2id, word2id=word2id) fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt") fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt") fsa_disambig.draw("L_disambig.pdf", title="L_disambig") diff --git a/egs/librispeech/ASR/local/validate_manifest.py b/egs/librispeech/ASR/local/validate_manifest.py index 7c57d629a8..f620b91ead 100755 --- a/egs/librispeech/ASR/local/validate_manifest.py +++ b/egs/librispeech/ASR/local/validate_manifest.py @@ -85,9 +85,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless/decode.py index 27414d717e..79b21fab19 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless/decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/decode.py @@ -272,8 +272,7 @@ def 
get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -366,9 +365,7 @@ def decode_one_batch( ) feature_lens += num_tail_padded_frames - encoder_out, encoder_out_lens, _ = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens, _ = model.encoder(x=feature, x_lens=feature_lens) hyps = [] @@ -427,10 +424,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -561,9 +555,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -596,8 +588,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -648,9 +639,7 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -682,9 +671,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -711,9 +700,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -772,9 +761,7 @@ def main(): decoding_graph.scores *= params.ngram_lm_scale else: word_table = None - decoding_graph = k2.trivial_graph( - params.vocab_size - 1, device=device - ) + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) else: decoding_graph = None word_table = None diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/export.py b/egs/librispeech/ASR/lstm_transducer_stateless/export.py index 13dac60092..45fa6d6620 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless/export.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/export.py @@ -172,8 +172,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; " "2 means tri-gram", ) add_model_arguments(parser) @@ -281,9 +280,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -310,9 +309,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -380,9 +379,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/jit_pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless/jit_pretrained.py index 594c33e4f5..51f4a2e8aa 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless/jit_pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/jit_pretrained.py @@ -124,8 +124,7 @@ def read_sound_files( for f in filenames: wave, sample_rate = torchaudio.load(f) assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" + f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}" ) # We use only the first channel ans.append(wave[0]) @@ -314,9 +313,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/lstm.py b/egs/librispeech/ASR/lstm_transducer_stateless/lstm.py index c54a4c4780..bbab16af7b 100644 --- a/egs/librispeech/ASR/lstm_transducer_stateless/lstm.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/lstm.py @@ -672,9 +672,7 @@ def __init__( self.stddev = stddev self.final_log_weight = ( - torch.tensor( - (final_weight / (1 - final_weight)) * (self.num_inputs - 1) - ) + torch.tensor((final_weight / (1 - final_weight)) * (self.num_inputs - 1)) .log() .item() ) @@ -771,16 +769,14 @@ def _get_random_pure_weights( # final contains self.num_inputs - 1 in all elements final = torch.full((num_frames,), self.num_inputs - 1, device=device) # nonfinal contains random integers in [0..num_inputs - 2], these are for non-final weights. 
# noqa - nonfinal = torch.randint( - self.num_inputs - 1, (num_frames,), device=device - ) + nonfinal = torch.randint(self.num_inputs - 1, (num_frames,), device=device) indexes = torch.where( torch.rand(num_frames, device=device) < final_prob, final, nonfinal ) - ans = torch.nn.functional.one_hot( - indexes, num_classes=self.num_inputs - ).to(dtype=dtype) + ans = torch.nn.functional.one_hot(indexes, num_classes=self.num_inputs).to( + dtype=dtype + ) return ans def _get_random_mixed_weights( diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/model.py b/egs/librispeech/ASR/lstm_transducer_stateless/model.py index d71132b4ac..e7bad7ed83 100644 --- a/egs/librispeech/ASR/lstm_transducer_stateless/model.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/model.py @@ -66,9 +66,7 @@ def __init__( self.decoder = decoder self.joiner = joiner - self.simple_am_proj = ScaledLinear( - encoder_dim, vocab_size, initial_speed=0.5 - ) + self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size, initial_speed=0.5) self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size) def forward( @@ -151,9 +149,7 @@ def forward( y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros( - (x.size(0), 4), dtype=torch.int64, device=x.device - ) + boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) boundary[:, 2] = y_lens boundary[:, 3] = x_lens diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py index 2a6e2adc67..9263b41b2a 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py @@ -166,8 +166,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -199,8 +198,7 @@ def read_sound_files( for f in filenames: wave, sample_rate = torchaudio.load(f) assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" + f"expected sample rate: {expected_sample_rate}. 
" f"Given: {sample_rate}" ) # We use only the first channel ans.append(wave[0]) @@ -264,15 +262,11 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) feature_lengths = torch.tensor(feature_lengths, device=device) - encoder_out, encoder_out_lens, _ = model.encoder( - x=features, x_lens=feature_lengths - ) + encoder_out, encoder_out_lens, _ = model.encoder(x=features, x_lens=feature_lengths) num_waves = encoder_out.size(0) hyps = [] @@ -344,9 +338,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/stream.py b/egs/librispeech/ASR/lstm_transducer_stateless/stream.py index 97d890c820..d8f7fd960c 100644 --- a/egs/librispeech/ASR/lstm_transducer_stateless/stream.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/stream.py @@ -70,14 +70,12 @@ def __init__( elif params.decoding_method == "fast_beam_search": # feature_len is needed to get partial results. # The rnnt_decoding_stream for fast_beam_search. - self.rnnt_decoding_stream: k2.RnntDecodingStream = ( - k2.RnntDecodingStream(decoding_graph) + self.rnnt_decoding_stream: k2.RnntDecodingStream = k2.RnntDecodingStream( + decoding_graph ) self.hyp: Optional[List[int]] = None else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") self.ground_truth: str = "" diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py b/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py index d6376bdc06..4cc2aabb2f 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py @@ -199,8 +199,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; " "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -359,9 +358,7 @@ def modified_beam_search( index=hyps_shape.row_ids(1).to(torch.int64), ) # (num_hyps, encoder_out_dim) - logits = model.joiner( - current_encoder_out, decoder_out, project_input=False - ) + logits = model.joiner(current_encoder_out, decoder_out, project_input=False) # logits is of shape (num_hyps, 1, 1, vocab_size) logits = logits.squeeze(1).squeeze(1) @@ -378,9 +375,7 @@ def modified_beam_search( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor( - shape=log_probs_shape, value=log_probs - ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) for i in range(batch_size): topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) @@ -539,9 +534,7 @@ def decode_one_chunk( feature_list, batch_first=True, padding_value=LOG_EPSILON ).to(device) feature_lens = torch.tensor(feature_len_list, device=device) - num_processed_frames = torch.tensor( - num_processed_frames_list, device=device - ) + num_processed_frames = torch.tensor(num_processed_frames_list, device=device) # Make sure it has at least 1 frame after subsampling tail_length = params.subsampling_factor + 5 @@ -583,8 +576,7 @@ def decode_one_chunk( with warnings.catch_warnings(): warnings.simplefilter("ignore") processed_lens = ( - num_processed_frames // params.subsampling_factor - + encoder_out_lens + num_processed_frames // params.subsampling_factor + encoder_out_lens ) fast_beam_search_one_best( model=model, @@ -596,9 +588,7 @@ def decode_one_chunk( max_states=params.max_states, ) else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") # Update cached states of each stream state_list = unstack_states(states) @@ -773,8 +763,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -816,9 +805,7 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -852,9 +839,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -881,9 +868,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/train.py 
b/egs/librispeech/ASR/lstm_transducer_stateless/train.py index d30fc260a9..b9a68753e4 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless/train.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/train.py @@ -87,9 +87,7 @@ str2bool, ) -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def add_model_arguments(parser: argparse.ArgumentParser): @@ -222,8 +220,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -246,8 +243,7 @@ def get_parser(): "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)" "part.", ) parser.add_argument( @@ -594,11 +590,7 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -638,9 +630,7 @@ def compute_loss( # If either all simple_loss or pruned_loss is inf or nan, # we stop the training process by raising an exception - if torch.all(~simple_loss_is_finite) or torch.all( - ~pruned_loss_is_finite - ): + if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite): raise ValueError( "There are too many utterances in this batch " "leading to inf or nan losses." ) @@ -653,14 +643,9 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training @@ -671,9 +656,7 @@ def compute_loss( # (1) The actual subsampling factor is ((lens - 1) // 2 - 1) // 2 # (2) If some utterances in the batch lead to inf/nan loss, they # are filtered out. - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -856,9 +839,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -989,8 +970,7 @@ def remove_short_and_long_utt(c: Cut): # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. 
" f"Duration: {c.duration}" ) return False diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py index bad4e243e5..41602d2070 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py @@ -295,8 +295,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -474,9 +473,7 @@ def decode_one_batch( ) feature_lens += num_tail_padded_frames - encoder_out, encoder_out_lens, _ = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens, _ = model.encoder(x=feature, x_lens=feature_lens) hyps = [] @@ -535,10 +532,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -700,9 +694,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -735,8 +727,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -789,9 +780,7 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -826,9 +815,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -860,9 +849,9 @@ def main(): ) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -961,9 +950,7 @@ def main(): decoding_graph.scores *= params.ngram_lm_scale else: word_table = None - decoding_graph = k2.trivial_graph( - params.vocab_size - 1, device=device - ) + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) else: decoding_graph = None word_table = None diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/export.py b/egs/librispeech/ASR/lstm_transducer_stateless2/export.py index 1906736382..2a25cb46a5 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/export.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/export.py @@ 
-225,8 +225,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) add_model_arguments(parser) @@ -342,9 +341,7 @@ def export_encoder_model_onnx( x = torch.zeros(N, 9, 80, dtype=torch.float32) x_lens = torch.tensor([9], dtype=torch.int64) h = torch.rand(encoder_model.num_encoder_layers, N, encoder_model.d_model) - c = torch.rand( - encoder_model.num_encoder_layers, N, encoder_model.rnn_hidden_size - ) + c = torch.rand(encoder_model.num_encoder_layers, N, encoder_model.rnn_hidden_size) warmup = 1.0 torch.onnx.export( @@ -445,13 +442,9 @@ def export_joiner_model_onnx( - projected_decoder_out: a tensor of shape (N, joiner_dim) """ - encoder_proj_filename = str(joiner_filename).replace( - ".onnx", "_encoder_proj.onnx" - ) + encoder_proj_filename = str(joiner_filename).replace(".onnx", "_encoder_proj.onnx") - decoder_proj_filename = str(joiner_filename).replace( - ".onnx", "_decoder_proj.onnx" - ) + decoder_proj_filename = str(joiner_filename).replace(".onnx", "_decoder_proj.onnx") encoder_out_dim = joiner_model.encoder_proj.weight.shape[1] decoder_out_dim = joiner_model.decoder_proj.weight.shape[1] @@ -550,9 +543,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -585,9 +578,9 @@ def main(): ) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -694,9 +687,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/jit_pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless2/jit_pretrained.py index da184b76fd..40f11018fe 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/jit_pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/jit_pretrained.py @@ -125,8 +125,7 @@ def read_sound_files( for f in filenames: wave, sample_rate = torchaudio.load(f) assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" + f"expected sample rate: {expected_sample_rate}. 
" f"Given: {sample_rate}" ) # We use only the first channel ans.append(wave[0]) @@ -315,9 +314,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/model.py b/egs/librispeech/ASR/lstm_transducer_stateless2/model.py index fadeb4ac2f..4957d14b11 100644 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/model.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/model.py @@ -84,9 +84,7 @@ def __init__( self.decoder_giga = decoder_giga self.joiner_giga = joiner_giga - self.simple_am_proj = ScaledLinear( - encoder_dim, vocab_size, initial_speed=0.5 - ) + self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size, initial_speed=0.5) self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size) if decoder_giga is not None: @@ -190,9 +188,7 @@ def forward( y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros( - (x.size(0), 4), dtype=torch.int64, device=x.device - ) + boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) boundary[:, 2] = y_lens boundary[:, 3] = x_lens diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/ncnn-decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/ncnn-decode.py index 410de8d3d3..ab2f174806 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/ncnn-decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/ncnn-decode.py @@ -156,9 +156,7 @@ def run_encoder(self, x, states): assert ret == 0, ret encoder_out = torch.from_numpy(ncnn_out0.numpy()).clone() - encoder_out_lens = torch.from_numpy(ncnn_out1.numpy()).to( - torch.int32 - ) + encoder_out_lens = torch.from_numpy(ncnn_out1.numpy()).to(torch.int32) hx = torch.from_numpy(ncnn_out2.numpy()).clone() cx = torch.from_numpy(ncnn_out3.numpy()).clone() return encoder_out, encoder_out_lens, hx, cx @@ -201,8 +199,7 @@ def read_sound_files( for f in filenames: wave, sample_rate = torchaudio.load(f) assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" + f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}" ) # We use only the first channel ans.append(wave[0]) @@ -286,9 +283,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py index bef0ad760f..2983328bf5 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py @@ -169,8 +169,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -202,8 +201,7 @@ def read_sound_files( for f in filenames: wave, sample_rate = torchaudio.load(f) assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. 
" - f"Given: {sample_rate}" + f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}" ) # We use only the first channel ans.append(wave[0]) @@ -267,15 +265,11 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) feature_lengths = torch.tensor(feature_lengths, device=device) - encoder_out, encoder_out_lens, _ = model.encoder( - x=features, x_lens=feature_lengths - ) + encoder_out, encoder_out_lens, _ = model.encoder(x=features, x_lens=feature_lengths) num_waves = encoder_out.size(0) hyps = [] @@ -347,9 +341,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-ncnn-decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-ncnn-decode.py index e47a05a9e4..a787a00e60 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-ncnn-decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-ncnn-decode.py @@ -144,9 +144,7 @@ def run_encoder(self, x, states): assert ret == 0, ret encoder_out = torch.from_numpy(ncnn_out0.numpy()).clone() - encoder_out_lens = torch.from_numpy(ncnn_out1.numpy()).to( - torch.int32 - ) + encoder_out_lens = torch.from_numpy(ncnn_out1.numpy()).to(torch.int32) hx = torch.from_numpy(ncnn_out2.numpy()).clone() cx = torch.from_numpy(ncnn_out3.numpy()).clone() return encoder_out, encoder_out_lens, hx, cx @@ -189,8 +187,7 @@ def read_sound_files( for f in filenames: wave, sample_rate = torchaudio.load(f) assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" + f"expected sample rate: {expected_sample_rate}. 
" f"Given: {sample_rate}" ) # We use only the first channel ans.append(wave[0]) @@ -229,9 +226,7 @@ def greedy_search( if decoder_out is None: assert hyp is None, hyp hyp = [blank_id] * context_size - decoder_input = torch.tensor( - hyp, dtype=torch.int32 - ) # (1, context_size) + decoder_input = torch.tensor(hyp, dtype=torch.int32) # (1, context_size) decoder_out = model.run_decoder(decoder_input).squeeze(0) else: assert decoder_out.ndim == 1 @@ -310,9 +305,7 @@ def main(): frames.append(online_fbank.get_frame(num_processed_frames + i)) num_processed_frames += offset frames = torch.cat(frames, dim=0) - encoder_out, encoder_out_lens, hx, cx = model.run_encoder( - frames, states - ) + encoder_out, encoder_out_lens, hx, cx = model.run_encoder(frames, states) states = (hx, cx) hyp, decoder_out = greedy_search( model, encoder_out.squeeze(0), decoder_out, hyp @@ -328,9 +321,7 @@ def main(): frames.append(online_fbank.get_frame(num_processed_frames + i)) num_processed_frames += offset frames = torch.cat(frames, dim=0) - encoder_out, encoder_out_lens, hx, cx = model.run_encoder( - frames, states - ) + encoder_out, encoder_out_lens, hx, cx = model.run_encoder(frames, states) states = (hx, cx) hyp, decoder_out = greedy_search( model, encoder_out.squeeze(0), decoder_out, hyp @@ -343,9 +334,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-onnx-decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-onnx-decode.py index 232d3dd18f..e896fd510c 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-onnx-decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-onnx-decode.py @@ -148,8 +148,7 @@ def read_sound_files( for f in filenames: wave, sample_rate = torchaudio.load(f) assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" + f"expected sample rate: {expected_sample_rate}. 
" f"Given: {sample_rate}" ) # We use only the first channel ans.append(wave[0]) @@ -199,9 +198,7 @@ def init_joiner_decoder_proj(self, args): sess_options=self.session_opts, ) - def run_encoder( - self, x, h0, c0 - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + def run_encoder(self, x, h0, c0) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Args: x: @@ -258,9 +255,7 @@ def run_decoder(self, decoder_input: torch.Tensor) -> torch.Tensor: }, )[0] - return self.run_joiner_decoder_proj( - torch.from_numpy(decoder_out).squeeze(1) - ) + return self.run_joiner_decoder_proj(torch.from_numpy(decoder_out).squeeze(1)) def run_joiner( self, @@ -303,11 +298,7 @@ def run_joiner_encoder_proj( projected_encoder_out = self.joiner_encoder_proj.run( [self.joiner_encoder_proj.get_outputs()[0].name], - { - self.joiner_encoder_proj.get_inputs()[ - 0 - ].name: encoder_out.numpy() - }, + {self.joiner_encoder_proj.get_inputs()[0].name: encoder_out.numpy()}, )[0] return torch.from_numpy(projected_encoder_out) @@ -326,11 +317,7 @@ def run_joiner_decoder_proj( projected_decoder_out = self.joiner_decoder_proj.run( [self.joiner_decoder_proj.get_outputs()[0].name], - { - self.joiner_decoder_proj.get_inputs()[ - 0 - ].name: decoder_out.numpy() - }, + {self.joiner_decoder_proj.get_inputs()[0].name: decoder_out.numpy()}, )[0] return torch.from_numpy(projected_decoder_out) @@ -369,9 +356,7 @@ def greedy_search( if decoder_out is None: assert hyp is None, hyp hyp = [blank_id] * context_size - decoder_input = torch.tensor( - [hyp], dtype=torch.int64 - ) # (1, context_size) + decoder_input = torch.tensor([hyp], dtype=torch.int64) # (1, context_size) decoder_out = model.run_decoder(decoder_input) else: assert decoder_out.shape[0] == 1 @@ -474,9 +459,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/train.py b/egs/librispeech/ASR/lstm_transducer_stateless2/train.py index 5eaaf321f9..056285c645 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/train.py @@ -95,9 +95,7 @@ str2bool, ) -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def add_model_arguments(parser: argparse.ArgumentParser): @@ -163,8 +161,7 @@ def get_parser(): "--full-libri", type=str2bool, default=True, - help="When enabled, use 960h LibriSpeech. " - "Otherwise, use 100h subset.", + help="When enabled, use 960h LibriSpeech. " "Otherwise, use 100h subset.", ) parser.add_argument( @@ -238,8 +235,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -262,8 +258,7 @@ def get_parser(): "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)" "part.", ) parser.add_argument( @@ -645,11 +640,7 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. 
""" - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -692,9 +683,7 @@ def compute_loss( # If either all simple_loss or pruned_loss is inf or nan, # we stop the training process by raising an exception - if torch.all(~simple_loss_is_finite) or torch.all( - ~pruned_loss_is_finite - ): + if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite): raise ValueError( "There are too many utterances in this batch " "leading to inf or nan losses." @@ -707,14 +696,9 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training @@ -725,9 +709,7 @@ def compute_loss( # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 # (2) If some utterances in the batch lead to inf/nan loss, they # are filtered out. - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -958,9 +940,7 @@ def train_one_epoch( f"train/current_{prefix}_", params.batch_idx_train, ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) libri_tot_loss.write_summary( tb_writer, "train/libri_tot_", params.batch_idx_train ) @@ -1006,8 +986,7 @@ def remove_short_and_long_utt(c: Cut): # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. " f"Duration: {c.duration}" ) return False @@ -1155,9 +1134,7 @@ def run(rank, world_size, args): train_giga_cuts = train_giga_cuts.repeat(times=None) if args.enable_musan: - cuts_musan = load_manifest( - Path(args.manifest_dir) / "musan_cuts.jsonl.gz" - ) + cuts_musan = load_manifest(Path(args.manifest_dir) / "musan_cuts.jsonl.gz") else: cuts_musan = None diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py index 9eee19379c..cba1ac6890 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py @@ -290,8 +290,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -386,9 +385,7 @@ def decode_one_batch( ) feature_lens += num_tail_padded_frames - encoder_out, encoder_out_lens, _ = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens, _ = model.encoder(x=feature, x_lens=feature_lens) if params.decoding_method == "fast_beam_search": res = fast_beam_search_one_best( @@ -441,10 +438,7 @@ def decode_one_batch( nbest_scale=params.nbest_scale, return_timestamps=True, ) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: res = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -522,9 +516,7 @@ def decode_dataset( sp: spm.SentencePieceProcessor, word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[ - str, List[Tuple[str, List[str], List[str], List[float], List[float]]] -]: +) -> Dict[str, List[Tuple[str, List[str], List[str], List[float], List[float]]]]: """Decode dataset. Args: @@ -599,9 +591,7 @@ def decode_dataset( cut_ids, hyps, texts, timestamps_hyp, timestamps_ref ): ref_words = ref_text.split() - this_batch.append( - (cut_id, ref_words, hyp_words, time_ref, time_hyp) - ) + this_batch.append((cut_id, ref_words, hyp_words, time_ref, time_hyp)) results[name].extend(this_batch) @@ -610,9 +600,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -650,8 +638,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -678,9 +665,7 @@ def save_results( note = "" logging.info(s) - s = "\nFor {}, symbol-delay of different settings are:\n".format( - test_set_name - ) + s = "\nFor {}, symbol-delay of different settings are:\n".format(test_set_name) note = "\tbest for {}".format(test_set_name) for key, val in test_set_delays: s += "{}\tmean: {}s, variance: {}{}\n".format(key, val[0], val[1], note) @@ -724,9 +709,7 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -758,9 +741,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -787,9 +770,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: 
raise ValueError( f"No checkpoints found for" @@ -848,9 +831,7 @@ def main(): decoding_graph.scores *= params.ngram_lm_scale else: word_table = None - decoding_graph = k2.trivial_graph( - params.vocab_size - 1, device=device - ) + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) else: decoding_graph = None word_table = None diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/export.py b/egs/librispeech/ASR/lstm_transducer_stateless3/export.py index 212c7bad69..457bd472f4 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/export.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/export.py @@ -172,8 +172,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) add_model_arguments(parser) @@ -281,9 +280,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -310,9 +309,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -380,9 +379,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/jit_pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless3/jit_pretrained.py index a3443cf0a6..71b37ac552 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/jit_pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/jit_pretrained.py @@ -124,8 +124,7 @@ def read_sound_files( for f in filenames: wave, sample_rate = torchaudio.load(f) assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" + f"expected sample rate: {expected_sample_rate}. 
" f"Given: {sample_rate}" ) # We use only the first channel ans.append(wave[0]) @@ -314,9 +313,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/lstm.py b/egs/librispeech/ASR/lstm_transducer_stateless3/lstm.py index 90bc351f48..6e51b85e41 100644 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/lstm.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/lstm.py @@ -661,9 +661,7 @@ def __init__( self.stddev = stddev self.final_log_weight = ( - torch.tensor( - (final_weight / (1 - final_weight)) * (self.num_inputs - 1) - ) + torch.tensor((final_weight / (1 - final_weight)) * (self.num_inputs - 1)) .log() .item() ) @@ -760,16 +758,14 @@ def _get_random_pure_weights( # final contains self.num_inputs - 1 in all elements final = torch.full((num_frames,), self.num_inputs - 1, device=device) # nonfinal contains random integers in [0..num_inputs - 2], these are for non-final weights. # noqa - nonfinal = torch.randint( - self.num_inputs - 1, (num_frames,), device=device - ) + nonfinal = torch.randint(self.num_inputs - 1, (num_frames,), device=device) indexes = torch.where( torch.rand(num_frames, device=device) < final_prob, final, nonfinal ) - ans = torch.nn.functional.one_hot( - indexes, num_classes=self.num_inputs - ).to(dtype=dtype) + ans = torch.nn.functional.one_hot(indexes, num_classes=self.num_inputs).to( + dtype=dtype + ) return ans def _get_random_mixed_weights( diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py index 0e48fef044..e72f4ee422 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py @@ -166,8 +166,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -199,8 +198,7 @@ def read_sound_files( for f in filenames: wave, sample_rate = torchaudio.load(f) assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" + f"expected sample rate: {expected_sample_rate}. 
" f"Given: {sample_rate}" ) # We use only the first channel ans.append(wave[0]) @@ -264,15 +262,11 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) feature_lengths = torch.tensor(feature_lengths, device=device) - encoder_out, encoder_out_lens, _ = model.encoder( - x=features, x_lens=feature_lengths - ) + encoder_out, encoder_out_lens, _ = model.encoder(x=features, x_lens=feature_lengths) num_waves = encoder_out.size(0) hyps = [] @@ -344,9 +338,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/streaming_decode.py b/egs/librispeech/ASR/lstm_transducer_stateless3/streaming_decode.py index cfa918ed53..dad6b905f6 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/streaming_decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/streaming_decode.py @@ -199,8 +199,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -359,9 +358,7 @@ def modified_beam_search( index=hyps_shape.row_ids(1).to(torch.int64), ) # (num_hyps, encoder_out_dim) - logits = model.joiner( - current_encoder_out, decoder_out, project_input=False - ) + logits = model.joiner(current_encoder_out, decoder_out, project_input=False) # logits is of shape (num_hyps, 1, 1, vocab_size) logits = logits.squeeze(1).squeeze(1) @@ -378,9 +375,7 @@ def modified_beam_search( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor( - shape=log_probs_shape, value=log_probs - ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) for i in range(batch_size): topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) @@ -539,9 +534,7 @@ def decode_one_chunk( feature_list, batch_first=True, padding_value=LOG_EPSILON ).to(device) feature_lens = torch.tensor(feature_len_list, device=device) - num_processed_frames = torch.tensor( - num_processed_frames_list, device=device - ) + num_processed_frames = torch.tensor(num_processed_frames_list, device=device) # Make sure it has at least 1 frame after subsampling tail_length = params.subsampling_factor + 5 @@ -583,8 +576,7 @@ def decode_one_chunk( with warnings.catch_warnings(): warnings.simplefilter("ignore") processed_lens = ( - num_processed_frames // params.subsampling_factor - + encoder_out_lens + num_processed_frames // params.subsampling_factor + encoder_out_lens ) fast_beam_search_one_best( model=model, @@ -596,9 +588,7 @@ def decode_one_chunk( max_states=params.max_states, ) else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") # Update cached states of each stream state_list = unstack_states(states) @@ -773,8 +763,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / 
f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -816,9 +805,7 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -852,9 +839,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -881,9 +868,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/train.py b/egs/librispeech/ASR/lstm_transducer_stateless3/train.py index 60a5a2be74..97ca4b94c8 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/train.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/train.py @@ -87,9 +87,7 @@ str2bool, ) -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def add_model_arguments(parser: argparse.ArgumentParser): @@ -232,8 +230,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -256,8 +253,7 @@ def get_parser(): "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)" "part.", ) parser.add_argument( @@ -606,11 +602,7 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -650,9 +642,7 @@ def compute_loss( # If either all simple_loss or pruned_loss is inf or nan, # we stop the training process by raising an exception - if torch.all(~simple_loss_is_finite) or torch.all( - ~pruned_loss_is_finite - ): + if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite): raise ValueError( "There are too many utterances in this batch " "leading to inf or nan losses." @@ -665,14 +655,9 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. 
pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training @@ -683,9 +668,7 @@ def compute_loss( # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 # (2) If some utterances in the batch lead to inf/nan loss, they # are filtered out. - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -852,10 +835,7 @@ def train_one_epoch( rank=rank, ) - if ( - batch_idx % params.log_interval == 0 - and not params.print_diagnostics - ): + if batch_idx % params.log_interval == 0 and not params.print_diagnostics: cur_lr = scheduler.get_last_lr()[0] logging.info( f"Epoch {params.cur_epoch}, " @@ -872,9 +852,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if ( batch_idx > 0 @@ -1009,8 +987,7 @@ def remove_short_and_long_utt(c: Cut): # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. " f"Duration: {c.duration}" ) return False diff --git a/egs/librispeech/ASR/pruned2_knowledge/asr_datamodule.py b/egs/librispeech/ASR/pruned2_knowledge/asr_datamodule.py index 8dd1459ca5..3dc9164f87 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/asr_datamodule.py +++ b/egs/librispeech/ASR/pruned2_knowledge/asr_datamodule.py @@ -83,8 +83,7 @@ def add_arguments(cls, parser: argparse.ArgumentParser): "--full-libri", type=str2bool, default=True, - help="When enabled, use 960h LibriSpeech. " - "Otherwise, use 100h subset.", + help="When enabled, use 960h LibriSpeech. " "Otherwise, use 100h subset.", ) group.add_argument( "--manifest-dir", @@ -208,13 +207,9 @@ def train_dataloaders( if self.args.enable_musan: logging.info("Enable MUSAN") logging.info("About to get Musan cuts") - cuts_musan = load_manifest( - self.args.manifest_dir / "cuts_musan.json.gz" - ) + cuts_musan = load_manifest(self.args.manifest_dir / "cuts_musan.json.gz") transforms.append( - CutMix( - cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True - ) + CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") @@ -236,9 +231,7 @@ def train_dataloaders( input_transforms = [] if self.args.enable_spec_aug: logging.info("Enable SpecAugment") - logging.info( - f"Time warp factor: {self.args.spec_aug_time_warp_factor}" - ) + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") # Set the value of num_frame_masks according to Lhotse's version. # In different Lhotse's versions, the default of num_frame_masks is # different. @@ -281,9 +274,7 @@ def train_dataloaders( # Drop feats to be on the safe side. 
train = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), input_transforms=input_transforms, return_cuts=self.args.return_cuts, ) @@ -340,9 +331,7 @@ def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: if self.args.on_the_fly_feats: validate = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), return_cuts=self.args.return_cuts, ) else: @@ -389,23 +378,17 @@ def test_dataloaders(self, cuts: CutSet) -> DataLoader: @lru_cache() def train_clean_100_cuts(self) -> CutSet: logging.info("About to get train-clean-100 cuts") - return load_manifest( - self.args.manifest_dir / "cuts_train-clean-100.json.gz" - ) + return load_manifest(self.args.manifest_dir / "cuts_train-clean-100.json.gz") @lru_cache() def train_clean_360_cuts(self) -> CutSet: logging.info("About to get train-clean-360 cuts") - return load_manifest( - self.args.manifest_dir / "cuts_train-clean-360.json.gz" - ) + return load_manifest(self.args.manifest_dir / "cuts_train-clean-360.json.gz") @lru_cache() def train_other_500_cuts(self) -> CutSet: logging.info("About to get train-other-500 cuts") - return load_manifest( - self.args.manifest_dir / "cuts_train-other-500.json.gz" - ) + return load_manifest(self.args.manifest_dir / "cuts_train-other-500.json.gz") @lru_cache() def dev_clean_cuts(self) -> CutSet: diff --git a/egs/librispeech/ASR/pruned2_knowledge/beam_search.py b/egs/librispeech/ASR/pruned2_knowledge/beam_search.py index 2e9bf3e0b7..785a8f0971 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/beam_search.py +++ b/egs/librispeech/ASR/pruned2_knowledge/beam_search.py @@ -172,9 +172,9 @@ def greedy_search( y = logits.argmax().item() if y != blank_id: hyp.append(y) - decoder_input = torch.tensor( - [hyp[-context_size:]], device=device - ).reshape(1, context_size) + decoder_input = torch.tensor([hyp[-context_size:]], device=device).reshape( + 1, context_size + ) decoder_out = model.decoder(decoder_input, need_pad=False) decoder_out = model.joiner.decoder_proj(decoder_out) @@ -302,9 +302,7 @@ def add(self, hyp: Hypothesis) -> None: key = hyp.key if key in self: old_hyp = self._data[key] # shallow copy - torch.logaddexp( - old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob - ) + torch.logaddexp(old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob) else: self._data[key] = hyp @@ -320,9 +318,7 @@ def get_most_probable(self, length_norm: bool = False) -> Hypothesis: Return the hypothesis that has the largest `log_prob`. 
""" if length_norm: - return max( - self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys) - ) + return max(self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys)) else: return max(self._data.values(), key=lambda hyp: hyp.log_prob) @@ -496,9 +492,7 @@ def modified_beam_search( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor( - shape=log_probs_shape, value=log_probs - ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) for i in range(batch_size): topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) diff --git a/egs/librispeech/ASR/pruned2_knowledge/conformer.py b/egs/librispeech/ASR/pruned2_knowledge/conformer.py index 295a352049..de367c2343 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/conformer.py +++ b/egs/librispeech/ASR/pruned2_knowledge/conformer.py @@ -18,10 +18,10 @@ import math import warnings from typing import Optional, Tuple -from sampling import create_knowledge_base, KnowledgeBaseLookup import torch from encoder_interface import EncoderInterface +from sampling import KnowledgeBaseLookup, create_knowledge_base from scaling import ( ActivationBalancer, BasicNorm, @@ -73,9 +73,9 @@ def __init__( if subsampling_factor != 4: raise NotImplementedError("Support only 'subsampling_factor=4'.") - - self.knowledge_base = create_knowledge_base(knowledge_M, knowledge_N, - knowledge_D) + self.knowledge_base = create_knowledge_base( + knowledge_M, knowledge_N, knowledge_D + ) # self.encoder_embed converts the input of shape (N, T, num_features) # to the shape (N, T//subsampling_factor, d_model). @@ -89,7 +89,7 @@ def __init__( # Pass in a lambda that creates a new ConformerEncoderLayer with these # args. Don't use deepcopy because we need the knowledge_base # to be shared. 
- encoder_layer_fn = lambda: ConformerEncoderLayer( + encoder_layer_fn = lambda: ConformerEncoderLayer( # noqa: E731 self.knowledge_base, d_model, nhead, @@ -100,7 +100,7 @@ def __init__( knowledge_M, knowledge_N, knowledge_D, - knowledge_K + knowledge_K, ) self.encoder = ConformerEncoder(encoder_layer_fn, num_encoder_layers) @@ -187,9 +187,7 @@ def __init__( self.d_model = d_model - self.self_attn = RelPositionMultiheadAttention( - d_model, nhead, dropout=0.0 - ) + self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) self.feed_forward = nn.Sequential( ScaledLinear(d_model, dim_feedforward), @@ -209,10 +207,9 @@ def __init__( self.conv_module = ConvolutionModule(d_model, cnn_module_kernel) - self.lookup = KnowledgeBaseLookup(knowledge_M, knowledge_N, - knowledge_D, knowledge_K, - d_model, - knowledge_base) + self.lookup = KnowledgeBaseLookup( + knowledge_M, knowledge_N, knowledge_D, knowledge_K, d_model, knowledge_base + ) self.norm_final = BasicNorm(d_model) @@ -311,9 +308,7 @@ class ConformerEncoder(nn.Module): def __init__(self, encoder_layer_fn, num_layers: int) -> None: super().__init__() - self.layers = nn.ModuleList( - [encoder_layer_fn() for i in range(num_layers)] - ) + self.layers = nn.ModuleList([encoder_layer_fn() for i in range(num_layers)]) self.num_layers = num_layers def forward( @@ -367,9 +362,7 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__( - self, d_model: int, dropout_rate: float, max_len: int = 5000 - ) -> None: + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -384,9 +377,7 @@ def extend_pe(self, x: Tensor) -> None: # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x.size(1) * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str( - x.device - ): + if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vecotr and `j` means the @@ -661,9 +652,9 @@ def multi_head_attention_forward( if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear( - query, in_proj_weight, in_proj_bias - ).chunk(3, dim=-1) + q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( + 3, dim=-1 + ) elif torch.equal(key, value): # encoder-decoder attention @@ -732,31 +723,22 @@ def multi_head_attention_forward( if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError( - "The size of the 2D attn_mask is not correct." - ) + raise RuntimeError("The size of the 2D attn_mask is not correct.") elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError( - "The size of the 3D attn_mask is not correct." - ) + raise RuntimeError("The size of the 3D attn_mask is not correct.") else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format( - attn_mask.dim() - ) + "attn_mask's dimension {} is not supported".format(attn_mask.dim()) ) # attn_mask's dim is 3 now. 
# convert ByteTensor key_padding_mask to bool - if ( - key_padding_mask is not None - and key_padding_mask.dtype == torch.uint8 - ): + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: warnings.warn( "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." ) @@ -795,9 +777,7 @@ def multi_head_attention_forward( # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul( - q_with_bias_u, k - ) # (batch, head, time1, time2) + matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) # compute matrix b and matrix d matrix_bd = torch.matmul( @@ -805,13 +785,9 @@ def multi_head_attention_forward( ) # (batch, head, time1, 2*time1-1) matrix_bd = self.rel_shift(matrix_bd) - attn_output_weights = ( - matrix_ac + matrix_bd - ) # (batch, head, time1, time2) + attn_output_weights = matrix_ac + matrix_bd # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view( - bsz * num_heads, tgt_len, -1 - ) + attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) assert list(attn_output_weights.size()) == [ bsz * num_heads, @@ -845,13 +821,9 @@ def multi_head_attention_forward( attn_output = torch.bmm(attn_output_weights, v) assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] attn_output = ( - attn_output.transpose(0, 1) - .contiguous() - .view(tgt_len, bsz, embed_dim) - ) - attn_output = nn.functional.linear( - attn_output, out_proj_weight, out_proj_bias + attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) ) + attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads @@ -874,9 +846,7 @@ class ConvolutionModule(nn.Module): """ - def __init__( - self, channels: int, kernel_size: int, bias: bool = True - ) -> None: + def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None: """Construct an ConvolutionModule object.""" super(ConvolutionModule, self).__init__() # kernerl_size should be a odd number for 'SAME' padding diff --git a/egs/librispeech/ASR/pruned2_knowledge/decode.py b/egs/librispeech/ASR/pruned2_knowledge/decode.py index b4a9af55a9..c3e7b01ab2 100755 --- a/egs/librispeech/ASR/pruned2_knowledge/decode.py +++ b/egs/librispeech/ASR/pruned2_knowledge/decode.py @@ -76,11 +76,7 @@ ) from train import get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.utils import ( AttributeDict, setup_logger, @@ -186,8 +182,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; " "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -245,9 +240,7 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] if params.decoding_method == "fast_beam_search": @@ -262,10 +255,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -385,9 +375,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -419,8 +407,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) diff --git a/egs/librispeech/ASR/pruned2_knowledge/decoder.py b/egs/librispeech/ASR/pruned2_knowledge/decoder.py index b6d94aaf15..0b9c886c7f 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/decoder.py +++ b/egs/librispeech/ASR/pruned2_knowledge/decoder.py @@ -90,9 +90,7 @@ def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor: if self.context_size > 1: embedding_out = embedding_out.permute(0, 2, 1) if need_pad is True: - embedding_out = F.pad( - embedding_out, pad=(self.context_size - 1, 0) - ) + embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0)) else: # During inference time, there is no need to do extra padding # as we only need one output diff --git a/egs/librispeech/ASR/pruned2_knowledge/decoder2.py b/egs/librispeech/ASR/pruned2_knowledge/decoder2.py index db51fb1cd7..0c9cee4314 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/decoder2.py +++ b/egs/librispeech/ASR/pruned2_knowledge/decoder2.py @@ -14,12 +14,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Optional + import torch import torch.nn as nn import torch.nn.functional as F -from torch import Tensor -from typing import Optional from subsampling import ScaledConv1d +from torch import Tensor class Decoder(nn.Module): @@ -90,9 +91,7 @@ def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor: if self.context_size > 1: embedding_out = embedding_out.permute(0, 2, 1) if need_pad is True: - embedding_out = F.pad( - embedding_out, pad=(self.context_size - 1, 0) - ) + embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0)) else: # During inference time, there is no need to do extra padding # as we only need one output @@ -102,7 +101,6 @@ def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor: return embedding_out - class ScaledEmbedding(nn.Module): r"""A simple lookup table that stores embeddings of a fixed dictionary and size. 
@@ -171,8 +169,13 @@ class ScaledEmbedding(nn.Module): [ 0.0000, 0.0000, 0.0000], [-0.1655, 0.9897, 0.0635]]]) """ - __constants__ = ['num_embeddings', 'embedding_dim', 'padding_idx', - 'scale_grad_by_freq', 'sparse'] + __constants__ = [ + "num_embeddings", + "embedding_dim", + "padding_idx", + "scale_grad_by_freq", + "sparse", + ] num_embeddings: int embedding_dim: int @@ -181,34 +184,41 @@ class ScaledEmbedding(nn.Module): weight: Tensor sparse: bool - def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None, - scale_grad_by_freq: bool = False, - sparse: bool = False, - scale_speed: float = 5.0) -> None: + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: Optional[int] = None, + scale_grad_by_freq: bool = False, + sparse: bool = False, + scale_speed: float = 5.0, + ) -> None: super(ScaledEmbedding, self).__init__() self.num_embeddings = num_embeddings self.embedding_dim = embedding_dim if padding_idx is not None: if padding_idx > 0: - assert padding_idx < self.num_embeddings, 'Padding_idx must be within num_embeddings' + assert ( + padding_idx < self.num_embeddings + ), "Padding_idx must be within num_embeddings" elif padding_idx < 0: - assert padding_idx >= -self.num_embeddings, 'Padding_idx must be within num_embeddings' + assert ( + padding_idx >= -self.num_embeddings + ), "Padding_idx must be within num_embeddings" padding_idx = self.num_embeddings + padding_idx self.padding_idx = padding_idx self.scale_grad_by_freq = scale_grad_by_freq self.scale_speed = scale_speed - self.scale = nn.Parameter(torch.zeros(())) # see reset_parameters() + self.scale = nn.Parameter(torch.zeros(())) # see reset_parameters() self.sparse = sparse self.weight = nn.Parameter(torch.Tensor(num_embeddings, embedding_dim)) self.reset_parameters() - - def reset_parameters(self) -> None: nn.init.normal_(self.weight, std=0.05) - nn.init.constant_(self.scale, torch.tensor(1.0/0.05).log() / self.scale_speed) + nn.init.constant_(self.scale, torch.tensor(1.0 / 0.05).log() / self.scale_speed) if self.padding_idx is not None: with torch.no_grad(): @@ -217,22 +227,35 @@ def reset_parameters(self) -> None: def forward(self, input: Tensor) -> Tensor: scale = (self.scale * self.scale_speed).exp() if input.numel() < self.num_embeddings: - return F.embedding( - input, self.weight, self.padding_idx, - None, 2.0, # None, 2.0 relate to normalization - self.scale_grad_by_freq, self.sparse) * scale + return ( + F.embedding( + input, + self.weight, + self.padding_idx, + None, + 2.0, # None, 2.0 relate to normalization + self.scale_grad_by_freq, + self.sparse, + ) + * scale + ) else: return F.embedding( - input, self.weight * scale, self.padding_idx, - None, 2.0, # None, 2.0 relates to normalization - self.scale_grad_by_freq, self.sparse) + input, + self.weight * scale, + self.padding_idx, + None, + 2.0, # None, 2.0 relates to normalization + self.scale_grad_by_freq, + self.sparse, + ) def extra_repr(self) -> str: - s = '{num_embeddings}, {embedding_dim}, scale_speed={scale_speed}, scale={scale}' + s = "{num_embeddings}, {embedding_dim}, scale_speed={scale_speed}, scale={scale}" if self.padding_idx is not None: - s += ', padding_idx={padding_idx}' + s += ", padding_idx={padding_idx}" if self.scale_grad_by_freq is not False: - s += ', scale_grad_by_freq={scale_grad_by_freq}' + s += ", scale_grad_by_freq={scale_grad_by_freq}" if self.sparse is not False: - s += ', sparse=True' + s += ", sparse=True" return s.format(**self.__dict__) diff --git 
a/egs/librispeech/ASR/pruned2_knowledge/export.py b/egs/librispeech/ASR/pruned2_knowledge/export.py index 96d1a30fbb..ce5f162bf3 100755 --- a/egs/librispeech/ASR/pruned2_knowledge/export.py +++ b/egs/librispeech/ASR/pruned2_knowledge/export.py @@ -105,8 +105,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) return parser @@ -174,9 +173,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned2_knowledge/joiner.py b/egs/librispeech/ASR/pruned2_knowledge/joiner.py index 35f75ed2ad..68c663b66f 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/joiner.py +++ b/egs/librispeech/ASR/pruned2_knowledge/joiner.py @@ -56,9 +56,7 @@ def forward( assert encoder_out.shape[:-1] == decoder_out.shape[:-1] if project_input: - logit = self.encoder_proj(encoder_out) + self.decoder_proj( - decoder_out - ) + logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out) else: logit = encoder_out + decoder_out diff --git a/egs/librispeech/ASR/pruned2_knowledge/model.py b/egs/librispeech/ASR/pruned2_knowledge/model.py index 599bf25067..ca8c28af17 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/model.py +++ b/egs/librispeech/ASR/pruned2_knowledge/model.py @@ -63,9 +63,7 @@ def __init__( self.decoder = decoder self.joiner = joiner - self.simple_am_proj = ScaledLinear( - encoder_dim, vocab_size, initial_speed=0.5 - ) + self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size, initial_speed=0.5) self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size) def forward( @@ -136,9 +134,7 @@ def forward( y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros( - (x.size(0), 4), dtype=torch.int64, device=x.device - ) + boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) boundary[:, 2] = y_lens boundary[:, 3] = x_lens diff --git a/egs/librispeech/ASR/pruned2_knowledge/optim.py b/egs/librispeech/ASR/pruned2_knowledge/optim.py index 432bf82207..76cd4e11ee 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/optim.py +++ b/egs/librispeech/ASR/pruned2_knowledge/optim.py @@ -72,17 +72,11 @@ def __init__( if not 0.0 <= eps: raise ValueError("Invalid epsilon value: {}".format(eps)) if not 0.0 <= betas[0] < 1.0: - raise ValueError( - "Invalid beta parameter at index 0: {}".format(betas[0]) - ) + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) if not 0.0 <= betas[1] < 1.0: - raise ValueError( - "Invalid beta parameter at index 1: {}".format(betas[1]) - ) + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) if not 0 <= weight_decay <= 0.1: - raise ValueError( - "Invalid weight_decay value: {}".format(weight_decay) - ) + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) if not 0 < target_rms <= 10.0: raise ValueError("Invalid target_rms value: {}".format(target_rms)) defaults = dict( @@ -118,9 +112,7 @@ def step(self, closure=None): # Perform optimization step grad = p.grad if grad.is_sparse: - raise RuntimeError( - "AdamW does not support sparse gradients" - ) + raise RuntimeError("AdamW does not support sparse gradients") state = self.state[p] @@ 
-147,7 +139,7 @@ def step(self, closure=None): # Decay the first and second moment running average coefficient exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) - denom = (exp_avg_sq.sqrt() * (bias_correction2 ** -0.5)).add_( + denom = (exp_avg_sq.sqrt() * (bias_correction2**-0.5)).add_( group["eps"] ) @@ -158,9 +150,7 @@ def step(self, closure=None): if p.numel() > 1: # avoid applying this weight-decay on "scaling factors" # (which are scalar). - is_above_target_rms = p.norm() > ( - target_rms * (p.numel() ** 0.5) - ) + is_above_target_rms = p.norm() > (target_rms * (p.numel() ** 0.5)) p.mul_(1 - (weight_decay * is_above_target_rms)) p.addcdiv_(exp_avg, denom, value=-step_size) @@ -176,18 +166,14 @@ class LRScheduler(object): def __init__(self, optimizer: Optimizer, verbose: bool = False): # Attach optimizer if not isinstance(optimizer, Optimizer): - raise TypeError( - "{} is not an Optimizer".format(type(optimizer).__name__) - ) + raise TypeError("{} is not an Optimizer".format(type(optimizer).__name__)) self.optimizer = optimizer self.verbose = verbose for group in optimizer.param_groups: group.setdefault("initial_lr", group["lr"]) - self.base_lrs = [ - group["initial_lr"] for group in optimizer.param_groups - ] + self.base_lrs = [group["initial_lr"] for group in optimizer.param_groups] self.epoch = 0 self.batch = 0 @@ -295,10 +281,9 @@ def __init__( def get_lr(self): factor = ( - (self.batch ** 2 + self.lr_batches ** 2) / self.lr_batches ** 2 + (self.batch**2 + self.lr_batches**2) / self.lr_batches**2 ) ** -0.25 * ( - ((self.epoch ** 2 + self.lr_epochs ** 2) / self.lr_epochs ** 2) - ** -0.25 + ((self.epoch**2 + self.lr_epochs**2) / self.lr_epochs**2) ** -0.25 ) return [x * factor for x in self.base_lrs] diff --git a/egs/librispeech/ASR/pruned2_knowledge/sampling.py b/egs/librispeech/ASR/pruned2_knowledge/sampling.py index 7b05e2f00b..5b595c76c9 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/sampling.py +++ b/egs/librispeech/ASR/pruned2_knowledge/sampling.py @@ -3,32 +3,29 @@ # This was copied from /ceph-dan/torch-sampling/torch_sampling/sampling_ref.py, # its git history is there. +import random import timeit +from typing import Optional, Tuple + import torch -from torch import Tensor -from torch import nn -from torch.cuda.amp import GradScaler, custom_fwd, custom_bwd -from typing import Tuple, Optional from scaling import ScaledLinear -import random +from torch import Tensor, nn +from torch.cuda.amp import GradScaler, custom_bwd, custom_fwd from torch_scheduled_sampling import sample_combined # The main exports of this file are the module KnowledgeBaseLookup and the # function create_knowledge_base. 
- - - - def create_knowledge_base(M: int, N: int, D: int) -> nn.Parameter: std = 0.1 - a = (3 ** 0.5) * std # this sqrt(3) thing is intended to get variance of - # 0.1 from uniform distribution - ans = nn.Parameter(torch.ones(M ** N, D)) + a = (3**0.5) * std # this sqrt(3) thing is intended to get variance of + # 0.1 from uniform distribution + ans = nn.Parameter(torch.ones(M**N, D)) nn.init.uniform_(ans, -a, a) return ans + def join_indexes(indexes: Tensor, M: int) -> Tensor: """ Combines N-tuples of indexes into single indexes that can be used for @@ -47,9 +44,9 @@ def join_indexes(indexes: Tensor, M: int) -> Tensor: # Note, we don't use this, we -def weighted_matrix_lookup(weights: Tensor, - indexes: Tensor, - knowledge_base: Tensor) -> Tensor: +def weighted_matrix_lookup( + weights: Tensor, indexes: Tensor, knowledge_base: Tensor +) -> Tensor: """ Weighted combination of specified rows of a matrix. weights: Tensor of shape (*, K), can contain any value but probably in [0..1]. @@ -65,9 +62,9 @@ def weighted_matrix_lookup(weights: Tensor, # simpler but less memory-efficient implementation lookup = torch.index_select(knowledge_base, dim=0, index=indexes.flatten()) D = knowledge_base.shape[-1] - weights = weights.unsqueeze(-2) # (*, 1, K) - lookup = lookup.reshape(*indexes.shape, D) # (*, K, D) - ans = torch.matmul(weights, lookup) # ans: (*, 1, D) + weights = weights.unsqueeze(-2) # (*, 1, K) + lookup = lookup.reshape(*indexes.shape, D) # (*, K, D) + ans = torch.matmul(weights, lookup) # ans: (*, 1, D) ans = ans.squeeze(-2) assert list(ans.shape) == list(weights.shape[:-2]) + [D] return ans @@ -76,7 +73,9 @@ def weighted_matrix_lookup(weights: Tensor, class WeightedMatrixLookupFunction(torch.autograd.Function): @staticmethod @custom_fwd - def forward(ctx, weights: Tensor, indexes: Tensor, knowledge_base: Tensor) -> Tensor: + def forward( + ctx, weights: Tensor, indexes: Tensor, knowledge_base: Tensor + ) -> Tensor: """ Weighted combination of specified rows of a matrix. weights: Tensor of shape (*, K), can contain any value but probably in [0..1]. @@ -88,15 +87,16 @@ def forward(ctx, weights: Tensor, indexes: Tensor, knowledge_base: Tensor) -> Te """ if random.random() < 0.001: print("dtype[1] = ", weights.dtype) - ctx.save_for_backward(weights.detach(), indexes.detach(), - knowledge_base.detach()) + ctx.save_for_backward( + weights.detach(), indexes.detach(), knowledge_base.detach() + ) with torch.no_grad(): lookup = torch.index_select(knowledge_base, dim=0, index=indexes.flatten()) D = knowledge_base.shape[-1] - weights = weights.unsqueeze(-2) # (*, 1, K) - lookup = lookup.reshape(*indexes.shape, D) # (*, K, D) - ans = torch.matmul(weights, lookup) # ans: (*, 1, D) - ans = ans.squeeze(-2) #(*, D) + weights = weights.unsqueeze(-2) # (*, 1, K) + lookup = lookup.reshape(*indexes.shape, D) # (*, K, D) + ans = torch.matmul(weights, lookup) # ans: (*, 1, D) + ans = ans.squeeze(-2) # (*, D) return ans @staticmethod @@ -107,7 +107,7 @@ def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None, Tensor]: knowledge_base.requires_grad = True dtype = ans_grad.dtype ans_grad = ans_grad.to(weights.dtype) - assert weights.requires_grad == False + assert weights.requires_grad is False D = knowledge_base.shape[-1] with torch.enable_grad(): # we'll use torch's autograd to differentiate this operation, which @@ -115,16 +115,19 @@ def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None, Tensor]: # We don't save `lookup` because it's large, that is the reason # we override Torch autograd. 
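weighted_matrix_lookup above is a batched weighted row-gather: ans[b] = sum_k weights[b, k] * knowledge_base[indexes[b, k]]. A reference computation with plain indexing and illustrative sizes, useful for checking the memory-saving custom-autograd version:

    import torch

    M, K, D = 100, 4, 8                                  # illustrative sizes
    kb = torch.randn(M, D)                               # knowledge-base rows
    weights = torch.softmax(torch.randn(3, K), dim=-1)   # (*, K)
    indexes = torch.randint(0, M, (3, K))                # (*, K), rows to gather

    ans = torch.einsum("bk,bkd->bd", weights, kb[indexes])   # (*, D)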
lookup = torch.index_select(knowledge_base, dim=0, index=indexes.flatten()) - lookup = lookup.reshape(*indexes.shape, D) # (*, K, D) - weights = weights.unsqueeze(-1) # (*, K, 1) + lookup = lookup.reshape(*indexes.shape, D) # (*, K, D) + weights = weights.unsqueeze(-1) # (*, K, 1) # forward pass: was: ## ans = torch.matmul(weights, lookup) ## ans: (*, 1, D) ## ans = ans.squeeze(-2) # ans, ans_grad: (*, D) - weights_grad = torch.matmul(lookup, # (*, K, D) - ans_grad.unsqueeze(-1)) # (*, D, 1) - weights_grad = weights_grad.squeeze(-1) # (*, K, 1) -> (*, K) - lookup_grad = weights * ans_grad.unsqueeze(-2) # (*, K, 1) * (*, 1, D) = (*, K, D) + weights_grad = torch.matmul( + lookup, ans_grad.unsqueeze(-1) # (*, K, D) + ) # (*, D, 1) + weights_grad = weights_grad.squeeze(-1) # (*, K, 1) -> (*, K) + lookup_grad = weights * ans_grad.unsqueeze( + -2 + ) # (*, K, 1) * (*, 1, D) = (*, K, D) lookup.backward(gradient=lookup_grad) return weights_grad.to(dtype), None, knowledge_base.grad.to(dtype) @@ -146,6 +149,7 @@ class PenalizeNegentropyFunction(torch.autograd.Function): Returns: logprobs """ + @staticmethod def forward(ctx, logprobs: Tensor, alpha: float): ctx.save_for_backward(logprobs.detach()) @@ -154,18 +158,23 @@ def forward(ctx, logprobs: Tensor, alpha: float): @staticmethod def backward(ctx, logprobs_grad: Tensor) -> Tuple[Tensor, None]: - logprobs, = ctx.saved_tensors + (logprobs,) = ctx.saved_tensors with torch.enable_grad(): logprobs.requires_grad = True # `negentropy` is the negative entropy of the average distribution. # distributions. It will be <= 0. - l = logprobs.reshape(-1, logprobs.shape[-1]) + l = logprobs.reshape(-1, logprobs.shape[-1]) # noqa: E741 scale = ctx.alpha * l.shape[0] avg_dist = l.exp().mean(dim=0) negentropy = (avg_dist * (avg_dist + 1.0e-20).log()).sum() if random.random() < 0.0005: negentropy_individual = (l * l.exp()).sum(dim=-1).mean() - print("Negentropy[individual,combined] = ", negentropy_individual.item(), ", ", negentropy.item()) + print( + "Negentropy[individual,combined] = ", + negentropy_individual.item(), + ", ", + negentropy.item(), + ) loss = negentropy * scale loss.backward() return logprobs_grad + logprobs.grad, None @@ -183,18 +192,23 @@ class KnowledgeBaseLookup(nn.Module): embedding_dim: the dimension to project from and to, e.g. the d_model of the conformer. """ - def __init__(self, M: int, N: int, D: int, - K: int, embedding_dim: int, - knowledge_base: nn.Parameter, - negentropy_penalty: float = 0.001): + + def __init__( + self, + M: int, + N: int, + D: int, + K: int, + embedding_dim: int, + knowledge_base: nn.Parameter, + negentropy_penalty: float = 0.001, + ): super(KnowledgeBaseLookup, self).__init__() self.knowledge_base = knowledge_base # shared! - self.in_proj = ScaledLinear(embedding_dim, M * N, - initial_scale=1.0) + self.in_proj = ScaledLinear(embedding_dim, M * N, initial_scale=1.0) # initial_scale = 4.0 because the knowlege_base activations are # quite small -- if we use our optimizer they'll have stddev <= 0.1. - self.out_proj = ScaledLinear(D, embedding_dim, - initial_scale = 4.0) + self.out_proj = ScaledLinear(D, embedding_dim, initial_scale=4.0) self.M = M self.N = N self.K = K @@ -210,14 +224,14 @@ def forward(self, x: Tensor) -> Tensor: # TODO: later we can try multiplying by a projection of x or something like that. 
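The quantity PenalizeNegentropyFunction differentiates above can be computed standalone: it is the negated entropy of the average of the per-row distributions, so it is always <= 0, and alpha (negentropy_penalty, default 0.001 above) times the row count sets its weight:

    import torch

    logprobs = torch.randn(32, 16).log_softmax(dim=-1)   # (rows, M), illustrative
    avg_dist = logprobs.exp().mean(dim=0)                # average distribution
    negentropy = (avg_dist * (avg_dist + 1.0e-20).log()).sum()
    assert negentropy.item() <= 0.0                      # -entropy of a distribution
    loss = negentropy * 0.001 * logprobs.shape[0]        # scale = alpha * num rows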
""" - x = self.in_proj(x) # now (*, M*N) - x = x.reshape(*x.shape[:-1], self.N, self.M) # now (*, N, M) - x = x.log_softmax(dim=-1) # now normalized logprobs, dim= (*, N, M) + x = self.in_proj(x) # now (*, M*N) + x = x.reshape(*x.shape[:-1], self.N, self.M) # now (*, N, M) + x = x.log_softmax(dim=-1) # now normalized logprobs, dim= (*, N, M) x = PenalizeNegentropyFunction.apply(x, self.negentropy_penalty) _, indexes, weights = sample_combined(x, self.K, input_is_log=True) - x = weighted_matrix_lookup(weights, indexes, self.knowledge_base) # now (*, D) - x = self.out_proj(x) # now (*, self.embedding_dim) + x = weighted_matrix_lookup(weights, indexes, self.knowledge_base) # now (*, D) + x = self.out_proj(x) # now (*, self.embedding_dim) return x @@ -237,38 +251,44 @@ def _test_knowledge_base_lookup(): x.requires_grad = True y = m(x) assert y.shape == x.shape - y.sum().backward() # make sure backward doesn't crash.. + y.sum().backward() # make sure backward doesn't crash.. print("y = ", y) print("x.grad = ", x.grad) print("knowlege_base.grad norm = ", knowledge_base.grad.norm()) dtype = torch.float32 - device = torch.device('cuda') - train_pairs = [ (torch.randn(B, T, E, device=device, dtype=dtype), torch.randn(B, T, E, device=device, dtype=dtype)) for _ in range(10) ] + device = torch.device("cuda") + train_pairs = [ + ( + torch.randn(B, T, E, device=device, dtype=dtype), + torch.randn(B, T, E, device=device, dtype=dtype), + ) + for _ in range(10) + ] from optim import Eve + optimizer = Eve(m.parameters(), lr=0.005, eps=1.0e-04) m = m.to(device).to(dtype) - start = timeit.default_timer() -# Epoch 0, batch 0, loss 1.0109944343566895 -# Epoch 10, batch 0, loss 1.0146660804748535 -# Epoch 20, batch 0, loss 1.0119813680648804 -# Epoch 30, batch 0, loss 1.0105408430099487 -# Epoch 40, batch 0, loss 1.0077732801437378 -# Epoch 50, batch 0, loss 1.0050103664398193 -# Epoch 60, batch 0, loss 1.0033129453659058 -# Epoch 70, batch 0, loss 1.0014232397079468 -# Epoch 80, batch 0, loss 0.9977912306785583 -# Epoch 90, batch 0, loss 0.8274348974227905 -# Epoch 100, batch 0, loss 0.3368612825870514 -# Epoch 110, batch 0, loss 0.11323091387748718 -# Time taken: 17.591704960912466 + # Epoch 0, batch 0, loss 1.0109944343566895 + # Epoch 10, batch 0, loss 1.0146660804748535 + # Epoch 20, batch 0, loss 1.0119813680648804 + # Epoch 30, batch 0, loss 1.0105408430099487 + # Epoch 40, batch 0, loss 1.0077732801437378 + # Epoch 50, batch 0, loss 1.0050103664398193 + # Epoch 60, batch 0, loss 1.0033129453659058 + # Epoch 70, batch 0, loss 1.0014232397079468 + # Epoch 80, batch 0, loss 0.9977912306785583 + # Epoch 90, batch 0, loss 0.8274348974227905 + # Epoch 100, batch 0, loss 0.3368612825870514 + # Epoch 110, batch 0, loss 0.11323091387748718 + # Time taken: 17.591704960912466 for epoch in range(150): - for n, (x,y) in enumerate(train_pairs): + for n, (x, y) in enumerate(train_pairs): y_out = m(x) - loss = ((y_out - y)**2).mean() * 100.0 + loss = ((y_out - y) ** 2).mean() * 100.0 if n % 10 == 0 and epoch % 10 == 0: print(f"Epoch {epoch}, batch {n}, loss {loss.item()}") loss.backward() @@ -276,7 +296,8 @@ def _test_knowledge_base_lookup(): optimizer.zero_grad() stop = timeit.default_timer() - print('Time taken: ', stop - start) + print("Time taken: ", stop - start) + def _test_knowledge_base_lookup_autocast(): K = 16 @@ -294,14 +315,18 @@ def _test_knowledge_base_lookup_autocast(): x.requires_grad = True y = m(x) assert y.shape == x.shape - y.sum().backward() # make sure backward doesn't crash.. 
+ y.sum().backward() # make sure backward doesn't crash.. print("y = ", y) print("x.grad = ", x.grad) print("knowlege_base.grad norm = ", knowledge_base.grad.norm()) - device = torch.device('cuda') - train_pairs = [ (torch.randn(B, T, E, device=device), torch.randn(B, T, E, device=device)) for _ in range(10) ] + device = torch.device("cuda") + train_pairs = [ + (torch.randn(B, T, E, device=device), torch.randn(B, T, E, device=device)) + for _ in range(10) + ] from optim import Eve + optimizer = Eve(m.parameters(), lr=0.005, eps=1.0e-04) m = m.to(device) @@ -309,12 +334,11 @@ def _test_knowledge_base_lookup_autocast(): start = timeit.default_timer() - for epoch in range(150): - for n, (x,y) in enumerate(train_pairs): + for n, (x, y) in enumerate(train_pairs): y_out = m(x) with torch.cuda.amp.autocast(enabled=True): - loss = ((y_out - y)**2).mean() * 100.0 + loss = ((y_out - y) ** 2).mean() * 100.0 if n % 10 == 0 and epoch % 10 == 0: print(f"Epoch {epoch}, batch {n}, loss {loss.item()}") scaler.scale(loss).backward() @@ -323,10 +347,9 @@ def _test_knowledge_base_lookup_autocast(): optimizer.zero_grad() stop = timeit.default_timer() - print('Time taken: ', stop - start) - + print("Time taken: ", stop - start) -if __name__ == '__main__': +if __name__ == "__main__": _test_knowledge_base_lookup() _test_knowledge_base_lookup_autocast() diff --git a/egs/librispeech/ASR/pruned2_knowledge/scaling.py b/egs/librispeech/ASR/pruned2_knowledge/scaling.py index f726c25832..527c735ebb 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/scaling.py +++ b/egs/librispeech/ASR/pruned2_knowledge/scaling.py @@ -18,11 +18,11 @@ import collections from itertools import repeat from typing import Optional, Tuple -from torch.cuda.amp import custom_fwd, custom_bwd import torch import torch.nn as nn from torch import Tensor +from torch.cuda.amp import custom_bwd, custom_fwd def _ntuple(n): @@ -79,9 +79,7 @@ def forward( below_threshold = mean_abs < min_abs above_threshold = mean_abs > max_abs - ctx.save_for_backward( - factor, xgt0, below_threshold, above_threshold - ) + ctx.save_for_backward(factor, xgt0, below_threshold, above_threshold) ctx.max_factor = max_factor ctx.sum_dims = sum_dims return x @@ -149,8 +147,7 @@ def __init__( def forward(self, x: Tensor) -> Tensor: assert x.shape[self.channel_dim] == self.num_channels scales = ( - torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) - + self.eps.exp() + torch.mean(x**2, dim=self.channel_dim, keepdim=True) + self.eps.exp() ) ** -0.5 return x * scales @@ -182,11 +179,7 @@ class ScaledLinear(nn.Linear): """ def __init__( - self, - *args, - initial_scale: float = 1.0, - initial_speed: float = 1.0, - **kwargs + self, *args, initial_scale: float = 1.0, initial_speed: float = 1.0, **kwargs ): super(ScaledLinear, self).__init__(*args, **kwargs) initial_scale = torch.tensor(initial_scale).log() @@ -202,12 +195,12 @@ def __init__( def _reset_parameters(self, initial_speed: float): std = 0.1 / initial_speed - a = (3 ** 0.5) * std + a = (3**0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in ** -0.5 # 1/sqrt(fan_in) + scale = fan_in**-0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() @@ -218,19 +211,13 @@ def get_bias(self): return None if self.bias is None else self.bias * self.bias_scale.exp() def forward(self, input: Tensor) -> Tensor: - return torch.nn.functional.linear( - input, 
self.get_weight(), self.get_bias() - ) + return torch.nn.functional.linear(input, self.get_weight(), self.get_bias()) class ScaledConv1d(nn.Conv1d): # See docs for ScaledLinear def __init__( - self, - *args, - initial_scale: float = 1.0, - initial_speed: float = 1.0, - **kwargs + self, *args, initial_scale: float = 1.0, initial_speed: float = 1.0, **kwargs ): super(ScaledConv1d, self).__init__(*args, **kwargs) initial_scale = torch.tensor(initial_scale).log() @@ -245,12 +232,12 @@ def __init__( def _reset_parameters(self, initial_speed: float): std = 0.1 / initial_speed - a = (3 ** 0.5) * std + a = (3**0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in ** -0.5 # 1/sqrt(fan_in) + scale = fan_in**-0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() @@ -290,11 +277,7 @@ def forward(self, input: Tensor) -> Tensor: class ScaledConv2d(nn.Conv2d): # See docs for ScaledLinear def __init__( - self, - *args, - initial_scale: float = 1.0, - initial_speed: float = 1.0, - **kwargs + self, *args, initial_scale: float = 1.0, initial_speed: float = 1.0, **kwargs ): super(ScaledConv2d, self).__init__(*args, **kwargs) initial_scale = torch.tensor(initial_scale).log() @@ -309,12 +292,12 @@ def __init__( def _reset_parameters(self, initial_speed: float): std = 0.1 / initial_speed - a = (3 ** 0.5) * std + a = (3**0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in ** -0.5 # 1/sqrt(fan_in) + scale = fan_in**-0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() @@ -653,9 +636,7 @@ def _test_activation_balancer_sign(): def _test_activation_balancer_magnitude(): magnitudes = torch.arange(0, 1, 0.01) N = 1000 - x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze( - -1 - ) + x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1) x = x.detach() x.requires_grad = True m = ActivationBalancer( @@ -685,8 +666,8 @@ def _test_basic_norm(): y = m(x) assert y.shape == x.shape - x_rms = (x ** 2).mean().sqrt() - y_rms = (y ** 2).mean().sqrt() + x_rms = (x**2).mean().sqrt() + y_rms = (y**2).mean().sqrt() print("x rms = ", x_rms) print("y rms = ", y_rms) assert y_rms < x_rms diff --git a/egs/librispeech/ASR/pruned2_knowledge/scaling_tmp.py b/egs/librispeech/ASR/pruned2_knowledge/scaling_tmp.py index 6293e081a3..3f21133a08 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/scaling_tmp.py +++ b/egs/librispeech/ASR/pruned2_knowledge/scaling_tmp.py @@ -15,21 +15,23 @@ # limitations under the License. +from typing import Optional, Tuple + import torch import torch.nn as nn from torch import Tensor -from typing import Tuple, Optional - -def _activation_balancer_loss(mean_pos: Tensor, - mean_neg: Tensor, - min_positive: float, # e.g. 0.05 - max_positive: float, # e.g. 0.95 - max_factor: float, # e.g. 0.01 - min_abs: float, # e.g. 0.2 - max_abs: float, # e.g. 100.0 - eps: float = 1.0e-10): +def _activation_balancer_loss( + mean_pos: Tensor, + mean_neg: Tensor, + min_positive: float, # e.g. 0.05 + max_positive: float, # e.g. 0.95 + max_factor: float, # e.g. 0.01 + min_abs: float, # e.g. 0.2 + max_abs: float, # e.g. 100.0 + eps: float = 1.0e-10, +): """ Returns a loss-function for the ActivationBalancer module. 
This loss function is not exposed to the user but is used internally, and eventually @@ -50,28 +52,32 @@ def _activation_balancer_loss(mean_pos: Tensor, """ loss_parts = [] - x_mean = mean_positive - mean_negative - x_mean_abs = (mean_positive + mean_negative + eps).detach() - x_rel_mean= x_mean / x_mean_abs + x_mean = mean_pos - mean_neg + x_mean_abs = (mean_pos + mean_neg + eps).detach() + x_rel_mean = x_mean / x_mean_abs if min_positive != 0.0: # e.g. x_mean_floor = -0.95 + 0.05 = -0.9 - x_rel_mean_floor = (-(1-min_positive) + min_positive) - min_positive_loss = (x_rel_mean_floor - x_rel_mean).relu().sum() * (1.0 / (2*min_positive)) + x_rel_mean_floor = -(1 - min_positive) + min_positive + min_positive_loss = (x_rel_mean_floor - x_rel_mean).relu().sum() * ( + 1.0 / (2 * min_positive) + ) # this part of the loss would be 1.0 * num_channels if all these constraints were # 100% violated. loss_parts.append(min_positive_loss) if max_positive != 1.0: # e.g. x_mean_floor = -0.05 + 0.95 = 0.8 - x_rel_mean_ceil = - (1.0-max_positive) + max_positive - max_positive_loss = (x_rel_mean - x_rel_mean_ceil).relu().sum() * (1.0 / (1 - x_rel_mean_ceil)) + x_rel_mean_ceil = -(1.0 - max_positive) + max_positive + max_positive_loss = (x_rel_mean - x_rel_mean_ceil).relu().sum() * ( + 1.0 / (1 - x_rel_mean_ceil) + ) # this part of the loss would be 1.0 * num_channels if all these constraints were # 100% violated. loss_parts.append(max_positive_loss) if min_abs != 0.0: - min_abs_loss = min_abs - x_mean_abs).relu().sum() / min_abs + min_abs_loss = (min_abs - x_mean_abs).relu().sum() / min_abs # this part of the loss would be 1.0 * num_channels if all these constraints were # 100% violated. loss_parts.append(min_abs_loss) @@ -82,43 +88,53 @@ def _activation_balancer_loss(mean_pos: Tensor, # 100% violated. loss_parts.append(max_abs_loss) - # the min_positive and 1 - max_positive are "ballast" added to the denom = mean_pos + mean_neg + (min_positive + (1 - max_positive)) - num + # num if min_positive != 0.0: - - + pass class ActivationBalancerFunction(torch.autograd.Function): @staticmethod - def forward(ctx, x: Tensor, - channel_dim: int, - min_positive: float, # e.g. 0.05 - max_positive: float, # e.g. 0.95 - max_factor: float, # e.g. 0.01 - min_abs: float, # e.g. 0.2 - max_abs: float, # e.g. 100.0 + def forward( + ctx, + x: Tensor, + channel_dim: int, + min_positive: float, # e.g. 0.05 + max_positive: float, # e.g. 0.95 + max_factor: float, # e.g. 0.01 + min_abs: float, # e.g. 0.2 + max_abs: float, # e.g. 
100.0 ) -> Tensor: if x.requires_grad: if channel_dim < 0: channel_dim += x.ndim sum_dims = [d for d in range(x.ndim) if d != channel_dim] xgt0 = x > 0 - proportion_positive = torch.mean(xgt0.to(x.dtype), dim=sum_dims, keepdim=True) - factor1 = ((min_positive - proportion_positive).relu() * (max_factor / min_positive) - if min_positive != 0.0 else 0.0) - factor2 = ((proportion_positive - max_positive).relu() * (max_factor / (max_positive - 1.0)) - if max_positive != 1.0 else 0.0) + proportion_positive = torch.mean( + xgt0.to(x.dtype), dim=sum_dims, keepdim=True + ) + factor1 = ( + (min_positive - proportion_positive).relu() + * (max_factor / min_positive) + if min_positive != 0.0 + else 0.0 + ) + factor2 = ( + (proportion_positive - max_positive).relu() + * (max_factor / (max_positive - 1.0)) + if max_positive != 1.0 + else 0.0 + ) factor = factor1 + factor2 if isinstance(factor, float): factor = torch.zeros_like(proportion_positive) mean_abs = torch.mean(x.abs(), dim=sum_dims, keepdim=True) - below_threshold = (mean_abs < min_abs) - above_threshold = (mean_abs > max_abs) + below_threshold = mean_abs < min_abs + above_threshold = mean_abs > max_abs ctx.save_for_backward(factor, xgt0, below_threshold, above_threshold) ctx.max_factor = max_factor @@ -126,11 +142,16 @@ def forward(ctx, x: Tensor, return x @staticmethod - def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None, None, None]: + def backward( + ctx, x_grad: Tensor + ) -> Tuple[Tensor, None, None, None, None, None, None]: factor, xgt0, below_threshold, above_threshold = ctx.saved_tensors dtype = x_grad.dtype - scale_factor = ((below_threshold.to(dtype) - above_threshold.to(dtype)) * - (xgt0.to(dtype) - 0.5) * (ctx.max_factor * 2.0)) + scale_factor = ( + (below_threshold.to(dtype) - above_threshold.to(dtype)) + * (xgt0.to(dtype) - 0.5) + * (ctx.max_factor * 2.0) + ) neg_delta_grad = x_grad.abs() * (factor + scale_factor) return x_grad - neg_delta_grad, None, None, None, None, None, None @@ -163,29 +184,30 @@ class BasicNorm(torch.nn.Module): learn_eps: if true, we learn epsilon; if false, we keep it at the initial value. """ - def __init__(self, - num_channels: int, - channel_dim: int = -1, # CAUTION: see documentation. - eps: float = 0.25, - learn_eps: bool = True) -> None: + + def __init__( + self, + num_channels: int, + channel_dim: int = -1, # CAUTION: see documentation. + eps: float = 0.25, + learn_eps: bool = True, + ) -> None: super(BasicNorm, self).__init__() self.num_channels = num_channels self.channel_dim = channel_dim if learn_eps: self.eps = nn.Parameter(torch.tensor(eps).log().detach()) else: - self.register_buffer('eps', torch.tensor(eps).log().detach()) - + self.register_buffer("eps", torch.tensor(eps).log().detach()) def forward(self, x: Tensor) -> Tensor: assert x.shape[self.channel_dim] == self.num_channels - scales = (torch.mean(x**2, dim=self.channel_dim, keepdim=True) + - self.eps.exp()) ** -0.5 + scales = ( + torch.mean(x**2, dim=self.channel_dim, keepdim=True) + self.eps.exp() + ) ** -0.5 return x * scales - - class ScaledLinear(nn.Linear): """ A modified version of nn.Linear where the parameters are scaled before @@ -207,27 +229,26 @@ class ScaledLinear(nn.Linear): inherited from nn.Linear. For modules with small fan-in, this may be larger than optimal. 
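At init, ScaledLinear's reparameterization works out to ordinary fan-in scaling: the stored weight has std 0.01 (from _reset_parameters here), and weight_scale absorbs log(fan_in ** -0.5 / std), so the effective weight, weight * weight_scale.exp(), has std close to 1/sqrt(fan_in). A quick check with an illustrative shape:

    import math
    import torch

    fan_in, std = 512, 0.01
    weight = torch.empty(256, fan_in).uniform_(-math.sqrt(3) * std, math.sqrt(3) * std)
    weight_scale = math.log(1.0) + math.log(fan_in**-0.5 / std)   # initial_scale = 1.0
    effective = weight * math.exp(weight_scale)
    print(effective.std().item(), fan_in**-0.5)   # both ~ 0.0442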
""" - def __init__(self, *args, - initial_scale: float = 1.0, - **kwargs): + + def __init__(self, *args, initial_scale: float = 1.0, **kwargs): super(ScaledLinear, self).__init__(*args, **kwargs) initial_scale = torch.tensor(initial_scale).log() self.weight_scale = nn.Parameter(initial_scale.clone().detach()) if self.bias is not None: self.bias_scale = nn.Parameter(initial_scale.clone().detach()) else: - self.register_parameter('bias_scale', None) + self.register_parameter("bias_scale", None) self._reset_parameters() # Overrides the reset_parameters in nn.Linear def _reset_parameters(self): std = 0.01 - a = (3 ** 0.5) * std + a = (3**0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in ** -0.5 # 1/sqrt(fan_in) + scale = fan_in**-0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() if self.bias is not None: @@ -237,56 +258,67 @@ def get_weight(self): return self.weight * self.weight_scale.exp() def get_bias(self): - return (None if self.bias is None else - self.bias * self.bias_scale.exp()) + return None if self.bias is None else self.bias * self.bias_scale.exp() def forward(self, input: Tensor) -> Tensor: - return torch.nn.functional.linear(input, self.get_weight(), - self.get_bias()) + return torch.nn.functional.linear(input, self.get_weight(), self.get_bias()) class ScaledConv1d(nn.Conv1d): - def __init__(self, *args, - initial_scale=1.0, **kwargs): + def __init__(self, *args, initial_scale=1.0, **kwargs): super(ScaledConv1d, self).__init__(*args, **kwargs) initial_scale = torch.tensor(initial_scale).log() self.weight_scale = nn.Parameter(initial_scale.clone().detach()) if self.bias is not None: self.bias_scale = nn.Parameter(initial_scale.clone().detach()) else: - self.register_parameter('bias_scale', None) + self.register_parameter("bias_scale", None) self._reset_parameters() # Overrides the reset_parameters in base class def _reset_parameters(self): std = 0.01 - a = (3 ** 0.5) * std + a = (3**0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in ** -0.5 # 1/sqrt(fan_in) + scale = fan_in**-0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() if self.bias is not None: self.bias_scale += torch.tensor(scale / std).log() - def get_weight(self): return self.weight * self.weight_scale.exp() def get_bias(self): - return (None if self.bias is None else - self.bias * self.bias_scale.exp()) + return None if self.bias is None else self.bias * self.bias_scale.exp() def forward(self, input: Tensor) -> Tensor: F = torch.nn.functional - if self.padding_mode != 'zeros': - return F.conv1d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode), - self.get_weight(), self.get_bias(), self.stride, - _single(0), self.dilation, self.groups) - return F.conv1d(input, self.get_weight(), self.get_bias(), self.stride, - self.padding, self.dilation, self.groups) - + if self.padding_mode != "zeros": + return F.conv1d( + F.pad( + input, + self._reversed_padding_repeated_twice, + mode=self.padding_mode, + ), + self.get_weight(), + self.get_bias(), + self.stride, + _single(0), # noqa: F821 + self.dilation, + self.groups, + ) + return F.conv1d( + input, + self.get_weight(), + self.get_bias(), + self.stride, + self.padding, + self.dilation, + self.groups, + ) 
class ScaledConv2d(nn.Conv2d): @@ -297,45 +329,58 @@ def __init__(self, *args, initial_scale=1.0, **kwargs): if self.bias is not None: self.bias_scale = nn.Parameter(initial_scale.clone().detach()) else: - self.register_parameter('bias_scale', None) + self.register_parameter("bias_scale", None) self._reset_parameters() # Overrides the reset_parameters in base class def _reset_parameters(self): std = 0.01 - a = (3 ** 0.5) * std + a = (3**0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in ** -0.5 # 1/sqrt(fan_in) + scale = fan_in**-0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() if self.bias is not None: self.bias_scale += torch.tensor(scale / std).log() - def get_weight(self): return self.weight * self.weight_scale.exp() def get_bias(self): - return (None if self.bias is None else - self.bias * self.bias_scale.exp()) + return None if self.bias is None else self.bias * self.bias_scale.exp() def _conv_forward(self, input, weight): F = torch.nn.functional - if self.padding_mode != 'zeros': - return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode), - weight, self.get_bias(), self.stride, - _pair(0), self.dilation, self.groups) - return F.conv2d(input, weight, self.get_bias(), self.stride, - self.padding, self.dilation, self.groups) + if self.padding_mode != "zeros": + return F.conv2d( + F.pad( + input, + self._reversed_padding_repeated_twice, + mode=self.padding_mode, + ), + weight, + self.get_bias(), + self.stride, + _pair(0), # noqa: F821 + self.dilation, + self.groups, + ) + return F.conv2d( + input, + weight, + self.get_bias(), + self.stride, + self.padding, + self.dilation, + self.groups, + ) def forward(self, input: Tensor) -> Tensor: return self._conv_forward(input, self.get_weight()) - - class ActivationBalancer(torch.nn.Module): """ Modifies the backpropped derivatives of a function to try to encourage, for @@ -364,12 +409,16 @@ class ActivationBalancer(torch.nn.Module): we allow, before we start to modify the derivatives to prevent this. """ - def __init__(self, channel_dim: int, - min_positive: float = 0.05, - max_positive: float = 0.95, - max_factor: float = 0.01, - min_abs: float = 0.2, - max_abs: float = 100.0): + + def __init__( + self, + channel_dim: int, + min_positive: float = 0.05, + max_positive: float = 0.95, + max_factor: float = 0.01, + min_abs: float = 0.2, + max_abs: float = 100.0, + ): super(ActivationBalancer, self).__init__() self.channel_dim = channel_dim self.min_positive = min_positive @@ -379,10 +428,15 @@ def __init__(self, channel_dim: int, self.max_abs = max_abs def forward(self, x: Tensor) -> Tensor: - return ActivationBalancerFunction.apply(x, self.channel_dim, - self.min_positive, self.max_positive, - self.max_factor, self.min_abs, - self.max_abs) + return ActivationBalancerFunction.apply( + x, + self.channel_dim, + self.min_positive, + self.max_positive, + self.max_factor, + self.min_abs, + self.max_abs, + ) class DoubleSwishFunction(torch.autograd.Function): @@ -400,6 +454,7 @@ class DoubleSwishFunction(torch.autograd.Function): = double_swish(x) * (1-s(x)) + s(x) ... so we just need to remember s(x) but not x itself. 
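The derivative identity in the DoubleSwishFunction docstring above, with s = sigmoid(x - 1) and y = x * s, is dy/dx = y * (1 - s) + s. It follows from the product rule, since s' = s * (1 - s) gives dy/dx = s + x * s * (1 - s), and it can be confirmed against autograd:

    import torch

    x = torch.linspace(-4.0, 4.0, 9, dtype=torch.float64, requires_grad=True)
    s = torch.sigmoid(x - 1.0)
    y = x * s                                   # double_swish(x)
    (grad,) = torch.autograd.grad(y.sum(), x)
    torch.testing.assert_close(grad, (y * (1.0 - s) + s).detach())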
""" + @staticmethod def forward(ctx, x: Tensor) -> Tensor: x = x.detach() @@ -411,18 +466,17 @@ def forward(ctx, x: Tensor) -> Tensor: @staticmethod def backward(ctx, y_grad: Tensor) -> Tensor: s, y = ctx.saved_tensors - return (y * (1-s) + s) * y_grad + return (y * (1 - s) + s) * y_grad + class DoubleSwish(torch.nn.Module): def forward(self, x: Tensor) -> Tensor: """Return double-swish activation function which is an approximation to Swish(Swish(x)), - that we approximate closely with x * sigmoid(x-1). + that we approximate closely with x * sigmoid(x-1). """ return DoubleSwishFunction.apply(x) - - class ScaledEmbedding(nn.Module): r"""A simple lookup table that stores embeddings of a fixed dictionary and size. @@ -491,8 +545,13 @@ class ScaledEmbedding(nn.Module): [ 0.0000, 0.0000, 0.0000], [-0.1655, 0.9897, 0.0635]]]) """ - __constants__ = ['num_embeddings', 'embedding_dim', 'padding_idx', - 'scale_grad_by_freq', 'sparse'] + __constants__ = [ + "num_embeddings", + "embedding_dim", + "padding_idx", + "scale_grad_by_freq", + "sparse", + ] num_embeddings: int embedding_dim: int @@ -501,33 +560,40 @@ class ScaledEmbedding(nn.Module): weight: Tensor sparse: bool - def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None, - scale_grad_by_freq: bool = False, - sparse: bool = False) -> None: + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: Optional[int] = None, + scale_grad_by_freq: bool = False, + sparse: bool = False, + ) -> None: super(ScaledEmbedding, self).__init__() self.num_embeddings = num_embeddings self.embedding_dim = embedding_dim if padding_idx is not None: if padding_idx > 0: - assert padding_idx < self.num_embeddings, 'Padding_idx must be within num_embeddings' + assert ( + padding_idx < self.num_embeddings + ), "Padding_idx must be within num_embeddings" elif padding_idx < 0: - assert padding_idx >= -self.num_embeddings, 'Padding_idx must be within num_embeddings' + assert ( + padding_idx >= -self.num_embeddings + ), "Padding_idx must be within num_embeddings" padding_idx = self.num_embeddings + padding_idx self.padding_idx = padding_idx self.scale_grad_by_freq = scale_grad_by_freq - self.scale = nn.Parameter(torch.zeros(())) # see reset_parameters() + self.scale = nn.Parameter(torch.zeros(())) # see reset_parameters() self.sparse = sparse self.weight = nn.Parameter(torch.Tensor(num_embeddings, embedding_dim)) self.reset_parameters() - - def reset_parameters(self) -> None: std = 0.01 nn.init.normal_(self.weight, std=std) - nn.init.constant_(self.scale, torch.tensor(1.0/std).log()) + nn.init.constant_(self.scale, torch.tensor(1.0 / std).log()) if self.padding_idx is not None: with torch.no_grad(): @@ -537,24 +603,37 @@ def forward(self, input: Tensor) -> Tensor: F = torch.nn.functional scale = self.scale.exp() if input.numel() < self.num_embeddings: - return F.embedding( - input, self.weight, self.padding_idx, - None, 2.0, # None, 2.0 relate to normalization - self.scale_grad_by_freq, self.sparse) * scale + return ( + F.embedding( + input, + self.weight, + self.padding_idx, + None, + 2.0, # None, 2.0 relate to normalization + self.scale_grad_by_freq, + self.sparse, + ) + * scale + ) else: return F.embedding( - input, self.weight * scale, self.padding_idx, - None, 2.0, # None, 2.0 relates to normalization - self.scale_grad_by_freq, self.sparse) + input, + self.weight * scale, + self.padding_idx, + None, + 2.0, # None, 2.0 relates to normalization + self.scale_grad_by_freq, + self.sparse, + ) def 
extra_repr(self) -> str: - s = '{num_embeddings}, {embedding_dim}, scale={scale}' + s = "{num_embeddings}, {embedding_dim}, scale={scale}" if self.padding_idx is not None: - s += ', padding_idx={padding_idx}' + s += ", padding_idx={padding_idx}" if self.scale_grad_by_freq is not False: - s += ', scale_grad_by_freq={scale_grad_by_freq}' + s += ", scale_grad_by_freq={scale_grad_by_freq}" if self.sparse is not False: - s += ', sparse=True' + s += ", sparse=True" return s.format(**self.__dict__) @@ -565,8 +644,13 @@ def _test_activation_balancer_sign(): x = 1.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1)) x = x.detach() x.requires_grad = True - m = ActivationBalancer(channel_dim=0, min_positive=0.05, max_positive=0.95, - max_factor=0.2, min_abs=0.0) + m = ActivationBalancer( + channel_dim=0, + min_positive=0.05, + max_positive=0.95, + max_factor=0.2, + min_abs=0.0, + ) y_grad = torch.sign(torch.randn(probs.numel(), N)) @@ -576,17 +660,22 @@ def _test_activation_balancer_sign(): print("_test_activation_balancer_sign: y grad = ", y_grad) print("_test_activation_balancer_sign: x grad = ", x.grad) + def _test_activation_balancer_magnitude(): channel_dim = 0 magnitudes = torch.arange(0, 1, 0.01) N = 1000 - x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1) + x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1) x = x.detach() x.requires_grad = True - m = ActivationBalancer(channel_dim=0, - min_positive=0.0, max_positive=1.0, - max_factor=0.2, - min_abs=0.2, max_abs=0.8) + m = ActivationBalancer( + channel_dim=0, + min_positive=0.0, + max_positive=1.0, + max_factor=0.2, + min_abs=0.2, + max_abs=0.8, + ) y_grad = torch.sign(torch.randn(magnitudes.numel(), N)) @@ -621,7 +710,7 @@ def _test_double_swish_deriv(): torch.autograd.gradcheck(m, x) -if __name__ == '__main__': +if __name__ == "__main__": _test_activation_balancer_sign() _test_activation_balancer_magnitude() _test_basic_norm() diff --git a/egs/librispeech/ASR/pruned2_knowledge/train.py b/egs/librispeech/ASR/pruned2_knowledge/train.py index 2f68401667..c322abaf89 100755 --- a/egs/librispeech/ASR/pruned2_knowledge/train.py +++ b/egs/librispeech/ASR/pruned2_knowledge/train.py @@ -78,9 +78,7 @@ from icefall.env import get_env_info from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def get_parser(): @@ -179,8 +177,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -203,8 +200,7 @@ def get_parser(): "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)" "part.", ) parser.add_argument( @@ -554,23 +550,16 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. 
pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() @@ -733,9 +722,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -835,7 +822,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2 ** 22 + 2**22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/decode.py b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/decode.py index 2d5724d306..891719f3da 100755 --- a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/decode.py +++ b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/decode.py @@ -204,8 +204,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; " "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -272,9 +271,7 @@ def decode_one_batch( value=LOG_EPS, ) - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] if params.decoding_method == "fast_beam_search": @@ -289,10 +286,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -415,9 +409,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -450,8 +442,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -494,9 +485,7 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -528,9 +517,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -557,9 +546,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" diff --git a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/emformer.py b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/emformer.py index 318cd5094d..008f40fb1d 100644 --- a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/emformer.py +++ b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/emformer.py @@ -272,13 +272,9 @@ def streaming_forward( # Caution: We assume the subsampling factor is 4! x_lens = (((x_lens - 1) >> 1) - 1) >> 1 - emformer_out, emformer_out_lens, states = self.model.infer( - x, x_lens, states - ) + emformer_out, emformer_out_lens, states = self.model.infer(x, x_lens, states) - if x.size(1) != ( - self.model.segment_length + self.model.right_context_length - ): + if x.size(1) != (self.model.segment_length + self.model.right_context_length): raise ValueError( "Incorrect input shape." 
f"{x.size(1)} vs {self.model.segment_length} + " diff --git a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/export.py b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/export.py index 2375f50018..047a1d4763 100755 --- a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/export.py +++ b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/export.py @@ -133,8 +133,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) add_model_arguments(parser) @@ -170,9 +169,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -199,9 +198,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -273,9 +272,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/model.py b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/model.py index 2f019bcdbd..ed6848879a 100644 --- a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/model.py +++ b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/model.py @@ -122,9 +122,7 @@ def forward( y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros( - (x.size(0), 4), dtype=torch.int64, device=x.device - ) + boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) boundary[:, 2] = y_lens boundary[:, 3] = x_lens diff --git a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py index fed814f19f..69e74cc572 100755 --- a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py +++ b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py @@ -209,8 +209,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -233,8 +232,7 @@ def get_parser(): "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)" "part.", ) parser.add_argument( @@ -566,11 +564,7 @@ def compute_loss( function enables autograd during computation; when it is False, it disables autograd. 
""" - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -599,9 +593,7 @@ def compute_loss( info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -782,9 +774,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -908,8 +898,7 @@ def remove_short_and_long_utt(c: Cut): # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. " f"Duration: {c.duration}" ) return False diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py index 7af9cc3d74..830b37cfb3 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py @@ -509,9 +509,9 @@ def greedy_search( y = logits.argmax().item() if y not in (blank_id, unk_id): hyp.append(y) - decoder_input = torch.tensor( - [hyp[-context_size:]], device=device - ).reshape(1, context_size) + decoder_input = torch.tensor([hyp[-context_size:]], device=device).reshape( + 1, context_size + ) decoder_out = model.decoder(decoder_input, need_pad=False) @@ -670,9 +670,7 @@ def add(self, hyp: Hypothesis, use_max: bool = False) -> None: if use_max: old_hyp.log_prob = max(old_hyp.log_prob, hyp.log_prob) else: - torch.logaddexp( - old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob - ) + torch.logaddexp(old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob) else: self._data[key] = hyp @@ -688,9 +686,7 @@ def get_most_probable(self, length_norm: bool = False) -> Hypothesis: Return the hypothesis that has the largest `log_prob`. 
""" if length_norm: - return max( - self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys) - ) + return max(self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys)) else: return max(self._data.values(), key=lambda hyp: hyp.log_prob) @@ -892,9 +888,7 @@ def modified_beam_search( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor( - shape=log_probs_shape, value=log_probs - ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) for i in range(batch_size): topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) @@ -1088,9 +1082,7 @@ def beam_search( t = 0 B = HypothesisList() - B.add( - Hypothesis(ys=[blank_id] * context_size, log_prob=0.0), use_max=use_max - ) + B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0), use_max=use_max) max_sym_per_utt = 20000 @@ -1130,9 +1122,7 @@ def beam_search( cached_key += f"-t-{t}" if cached_key not in joint_cache: - logits = model.joiner( - current_encoder_out, decoder_out.unsqueeze(1) - ) + logits = model.joiner(current_encoder_out, decoder_out.unsqueeze(1)) # TODO(fangjun): Scale the blank posterior diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py index 7b63389484..12bd7f9bb4 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py @@ -128,11 +128,7 @@ ) from train import add_model_arguments, get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, @@ -269,8 +265,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; " "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -383,9 +378,7 @@ def decode_one_batch( simulate_streaming=True, ) else: - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] if ( @@ -450,10 +443,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -584,9 +574,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -619,8 +607,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -678,9 +665,7 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -718,8 +703,7 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -757,9 +741,7 @@ def main(): decoding_graph.scores *= params.ngram_lm_scale else: word_table = None - decoding_graph = k2.trivial_graph( - params.vocab_size - 1, device=device - ) + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) else: decoding_graph = None word_table = None diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decode_stream.py b/egs/librispeech/ASR/pruned_transducer_stateless/decode_stream.py index 3862485546..e522943c07 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless/decode_stream.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/decode_stream.py @@ -75,9 +75,7 @@ def __init__( # encoder.streaming_forward self.done_frames: int = 0 - self.pad_length = ( - params.right_context + 2 - ) * params.subsampling_factor + 3 + self.pad_length = (params.right_context + 2) * params.subsampling_factor + 3 if params.decoding_method == "greedy_search": self.hyp = [params.blank_id] * params.context_size @@ -91,13 +89,11 @@ def __init__( ) elif params.decoding_method == "fast_beam_search": # The rnnt_decoding_stream for fast_beam_search. 
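The frame bookkeeping in DecodeStream above assumes the subsampling factor of 4 that the emformer hunk earlier cautions about, via the length relation (((lens - 1) >> 1) - 1) >> 1. A worked example with illustrative right_context and input sizes (hypothetical values, not the recipe defaults):

    subsampling_factor, right_context = 4, 4                   # illustrative values
    pad_length = (right_context + 2) * subsampling_factor + 3  # 27 extra feature frames

    def encoder_out_lens(in_lens: int) -> int:
        # the length relation quoted above, for subsampling factor 4
        return (((in_lens - 1) >> 1) - 1) >> 1

    print(pad_length, encoder_out_lens(100))   # 27, 24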
- self.rnnt_decoding_stream: k2.RnntDecodingStream = ( - k2.RnntDecodingStream(decoding_graph) + self.rnnt_decoding_stream: k2.RnntDecodingStream = k2.RnntDecodingStream( + decoding_graph ) else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") @property def done(self) -> bool: @@ -126,13 +122,10 @@ def get_feature_frames(self, chunk_size: int) -> Tuple[torch.Tensor, int]: """Consume chunk_size frames of features""" chunk_length = chunk_size + self.pad_length - ret_length = min( - self.num_frames - self.num_processed_frames, chunk_length - ) + ret_length = min(self.num_frames - self.num_processed_frames, chunk_length) ret_features = self.features[ - self.num_processed_frames : self.num_processed_frames # noqa - + ret_length + self.num_processed_frames : self.num_processed_frames + ret_length # noqa ] self.num_processed_frames += chunk_size diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless/decoder.py index f4355e8a04..72593173cd 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless/decoder.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/decoder.py @@ -92,9 +92,7 @@ def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor: if self.context_size > 1: embedding_out = embedding_out.permute(0, 2, 1) if need_pad is True: - embedding_out = F.pad( - embedding_out, pad=(self.context_size - 1, 0) - ) + embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0)) else: # During inference time, there is no need to do extra padding # as we only need one output diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/export.py b/egs/librispeech/ASR/pruned_transducer_stateless/export.py index b5a151878b..be45536d8f 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/export.py @@ -105,8 +105,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -192,9 +191,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/model.py b/egs/librispeech/ASR/pruned_transducer_stateless/model.py index 73b651b3fa..2cca7fa278 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/model.py @@ -130,9 +130,7 @@ def forward( y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros( - (x.size(0), 4), dtype=torch.int64, device=x.device - ) + boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) boundary[:, 2] = y_lens boundary[:, 3] = x_lens diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py index eb95827af0..6e91e05014 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py @@ -168,8 +168,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -222,8 +221,7 @@ def read_sound_files( for f in filenames: wave, sample_rate = torchaudio.load(f) assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" + f"expected sample rate: {expected_sample_rate}. 
" f"Given: {sample_rate}" ) # We use only the first channel ans.append(wave[0]) @@ -292,9 +290,7 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) feature_lengths = torch.tensor(feature_lengths, device=device) @@ -381,9 +377,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/streaming_beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless/streaming_beam_search.py index dcf6dc42f0..9e09200a11 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless/streaming_beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/streaming_beam_search.py @@ -166,14 +166,10 @@ def modified_beam_search( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor( - shape=log_probs_shape, value=log_probs - ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) for i in range(batch_size): - topk_log_probs, topk_indexes = ragged_log_probs[i].topk( - num_active_paths - ) + topk_log_probs, topk_indexes = ragged_log_probs[i].topk(num_active_paths) with warnings.catch_warnings(): warnings.simplefilter("ignore") diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py index d2cae4f9f3..ce8e2f3481 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py @@ -51,11 +51,7 @@ from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.utils import ( AttributeDict, setup_logger, @@ -162,8 +158,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -269,9 +264,7 @@ def decode_one_chunk( ) if params.decoding_method == "greedy_search": - greedy_search( - model=model, encoder_out=encoder_out, streams=decode_streams - ) + greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) elif params.decoding_method == "fast_beam_search": processed_lens = processed_lens + encoder_out_lens fast_beam_search_one_best( @@ -291,9 +284,7 @@ def decode_one_chunk( num_active_paths=params.num_active_paths, ) else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") states = [torch.unbind(states[0], dim=2), torch.unbind(states[1], dim=2)] @@ -349,9 +340,7 @@ def decode_dataset( decode_results = [] # Contain decode streams currently running. 
decode_streams = [] - initial_states = model.encoder.get_init_state( - params.left_context, device=device - ) + initial_states = model.encoder.get_init_state(params.left_context, device=device) for num, cut in enumerate(cuts): # each utterance has a DecodeStream. decode_stream = DecodeStream( @@ -422,9 +411,7 @@ def decode_dataset( elif params.decoding_method == "modified_beam_search": key = f"num_active_paths_{params.num_active_paths}" else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") return {key: decode_results} @@ -460,8 +447,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -533,8 +519,7 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/train.py b/egs/librispeech/ASR/pruned_transducer_stateless/train.py index 399b11a293..7861df874c 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/train.py @@ -203,8 +203,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -227,8 +226,7 @@ def get_parser(): "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)" "part.", ) parser.add_argument( @@ -562,9 +560,7 @@ def compute_loss( # If either all simple_loss or pruned_loss is inf or nan, # we stop the training process by raising an exception - if torch.all(~simple_loss_is_finite) or torch.all( - ~pruned_loss_is_finite - ): + if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite): raise ValueError( "There are too many utterances in this batch " "leading to inf or nan losses." @@ -584,9 +580,7 @@ def compute_loss( # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 # (2) If some utterances in the batch lead to inf/nan loss, they # are filtered out. - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -777,9 +771,7 @@ def maybe_log_weights(tag: str): loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -897,8 +889,7 @@ def remove_short_and_long_utt(c: Cut): # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. 
" - f"Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. " f"Duration: {c.duration}" ) return False @@ -956,9 +947,7 @@ def remove_short_and_long_utt(c: Cut): cur_lr = optimizer._rate if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) + tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) if rank == 0: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index b7c2010f7a..5e9428b607 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -580,9 +580,9 @@ def greedy_search( if y not in (blank_id, unk_id): hyp.append(y) timestamp.append(t) - decoder_input = torch.tensor( - [hyp[-context_size:]], device=device - ).reshape(1, context_size) + decoder_input = torch.tensor([hyp[-context_size:]], device=device).reshape( + 1, context_size + ) decoder_out = model.decoder(decoder_input, need_pad=False) decoder_out = model.joiner.decoder_proj(decoder_out) @@ -775,9 +775,7 @@ def add(self, hyp: Hypothesis) -> None: key = hyp.key if key in self: old_hyp = self._data[key] # shallow copy - torch.logaddexp( - old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob - ) + torch.logaddexp(old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob) else: self._data[key] = hyp @@ -793,9 +791,7 @@ def get_most_probable(self, length_norm: bool = False) -> Hypothesis: Return the hypothesis that has the largest `log_prob`. """ if length_norm: - return max( - self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys) - ) + return max(self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys)) else: return max(self._data.values(), key=lambda hyp: hyp.log_prob) @@ -990,9 +986,7 @@ def modified_beam_search( logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) - log_probs = (logits / temperature).log_softmax( - dim=-1 - ) # (num_hyps, vocab_size) + log_probs = (logits / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size) log_probs.add_(ys_log_probs) @@ -1004,9 +998,7 @@ def modified_beam_search( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor( - shape=log_probs_shape, value=log_probs - ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) for i in range(batch_size): topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) @@ -1676,9 +1668,7 @@ def fast_beam_search_with_nbest_rnn_rescoring( for rnn_scale in rnn_lm_scale_list: key = f"ngram_lm_scale_{n_scale}_rnn_lm_scale_{rnn_scale}" tot_scores = ( - am_scores.values - + n_scale * ngram_lm_scores - + rnn_scale * rnn_lm_scores + am_scores.values + n_scale * ngram_lm_scores + rnn_scale * rnn_lm_scores ) ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) max_indexes = ragged_tot_scores.argmax() @@ -1804,9 +1794,7 @@ def modified_beam_search_ngram_rescoring( logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) - log_probs = (logits / temperature).log_softmax( - dim=-1 - ) # (num_hyps, vocab_size) + log_probs = (logits / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size) log_probs.add_(ys_log_probs) vocab_size = log_probs.size(-1) @@ -1816,9 +1804,7 @@ def modified_beam_search_ngram_rescoring( log_probs_shape = k2.ragged.create_ragged_shape2( 
row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor( - shape=log_probs_shape, value=log_probs - ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) for i in range(batch_size): topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) @@ -1841,9 +1827,7 @@ def modified_beam_search_ngram_rescoring( state_cost = hyp.state_cost # We only keep AM scores in new_hyp.log_prob - new_log_prob = ( - topk_log_probs[k] - hyp.state_cost.lm_score * lm_scale - ) + new_log_prob = topk_log_probs[k] - hyp.state_cost.lm_score * lm_scale new_hyp = Hypothesis( ys=new_ys, log_prob=new_log_prob, state_cost=state_cost @@ -1995,9 +1979,7 @@ def modified_beam_search_rnnlm_shallow_fusion( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor( - shape=log_probs_shape, value=log_probs - ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) """ for all hyps with a non-blank new token, score this token. It is a little confusing here because this for-loop @@ -2032,10 +2014,7 @@ def modified_beam_search_rnnlm_shallow_fusion( # forward RNNLM to get new states and scores if len(token_list) != 0: tokens_to_score = ( - torch.tensor(token_list) - .to(torch.int64) - .to(device) - .reshape(-1, 1) + torch.tensor(token_list).to(torch.int64).to(device).reshape(-1, 1) ) hs = torch.cat(hs, dim=1).to(device) @@ -2067,9 +2046,7 @@ def modified_beam_search_rnnlm_shallow_fusion( ys.append(new_token) new_timestamp.append(t) - hyp_log_prob += ( - lm_score[new_token] * lm_scale - ) # add the lm score + hyp_log_prob += lm_score[new_token] * lm_scale # add the lm score lm_score = scores[count] state = ( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index bc273d33bc..f94ffef59b 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -214,10 +214,7 @@ def get_init_state( NOTE: the returned tensors are on the given device. 
""" - if ( - len(self._init_state) == 2 - and self._init_state[0].size(1) == left_context - ): + if len(self._init_state) == 2 and self._init_state[0].size(1) == left_context: # Note: It is OK to share the init state as it is # not going to be modified by the model return self._init_state @@ -439,9 +436,7 @@ def __init__( self.d_model = d_model - self.self_attn = RelPositionMultiheadAttention( - d_model, nhead, dropout=0.0 - ) + self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) self.feed_forward = nn.Sequential( ScaledLinear(d_model, dim_feedforward), @@ -459,9 +454,7 @@ def __init__( ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), ) - self.conv_module = ConvolutionModule( - d_model, cnn_module_kernel, causal=causal - ) + self.conv_module = ConvolutionModule(d_model, cnn_module_kernel, causal=causal) self.norm_final = BasicNorm(d_model) @@ -527,9 +520,7 @@ def forward( src = src + self.dropout(src_att) # convolution module - conv, _ = self.conv_module( - src, src_key_padding_mask=src_key_padding_mask - ) + conv, _ = self.conv_module(src, src_key_padding_mask=src_key_padding_mask) src = src + self.dropout(conv) # feed forward module @@ -785,9 +776,7 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__( - self, d_model: int, dropout_rate: float, max_len: int = 5000 - ) -> None: + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() if is_jit_tracing(): @@ -811,9 +800,7 @@ def extend_pe(self, x: Tensor, left_context: int = 0) -> None: # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x_size_1 * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str( - x.device - ): + if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vector and `j` means the @@ -1127,9 +1114,9 @@ def multi_head_attention_forward( if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear( - query, in_proj_weight, in_proj_bias - ).chunk(3, dim=-1) + q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( + 3, dim=-1 + ) elif torch.equal(key, value): # encoder-decoder attention @@ -1198,31 +1185,22 @@ def multi_head_attention_forward( if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError( - "The size of the 2D attn_mask is not correct." - ) + raise RuntimeError("The size of the 2D attn_mask is not correct.") elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError( - "The size of the 3D attn_mask is not correct." - ) + raise RuntimeError("The size of the 3D attn_mask is not correct.") else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format( - attn_mask.dim() - ) + "attn_mask's dimension {} is not supported".format(attn_mask.dim()) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if ( - key_padding_mask is not None - and key_padding_mask.dtype == torch.uint8 - ): + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: warnings.warn( "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." 
) @@ -1264,23 +1242,15 @@ def multi_head_attention_forward( # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul( - q_with_bias_u, k - ) # (batch, head, time1, time2) + matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) # compute matrix b and matrix d - matrix_bd = torch.matmul( - q_with_bias_v, p - ) # (batch, head, time1, 2*time1-1) + matrix_bd = torch.matmul(q_with_bias_v, p) # (batch, head, time1, 2*time1-1) matrix_bd = self.rel_shift(matrix_bd, left_context) - attn_output_weights = ( - matrix_ac + matrix_bd - ) # (batch, head, time1, time2) + attn_output_weights = matrix_ac + matrix_bd # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view( - bsz * num_heads, tgt_len, -1 - ) + attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) if not is_jit_tracing(): assert list(attn_output_weights.size()) == [ @@ -1322,21 +1292,17 @@ def multi_head_attention_forward( ): if attn_mask.size(0) != 1: attn_mask = attn_mask.view(bsz, num_heads, tgt_len, src_len) - combined_mask = attn_mask | key_padding_mask.unsqueeze( - 1 - ).unsqueeze(2) + combined_mask = attn_mask | key_padding_mask.unsqueeze(1).unsqueeze(2) else: # attn_mask.shape == (1, tgt_len, src_len) - combined_mask = attn_mask.unsqueeze( - 0 - ) | key_padding_mask.unsqueeze(1).unsqueeze(2) + combined_mask = attn_mask.unsqueeze(0) | key_padding_mask.unsqueeze( + 1 + ).unsqueeze(2) attn_output_weights = attn_output_weights.view( bsz, num_heads, tgt_len, src_len ) - attn_output_weights = attn_output_weights.masked_fill( - combined_mask, 0.0 - ) + attn_output_weights = attn_output_weights.masked_fill(combined_mask, 0.0) attn_output_weights = attn_output_weights.view( bsz * num_heads, tgt_len, src_len ) @@ -1355,13 +1321,9 @@ def multi_head_attention_forward( ] attn_output = ( - attn_output.transpose(0, 1) - .contiguous() - .view(tgt_len, bsz, embed_dim) - ) - attn_output = nn.functional.linear( - attn_output, out_proj_weight, out_proj_bias + attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) ) + attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads @@ -1498,16 +1460,12 @@ def forward( # manualy padding self.lorder zeros to the left x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0) else: - assert ( - not self.training - ), "Cache should be None in training time" + assert not self.training, "Cache should be None in training time" assert cache.size(0) == self.lorder x = torch.cat([cache.permute(1, 2, 0), x], dim=2) if right_context > 0: cache = x.permute(2, 0, 1)[ - -(self.lorder + right_context) : ( # noqa - -right_context - ), + -(self.lorder + right_context) : (-right_context), # noqa ..., ] else: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py index 979a0e02e0..92138a5ea1 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py @@ -132,11 +132,7 @@ ) from train import add_model_arguments, get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.lexicon import Lexicon 
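# --------------------------------------------------------------------------
# For reference, a self-contained sketch of what the checkpoint averaging
# imported above boils down to: an elementwise mean over the saved
# state_dicts.  This assumes each checkpoint file stores its weights under a
# "model" key; it mirrors icefall.checkpoint.average_checkpoints only in
# spirit (the real function also handles device placement and related
# bookkeeping).
# --------------------------------------------------------------------------
from typing import Dict, List

import torch


def average_state_dicts(filenames: List[str]) -> Dict[str, torch.Tensor]:
    n = len(filenames)
    avg = torch.load(filenames[0], map_location="cpu")["model"]
    for name in filenames[1:]:
        state = torch.load(name, map_location="cpu")["model"]
        for k in avg:
            avg[k] = avg[k] + state[k]
    for k in avg:
        # Skip integer buffers; only floating-point weights are averaged.
        if avg[k].is_floating_point():
            avg[k] = avg[k] / n
    return avg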
from icefall.utils import ( AttributeDict, @@ -275,8 +271,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -397,9 +392,7 @@ def decode_one_batch( simulate_streaming=True, ) else: - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] @@ -465,10 +458,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -608,9 +598,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -643,8 +631,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -700,9 +687,7 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -740,8 +725,7 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -779,9 +763,7 @@ def main(): decoding_graph.scores *= params.ngram_lm_scale else: word_table = None - decoding_graph = k2.trivial_graph( - params.vocab_size - 1, device=device - ) + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) else: decoding_graph = None word_table = None diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py index ba91302ce9..b599281038 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py @@ -107,15 +107,11 @@ def forward( # This is for exporting to PNNX via ONNX embedding_out = self.embedding(y) else: - embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze( - -1 - ) + embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze(-1) if self.context_size > 1: embedding_out = embedding_out.permute(0, 2, 1) if need_pad: - embedding_out = F.pad( - embedding_out, pad=(self.context_size - 1, 0) - ) + embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0)) else: # During inference time, there is no need to do extra padding # as we only need one output diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/export.py 
b/egs/librispeech/ASR/pruned_transducer_stateless2/export.py index f1a8ea589f..4f1170bbc8 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/export.py @@ -51,11 +51,7 @@ import torch from train import add_model_arguments, get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.utils import str2bool @@ -120,8 +116,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -173,8 +168,7 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -222,9 +216,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py b/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py index 6a9d080331..1954f47246 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py @@ -60,9 +60,7 @@ def forward( assert encoder_out.shape == decoder_out.shape if project_input: - logit = self.encoder_proj(encoder_out) + self.decoder_proj( - decoder_out - ) + logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out) else: logit = encoder_out + decoder_out diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py index 417c391d9b..272d06c37d 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py @@ -66,9 +66,7 @@ def __init__( self.decoder = decoder self.joiner = joiner - self.simple_am_proj = ScaledLinear( - encoder_dim, vocab_size, initial_speed=0.5 - ) + self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size, initial_speed=0.5) self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size) def forward( @@ -152,9 +150,7 @@ def forward( y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros( - (x.size(0), 4), dtype=torch.int64, device=x.device - ) + boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) boundary[:, 2] = y_lens boundary[:, 3] = x_lens diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py index 041a81f45b..2d7f557ad9 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py @@ -72,17 +72,11 @@ def __init__( if not 0.0 <= eps: raise ValueError("Invalid epsilon value: {}".format(eps)) if not 0.0 <= betas[0] < 1.0: - raise ValueError( - "Invalid beta parameter at index 0: {}".format(betas[0]) - ) + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) if not 0.0 <= betas[1] < 1.0: - raise 
ValueError( - "Invalid beta parameter at index 1: {}".format(betas[1]) - ) + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) if not 0 <= weight_decay <= 0.1: - raise ValueError( - "Invalid weight_decay value: {}".format(weight_decay) - ) + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) if not 0 < target_rms <= 10.0: raise ValueError("Invalid target_rms value: {}".format(target_rms)) defaults = dict( @@ -118,9 +112,7 @@ def step(self, closure=None): # Perform optimization step grad = p.grad if grad.is_sparse: - raise RuntimeError( - "AdamW does not support sparse gradients" - ) + raise RuntimeError("AdamW does not support sparse gradients") state = self.state[p] @@ -147,7 +139,7 @@ def step(self, closure=None): # Decay the first and second moment running average coefficient exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) - denom = (exp_avg_sq.sqrt() * (bias_correction2 ** -0.5)).add_( + denom = (exp_avg_sq.sqrt() * (bias_correction2**-0.5)).add_( group["eps"] ) @@ -158,9 +150,7 @@ def step(self, closure=None): if p.numel() > 1: # avoid applying this weight-decay on "scaling factors" # (which are scalar). - is_above_target_rms = p.norm() > ( - target_rms * (p.numel() ** 0.5) - ) + is_above_target_rms = p.norm() > (target_rms * (p.numel() ** 0.5)) p.mul_(1 - (weight_decay * is_above_target_rms)) p.addcdiv_(exp_avg, denom, value=-step_size) @@ -180,18 +170,14 @@ class LRScheduler(object): def __init__(self, optimizer: Optimizer, verbose: bool = False): # Attach optimizer if not isinstance(optimizer, Optimizer): - raise TypeError( - "{} is not an Optimizer".format(type(optimizer).__name__) - ) + raise TypeError("{} is not an Optimizer".format(type(optimizer).__name__)) self.optimizer = optimizer self.verbose = verbose for group in optimizer.param_groups: group.setdefault("initial_lr", group["lr"]) - self.base_lrs = [ - group["initial_lr"] for group in optimizer.param_groups - ] + self.base_lrs = [group["initial_lr"] for group in optimizer.param_groups] self.epoch = 0 self.batch = 0 @@ -299,10 +285,9 @@ def __init__( def get_lr(self): factor = ( - (self.batch ** 2 + self.lr_batches ** 2) / self.lr_batches ** 2 + (self.batch**2 + self.lr_batches**2) / self.lr_batches**2 ) ** -0.25 * ( - ((self.epoch ** 2 + self.lr_epochs ** 2) / self.lr_epochs ** 2) - ** -0.25 + ((self.epoch**2 + self.lr_epochs**2) / self.lr_epochs**2) ** -0.25 ) return [x * factor for x in self.base_lrs] diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py index f52cb22ab5..e5b5aeba55 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py @@ -168,8 +168,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -223,8 +222,7 @@ def read_sound_files( for f in filenames: wave, sample_rate = torchaudio.load(f) assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" + f"expected sample rate: {expected_sample_rate}. 
" f"Given: {sample_rate}" ) # We use only the first channel ans.append(wave[0]) @@ -293,9 +291,7 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) feature_lengths = torch.tensor(feature_lengths, device=device) @@ -382,9 +378,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py index 8c572a9ef1..c802ecf89c 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py @@ -89,9 +89,7 @@ def forward( below_threshold = mean_abs < min_abs above_threshold = mean_abs > max_abs - ctx.save_for_backward( - factor, xgt0, below_threshold, above_threshold - ) + ctx.save_for_backward(factor, xgt0, below_threshold, above_threshold) ctx.max_factor = max_factor ctx.sum_dims = sum_dims return x @@ -137,7 +135,7 @@ def backward( eps = 1.0e-20 dim = ctx.batch_dim norm_dims = [d for d in range(x_grad.ndim) if d != dim] - norm_of_batch = (x_grad ** 2).mean(dim=norm_dims, keepdim=True).sqrt() + norm_of_batch = (x_grad**2).mean(dim=norm_dims, keepdim=True).sqrt() median_norm = norm_of_batch.median() cutoff = median_norm * ctx.threshold @@ -229,8 +227,7 @@ def forward(self, x: Tensor) -> Tensor: if not is_jit_tracing(): assert x.shape[self.channel_dim] == self.num_channels scales = ( - torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) - + self.eps.exp() + torch.mean(x**2, dim=self.channel_dim, keepdim=True) + self.eps.exp() ) ** -0.5 return x * scales @@ -282,12 +279,12 @@ def __init__( def _reset_parameters(self, initial_speed: float): std = 0.1 / initial_speed - a = (3 ** 0.5) * std + a = (3**0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in ** -0.5 # 1/sqrt(fan_in) + scale = fan_in**-0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() @@ -301,9 +298,7 @@ def get_bias(self): return self.bias * self.bias_scale.exp() def forward(self, input: Tensor) -> Tensor: - return torch.nn.functional.linear( - input, self.get_weight(), self.get_bias() - ) + return torch.nn.functional.linear(input, self.get_weight(), self.get_bias()) class ScaledConv1d(nn.Conv1d): @@ -331,12 +326,12 @@ def __init__( def _reset_parameters(self, initial_speed: float): std = 0.1 / initial_speed - a = (3 ** 0.5) * std + a = (3**0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in ** -0.5 # 1/sqrt(fan_in) + scale = fan_in**-0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() @@ -400,12 +395,12 @@ def __init__( def _reset_parameters(self, initial_speed: float): std = 0.1 / initial_speed - a = (3 ** 0.5) * std + a = (3**0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * 
self.weight[0][0].numel() - scale = fan_in ** -0.5 # 1/sqrt(fan_in) + scale = fan_in**-0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() @@ -476,9 +471,7 @@ def __init__( setattr(self, scale_name, param) self._scales.append(param) - self.grad_filter = GradientFilter( - batch_dim=1, threshold=grad_norm_threshold - ) + self.grad_filter = GradientFilter(batch_dim=1, threshold=grad_norm_threshold) self._reset_parameters( initial_speed @@ -486,8 +479,8 @@ def __init__( def _reset_parameters(self, initial_speed: float): std = 0.1 / initial_speed - a = (3 ** 0.5) * std - scale = self.hidden_size ** -0.5 + a = (3**0.5) * std + scale = self.hidden_size**-0.5 v = scale / std for idx, name in enumerate(self._flat_weights_names): if "weight" in name: @@ -559,15 +552,11 @@ def _get_flat_weights(self): """Get scaled weights, and resets their data pointer.""" flat_weights = [] for idx in range(len(self._flat_weights_names)): - flat_weights.append( - self._flat_weights[idx] * self._scales[idx].exp() - ) + flat_weights.append(self._flat_weights[idx] * self._scales[idx].exp()) self._flatten_parameters(flat_weights) return flat_weights - def forward( - self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None - ): + def forward(self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None): # This function is modified from https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/rnn.py # noqa # The change for calling `_VF.lstm()` is: # self._flat_weights -> self._get_flat_weights() @@ -915,9 +904,7 @@ def _test_activation_balancer_sign(): def _test_activation_balancer_magnitude(): magnitudes = torch.arange(0, 1, 0.01) N = 1000 - x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze( - -1 - ) + x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1) x = x.detach() x.requires_grad = True m = ActivationBalancer( @@ -947,8 +934,8 @@ def _test_basic_norm(): y = m(x) assert y.shape == x.shape - x_rms = (x ** 2).mean().sqrt() - y_rms = (y ** 2).mean().sqrt() + x_rms = (x**2).mean().sqrt() + y_rms = (y**2).mean().sqrt() print("x rms = ", x_rms) print("y rms = ", y_rms) assert y_rms < x_rms @@ -1007,11 +994,11 @@ def _test_grad_filter(): print( "_test_grad_filter: x_out_grad norm = ", - (x_out_grad ** 2).mean(dim=(0, 2)).sqrt(), + (x_out_grad**2).mean(dim=(0, 2)).sqrt(), ) print( "_test_grad_filter: x.grad norm = ", - (x.grad ** 2).mean(dim=(0, 2)).sqrt(), + (x.grad**2).mean(dim=(0, 2)).sqrt(), ) print("_test_grad_filter: w_out_grad = ", w_out_grad) print("_test_grad_filter: w.grad = ", w.grad) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_beam_search.py index 9bcd2f9f91..e6e0fb1c84 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_beam_search.py @@ -153,9 +153,7 @@ def modified_beam_search( index=hyps_shape.row_ids(1).to(torch.int64), ) # (num_hyps, encoder_out_dim) - logits = model.joiner( - current_encoder_out, decoder_out, project_input=False - ) + logits = model.joiner(current_encoder_out, decoder_out, project_input=False) # logits is of shape (num_hyps, 1, 1, vocab_size) logits = logits.squeeze(1).squeeze(1) @@ -172,14 +170,10 @@ def modified_beam_search( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor( - 
shape=log_probs_shape, value=log_probs - ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) for i in range(batch_size): - topk_log_probs, topk_indexes = ragged_log_probs[i].topk( - num_active_paths - ) + topk_log_probs, topk_indexes = ragged_log_probs[i].topk(num_active_paths) with warnings.catch_warnings(): warnings.simplefilter("ignore") diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py index d76a039462..0eea3a7828 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py @@ -51,11 +51,7 @@ from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.utils import ( AttributeDict, setup_logger, @@ -162,8 +158,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -271,9 +266,7 @@ def decode_one_chunk( encoder_out = model.joiner.encoder_proj(encoder_out) if params.decoding_method == "greedy_search": - greedy_search( - model=model, encoder_out=encoder_out, streams=decode_streams - ) + greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) elif params.decoding_method == "fast_beam_search": processed_lens = processed_lens + encoder_out_lens fast_beam_search_one_best( @@ -293,9 +286,7 @@ def decode_one_chunk( num_active_paths=params.num_active_paths, ) else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") states = [torch.unbind(states[0], dim=2), torch.unbind(states[1], dim=2)] @@ -351,9 +342,7 @@ def decode_dataset( decode_results = [] # Contain decode streams currently running. decode_streams = [] - initial_states = model.encoder.get_init_state( - params.left_context, device=device - ) + initial_states = model.encoder.get_init_state(params.left_context, device=device) for num, cut in enumerate(cuts): # each utterance has a DecodeStream. 
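# --------------------------------------------------------------------------
# The streaming_beam_search hunks above keep the hypotheses of all streams
# in one flat score tensor and wrap it in a k2 ragged tensor, so the
# per-stream top-k comes out of a single indexed call.  A tiny runnable
# illustration of that pattern (the scores here are arbitrary; in the recipe
# `row_splits` marks where each stream's hypotheses begin and end):
# --------------------------------------------------------------------------
import k2
import torch

scores = torch.tensor([0.1, 0.7, 0.2, 0.9, 0.3])
row_splits = torch.tensor([0, 2, 5], dtype=torch.int32)  # stream 0: 2 hyps; stream 1: 3 hyps
shape = k2.ragged.create_ragged_shape2(
    row_splits=row_splits, cached_tot_size=scores.numel()
)
ragged_scores = k2.RaggedTensor(shape=shape, value=scores)
for i in range(2):
    # Per-stream top-2 without ever materializing a padded rectangle.
    topk_values, topk_indexes = ragged_scores[i].topk(2)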
decode_stream = DecodeStream( @@ -425,9 +414,7 @@ def decode_dataset( elif params.decoding_method == "modified_beam_search": key = f"num_active_paths_{params.num_active_paths}" else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") return {key: decode_results} @@ -462,8 +449,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -536,8 +522,7 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index 1947834bf2..f6702ef16b 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -96,9 +96,7 @@ str2bool, ) -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def add_model_arguments(parser: argparse.ArgumentParser): @@ -210,8 +208,7 @@ def get_parser(): "--initial-lr", type=float, default=0.003, - help="The initial learning rate. This value should not need to " - "be changed.", + help="The initial learning rate. This value should not need to " "be changed.", ) parser.add_argument( @@ -234,8 +231,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -258,8 +254,7 @@ def get_parser(): "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)" "part.", ) parser.add_argument( @@ -634,9 +629,7 @@ def compute_loss( # If either all simple_loss or pruned_loss is inf or nan, # we stop the training process by raising an exception - if torch.all(~simple_loss_is_finite) or torch.all( - ~pruned_loss_is_finite - ): + if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite): raise ValueError( "There are too many utterances in this batch " "leading to inf or nan losses." @@ -649,14 +642,9 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training @@ -667,9 +655,7 @@ def compute_loss( # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 # (2) If some utterances in the batch lead to inf/nan loss, they # are filtered out. 
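# --------------------------------------------------------------------------
# A compact restatement of the warmup gating reformatted a few lines above:
# the pruned loss is off until warmup reaches 1.0, damped by 0.1 between 1.0
# and 2.0, and fully weighted afterwards.  The 0.5 below is only an assumed
# value for params.simple_loss_scale, chosen to make the worked examples
# concrete.
# --------------------------------------------------------------------------
def combine_losses(simple_loss, pruned_loss, warmup, simple_loss_scale=0.5):
    pruned_loss_scale = 0.0 if warmup < 1.0 else (0.1 if 1.0 < warmup < 2.0 else 1.0)
    return simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss


assert combine_losses(2.0, 4.0, warmup=0.5) == 1.0            # pruned loss disabled
assert abs(combine_losses(2.0, 4.0, warmup=1.5) - 1.4) < 1e-9  # pruned loss damped
assert combine_losses(2.0, 4.0, warmup=3.0) == 5.0            # full weight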
- info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -837,9 +823,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -963,8 +947,7 @@ def remove_short_and_long_utt(c: Cut): # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. " f"Duration: {c.duration}" ) return False diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/asr_datamodule.py b/egs/librispeech/ASR/pruned_transducer_stateless3/asr_datamodule.py index 1df7f9ee57..b7735be854 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/asr_datamodule.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/asr_datamodule.py @@ -27,10 +27,7 @@ K2SpeechRecognitionDataset, SpecAugment, ) -from lhotse.dataset.input_strategies import ( - OnTheFlyFeatures, - PrecomputedFeatures, -) +from lhotse.dataset.input_strategies import OnTheFlyFeatures, PrecomputedFeatures from torch.utils.data import DataLoader from icefall.utils import str2bool @@ -167,9 +164,7 @@ def train_dataloaders( if cuts_musan is not None: logging.info("Enable MUSAN") transforms.append( - CutMix( - cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True - ) + CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") @@ -178,9 +173,7 @@ def train_dataloaders( if self.args.enable_spec_aug: logging.info("Enable SpecAugment") - logging.info( - f"Time warp factor: {self.args.spec_aug_time_warp_factor}" - ) + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") input_transforms.append( SpecAugment( time_warp_factor=self.args.spec_aug_time_warp_factor, @@ -250,9 +243,7 @@ def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: if self.args.on_the_fly_feats: validate = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), return_cuts=self.args.return_cuts, ) else: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/decode-giga.py b/egs/librispeech/ASR/pruned_transducer_stateless3/decode-giga.py index 5784a78bab..df24d95854 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/decode-giga.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/decode-giga.py @@ -79,11 +79,7 @@ from gigaspeech_scoring import asr_text_post_processing from train import get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.utils import ( AttributeDict, setup_logger, @@ -192,8 +188,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; " "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -280,9 +275,7 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] if params.decoding_method == "fast_beam_search": @@ -312,10 +305,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -446,9 +436,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -481,8 +469,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -532,9 +519,7 @@ def main(): params.suffix += f"-num-paths-{params.num_paths}" params.suffix += f"-nbest-scale-{params.nbest_scale}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -567,8 +552,7 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py index 8025d6be17..55585e08c9 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py @@ -120,11 +120,7 @@ from librispeech import LibriSpeech from train import add_model_arguments, get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.lexicon import Lexicon from icefall.rnn_lm.model import RnnLmModel from icefall.utils import ( @@ -265,8 +261,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; " "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -478,9 +473,7 @@ def decode_one_batch( simulate_streaming=True, ) else: - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] @@ -550,10 +543,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -691,10 +681,7 @@ def decode_one_batch( return {key: hyps} else: return { - ( - f"beam_size_{params.beam_size}_" - f"temperature_{params.temperature}" - ): hyps + (f"beam_size_{params.beam_size}_" f"temperature_{params.temperature}"): hyps } @@ -779,9 +766,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -814,8 +799,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -939,9 +923,7 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" params.suffix += f"-temperature-{params.temperature}" else: params.suffix += f"-context-{params.context_size}" @@ -981,8 +963,7 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -1032,15 +1013,10 @@ def main(): word_table=word_table, device=device, ) - decoding_graph = k2.trivial_graph( - params.vocab_size - 1, device=device - ) + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) logging.info(f"G properties_str: {G.properties_str}") rnn_lm_model = None - if ( - params.decoding_method - == "fast_beam_search_with_nbest_rnn_rescoring" - ): + if params.decoding_method == "fast_beam_search_with_nbest_rnn_rescoring": rnn_lm_model = RnnLmModel( vocab_size=params.vocab_size, embedding_dim=params.rnn_lm_embedding_dim, @@ -1065,9 +1041,7 @@ def main(): rnn_lm_model.eval() else: word_table = None - decoding_graph = k2.trivial_graph( - params.vocab_size - 1, device=device - ) + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) rnn_lm_model = None else: decoding_graph = None diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/export.py b/egs/librispeech/ASR/pruned_transducer_stateless3/export.py index 47217ba057..2e444353cf 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/export.py @@ -128,11 +128,7 @@ from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, 
get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.utils import str2bool @@ -235,8 +231,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -509,13 +504,9 @@ def export_joiner_model_onnx( - projected_decoder_out: a tensor of shape (N, joiner_dim) """ - encoder_proj_filename = str(joiner_filename).replace( - ".onnx", "_encoder_proj.onnx" - ) + encoder_proj_filename = str(joiner_filename).replace(".onnx", "_encoder_proj.onnx") - decoder_proj_filename = str(joiner_filename).replace( - ".onnx", "_decoder_proj.onnx" - ) + decoder_proj_filename = str(joiner_filename).replace(".onnx", "_decoder_proj.onnx") encoder_out_dim = joiner_model.encoder_proj.weight.shape[1] decoder_out_dim = joiner_model.decoder_proj.weight.shape[1] @@ -616,8 +607,7 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -715,9 +705,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/gigaspeech.py b/egs/librispeech/ASR/pruned_transducer_stateless3/gigaspeech.py index 36f32c6b38..598434f549 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/gigaspeech.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/gigaspeech.py @@ -52,18 +52,14 @@ def train_XL_cuts(self) -> CutSet: ) pattern = re.compile(r"gigaspeech_cuts_XL.([0-9]+).jsonl.gz") - idx_filenames = [ - (int(pattern.search(f).group(1)), f) for f in filenames - ] + idx_filenames = [(int(pattern.search(f).group(1)), f) for f in filenames] idx_filenames = sorted(idx_filenames, key=lambda x: x[0]) sorted_filenames = [f[1] for f in idx_filenames] logging.info(f"Loading {len(sorted_filenames)} splits") - return lhotse.combine( - lhotse.load_manifest_lazy(p) for p in sorted_filenames - ) + return lhotse.combine(lhotse.load_manifest_lazy(p) for p in sorted_filenames) def train_L_cuts(self) -> CutSet: f = self.manifest_dir / "gigaspeech_cuts_L_raw.jsonl.gz" diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/jit_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless3/jit_pretrained.py index 162f8c7dbe..86cb45c09c 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/jit_pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/jit_pretrained.py @@ -143,8 +143,7 @@ def read_sound_files( for f in filenames: wave, sample_rate = torchaudio.load(f) assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" + f"expected sample rate: {expected_sample_rate}. 
" f"Given: {sample_rate}" ) # We use only the first channel ans.append(wave[0]) @@ -330,9 +329,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/model.py b/egs/librispeech/ASR/pruned_transducer_stateless3/model.py index 7852f84e97..d45f6dadc9 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/model.py @@ -84,9 +84,7 @@ def __init__( self.decoder_giga = decoder_giga self.joiner_giga = joiner_giga - self.simple_am_proj = ScaledLinear( - encoder_dim, vocab_size, initial_speed=0.5 - ) + self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size, initial_speed=0.5) self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size) if decoder_giga is not None: @@ -190,9 +188,7 @@ def forward( y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros( - (x.size(0), 4), dtype=torch.int64, device=x.device - ) + boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) boundary[:, 2] = y_lens boundary[:, 3] = encoder_out_lens diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py index d03d1d7ef1..163d737e3d 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py @@ -203,9 +203,7 @@ def test_joiner( ) # Now test encoder_proj - joiner_encoder_proj_inputs = { - encoder_proj_input_name: encoder_out.numpy() - } + joiner_encoder_proj_inputs = {encoder_proj_input_name: encoder_out.numpy()} joiner_encoder_proj_out = joiner_encoder_proj_session.run( [encoder_proj_output_name], joiner_encoder_proj_inputs )[0] @@ -214,16 +212,10 @@ def test_joiner( torch_joiner_encoder_proj_out = model.joiner.encoder_proj(encoder_out) assert torch.allclose( joiner_encoder_proj_out, torch_joiner_encoder_proj_out, atol=1e-5 - ), ( - (joiner_encoder_proj_out - torch_joiner_encoder_proj_out) - .abs() - .max() - ) + ), ((joiner_encoder_proj_out - torch_joiner_encoder_proj_out).abs().max()) # Now test decoder_proj - joiner_decoder_proj_inputs = { - decoder_proj_input_name: decoder_out.numpy() - } + joiner_decoder_proj_inputs = {decoder_proj_input_name: decoder_out.numpy()} joiner_decoder_proj_out = joiner_decoder_proj_session.run( [decoder_proj_output_name], joiner_decoder_proj_inputs )[0] @@ -232,11 +224,7 @@ def test_joiner( torch_joiner_decoder_proj_out = model.joiner.decoder_proj(decoder_out) assert torch.allclose( joiner_decoder_proj_out, torch_joiner_decoder_proj_out, atol=1e-5 - ), ( - (joiner_decoder_proj_out - torch_joiner_decoder_proj_out) - .abs() - .max() - ) + ), ((joiner_decoder_proj_out - torch_joiner_decoder_proj_out).abs().max()) @torch.no_grad() @@ -288,9 +276,7 @@ def main(): if __name__ == "__main__": torch.manual_seed(20220727) - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py index 
ea5d4e674d..825c6510bb 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py @@ -141,8 +141,7 @@ def read_sound_files( for f in filenames: wave, sample_rate = torchaudio.load(f) assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" + f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}" ) # We use only the first channel ans.append(wave[0]) @@ -191,11 +190,7 @@ def greedy_search( projected_encoder_out = joiner_encoder_proj.run( [joiner_encoder_proj.get_outputs()[0].name], - { - joiner_encoder_proj.get_inputs()[ - 0 - ].name: packed_encoder_out.data.numpy() - }, + {joiner_encoder_proj.get_inputs()[0].name: packed_encoder_out.data.numpy()}, )[0] blank_id = 0 # hard-code to 0 @@ -382,9 +377,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py index 19b636a237..77bd6d13db 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py @@ -177,8 +177,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -232,8 +231,7 @@ def read_sound_files( for f in filenames: wave, sample_rate = torchaudio.load(f) assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" + f"expected sample rate: {expected_sample_rate}. 
" f"Given: {sample_rate}" ) # We use only the first channel ans.append(wave[0]) @@ -302,9 +300,7 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) feature_lengths = torch.tensor(feature_lengths, device=device) @@ -391,9 +387,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py b/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py index 1e6022b57a..b712eeda03 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py @@ -234,9 +234,7 @@ def scaled_lstm_to_lstm(scaled_lstm: ScaledLSTM) -> nn.LSTM: assert lstm._flat_weights_names == scaled_lstm._flat_weights_names for idx in range(len(scaled_lstm._flat_weights_names)): - scaled_weight = ( - scaled_lstm._flat_weights[idx] * scaled_lstm._scales[idx].exp() - ) + scaled_weight = scaled_lstm._flat_weights[idx] * scaled_lstm._scales[idx].exp() lstm._flat_weights[idx].data.copy_(scaled_weight) return lstm diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py index 10bb44e002..e85d2060a3 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py @@ -52,11 +52,7 @@ from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.utils import ( AttributeDict, setup_logger, @@ -163,8 +159,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -272,9 +267,7 @@ def decode_one_chunk( encoder_out = model.joiner.encoder_proj(encoder_out) if params.decoding_method == "greedy_search": - greedy_search( - model=model, encoder_out=encoder_out, streams=decode_streams - ) + greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) elif params.decoding_method == "fast_beam_search": processed_lens = processed_lens + encoder_out_lens fast_beam_search_one_best( @@ -294,9 +287,7 @@ def decode_one_chunk( num_active_paths=params.num_active_paths, ) else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") states = [torch.unbind(states[0], dim=2), torch.unbind(states[1], dim=2)] @@ -352,9 +343,7 @@ def decode_dataset( decode_results = [] # Contain decode streams currently running. 
decode_streams = [] - initial_states = model.encoder.get_init_state( - params.left_context, device=device - ) + initial_states = model.encoder.get_init_state(params.left_context, device=device) for num, cut in enumerate(cuts): # each utterance has a DecodeStream. decode_stream = DecodeStream( @@ -426,9 +415,7 @@ def decode_dataset( elif params.decoding_method == "modified_beam_search": key = f"num_active_paths_{params.num_active_paths}" else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") return {key: decode_results} @@ -461,8 +448,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -535,8 +521,7 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/test_onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless3/test_onnx.py index 66ffbd3ec1..598fcf3442 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/test_onnx.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/test_onnx.py @@ -90,9 +90,7 @@ def test_conv2d_subsampling(): onnx_y = torch.from_numpy(onnx_y) torch_y = jit_model(x) - assert torch.allclose(onnx_y, torch_y, atol=1e-05), ( - (onnx_y - torch_y).abs().max() - ) + assert torch.allclose(onnx_y, torch_y, atol=1e-05), (onnx_y - torch_y).abs().max() os.remove(filename) @@ -147,9 +145,7 @@ def test_rel_pos(): onnx_pos_emb = torch.from_numpy(onnx_pos_emb) torch_y, torch_pos_emb = jit_model(x) - assert torch.allclose(onnx_y, torch_y, atol=1e-05), ( - (onnx_y - torch_y).abs().max() - ) + assert torch.allclose(onnx_y, torch_y, atol=1e-05), (onnx_y - torch_y).abs().max() assert torch.allclose(onnx_pos_emb, torch_pos_emb, atol=1e-05), ( (onnx_pos_emb - torch_pos_emb).abs().max() @@ -197,9 +193,7 @@ def test_conformer_encoder_layer(): encoder_layer.eval() encoder_layer = convert_scaled_to_non_scaled(encoder_layer, inplace=True) - jit_model = torch.jit.trace( - encoder_layer, (x, pos_emb, src_key_padding_mask) - ) + jit_model = torch.jit.trace(encoder_layer, (x, pos_emb, src_key_padding_mask)) torch.onnx.export( encoder_layer, @@ -236,9 +230,7 @@ def test_conformer_encoder_layer(): onnx_y = torch.from_numpy(onnx_y) torch_y = jit_model(x, pos_emb, src_key_padding_mask) - assert torch.allclose(onnx_y, torch_y, atol=1e-05), ( - (onnx_y - torch_y).abs().max() - ) + assert torch.allclose(onnx_y, torch_y, atol=1e-05), (onnx_y - torch_y).abs().max() print(onnx_y.abs().sum(), torch_y.abs().sum(), onnx_y.shape, torch_y.shape) @@ -322,9 +314,7 @@ def test_conformer_encoder(): onnx_y = torch.from_numpy(onnx_y) torch_y = jit_model(x, pos_emb, src_key_padding_mask) - assert torch.allclose(onnx_y, torch_y, atol=1e-05), ( - (onnx_y - torch_y).abs().max() - ) + assert torch.allclose(onnx_y, torch_y, atol=1e-05), (onnx_y - torch_y).abs().max() print(onnx_y.abs().sum(), torch_y.abs().sum(), onnx_y.shape, torch_y.shape) @@ -379,9 +369,7 @@ def test_conformer(): onnx_y_lens = torch.from_numpy(onnx_y_lens) torch_y, torch_y_lens = jit_model(x, x_lens) 
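Note: the test_onnx.py hunks all follow one parity pattern: trace the module, export it to ONNX, run it under onnxruntime, and assert agreement with the traced model at atol=1e-05, reporting the worst element on failure. A self-contained sketch of that pattern on a stand-in module (the Linear layer and file name are illustrative, not taken from the diff):

import os

import onnxruntime as ort
import torch

model = torch.nn.Linear(4, 4).eval()
x = torch.randn(2, 4)
filename = "tmp_linear.onnx"

torch.onnx.export(model, x, filename, input_names=["x"], output_names=["y"])

session = ort.InferenceSession(filename, providers=["CPUExecutionProvider"])
onnx_y = torch.from_numpy(session.run(["y"], {"x": x.numpy()})[0])
torch_y = model(x)

# On mismatch, surface the largest absolute difference, as the asserts above do.
assert torch.allclose(onnx_y, torch_y, atol=1e-05), (onnx_y - torch_y).abs().max()
os.remove(filename)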
- assert torch.allclose(onnx_y, torch_y, atol=1e-05), ( - (onnx_y - torch_y).abs().max() - ) + assert torch.allclose(onnx_y, torch_y, atol=1e-05), (onnx_y - torch_y).abs().max() assert torch.allclose(onnx_y_lens, torch_y_lens, atol=1e-05), ( (onnx_y_lens - torch_y_lens).abs().max() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/train.py b/egs/librispeech/ASR/pruned_transducer_stateless3/train.py index 44e96644a7..e9ceb60de8 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/train.py @@ -92,9 +92,7 @@ str2bool, ) -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def add_model_arguments(parser: argparse.ArgumentParser): @@ -163,8 +161,7 @@ def get_parser(): "--full-libri", type=str2bool, default=True, - help="When enabled, use 960h LibriSpeech. " - "Otherwise, use 100h subset.", + help="When enabled, use 960h LibriSpeech. " "Otherwise, use 100h subset.", ) parser.add_argument( @@ -214,8 +211,7 @@ def get_parser(): "--initial-lr", type=float, default=0.003, - help="The initial learning rate. This value should not need " - "to be changed.", + help="The initial learning rate. This value should not need " "to be changed.", ) parser.add_argument( @@ -238,8 +234,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -262,8 +257,7 @@ def get_parser(): "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)" "part.", ) parser.add_argument( @@ -672,9 +666,7 @@ def compute_loss( # If either all simple_loss or pruned_loss is inf or nan, # we stop the training process by raising an exception - if torch.all(~simple_loss_is_finite) or torch.all( - ~pruned_loss_is_finite - ): + if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite): raise ValueError( "There are too many utterances in this batch " "leading to inf or nan losses." @@ -687,14 +679,9 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training @@ -705,9 +692,7 @@ def compute_loss( # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 # (2) If some utterances in the batch lead to inf/nan loss, they # are filtered out. 
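Note on the reflowed `pruned_loss_scale` expression: the pruned loss is disabled while warmup < 1.0, down-weighted by 0.1 for 1.0 < warmup < 2.0, and fully enabled afterwards, so early in training it cannot overwhelm the simple loss before the alignment is learned. Worked values (the simple_loss_scale of 0.5 is illustrative):

def combined_loss(simple_loss, pruned_loss, warmup, simple_loss_scale=0.5):
    # Mirrors the schedule in compute_loss above.
    pruned_loss_scale = (
        0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
    )
    return simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss

assert combined_loss(2.0, 5.0, warmup=0.5) == 1.0  # pruned loss off
assert combined_loss(2.0, 5.0, warmup=1.5) == 1.5  # 1.0 + 0.1 * 5.0
assert combined_loss(2.0, 5.0, warmup=3.0) == 6.0  # fully warmed up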
- info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -919,9 +904,7 @@ def train_one_epoch( f"train/current_{prefix}_", params.batch_idx_train, ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) libri_tot_loss.write_summary( tb_writer, "train/libri_tot_", params.batch_idx_train ) @@ -967,8 +950,7 @@ def remove_short_and_long_utt(c: Cut): # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. " f"Duration: {c.duration}" ) return False @@ -1109,9 +1091,7 @@ def run(rank, world_size, args): train_giga_cuts = train_giga_cuts.repeat(times=None) if args.enable_musan: - cuts_musan = load_manifest( - Path(args.manifest_dir) / "musan_cuts.jsonl.gz" - ) + cuts_musan = load_manifest(Path(args.manifest_dir) / "musan_cuts.jsonl.gz") else: cuts_musan = None diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py index 4f043e5a66..2f9a60f13e 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py @@ -306,8 +306,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -427,9 +426,7 @@ def decode_one_batch( simulate_streaming=True, ) else: - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) if ( params.decoding_method == "fast_beam_search" @@ -485,10 +482,7 @@ def decode_one_batch( nbest_scale=params.nbest_scale, return_timestamps=True, ) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: res = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -566,9 +560,7 @@ def decode_dataset( sp: spm.SentencePieceProcessor, word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[ - str, List[Tuple[str, List[str], List[str], List[float], List[float]]] -]: +) -> Dict[str, List[Tuple[str, List[str], List[str], List[float], List[float]]]]: """Decode dataset. 
Args: @@ -643,9 +635,7 @@ def decode_dataset( cut_ids, hyps, texts, timestamps_hyp, timestamps_ref ): ref_words = ref_text.split() - this_batch.append( - (cut_id, ref_words, hyp_words, time_ref, time_hyp) - ) + this_batch.append((cut_id, ref_words, hyp_words, time_ref, time_hyp)) results[name].extend(this_batch) @@ -654,9 +644,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -694,8 +682,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -722,9 +709,7 @@ def save_results( note = "" logging.info(s) - s = "\nFor {}, symbol-delay of different settings are:\n".format( - test_set_name - ) + s = "\nFor {}, symbol-delay of different settings are:\n".format(test_set_name) note = "\tbest for {}".format(test_set_name) for key, val in test_set_delays: s += "{}\tmean: {}s, variance: {}{}\n".format(key, val[0], val[1], note) @@ -773,9 +758,7 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -812,9 +795,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -841,9 +824,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -902,9 +885,7 @@ def main(): decoding_graph.scores *= params.ngram_lm_scale else: word_table = None - decoding_graph = k2.trivial_graph( - params.vocab_size - 1, device=device - ) + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) else: decoding_graph = None word_table = None diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/export.py b/egs/librispeech/ASR/pruned_transducer_stateless4/export.py index ce7518ceb4..64ef897339 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/export.py @@ -133,8 +133,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -183,9 +182,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -212,9 +211,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -282,9 +281,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py index 7af9ea9b8d..d74d1c89db 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py @@ -175,8 +175,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -284,9 +283,7 @@ def decode_one_chunk( encoder_out = model.joiner.encoder_proj(encoder_out) if params.decoding_method == "greedy_search": - greedy_search( - model=model, encoder_out=encoder_out, streams=decode_streams - ) + greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) elif params.decoding_method == "fast_beam_search": processed_lens = processed_lens + encoder_out_lens fast_beam_search_one_best( @@ -306,9 +303,7 @@ def decode_one_chunk( num_active_paths=params.num_active_paths, ) else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") states = [torch.unbind(states[0], dim=2), torch.unbind(states[1], dim=2)] @@ -364,9 +359,7 @@ def decode_dataset( decode_results = [] # Contain decode streams currently running. decode_streams = [] - initial_states = model.encoder.get_init_state( - params.left_context, device=device - ) + initial_states = model.encoder.get_init_state(params.left_context, device=device) for num, cut in enumerate(cuts): # each utterance has a DecodeStream. 
decode_stream = DecodeStream( @@ -438,9 +431,7 @@ def decode_dataset( elif params.decoding_method == "modified_beam_search": key = f"num_active_paths_{params.num_active_paths}" else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") return {key: decode_results} @@ -473,8 +464,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -547,9 +537,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -576,9 +566,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py index cf32e565b6..97f3e56a99 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py @@ -101,9 +101,7 @@ str2bool, ) -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def add_model_arguments(parser: argparse.ArgumentParser): @@ -239,8 +237,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -263,8 +260,7 @@ def get_parser(): "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)" "part.", ) parser.add_argument( @@ -621,11 +617,7 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -665,9 +657,7 @@ def compute_loss( # If either all simple_loss or pruned_loss is inf or nan, # we stop the training process by raising an exception - if torch.all(~simple_loss_is_finite) or torch.all( - ~pruned_loss_is_finite - ): + if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite): raise ValueError( "There are too many utterances in this batch " "leading to inf or nan losses." @@ -680,14 +670,9 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. 
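Note on the `torch.all(~simple_loss_is_finite)` checks reflowed above: training is aborted only when every utterance in the batch yields a non-finite loss; isolated inf/nan utterances are filtered out instead. An illustrative stand-alone version with toy loss vectors (no exception fires here, since some entries are finite; computing the masks via torch.isfinite is an assumption):

import torch

simple_loss = torch.tensor([1.2, float("inf"), 0.7])
pruned_loss = torch.tensor([0.3, 0.4, float("nan")])

simple_loss_is_finite = torch.isfinite(simple_loss)
pruned_loss_is_finite = torch.isfinite(pruned_loss)

if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite):
    raise ValueError(
        "There are too many utterances in this batch "
        "leading to inf or nan losses."
    )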
pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training @@ -698,9 +683,7 @@ def compute_loss( # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 # (2) If some utterances in the batch lead to inf/nan loss, they # are filtered out. - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -879,9 +862,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -1013,8 +994,7 @@ def remove_short_and_long_utt(c: Cut): # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. " f"Duration: {c.duration}" ) return False diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py index 427b062943..b3a7d71bc0 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py @@ -214,10 +214,7 @@ def get_init_state( (num_encoder_layers, cnn_module_kernel - 1, encoder_dim). NOTE: the returned tensors are on the given device. 
""" - if ( - len(self._init_state) == 2 - and self._init_state[0].size(1) == left_context - ): + if len(self._init_state) == 2 and self._init_state[0].size(1) == left_context: # Note: It is OK to share the init state as it is # not going to be modified by the model return self._init_state @@ -439,9 +436,7 @@ def __init__( self.d_model = d_model - self.self_attn = RelPositionMultiheadAttention( - d_model, nhead, dropout=0.0 - ) + self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) self.feed_forward = nn.Sequential( ScaledLinear(d_model, dim_feedforward), @@ -459,9 +454,7 @@ def __init__( ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), ) - self.conv_module = ConvolutionModule( - d_model, cnn_module_kernel, causal=causal - ) + self.conv_module = ConvolutionModule(d_model, cnn_module_kernel, causal=causal) self.norm_final = BasicNorm(d_model) @@ -527,9 +520,7 @@ def forward( src = src + self.dropout(src_att) # convolution module - conv, _ = self.conv_module( - src, src_key_padding_mask=src_key_padding_mask - ) + conv, _ = self.conv_module(src, src_key_padding_mask=src_key_padding_mask) src = src + self.dropout(conv) # feed forward module @@ -802,9 +793,7 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__( - self, d_model: int, dropout_rate: float, max_len: int = 5000 - ) -> None: + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -820,9 +809,7 @@ def extend_pe(self, x: Tensor, left_context: int = 0) -> None: # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x_size_1 * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str( - x.device - ): + if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vector and `j` means the @@ -848,9 +835,7 @@ def extend_pe(self, x: Tensor, left_context: int = 0) -> None: pe = torch.cat([pe_positive, pe_negative], dim=1) self.pe = pe.to(device=x.device, dtype=x.dtype) - def forward( - self, x: torch.Tensor, left_context: int = 0 - ) -> Tuple[Tensor, Tensor]: + def forward(self, x: torch.Tensor, left_context: int = 0) -> Tuple[Tensor, Tensor]: """Add positional encoding. Args: @@ -1118,9 +1103,9 @@ def multi_head_attention_forward( if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear( - query, in_proj_weight, in_proj_bias - ).chunk(3, dim=-1) + q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( + 3, dim=-1 + ) elif torch.equal(key, value): # encoder-decoder attention @@ -1189,31 +1174,22 @@ def multi_head_attention_forward( if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError( - "The size of the 2D attn_mask is not correct." - ) + raise RuntimeError("The size of the 2D attn_mask is not correct.") elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError( - "The size of the 3D attn_mask is not correct." 
- ) + raise RuntimeError("The size of the 3D attn_mask is not correct.") else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format( - attn_mask.dim() - ) + "attn_mask's dimension {} is not supported".format(attn_mask.dim()) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if ( - key_padding_mask is not None - and key_padding_mask.dtype == torch.uint8 - ): + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: warnings.warn( "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." ) @@ -1253,23 +1229,15 @@ def multi_head_attention_forward( # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul( - q_with_bias_u, k - ) # (batch, head, time1, time2) + matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) # compute matrix b and matrix d - matrix_bd = torch.matmul( - q_with_bias_v, p - ) # (batch, head, time1, 2*time1-1) + matrix_bd = torch.matmul(q_with_bias_v, p) # (batch, head, time1, 2*time1-1) matrix_bd = self.rel_shift(matrix_bd, left_context) - attn_output_weights = ( - matrix_ac + matrix_bd - ) # (batch, head, time1, time2) + attn_output_weights = matrix_ac + matrix_bd # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view( - bsz * num_heads, tgt_len, -1 - ) + attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) assert list(attn_output_weights.size()) == [ bsz * num_heads, @@ -1310,21 +1278,17 @@ def multi_head_attention_forward( ): if attn_mask.size(0) != 1: attn_mask = attn_mask.view(bsz, num_heads, tgt_len, src_len) - combined_mask = attn_mask | key_padding_mask.unsqueeze( - 1 - ).unsqueeze(2) + combined_mask = attn_mask | key_padding_mask.unsqueeze(1).unsqueeze(2) else: # attn_mask.shape == (1, tgt_len, src_len) - combined_mask = attn_mask.unsqueeze( - 0 - ) | key_padding_mask.unsqueeze(1).unsqueeze(2) + combined_mask = attn_mask.unsqueeze(0) | key_padding_mask.unsqueeze( + 1 + ).unsqueeze(2) attn_output_weights = attn_output_weights.view( bsz, num_heads, tgt_len, src_len ) - attn_output_weights = attn_output_weights.masked_fill( - combined_mask, 0.0 - ) + attn_output_weights = attn_output_weights.masked_fill(combined_mask, 0.0) attn_output_weights = attn_output_weights.view( bsz * num_heads, tgt_len, src_len ) @@ -1336,13 +1300,9 @@ def multi_head_attention_forward( attn_output = torch.bmm(attn_output_weights, v) assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] attn_output = ( - attn_output.transpose(0, 1) - .contiguous() - .view(tgt_len, bsz, embed_dim) - ) - attn_output = nn.functional.linear( - attn_output, out_proj_weight, out_proj_bias + attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) ) + attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads @@ -1481,16 +1441,12 @@ def forward( # manualy padding self.lorder zeros to the left x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0) else: - assert ( - not self.training - ), "Cache should be None in training time" + assert not self.training, "Cache should be None in training time" assert cache.size(0) == self.lorder x = torch.cat([cache.permute(1, 2, 0), x], dim=2) if right_context > 0: cache = x.permute(2, 0, 1)[ - -(self.lorder + right_context) : ( # noqa - -right_context - ), + 
-(self.lorder + right_context) : (-right_context), # noqa ..., ] else: @@ -1666,9 +1622,7 @@ def __init__( self.stddev = stddev self.final_log_weight = ( - torch.tensor( - (final_weight / (1 - final_weight)) * (self.num_inputs - 1) - ) + torch.tensor((final_weight / (1 - final_weight)) * (self.num_inputs - 1)) .log() .item() ) @@ -1765,16 +1719,14 @@ def _get_random_pure_weights( # final contains self.num_inputs - 1 in all elements final = torch.full((num_frames,), self.num_inputs - 1, device=device) # nonfinal contains random integers in [0..num_inputs - 2], these are for non-final weights. - nonfinal = torch.randint( - self.num_inputs - 1, (num_frames,), device=device - ) + nonfinal = torch.randint(self.num_inputs - 1, (num_frames,), device=device) indexes = torch.where( torch.rand(num_frames, device=device) < final_prob, final, nonfinal ) - ans = torch.nn.functional.one_hot( - indexes, num_classes=self.num_inputs - ).to(dtype=dtype) + ans = torch.nn.functional.one_hot(indexes, num_classes=self.num_inputs).to( + dtype=dtype + ) return ans def _get_random_mixed_weights( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py index 22bcdd88e6..5c76afde61 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py @@ -303,8 +303,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -477,9 +476,7 @@ def decode_one_batch( simulate_streaming=True, ) else: - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] @@ -545,10 +542,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -696,9 +690,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -731,8 +723,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -787,9 +778,7 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -828,9 +817,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = 
find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -857,9 +846,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -937,9 +926,7 @@ def main(): decoding_graph.scores *= params.ngram_lm_scale else: word_table = None - decoding_graph = k2.trivial_graph( - params.vocab_size - 1, device=device - ) + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) else: decoding_graph = None word_table = None diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/export.py b/egs/librispeech/ASR/pruned_transducer_stateless5/export.py index b2e5b430e2..f0bfd3b4ce 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/export.py @@ -133,8 +133,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -181,9 +180,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -210,9 +209,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -280,9 +279,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py index 1e100fcbd7..77ba0873b5 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py @@ -166,8 +166,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -199,8 +198,7 @@ def read_sound_files( for f in filenames: wave, sample_rate = torchaudio.load(f) assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" + f"expected sample rate: {expected_sample_rate}. 
" f"Given: {sample_rate}" ) # We use only the first channel ans.append(wave[0]) @@ -264,15 +262,11 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) feature_lengths = torch.tensor(feature_lengths, device=device) - encoder_out, encoder_out_lens = model.encoder( - x=features, x_lens=feature_lengths - ) + encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths) num_waves = encoder_out.size(0) hyps = [] @@ -344,9 +338,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py index 6fee9483ef..e750f5554d 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py @@ -175,8 +175,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -284,9 +283,7 @@ def decode_one_chunk( encoder_out = model.joiner.encoder_proj(encoder_out) if params.decoding_method == "greedy_search": - greedy_search( - model=model, encoder_out=encoder_out, streams=decode_streams - ) + greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) elif params.decoding_method == "fast_beam_search": processed_lens = processed_lens + encoder_out_lens fast_beam_search_one_best( @@ -306,9 +303,7 @@ def decode_one_chunk( num_active_paths=params.num_active_paths, ) else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") states = [torch.unbind(states[0], dim=2), torch.unbind(states[1], dim=2)] @@ -364,9 +359,7 @@ def decode_dataset( decode_results = [] # Contain decode streams currently running. decode_streams = [] - initial_states = model.encoder.get_init_state( - params.left_context, device=device - ) + initial_states = model.encoder.get_init_state(params.left_context, device=device) for num, cut in enumerate(cuts): # each utterance has a DecodeStream. 
decode_stream = DecodeStream( @@ -438,9 +431,7 @@ def decode_dataset( elif params.decoding_method == "modified_beam_search": key = f"num_active_paths_{params.num_active_paths}" else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") return {key: decode_results} @@ -473,8 +464,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -547,9 +537,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -576,9 +566,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/train.py b/egs/librispeech/ASR/pruned_transducer_stateless5/train.py index 179d9372ef..a1a810d3eb 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/train.py @@ -89,9 +89,7 @@ str2bool, ) -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def add_model_arguments(parser: argparse.ArgumentParser): @@ -248,8 +246,7 @@ def get_parser(): "--initial-lr", type=float, default=0.003, - help="The initial learning rate. This value should not need " - "to be changed.", + help="The initial learning rate. This value should not need " "to be changed.", ) parser.add_argument( @@ -272,8 +269,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -296,8 +292,7 @@ def get_parser(): "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)" "part.", ) parser.add_argument( @@ -645,11 +640,7 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. 
""" - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -690,9 +681,7 @@ def compute_loss( # If the batch contains more than 10 utterances AND # if either all simple_loss or pruned_loss is inf or nan, # we stop the training process by raising an exception - if torch.all(~simple_loss_is_finite) or torch.all( - ~pruned_loss_is_finite - ): + if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite): raise ValueError( "There are too many utterances in this batch " "leading to inf or nan losses." @@ -705,14 +694,9 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training @@ -723,9 +707,7 @@ def compute_loss( # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 # (2) If some utterances in the batch lead to inf/nan loss, they # are filtered out. - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -908,9 +890,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -1023,7 +1003,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2 ** 22 + 2**22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) @@ -1045,8 +1025,7 @@ def remove_short_and_long_utt(c: Cut): # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. " f"Duration: {c.duration}" ) return False diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless6/conformer.py index 53788b3f77..0667e7f61e 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/conformer.py @@ -90,10 +90,7 @@ def __init__( output_layers = [] if middle_output_layer is not None: - assert ( - middle_output_layer >= 0 - and middle_output_layer < num_encoder_layers - ) + assert middle_output_layer >= 0 and middle_output_layer < num_encoder_layers output_layers.append(middle_output_layer) # The last layer is always needed. 
@@ -178,9 +175,7 @@ def __init__( self.d_model = d_model - self.self_attn = RelPositionMultiheadAttention( - d_model, nhead, dropout=0.0 - ) + self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) self.feed_forward = nn.Sequential( ScaledLinear(d_model, dim_feedforward), @@ -362,9 +357,7 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__( - self, d_model: int, dropout_rate: float, max_len: int = 5000 - ) -> None: + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -379,9 +372,7 @@ def extend_pe(self, x: Tensor) -> None: # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x.size(1) * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str( - x.device - ): + if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vector and `j` means the @@ -656,9 +647,9 @@ def multi_head_attention_forward( if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear( - query, in_proj_weight, in_proj_bias - ).chunk(3, dim=-1) + q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( + 3, dim=-1 + ) elif torch.equal(key, value): # encoder-decoder attention @@ -727,31 +718,22 @@ def multi_head_attention_forward( if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError( - "The size of the 2D attn_mask is not correct." - ) + raise RuntimeError("The size of the 2D attn_mask is not correct.") elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError( - "The size of the 3D attn_mask is not correct." - ) + raise RuntimeError("The size of the 3D attn_mask is not correct.") else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format( - attn_mask.dim() - ) + "attn_mask's dimension {} is not supported".format(attn_mask.dim()) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if ( - key_padding_mask is not None - and key_padding_mask.dtype == torch.uint8 - ): + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: warnings.warn( "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." 
) @@ -790,9 +772,7 @@ def multi_head_attention_forward( # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul( - q_with_bias_u, k - ) # (batch, head, time1, time2) + matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) # compute matrix b and matrix d matrix_bd = torch.matmul( @@ -800,13 +780,9 @@ def multi_head_attention_forward( ) # (batch, head, time1, 2*time1-1) matrix_bd = self.rel_shift(matrix_bd) - attn_output_weights = ( - matrix_ac + matrix_bd - ) # (batch, head, time1, time2) + attn_output_weights = matrix_ac + matrix_bd # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view( - bsz * num_heads, tgt_len, -1 - ) + attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) assert list(attn_output_weights.size()) == [ bsz * num_heads, @@ -840,13 +816,9 @@ def multi_head_attention_forward( attn_output = torch.bmm(attn_output_weights, v) assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] attn_output = ( - attn_output.transpose(0, 1) - .contiguous() - .view(tgt_len, bsz, embed_dim) - ) - attn_output = nn.functional.linear( - attn_output, out_proj_weight, out_proj_bias + attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) ) + attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads @@ -869,9 +841,7 @@ class ConvolutionModule(nn.Module): """ - def __init__( - self, channels: int, kernel_size: int, bias: bool = True - ) -> None: + def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None: """Construct an ConvolutionModule object.""" super(ConvolutionModule, self).__init__() # kernerl_size should be a odd number for 'SAME' padding diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py index 74df040060..3734564fe5 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py @@ -208,8 +208,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; " "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -267,9 +266,7 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - layer_results, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + layer_results, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) encoder_out = layer_results[-1] hyps = [] @@ -285,10 +282,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -411,9 +405,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -446,8 +438,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -490,9 +481,7 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -524,9 +513,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -553,9 +542,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/export.py b/egs/librispeech/ASR/pruned_transducer_stateless6/export.py index cff9c73776..3d1e7bc18b 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/export.py @@ -51,11 +51,7 @@ import torch from train import get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.utils import str2bool @@ -120,8 +116,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; " "2 means tri-gram", ) return parser @@ -160,8 +155,7 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -209,9 +203,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/extract_codebook_index.py b/egs/librispeech/ASR/pruned_transducer_stateless6/extract_codebook_index.py index 21409287c0..86cf348779 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/extract_codebook_index.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/extract_codebook_index.py @@ -21,9 +21,10 @@ from pathlib import Path import torch -from vq_utils import CodebookIndexExtractor from asr_datamodule import LibriSpeechAsrDataModule from hubert_xlarge import HubertXlargeFineTuned +from vq_utils import CodebookIndexExtractor + from icefall.utils import AttributeDict, str2bool diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/hubert_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless6/hubert_decode.py index 49b5578142..b8440f90a7 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/hubert_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/hubert_decode.py @@ -23,7 +23,6 @@ from typing import Dict, List, Tuple import torch - from asr_datamodule import LibriSpeechAsrDataModule from hubert_xlarge import HubertXlargeFineTuned @@ -99,9 +98,7 @@ def decode_dataset( if batch_idx % 20 == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -124,9 +121,7 @@ def save_results( ) test_set_wers[key] = wer - logging.info( - "Wrote detailed error stats to {}".format(errs_filename) - ) + logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = params.res_dir / f"wer-summary-{test_set_name}.txt" @@ -155,9 +150,7 @@ def main(): # reset some parameters needed by hubert. 
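# Illustrative sketch of the checkpoint averaging used by the export/decode
# scripts above: `find_checkpoints(...)[: params.avg]` picks the most recent
# checkpoints and `average_checkpoints` takes the element-wise mean of their
# parameters. This is a minimal version assuming each file stores a
# {"model": state_dict} of tensors, as icefall checkpoints do; the real
# helper handles more cases (e.g. non-float entries) than shown here.
import torch


def average_state_dicts(filenames, device):
    n = len(filenames)
    avg = torch.load(filenames[0], map_location=device)["model"]
    for f in filenames[1:]:
        state = torch.load(f, map_location=device)["model"]
        for k in avg:
            avg[k] += state[k]
    for k in avg:
        if avg[k].is_floating_point():
            avg[k] /= n
    return avg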
params.update(HubertXlargeFineTuned.get_params()) - params.res_dir = ( - params.exp_dir / f"ctc_greedy_search-{params.teacher_model_id}" - ) + params.res_dir = params.exp_dir / f"ctc_greedy_search-{params.teacher_model_id}" setup_logger(f"{params.res_dir}/log/log-ctc_greedy_search") logging.info("Decoding started") @@ -190,9 +183,7 @@ def main(): params=params, ) - save_results( - params=params, test_set_name=test_set, results_dict=results_dict - ) + save_results(params=params, test_set_name=test_set, results_dict=results_dict) logging.info("Done!") diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/hubert_xlarge.py b/egs/librispeech/ASR/pruned_transducer_stateless6/hubert_xlarge.py index 55ce7b00da..4f9417c9f7 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/hubert_xlarge.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/hubert_xlarge.py @@ -22,11 +22,7 @@ from typing import Dict, List, Tuple import torch -from fairseq import ( - checkpoint_utils, - tasks, - utils, -) +from fairseq import checkpoint_utils, tasks, utils from fairseq.data.data_utils import post_process from omegaconf import OmegaConf @@ -51,9 +47,7 @@ def _load_hubert_model(params: AttributeDict): "data": str(params.hubert_model_dir), } ) - model_path = Path(params.hubert_model_dir) / ( - params.teacher_model_id + ".pt" - ) + model_path = Path(params.hubert_model_dir) / (params.teacher_model_id + ".pt") task = tasks.setup_task(cfg_task) processor = task.target_dictionary models, saved_cfg = checkpoint_utils.load_model_ensemble( @@ -151,9 +145,7 @@ def extract_layers_result( supervisions = batch["supervisions"] num_samples = supervisions["num_samples"] B, T = features.shape - padding_mask = torch.arange(0, T).expand(B, T) > num_samples.reshape( - [-1, 1] - ) + padding_mask = torch.arange(0, T).expand(B, T) > num_samples.reshape([-1, 1]) padding_mask = padding_mask.to(self.params.device) features = features.to(self.params.device) @@ -163,9 +155,7 @@ def extract_layers_result( features = features.transpose(1, 2) features = self.w2v_model.layer_norm(features) - padding_mask = self.w2v_model.forward_padding_mask( - features, padding_mask - ) + padding_mask = self.w2v_model.forward_padding_mask(features, padding_mask) if self.w2v_model.post_extract_proj is not None: features = self.w2v_model.post_extract_proj(features) @@ -212,9 +202,7 @@ def ctc_greedy_search(self, batch): toks = encoder_out.argmax(dim=-1) blank = 0 toks = [tok.unique_consecutive() for tok in toks] - hyps = [ - self.processor.string(tok[tok != blank].int().cpu()) for tok in toks - ] + hyps = [self.processor.string(tok[tok != blank].int().cpu()) for tok in toks] hyps = [post_process(hyp, "letter") for hyp in hyps] return hyps diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/model.py b/egs/librispeech/ASR/pruned_transducer_stateless6/model.py index 7716d19cf0..daadb70c98 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/model.py @@ -69,9 +69,7 @@ def __init__( self.decoder = decoder self.joiner = joiner - self.simple_am_proj = ScaledLinear( - encoder_dim, vocab_size, initial_speed=0.5 - ) + self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size, initial_speed=0.5) self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size) from icefall import is_module_available @@ -180,9 +178,7 @@ def forward( y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros( - (x.size(0), 4), dtype=torch.int64, 
device=x.device - ) + boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) boundary[:, 2] = y_lens boundary[:, 3] = x_lens @@ -237,9 +233,7 @@ def forward( return (simple_loss, pruned_loss, codebook_loss) @staticmethod - def concat_successive_codebook_indexes( - middle_layer_output, codebook_indexes - ): + def concat_successive_codebook_indexes(middle_layer_output, codebook_indexes): # Output rate of hubert is 50 frames per second, # while that of current encoder is 25. # The following code handles two issues: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/train.py b/egs/librispeech/ASR/pruned_transducer_stateless6/train.py index f717d85fb8..a24becb14f 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/train.py @@ -101,9 +101,7 @@ str2bool, ) -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def get_parser(): @@ -203,8 +201,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -227,8 +224,7 @@ def get_parser(): "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)" "part.", ) parser.add_argument( @@ -569,9 +565,7 @@ def save_checkpoint( def extract_codebook_indexes(batch): cuts = batch["supervisions"]["cut"] # -100 is identical to ignore_value in CE loss computation. - cuts_pre_mixed = [ - c if isinstance(c, MonoCut) else c.tracks[0].cut for c in cuts - ] + cuts_pre_mixed = [c if isinstance(c, MonoCut) else c.tracks[0].cut for c in cuts] codebook_indexes, codebook_indexes_lens = collate_custom_field( cuts_pre_mixed, "codebook_indexes", pad_value=-100 ) @@ -604,11 +598,7 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -655,9 +645,7 @@ def compute_loss( # If the batch contains more than 10 utterances AND # if either all simple_loss or pruned_loss is inf or nan, # we stop the training process by raising an exception - if torch.all(~simple_loss_is_finite) or torch.all( - ~pruned_loss_is_finite - ): + if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite): raise ValueError( "There are too many utterances in this batch " "leading to inf or nan losses." ) @@ -670,14 +658,9 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. 
pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss if is_training and params.enable_distillation: assert codebook_loss is not None loss += params.codebook_loss_scale * codebook_loss @@ -690,9 +673,7 @@ def compute_loss( # (1) The actual subsampling factor is ((lens - 1) // 2 - 1) // 2 # (2) If some utterances in the batch lead to inf/nan loss, they # are filtered out. - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -873,9 +854,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -1007,8 +986,7 @@ def remove_short_and_long_utt(c: Cut): # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. " f"Duration: {c.duration}" ) return False diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py b/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py index 47cf2b14b4..97a83b9744 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py @@ -68,9 +68,7 @@ def __init__(self, params: AttributeDict): def init_dirs(self): # vq_dir is the root dir for quantization, containing: # training data, trained quantizer, and extracted codebook indexes - self.vq_dir = ( - self.params.exp_dir / f"vq/{self.params.teacher_model_id}/" - ) + self.vq_dir = self.params.exp_dir / f"vq/{self.params.teacher_model_id}/" self.vq_dir.mkdir(parents=True, exist_ok=True) # manifest_dir contains: @@ -208,9 +206,7 @@ def minibatch_generator(data: torch.Tensor, repeat: bool): start = cur_offset % (data.shape[0] + 1 - B) end = start + B cur_offset += B - yield data[start:end, :].to(self.params.device).to( - dtype=torch.float - ) + yield data[start:end, :].to(self.params.device).to(dtype=torch.float) for x in minibatch_generator(train, repeat=True): trainer.step(x) @@ -227,9 +223,7 @@ def split_ori_manifests(self): """ for subset in self.params.subsets: logging.info(f"About to split {subset}.") - ori_manifest = ( - f"./data/fbank/librispeech_cuts_train-{subset}.jsonl.gz" - ) + ori_manifest = f"./data/fbank/librispeech_cuts_train-{subset}.jsonl.gz" split_cmd = f"lhotse split {self.params.world_size} {ori_manifest} {self.manifest_dir}" os.system(f"{split_cmd}") @@ -240,16 +234,13 @@ def join_manifests(self): """ logging.info("Start to join manifest files.") for subset in self.params.subsets: vq_manifest_path = ( - self.dst_manifest_dir - / f"librispeech_cuts_train-{subset}-vq.jsonl.gz" + self.dst_manifest_dir / f"librispeech_cuts_train-{subset}-vq.jsonl.gz" ) ori_manifest_path = ( - self.ori_manifest_dir - / f"librispeech_cuts_train-{subset}.jsonl.gz" + 
self.ori_manifest_dir / f"librispeech_cuts_train-{subset}.jsonl.gz" ) dst_vq_manifest_path = ( - self.dst_manifest_dir - / f"librispeech_cuts_train-{subset}.jsonl.gz" + self.dst_manifest_dir / f"librispeech_cuts_train-{subset}.jsonl.gz" ) cuts_vq = load_manifest(vq_manifest_path) cuts_ori = load_manifest(ori_manifest_path) @@ -269,8 +260,7 @@ def merge_vq_manifests(self): for subset in self.params.subsets: vq_manifests = f"{self.manifest_dir}/with_codebook_indexes-librispeech-cuts_train-{subset}*.jsonl.gz" dst_vq_manifest = ( - self.dst_manifest_dir - / f"librispeech_cuts_train-{subset}-vq.jsonl.gz" + self.dst_manifest_dir / f"librispeech_cuts_train-{subset}-vq.jsonl.gz" ) if 1 == self.params.world_size: merge_cmd = f"cp {vq_manifests} {dst_vq_manifest}" @@ -330,9 +320,7 @@ def quantizer(self): def load_ori_dl(self, subset): if self.params.world_size == 1: - ori_manifest_path = ( - f"./data/fbank/librispeech_cuts_train-{subset}.jsonl.gz" - ) + ori_manifest_path = f"./data/fbank/librispeech_cuts_train-{subset}.jsonl.gz" else: ori_manifest_path = ( self.manifest_dir diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py index 06c5863f19..162966df81 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py @@ -272,8 +272,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -393,9 +392,7 @@ def decode_one_batch( simulate_streaming=True, ) else: - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] @@ -454,10 +451,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -588,9 +582,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -623,8 +615,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -679,9 +670,7 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -718,9 +707,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = 
find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -747,9 +736,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -808,9 +797,7 @@ def main(): decoding_graph.scores *= params.ngram_lm_scale else: word_table = None - decoding_graph = k2.trivial_graph( - params.vocab_size - 1, device=device - ) + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) else: decoding_graph = None word_table = None diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless7/decoder.py index 712dc8ce12..5f90e6375f 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/decoder.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/decoder.py @@ -69,7 +69,7 @@ def __init__( out_channels=decoder_dim, kernel_size=context_size, padding=0, - groups=decoder_dim//4, # group size == 4 + groups=decoder_dim // 4, # group size == 4 bias=False, ) @@ -91,9 +91,7 @@ def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor: if self.context_size > 1: embedding_out = embedding_out.permute(0, 2, 1) if need_pad is True: - embedding_out = F.pad( - embedding_out, pad=(self.context_size - 1, 0) - ) + embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0)) else: # During inference time, there is no need to do extra padding # as we only need one output diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/export.py b/egs/librispeech/ASR/pruned_transducer_stateless7/export.py index 5744ea3ead..57af52fb13 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/export.py @@ -176,8 +176,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; " "2 means tri-gram", ) add_model_arguments(parser) @@ -215,9 +214,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -244,9 +243,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -316,9 +315,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/jit_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7/jit_pretrained.py index e2405d5ef0..f469442eda 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/jit_pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/jit_pretrained.py @@ -94,8 +94,7 @@ def read_sound_files( for f in filenames: wave, sample_rate = torchaudio.load(f) assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" + f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}" ) # We use only the first channel ans.append(wave[0]) @@ -267,9 +266,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/joiner.py b/egs/librispeech/ASR/pruned_transducer_stateless7/joiner.py index 7d8de5afed..3ddac2cf2a 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/joiner.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/joiner.py @@ -56,9 +56,7 @@ def forward( assert encoder_out.shape[:-1] == decoder_out.shape[:-1] if project_input: - logit = self.encoder_proj(encoder_out) + self.decoder_proj( - decoder_out - ) + logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out) else: logit = encoder_out + decoder_out diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/model.py b/egs/librispeech/ASR/pruned_transducer_stateless7/model.py index 53cde6c6f4..0e59b0f2f6 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/model.py @@ -15,14 +15,15 @@ # limitations under the License. 
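# Illustrative sketch, with made-up sizes, of the stateless decoder + joiner
# pattern refactored in decoder.py and joiner.py above: the "decoder" is just
# an embedding followed by a grouped Conv1d over the last `context_size`
# labels (context_size=2 gives the tri-gram-like history mentioned in the
# --context-size help text), and the joiner adds the projected encoder and
# decoder outputs. Sizes below are assumptions for the demo only.
import torch
import torch.nn as nn
import torch.nn.functional as F

vocab_size, decoder_dim, context_size = 500, 512, 2
embedding = nn.Embedding(vocab_size, decoder_dim)
conv = nn.Conv1d(
    decoder_dim,
    decoder_dim,
    kernel_size=context_size,
    groups=decoder_dim // 4,  # group size == 4, as in decoder.py
    bias=False,
)

y = torch.randint(0, vocab_size, (8, 20))  # (batch, num_labels)
emb = embedding(y).permute(0, 2, 1)  # (batch, decoder_dim, num_labels)
emb = F.pad(emb, pad=(context_size - 1, 0))  # left-pad: causal context only
decoder_out = conv(emb).permute(0, 2, 1)  # (batch, num_labels, decoder_dim)
assert decoder_out.shape == (8, 20, decoder_dim)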
+import random + import k2 import torch import torch.nn as nn -import random from encoder_interface import EncoderInterface +from scaling import penalize_abs_values_gt from icefall.utils import add_sos -from scaling import penalize_abs_values_gt class Transducer(nn.Module): @@ -65,7 +66,8 @@ def __init__( self.joiner = joiner self.simple_am_proj = nn.Linear( - encoder_dim, vocab_size, + encoder_dim, + vocab_size, ) self.simple_lm_proj = nn.Linear(decoder_dim, vocab_size) @@ -133,18 +135,16 @@ def forward( y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros( - (x.size(0), 4), dtype=torch.int64, device=x.device - ) + boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) boundary[:, 2] = y_lens boundary[:, 3] = x_lens lm = self.simple_lm_proj(decoder_out) am = self.simple_am_proj(encoder_out) - #if self.training and random.random() < 0.25: + # if self.training and random.random() < 0.25: # lm = penalize_abs_values_gt(lm, 100.0, 1.0e-04) - #if self.training and random.random() < 0.25: + # if self.training and random.random() < 0.25: # am = penalize_abs_values_gt(am, 30.0, 1.0e-04) with torch.cuda.amp.autocast(enabled=False): diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py index bb8b0a0e38..8b90c9a0d8 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py @@ -14,17 +14,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +import contextlib +import logging +import random from collections import defaultdict -from typing import List, Optional, Union, Tuple, List -from lhotse.utils import fix_random_seed +from typing import List, Optional, Tuple, Union + import torch +from lhotse.utils import fix_random_seed from scaling import ActivationBalancer -import random from torch import Tensor from torch.optim import Optimizer -import logging -import contextlib - class BatchedOptimizer(Optimizer): @@ -37,11 +37,10 @@ class BatchedOptimizer(Optimizer): Args: params: """ + def __init__(self, params, defaults): super(BatchedOptimizer, self).__init__(params, defaults) - - @contextlib.contextmanager def batched_params(self, param_group): """ @@ -73,7 +72,9 @@ def batched_params(self, param_group): group: a parameter group, which is a list of parameters; should be one of self.groups. """ - batches = defaultdict(list) # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter + batches = defaultdict( + list + ) # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter for p in param_group: key = (str(p.dtype), *p.shape) @@ -82,7 +83,7 @@ def batched_params(self, param_group): stacked_params_dict = dict() # turn batches into a list, in deterministic order. - batches = [ batches[key] for key in sorted(batches.keys()) ] + batches = [batches[key] for key in sorted(batches.keys())] # pairs will contain pairs of (stacked_param, state), one for each batch # in `batches`. pairs = [] @@ -94,77 +95,78 @@ def batched_params(self, param_group): # group. class Optimizer will take care of saving/loading state. 
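# Illustrative sketch, on toy tensors, of the batching trick implemented by
# `batched_params` above: parameters sharing the same (dtype, shape) key are
# stacked along a new leading dim so that one vectorized update can serve the
# whole group, and the updated values are copied back afterwards. The 0.5
# scaling below is only a stand-in for a real optimizer update.
from collections import defaultdict

import torch

params = [torch.randn(3, 4), torch.randn(3, 4), torch.randn(5)]
batches = defaultdict(list)
for p in params:
    batches[(str(p.dtype), *p.shape)].append(p)

for key in sorted(batches.keys()):
    batch = batches[key]
    stacked = torch.stack(batch)  # (batch_size, *shape)
    stacked *= 0.5  # stand-in for the real optimizer update
    for i, p in enumerate(batch):
        p.copy_(stacked[i])  # write the updated values back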
state = self.state[p] p_stacked = torch.stack(batch) - grad = torch.stack([torch.zeros_like(p) if p.grad is None else p.grad for p in batch ]) + grad = torch.stack( + [torch.zeros_like(p) if p.grad is None else p.grad for p in batch] + ) p_stacked.grad = grad stacked_params_dict[key] = p_stacked pairs.append((p_stacked, state)) - yield pairs # <-- calling code will do the actual optimization here! + yield pairs # <-- calling code will do the actual optimization here! for ((stacked_params, _state), batch) in zip(pairs, batches): for i, p in enumerate(batch): # batch is list of Parameter p.copy_(stacked_params[i]) - class ScaledAdam(BatchedOptimizer): """ - Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update - proportional to the norm of that parameter; and also learn the scale of the parameter, - in log space, subject to upper and lower limits (as if we had factored each parameter as - param = underlying_param * log_scale.exp()) - - - Args: - params: The parameters or param_groups to optimize (like other Optimizer subclasses) - lr: The learning rate. We will typically use a learning rate schedule that starts - at 0.03 and decreases over time, i.e. much higher than other common - optimizers. - clipping_scale: (e.g. 2.0) - A scale for gradient-clipping: if specified, the normalized gradients - over the whole model will be clipped to have 2-norm equal to - `clipping_scale` times the median 2-norm over the most recent period - of `clipping_update_period` minibatches. By "normalized gradients", - we mean after multiplying by the rms parameter value for this tensor - [for non-scalars]; this is appropriate because our update is scaled - by this quantity. - betas: beta1,beta2 are momentum constants for regular momentum, and moving sum-sq grad. - Must satisfy 0 < beta <= beta2 < 1. - scalar_lr_scale: A scaling factor on the learning rate, that we use to update the - scale of each parameter tensor and scalar parameters of the mode.. - If each parameter were decomposed - as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, scalar_lr_scale - would be a the scaling factor on the learning rate of p_scale. - eps: A general-purpose epsilon to prevent division by zero - param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of - learning the scale on the parameters (we'll constrain the rms of each non-scalar - parameter tensor to be >= this value) - param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of - learning the scale on the parameters (we'll constrain the rms of each non-scalar - parameter tensor to be <= this value) - scalar_max: Maximum absolute value for scalar parameters (applicable if your - model has any parameters with numel() == 1). - size_update_period: The periodicity, in steps, with which we update the size (scale) - of the parameter tensor. This is provided to save a little time - in the update. - clipping_update_period: if clipping_scale is specified, this is the period + Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update + proportional to the norm of that parameter; and also learn the scale of the parameter, + in log space, subject to upper and lower limits (as if we had factored each parameter as + param = underlying_param * log_scale.exp()) + + + Args: + params: The parameters or param_groups to optimize (like other Optimizer subclasses) + lr: The learning rate. We will typically use a learning rate schedule that starts + at 0.03 and decreases over time, i.e. 
much higher than other common + optimizers. + clipping_scale: (e.g. 2.0) + A scale for gradient-clipping: if specified, the normalized gradients + over the whole model will be clipped to have 2-norm equal to + `clipping_scale` times the median 2-norm over the most recent period + of `clipping_update_period` minibatches. By "normalized gradients", + we mean after multiplying by the rms parameter value for this tensor + [for non-scalars]; this is appropriate because our update is scaled + by this quantity. + betas: beta1,beta2 are momentum constants for regular momentum, and moving sum-sq grad. + Must satisfy 0 < beta1 <= beta2 < 1. + scalar_lr_scale: A scaling factor on the learning rate, that we use to update the + scale of each parameter tensor and scalar parameters of the model. + If each parameter were decomposed + as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, scalar_lr_scale + would be the scaling factor on the learning rate of p_scale. + eps: A general-purpose epsilon to prevent division by zero + param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of + learning the scale on the parameters (we'll constrain the rms of each non-scalar + parameter tensor to be >= this value) + param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of + learning the scale on the parameters (we'll constrain the rms of each non-scalar + parameter tensor to be <= this value) + scalar_max: Maximum absolute value for scalar parameters (applicable if your + model has any parameters with numel() == 1). + size_update_period: The periodicity, in steps, with which we update the size (scale) + of the parameter tensor. This is provided to save a little time + in the update. + clipping_update_period: if clipping_scale is specified, this is the period """ + def __init__( - self, - params, - lr=3e-02, - clipping_scale=None, - betas=(0.9, 0.98), - scalar_lr_scale=0.1, - eps=1.0e-08, - param_min_rms=1.0e-05, - param_max_rms=3.0, - scalar_max=10.0, - size_update_period=4, - clipping_update_period=100, + self, + params, + lr=3e-02, + clipping_scale=None, + betas=(0.9, 0.98), + scalar_lr_scale=0.1, + eps=1.0e-08, + param_min_rms=1.0e-05, + param_max_rms=3.0, + scalar_max=10.0, + size_update_period=4, + clipping_update_period=100, ): - defaults = dict( lr=lr, clipping_scale=clipping_scale, @@ -183,7 +185,6 @@ def __init__( def __setstate__(self, state): super(ScaledAdam, self).__setstate__(state) - @torch.no_grad() def step(self, closure=None): """Performs a single optimization step. @@ -206,7 +207,9 @@ def step(self, closure=None): # a regular parameter, and will have a .grad, but the 1st dim corresponds to # a stacking dim, it is not a real dim. - if len(batches[0][1]) == 0: # if len(first state) == 0: not yet initialized + if ( + len(batches[0][1]) == 0 + ): # if len(first state) == 0: not yet initialized clipping_scale = 1 else: clipping_scale = self._get_clipping_scale(group, batches) @@ -225,13 +228,9 @@ def step(self, closure=None): self._step_one_batch(group, p, state, clipping_scale) - return loss - def _init_state(self, - group: dict, - p: Tensor, - state: dict): + def _init_state(self, group: dict, p: Tensor, state: dict): """ Initializes state dict for parameter 'p'. 
Assumes that dim 0 of tensor p is actually the batch dimension, corresponding to batched-together @@ -247,7 +246,7 @@ def _init_state(self, state["step"] = 0 - kwargs = {'device':p.device, 'dtype':p.dtype} + kwargs = {"device": p.device, "dtype": p.dtype} # 'delta' implements conventional momentum. There are # several different kinds of update going on, so rather than @@ -255,36 +254,30 @@ def _init_state(self, # parameter-change "delta", which combines all forms of # update. this is equivalent to how it's done in Adam, # except for the first few steps. - state["delta"] = torch.zeros_like( - p, memory_format=torch.preserve_format - ) + state["delta"] = torch.zeros_like(p, memory_format=torch.preserve_format) batch_size = p.shape[0] numel = p.numel() // batch_size numel = p.numel() - if numel > 1: # "param_rms" just periodically records the scalar root-mean-square value of # the parameter tensor. # it has a shape like (batch_size, 1, 1, 1, 1) - param_rms = (p**2).mean(dim=list(range(1, p.ndim)), - keepdim=True).sqrt() + param_rms = (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt() state["param_rms"] = param_rms state["scale_exp_avg_sq"] = torch.zeros_like(param_rms) - state["scale_grads"] = torch.zeros(size_update_period, *param_rms.shape, - **kwargs) - + state["scale_grads"] = torch.zeros( + size_update_period, *param_rms.shape, **kwargs + ) # exp_avg_sq is the weighted sum of scaled gradients. as in Adam. - state["exp_avg_sq"] = torch.zeros_like( - p, memory_format=torch.preserve_format - ) + state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format) - def _get_clipping_scale(self, - group: dict, - pairs: List[Tuple[Tensor, dict]]) -> float: + def _get_clipping_scale( + self, group: dict, pairs: List[Tuple[Tensor, dict]] + ) -> float: """ Returns a scalar factor <= 1.0 that dictates gradient clipping, i.e. we will scale the gradients by this amount before applying the rest of the update. @@ -314,57 +307,65 @@ def _get_clipping_scale(self, if p.numel() == p.shape[0]: # a batch of scalars tot_sumsq += (grad**2).sum() # sum() to change shape [1] to [] else: - tot_sumsq += ((grad * state["param_rms"])**2).sum() + tot_sumsq += ((grad * state["param_rms"]) ** 2).sum() tot_norm = tot_sumsq.sqrt() - if not "model_norms" in first_state: - first_state["model_norms"] = torch.zeros(clipping_update_period, - device=p.device) + if "model_norms" not in first_state: + first_state["model_norms"] = torch.zeros( + clipping_update_period, device=p.device + ) first_state["model_norms"][step % clipping_update_period] = tot_norm if step % clipping_update_period == 0: # Print some stats. # We don't reach here if step == 0 because we would have returned # above. 
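# Illustrative sketch, on two fake tensors, of the "normalized" gradient norm
# accumulated in `_get_clipping_scale` above: each non-scalar gradient is
# weighted by its tensor's parameter rms before the global 2-norm is taken,
# because the eventual parameter update is scaled by that rms. The values
# below are assumptions for the demo only.
import torch

grads = [torch.randn(10, 10), torch.randn(4)]
param_rms = [torch.tensor(0.3), torch.tensor(1.2)]

tot_sumsq = torch.tensor(0.0)
for g, rms in zip(grads, param_rms):
    tot_sumsq += ((g * rms) ** 2).sum()
tot_norm = tot_sumsq.sqrt()  # compared against clipping_scale * median norm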
- sorted_norms = first_state["model_norms"].sort()[0].to('cpu') + sorted_norms = first_state["model_norms"].sort()[0].to("cpu") quartiles = [] for n in range(0, 5): - index = min(clipping_update_period - 1, - (clipping_update_period // 4) * n) + index = min( + clipping_update_period - 1, (clipping_update_period // 4) * n + ) quartiles.append(sorted_norms[index].item()) median = quartiles[2] threshold = clipping_scale * median first_state["model_norm_threshold"] = threshold - percent_clipped = (first_state["num_clipped"] * 100.0 / clipping_update_period - if "num_clipped" in first_state else 0.0) + percent_clipped = ( + first_state["num_clipped"] * 100.0 / clipping_update_period + if "num_clipped" in first_state + else 0.0 + ) first_state["num_clipped"] = 0 - quartiles = ' '.join([ '%.3e' % x for x in quartiles ]) - logging.info(f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, " - f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}") + quartiles = " ".join(["%.3e" % x for x in quartiles]) + logging.info( + f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, " + f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}" + ) if step < clipping_update_period: return 1.0 # We have not yet estimated a norm to clip to. else: try: model_norm_threshold = first_state["model_norm_threshold"] - except: - logging.info("Warning: model_norm_threshold not in state: possibly " - "you changed config when restarting, adding clipping_scale option?") + except KeyError: + logging.info( + "Warning: model_norm_threshold not in state: possibly " + "you changed config when restarting, adding clipping_scale option?" + ) return 1.0 - ans = min(1.0,(model_norm_threshold / (tot_norm + 1.0e-20)).item()) + ans = min(1.0, (model_norm_threshold / (tot_norm + 1.0e-20)).item()) if ans < 1.0: first_state["num_clipped"] += 1 if ans < 0.1: - logging.warn(f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}") + logging.warn( + f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}" + ) return ans - - def _step_one_batch(self, - group: dict, - p: Tensor, - state: dict, - clipping_scale: float): + def _step_one_batch( + self, group: dict, p: Tensor, state: dict, clipping_scale: float + ): """ Do the step for one parameter, which is actually going to be a batch of `real` parameters, with dim 0 as the batch dim. @@ -391,17 +392,18 @@ def _step_one_batch(self, # Update the size/scale of p, and set param_rms scale_grads = state["scale_grads"] scale_grads[step % size_update_period] = (p * grad).sum( - dim=list(range(1, p.ndim)), keepdim=True) + dim=list(range(1, p.ndim)), keepdim=True + ) if step % size_update_period == size_update_period - 1: param_rms = state["param_rms"] # shape: (batch_size, 1, 1, ..) - param_rms.copy_((p**2).mean(dim=list(range(1, p.ndim)), - keepdim=True).sqrt()) + param_rms.copy_( + (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt() + ) if step > 0: # self._size_update() learns the overall scale on the # parameter, by shrinking or expanding it. self._size_update(group, scale_grads, p, state) - if numel == 1: # For parameters with 1 element we just use regular Adam. # Updates delta. 
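# Why `(p * grad).sum(...)` (stored into `scale_grads` above) is the gradient
# with respect to a log-scale: writing p = u * exp(s) with u fixed gives
# dL/ds = sum(dL/dp * dp/ds) = sum(grad * p). A quick autograd check on a toy
# tensor and an arbitrary differentiable loss:
import torch

u = torch.randn(3, 3)
s = torch.zeros((), requires_grad=True)
p = u * s.exp()
loss = (p**3).sum()  # any differentiable loss would do
loss.backward()
grad_p = 3 * (u * s.exp().detach()) ** 2  # dL/dp, computed by hand
assert torch.allclose(s.grad, (grad_p * p.detach()).sum())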
@@ -411,24 +413,21 @@ def _step_one_batch(self, state["step"] = step + 1 - - def _size_update(self, - group: dict, - scale_grads: Tensor, - p: Tensor, - state: dict) -> None: + def _size_update( + self, group: dict, scale_grads: Tensor, p: Tensor, state: dict + ) -> None: """ - Called only where p.numel() > 1, this updates the scale of the parameter. - If we imagine: p = underlying_param * scale.exp(), and we are doing - gradient descent on underlying param and on scale, this function does the update - on `scale`. - - Args: - group: dict to look up configuration values - scale_grads: a tensor of shape (size_update_period, batch_size, 1, 1,...) containing - grads w.r.t. the scales. - p: The parameter to update - state: The state-dict of p + Called only where p.numel() > 1, this updates the scale of the parameter. + If we imagine: p = underlying_param * scale.exp(), and we are doing + gradient descent on underlying param and on scale, this function does the update + on `scale`. + + Args: + group: dict to look up configuration values + scale_grads: a tensor of shape (size_update_period, batch_size, 1, 1,...) containing + grads w.r.t. the scales. + p: The parameter to update + state: The state-dict of p """ param_rms = state["param_rms"] @@ -443,25 +442,28 @@ def _size_update(self, size_update_period = scale_grads.shape[0] # correct beta2 for the size update period: we will have # faster decay at this level. - beta2_corr = beta2 ** size_update_period + beta2_corr = beta2**size_update_period scale_exp_avg_sq = state["scale_exp_avg_sq"] # shape: (batch_size, 1, 1, ..) scale_exp_avg_sq.mul_(beta2_corr).add_( - (scale_grads ** 2).mean(dim=0), # mean over dim `size_update_period` - alpha=1-beta2_corr) # shape is (batch_size, 1, 1, ...) + (scale_grads**2).mean(dim=0), # mean over dim `size_update_period` + alpha=1 - beta2_corr, + ) # shape is (batch_size, 1, 1, ...) # The 1st time we reach here is when size_step == 1. size_step = (step + 1) // size_update_period - bias_correction2 = 1 - beta2_corr ** size_step + bias_correction2 = 1 - beta2_corr**size_step # we don't bother with bias_correction1; this will help prevent divergence # at the start of training. denom = scale_exp_avg_sq.sqrt() + eps - scale_step = -size_lr * (bias_correction2 ** 0.5) * scale_grads.sum(dim=0) / denom + scale_step = ( + -size_lr * (bias_correction2**0.5) * scale_grads.sum(dim=0) / denom + ) - is_too_small = (param_rms < param_min_rms) - is_too_large = (param_rms > param_max_rms) + is_too_small = param_rms < param_min_rms + is_too_large = param_rms > param_max_rms # when the param gets too small, just don't shrink it any further. scale_step.masked_fill_(is_too_small, 0.0) @@ -469,13 +471,9 @@ def _size_update(self, scale_step.masked_fill_(is_too_large, -size_lr * size_update_period) delta = state["delta"] # the factor of (1-beta1) relates to momentum. - delta.add_(p * scale_step, alpha=(1-beta1)) - + delta.add_(p * scale_step, alpha=(1 - beta1)) - def _step(self, - group: dict, - p: Tensor, - state: dict): + def _step(self, group: dict, p: Tensor, state: dict): """ This function does the core update of self.step(), in the case where the members of the batch have more than 1 element. 
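# A stripped-down sketch of the per-tensor update in `_step` below, assuming
# a single unbatched tensor and omitting gradient clipping, the periodic
# size/scale update, and the momentum decay of `delta`: it is an
# Adam-normalized gradient step whose size is multiplied by the parameter's
# rms, so updates stay proportional to the parameter's own magnitude.
import torch


def scaled_adam_core(p, grad, exp_avg_sq, delta, step, lr=0.03,
                     beta1=0.9, beta2=0.98, eps=1e-8, param_min_rms=1e-5):
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
    bias_correction2 = 1 - beta2 ** (step + 1)
    denom = (exp_avg_sq / bias_correction2).sqrt() + eps
    param_rms = (p**2).mean().sqrt().clamp(min=param_min_rms)
    alpha = -lr * (1 - beta1) * param_rms
    delta.add_(grad / denom * alpha)
    p.add_(delta)


p = torch.randn(4, 4)
scaled_adam_core(p, torch.randn(4, 4), torch.zeros(4, 4),
                 torch.zeros(4, 4), step=0)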
@@ -496,8 +494,7 @@ def _step(self, step = state["step"] exp_avg_sq = state["exp_avg_sq"] - exp_avg_sq.mul_(beta2).addcmul_(grad, grad, - value=(1-beta2)) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=(1 - beta2)) this_step = state["step"] - (state["zero_step"] if "zero_step" in state else 0) bias_correction2 = 1 - beta2 ** (this_step + 1) @@ -509,17 +506,13 @@ def _step(self, denom += eps grad = grad / denom - alpha = -lr * (1-beta1) * state["param_rms"].clamp(min=param_min_rms) + alpha = -lr * (1 - beta1) * state["param_rms"].clamp(min=param_min_rms) delta = state["delta"] delta.add_(grad * alpha) p.add_(delta) - - def _step_scalar(self, - group: dict, - p: Tensor, - state: dict): + def _step_scalar(self, group: dict, p: Tensor, state: dict): """ A simplified form of the core update for scalar tensors, where we cannot get a good estimate of the parameter rms. @@ -531,8 +524,7 @@ def _step_scalar(self, grad = p.grad exp_avg_sq = state["exp_avg_sq"] # shape: (batch_size,) - exp_avg_sq.mul_(beta2).addcmul_(grad, grad, - value=1-beta2) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) # bias_correction2 is like in Adam. Don't bother with bias_correction1; # slower update at the start will help stability anyway. @@ -540,12 +532,11 @@ def _step_scalar(self, denom = (exp_avg_sq / bias_correction2).sqrt() + eps delta = state["delta"] - delta.add_(grad / denom, alpha=-lr*(1-beta1)) + delta.add_(grad / denom, alpha=-lr * (1 - beta1)) p.clamp_(min=-scalar_max, max=scalar_max) p.add_(delta) - class LRScheduler(object): """ Base-class for learning rate schedulers where the learning-rate depends on both the @@ -555,18 +546,14 @@ class LRScheduler(object): def __init__(self, optimizer: Optimizer, verbose: bool = False): # Attach optimizer if not isinstance(optimizer, Optimizer): - raise TypeError( - "{} is not an Optimizer".format(type(optimizer).__name__) - ) + raise TypeError("{} is not an Optimizer".format(type(optimizer).__name__)) self.optimizer = optimizer self.verbose = verbose for group in optimizer.param_groups: group.setdefault("base_lr", group["lr"]) - self.base_lrs = [ - group["base_lr"] for group in optimizer.param_groups - ] + self.base_lrs = [group["base_lr"] for group in optimizer.param_groups] self.epoch = 0 self.batch = 0 @@ -680,13 +667,15 @@ def __init__( def get_lr(self): factor = ( - (self.batch ** 2 + self.lr_batches ** 2) / self.lr_batches ** 2 + (self.batch**2 + self.lr_batches**2) / self.lr_batches**2 ) ** -0.25 * ( - ((self.epoch ** 2 + self.lr_epochs ** 2) / self.lr_epochs ** 2) - ** -0.25 + ((self.epoch**2 + self.lr_epochs**2) / self.lr_epochs**2) ** -0.25 + ) + warmup_factor = ( + 1.0 + if self.batch >= self.warmup_batches + else 0.5 + 0.5 * (self.batch / self.warmup_batches) ) - warmup_factor = (1.0 if self.batch >= self.warmup_batches - else 0.5 + 0.5 * (self.batch / self.warmup_batches)) return [x * factor * warmup_factor for x in self.base_lrs] @@ -745,13 +734,14 @@ class Eve(Optimizer): parameters, if they fall below this we will stop applying weight decay. - .. _Adam\: A Method for Stochastic Optimization: + .. _Adam: A Method for Stochastic Optimization: https://arxiv.org/abs/1412.6980 .. _Decoupled Weight Decay Regularization: https://arxiv.org/abs/1711.05101 .. 
_On the Convergence of Adam and Beyond: https://openreview.net/forum?id=ryQu7f-RZ """ + def __init__( self, params, @@ -766,17 +756,11 @@ def __init__( if not 0.0 <= eps: raise ValueError("Invalid epsilon value: {}".format(eps)) if not 0.0 <= betas[0] < 1.0: - raise ValueError( - "Invalid beta parameter at index 0: {}".format(betas[0]) - ) + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) if not 0.0 <= betas[1] < 1.0: - raise ValueError( - "Invalid beta parameter at index 1: {}".format(betas[1]) - ) + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) if not 0 <= weight_decay <= 0.1: - raise ValueError( - "Invalid weight_decay value: {}".format(weight_decay) - ) + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) if not 0 < target_rms <= 10.0: raise ValueError("Invalid target_rms value: {}".format(target_rms)) defaults = dict( @@ -812,9 +796,7 @@ def step(self, closure=None): # Perform optimization step grad = p.grad if grad.is_sparse: - raise RuntimeError( - "AdamW does not support sparse gradients" - ) + raise RuntimeError("AdamW does not support sparse gradients") state = self.state[p] @@ -841,7 +823,7 @@ def step(self, closure=None): # Decay the first and second moment running average coefficient exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) - denom = (exp_avg_sq.sqrt() * (bias_correction2 ** -0.5)).add_( + denom = (exp_avg_sq.sqrt() * (bias_correction2**-0.5)).add_( group["eps"] ) @@ -852,30 +834,31 @@ def step(self, closure=None): if p.numel() > 1: # avoid applying this weight-decay on "scaling factors" # (which are scalar). - is_above_target_rms = p.norm() > ( - target_rms * (p.numel() ** 0.5) - ) + is_above_target_rms = p.norm() > (target_rms * (p.numel() ** 0.5)) p.mul_(1 - (weight_decay * is_above_target_rms)) p.addcdiv_(exp_avg, denom, value=-step_size) if random.random() < 0.0005: - step = (exp_avg/denom) * step_size - logging.info(f"Delta rms = {(step**2).mean().item()}, shape = {step.shape}") - + step = (exp_avg / denom) * step_size + logging.info( + f"Delta rms = {(step**2).mean().item()}, shape = {step.shape}" + ) return loss def _test_scaled_adam(hidden_dim: int): import timeit + from scaling import ScaledLinear + E = 100 B = 4 T = 2 logging.info("in test_eve_cain") - #device = torch.device('cuda') - device = torch.device('cpu') + # device = torch.device('cuda') + device = torch.device("cpu") dtype = torch.float32 fix_random_seed(42) @@ -889,79 +872,92 @@ def _test_scaled_adam(hidden_dim: int): fix_random_seed(42) Linear = torch.nn.Linear if iter == 0 else ScaledLinear - m = torch.nn.Sequential(Linear(E, hidden_dim), - torch.nn.PReLU(), - Linear(hidden_dim, hidden_dim), - torch.nn.PReLU(), - Linear(hidden_dim, E), - ).to(device) - - train_pairs = [ (100.0 * torch.randn(B, T, E, device=device, dtype=dtype) * input_magnitudes, - torch.randn(B, T, E, device=device, dtype=dtype) * output_magnitudes) for _ in range(20) ] + m = torch.nn.Sequential( + Linear(E, hidden_dim), + torch.nn.PReLU(), + Linear(hidden_dim, hidden_dim), + torch.nn.PReLU(), + Linear(hidden_dim, E), + ).to(device) + + train_pairs = [ + ( + 100.0 + * torch.randn(B, T, E, device=device, dtype=dtype) + * input_magnitudes, + torch.randn(B, T, E, device=device, dtype=dtype) * output_magnitudes, + ) + for _ in range(20) + ] - if iter == 0: optim = Eve(m.parameters(), lr=0.003) - elif iter == 1: optim = ScaledAdam(m.parameters(), lr=0.03, clipping_scale=2.0) + if iter == 0: + optim 
= Eve(m.parameters(), lr=0.003) + elif iter == 1: + optim = ScaledAdam(m.parameters(), lr=0.03, clipping_scale=2.0) scheduler = Eden(optim, lr_batches=200, lr_epochs=5, verbose=False) - start = timeit.default_timer() avg_loss = 0.0 for epoch in range(180): scheduler.step_epoch() - #if epoch == 100 and iter in [2,3]: + # if epoch == 100 and iter in [2,3]: # optim.reset_speedup() # check it doesn't crash. - #if epoch == 130: + # if epoch == 130: # opts = diagnostics.TensorDiagnosticOptions( # 2 ** 22 # ) # allow 4 megabytes per sub-module # diagnostic = diagnostics.attach_diagnostics(m, opts) - - for n, (x,y) in enumerate(train_pairs): + for n, (x, y) in enumerate(train_pairs): y_out = m(x) - loss = ((y_out - y)**2).mean() * 100.0 + loss = ((y_out - y) ** 2).mean() * 100.0 if epoch == 0 and n == 0: avg_loss = loss.item() else: avg_loss = 0.98 * avg_loss + 0.02 * loss.item() if n == 0 and epoch % 5 == 0: - #norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item() - #norm1b = '%.2e' % (m[0].bias**2).mean().sqrt().item() - #norm2 = '%.2e' % (m[2].weight**2).mean().sqrt().item() - #norm2b = '%.2e' % (m[2].bias**2).mean().sqrt().item() - #scale1 = '%.2e' % (m[0].weight_scale.exp().item()) - #scale1b = '%.2e' % (m[0].bias_scale.exp().item()) - #scale2 = '%.2e' % (m[2].weight_scale.exp().item()) - #scale2b = '%.2e' % (m[2].bias_scale.exp().item()) + # norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item() + # norm1b = '%.2e' % (m[0].bias**2).mean().sqrt().item() + # norm2 = '%.2e' % (m[2].weight**2).mean().sqrt().item() + # norm2b = '%.2e' % (m[2].bias**2).mean().sqrt().item() + # scale1 = '%.2e' % (m[0].weight_scale.exp().item()) + # scale1b = '%.2e' % (m[0].bias_scale.exp().item()) + # scale2 = '%.2e' % (m[2].weight_scale.exp().item()) + # scale2b = '%.2e' % (m[2].bias_scale.exp().item()) lr = scheduler.get_last_lr()[0] - logging.info(f"Iter {iter}, epoch {epoch}, batch {n}, avg_loss {avg_loss:.4g}, lr={lr:.4e}") #, norms={norm1,norm1b,norm2,norm2b}") # scales={scale1,scale1b,scale2,scale2b} + logging.info( + f"Iter {iter}, epoch {epoch}, batch {n}, avg_loss {avg_loss:.4g}, lr={lr:.4e}" + ) # , norms={norm1,norm1b,norm2,norm2b}") # scales={scale1,scale1b,scale2,scale2b} loss.log().backward() optim.step() optim.zero_grad() scheduler.step_batch() - #diagnostic.print_diagnostics() + # diagnostic.print_diagnostics() stop = timeit.default_timer() logging.info(f"Iter={iter}, Time taken: {stop - start}") logging.info(f"last lr = {scheduler.get_last_lr()}") - #logging.info("state dict = ", scheduler.state_dict()) - #logging.info("optim state_dict = ", optim.state_dict()) + # logging.info("state dict = ", scheduler.state_dict()) + # logging.info("optim state_dict = ", optim.state_dict()) logging.info(f"input_magnitudes = {input_magnitudes}") logging.info(f"output_magnitudes = {output_magnitudes}") - if __name__ == "__main__": torch.set_num_threads(1) torch.set_num_interop_threads(1) logging.getLogger().setLevel(logging.INFO) import subprocess - s = subprocess.check_output("git status -uno .; git log -1; git diff HEAD .", shell=True) + + s = subprocess.check_output( + "git status -uno .; git log -1; git diff HEAD .", shell=True + ) logging.info(s) import sys + if len(sys.argv) > 1: hidden_dim = int(sys.argv[1]) else: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py index 7fe1e681a3..758e0c0366 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py +++ 
b/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py @@ -177,8 +177,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -210,8 +209,7 @@ def read_sound_files( for f in filenames: wave, sample_rate = torchaudio.load(f) assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" + f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}" ) # We use only the first channel ans.append(wave[0]) @@ -275,15 +273,11 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) feature_lengths = torch.tensor(feature_lengths, device=device) - encoder_out, encoder_out_lens = model.encoder( - x=features, x_lens=feature_lengths - ) + encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths) num_waves = encoder_out.size(0) hyps = [] @@ -355,9 +349,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py index 50cedba568..6f63e0629f 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py @@ -16,12 +16,12 @@ import collections +import logging +import random +from functools import reduce from itertools import repeat from typing import Optional, Tuple, Union -from functools import reduce -import logging -import random import torch import torch.nn as nn import torch.nn.functional as F @@ -32,27 +32,24 @@ class ActivationBalancerFunction(torch.autograd.Function): @staticmethod def forward( - ctx, - x: Tensor, - scale_factor: Tensor, - sign_factor: Optional[Tensor], - channel_dim: int, + ctx, + x: Tensor, + scale_factor: Tensor, + sign_factor: Optional[Tensor], + channel_dim: int, ) -> Tensor: if channel_dim < 0: channel_dim += x.ndim ctx.channel_dim = channel_dim - xgt0 = (x > 0) + xgt0 = x > 0 if sign_factor is None: ctx.save_for_backward(xgt0, scale_factor) else: ctx.save_for_backward(xgt0, scale_factor, sign_factor) return x - @staticmethod - def backward( - ctx, x_grad: Tensor - ) -> Tuple[Tensor, None, None, None]: + def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None]: if len(ctx.saved_tensors) == 3: xgt0, scale_factor, sign_factor = ctx.saved_tensors for _ in range(ctx.channel_dim, x_grad.ndim - 1): @@ -65,14 +62,22 @@ def backward( scale_factor = scale_factor.unsqueeze(-1) factor = scale_factor * (xgt0.to(x_grad.dtype) - 0.5) neg_delta_grad = x_grad.abs() * factor - return x_grad - neg_delta_grad, None, None, None, - -def _compute_scale_factor(x: Tensor, - channel_dim: int, - min_abs: float, - max_abs: float, - gain_factor: float, - max_factor: float) -> Tensor: + return ( + x_grad - neg_delta_grad, + None, + None, + None, + ) + + +def _compute_scale_factor( + x: Tensor, + channel_dim: int, + min_abs: float, + max_abs: float, + 
gain_factor: float, + max_factor: float, +) -> Tensor: if channel_dim < 0: channel_dim += x.ndim sum_dims = [d for d in range(x.ndim) if d != channel_dim] @@ -83,71 +88,76 @@ def _compute_scale_factor(x: Tensor, else: # below_threshold is 0 if x_abs_mean > min_abs, can be at most max_factor if # x_abs_mean < min_abs. - below_threshold = ((min_abs - x_abs_mean) * (gain_factor / min_abs)).clamp(min=0, max=max_factor) + below_threshold = ((min_abs - x_abs_mean) * (gain_factor / min_abs)).clamp( + min=0, max=max_factor + ) - above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp(min=0, max=max_factor) + above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp( + min=0, max=max_factor + ) return below_threshold - above_threshold -def _compute_sign_factor(x: Tensor, - channel_dim: int, - min_positive: float, - max_positive: float, - gain_factor: float, - max_factor: float) -> Tensor: + +def _compute_sign_factor( + x: Tensor, + channel_dim: int, + min_positive: float, + max_positive: float, + gain_factor: float, + max_factor: float, +) -> Tensor: if channel_dim < 0: channel_dim += x.ndim sum_dims = [d for d in range(x.ndim) if d != channel_dim] - proportion_positive = torch.mean((x > 0).to(torch.float32), - dim=sum_dims) + proportion_positive = torch.mean((x > 0).to(torch.float32), dim=sum_dims) if min_positive == 0.0: factor1 = 0.0 else: # 0 if proportion_positive >= min_positive, else can be # as large as max_factor. - factor1 = ((min_positive - proportion_positive) * - (gain_factor / min_positive)).clamp_(min=0, max=max_factor) + factor1 = ( + (min_positive - proportion_positive) * (gain_factor / min_positive) + ).clamp_(min=0, max=max_factor) if max_positive == 1.0: factor2 = 0.0 else: # 0 if proportion_positive <= max_positive, else can be # as large as -max_factor. - factor2 = ((proportion_positive - max_positive) * - (gain_factor / (1.0 - max_positive))).clamp_(min=0, max=max_factor) + factor2 = ( + (proportion_positive - max_positive) * (gain_factor / (1.0 - max_positive)) + ).clamp_(min=0, max=max_factor) sign_factor = factor1 - factor2 # require min_positive != 0 or max_positive != 1: assert not isinstance(sign_factor, float) return sign_factor - class ActivationScaleBalancerFunction(torch.autograd.Function): """ This object is used in class ActivationBalancer when the user specified min_positive=0, max_positive=1, so there are no constraints on the signs of the activations and only the absolute value has a constraint. 
""" + @staticmethod def forward( - ctx, - x: Tensor, - sign_factor: Tensor, - scale_factor: Tensor, - channel_dim: int, + ctx, + x: Tensor, + sign_factor: Tensor, + scale_factor: Tensor, + channel_dim: int, ) -> Tensor: if channel_dim < 0: channel_dim += x.ndim ctx.channel_dim = channel_dim - xgt0 = (x > 0) + xgt0 = x > 0 ctx.save_for_backward(xgt0, sign_factor, scale_factor) return x - @staticmethod - def backward( - ctx, x_grad: Tensor - ) -> Tuple[Tensor, None, None, None]: + def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None]: xgt0, sign_factor, scale_factor = ctx.saved_tensors for _ in range(ctx.channel_dim, x_grad.ndim - 1): sign_factor = sign_factor.unsqueeze(-1) @@ -155,18 +165,24 @@ def backward( factor = sign_factor + scale_factor * (xgt0.to(x_grad.dtype) - 0.5) neg_delta_grad = x_grad.abs() * factor - return x_grad - neg_delta_grad, None, None, None, + return ( + x_grad - neg_delta_grad, + None, + None, + None, + ) class RandomClampFunction(torch.autograd.Function): @staticmethod def forward( - ctx, - x: Tensor, - min: Optional[float], - max: Optional[float], - prob: float, - reflect: float) -> Tensor: + ctx, + x: Tensor, + min: Optional[float], + max: Optional[float], + prob: float, + reflect: float, + ) -> Tensor: x_clamped = torch.clamp(x, min=min, max=max) mask = torch.rand_like(x) < prob ans = torch.where(mask, x_clamped, x) @@ -179,30 +195,32 @@ def forward( @staticmethod def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None, None, None, None]: - is_same, = ctx.saved_tensors + (is_same,) = ctx.saved_tensors x_grad = ans_grad * is_same.to(ans_grad.dtype) reflect = ctx.reflect - if reflect != 0.0: + if reflect != 0.0: x_grad = x_grad * (1.0 + reflect) - (ans_grad * reflect) return x_grad, None, None, None, None -def random_clamp(x: Tensor, - min: Optional[float] = None, - max: Optional[float] = None, - prob: float = 0.5, - reflect: float = 0.0): + +def random_clamp( + x: Tensor, + min: Optional[float] = None, + max: Optional[float] = None, + prob: float = 0.5, + reflect: float = 0.0, +): return RandomClampFunction.apply(x, min, max, prob, reflect) -def random_cast_to_half(x: Tensor, - min_abs: float = 5.0e-06) -> Tensor: +def random_cast_to_half(x: Tensor, min_abs: float = 5.0e-06) -> Tensor: """ A randomized way of casting a floating point value to half precision. """ if x.dtype == torch.float16: return x x_abs = x.abs() - is_too_small = (x_abs < min_abs) + is_too_small = x_abs < min_abs # for elements where is_too_small is true, random_val will contain +-min_abs with # probability (x.abs() / min_abs), and 0.0 otherwise. [so this preserves expectations, # for those elements]. @@ -215,6 +233,7 @@ class RandomGradFunction(torch.autograd.Function): Does nothing in forward pass; in backward pass, gets rid of very small grads using randomized approach that preserves expectations (intended to reduce roundoff). 
""" + @staticmethod def forward(ctx, x: Tensor, min_abs: float) -> Tensor: ctx.min_abs = min_abs @@ -223,35 +242,37 @@ def forward(ctx, x: Tensor, min_abs: float) -> Tensor: @staticmethod def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None]: if ans_grad.dtype == torch.float16: - return random_cast_to_half(ans_grad.to(torch.float32), - min_abs=ctx.min_abs), None + return ( + random_cast_to_half(ans_grad.to(torch.float32), min_abs=ctx.min_abs), + None, + ) else: return ans_grad, None + class RandomGrad(torch.nn.Module): """ Gets rid of very small gradients using an expectation-preserving method, intended to increase accuracy of training when using amp (automatic mixed precision) """ - def __init__(self, - min_abs: float = 5.0e-06): + + def __init__(self, min_abs: float = 5.0e-06): super(RandomGrad, self).__init__() self.min_abs = min_abs - def forward(self, - x: Tensor): + def forward(self, x: Tensor): if torch.jit.is_scripting() or not self.training: return x else: return RandomGradFunction.apply(x, self.min_abs) - class SoftmaxFunction(torch.autograd.Function): """ Tries to handle half-precision derivatives in a randomized way that should be more accurate for training than the default behavior. """ + @staticmethod def forward(ctx, x: Tensor, dim: int): ans = x.softmax(dim=dim) @@ -267,7 +288,7 @@ def forward(ctx, x: Tensor, dim: int): @staticmethod def backward(ctx, ans_grad: Tensor): - ans, = ctx.saved_tensors + (ans,) = ctx.saved_tensors with torch.cuda.amp.autocast(enabled=False): ans_grad = ans_grad.to(torch.float32) ans = ans.to(torch.float32) @@ -276,9 +297,7 @@ def backward(ctx, ans_grad: Tensor): return x_grad, None - -def softmax(x: Tensor, - dim: int): +def softmax(x: Tensor, dim: int): if torch.jit.is_scripting(): return x.softmax(dim) @@ -288,20 +307,18 @@ def softmax(x: Tensor, class MaxEigLimiterFunction(torch.autograd.Function): @staticmethod def forward( - ctx, - x: Tensor, - coeffs: Tensor, - direction: Tensor, - channel_dim: int, - grad_scale: float) -> Tensor: + ctx, + x: Tensor, + coeffs: Tensor, + direction: Tensor, + channel_dim: int, + grad_scale: float, + ) -> Tensor: ctx.channel_dim = channel_dim ctx.grad_scale = grad_scale - ctx.save_for_backward(x.detach(), - coeffs.detach(), - direction.detach()) + ctx.save_for_backward(x.detach(), coeffs.detach(), direction.detach()) return x - @staticmethod def backward(ctx, x_grad, *args): with torch.enable_grad(): @@ -311,15 +328,20 @@ def backward(ctx, x_grad, *args): x = x_orig.transpose(ctx.channel_dim, -1).reshape(-1, num_channels) new_direction.requires_grad = False x = x - x.mean(dim=0) - x_var = (x ** 2).mean() + x_var = (x**2).mean() x_residual = x - coeffs * new_direction - x_residual_var = (x_residual ** 2).mean() + x_residual_var = (x_residual**2).mean() # `variance_proportion` is the proportion of the variance accounted for # by the top eigen-direction. This is to be minimized. variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20) variance_proportion.backward() x_orig_grad = x_orig.grad - x_extra_grad = x_orig.grad * ctx.grad_scale * x_grad.norm() / (x_orig_grad.norm() + 1.0e-20) + x_extra_grad = ( + x_orig.grad + * ctx.grad_scale + * x_grad.norm() + / (x_orig_grad.norm() + 1.0e-20) + ) return x_grad + x_extra_grad.detach(), None, None, None, None @@ -385,15 +407,12 @@ def forward(self, x: Tensor) -> Tensor: # region if it happens to exit it. 
eps = eps.clamp(min=self.eps_min, max=self.eps_max) scales = ( - torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) + eps.exp() + torch.mean(x**2, dim=self.channel_dim, keepdim=True) + eps.exp() ) ** -0.5 return x * scales - -def ScaledLinear(*args, - initial_scale: float = 1.0, - **kwargs ) -> nn.Linear: +def ScaledLinear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear: """ Behaves like a constructor of a modified version of nn.Linear that gives an easy way to set the default initial parameter scale. @@ -412,16 +431,11 @@ def ScaledLinear(*args, with torch.no_grad(): ans.weight[:] *= initial_scale if ans.bias is not None: - torch.nn.init.uniform_(ans.bias, - -0.1 * initial_scale, - 0.1 * initial_scale) + torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale) return ans - -def ScaledConv1d(*args, - initial_scale: float = 1.0, - **kwargs ) -> nn.Conv1d: +def ScaledConv1d(*args, initial_scale: float = 1.0, **kwargs) -> nn.Conv1d: """ Behaves like a constructor of a modified version of nn.Conv1d that gives an easy way to set the default initial parameter scale. @@ -440,13 +454,10 @@ def ScaledConv1d(*args, with torch.no_grad(): ans.weight[:] *= initial_scale if ans.bias is not None: - torch.nn.init.uniform_(ans.bias, - -0.1 * initial_scale, - 0.1 * initial_scale) + torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale) return ans - class ActivationBalancer(torch.nn.Module): """ Modifies the backpropped derivatives of a function to try to encourage, for @@ -486,18 +497,19 @@ class ActivationBalancer(torch.nn.Module): from doing it at the same time. Early in training we may use higher probabilities than this; it will decay to this value. """ + def __init__( - self, - num_channels: int, - channel_dim: int, - min_positive: float = 0.05, - max_positive: float = 0.95, - max_factor: float = 0.04, - sign_gain_factor: float = 0.01, - scale_gain_factor: float = 0.02, - min_abs: float = 0.2, - max_abs: float = 100.0, - min_prob: float = 0.1, + self, + num_channels: int, + channel_dim: int, + min_positive: float = 0.05, + max_positive: float = 0.95, + max_factor: float = 0.04, + sign_gain_factor: float = 0.01, + scale_gain_factor: float = 0.02, + min_abs: float = 0.2, + max_abs: float = 100.0, + min_prob: float = 0.1, ): super(ActivationBalancer, self).__init__() self.num_channels = num_channels @@ -515,9 +527,7 @@ def __init__( # We occasionally sync this to a tensor called `count`, that exists to # make sure it is synced to disk when we load and save the model. 
self.cpu_count = 0 - self.register_buffer('count', torch.tensor(0, dtype=torch.int64)) - - + self.register_buffer("count", torch.tensor(0, dtype=torch.int64)) def forward(self, x: Tensor) -> Tensor: if torch.jit.is_scripting() or not x.requires_grad: @@ -535,26 +545,35 @@ def forward(self, x: Tensor) -> Tensor: # the prob of doing some work exponentially decreases from 0.5 till it hits # a floor at min_prob (==0.1, by default) - prob = max(self.min_prob, 0.5 ** (1 + (count/4000.0))) + prob = max(self.min_prob, 0.5 ** (1 + (count / 4000.0))) if random.random() < prob: sign_gain_factor = 0.5 if self.min_positive != 0.0 or self.max_positive != 1.0: - sign_factor = _compute_sign_factor(x, self.channel_dim, - self.min_positive, self.max_positive, - gain_factor=self.sign_gain_factor / prob, - max_factor=self.max_factor) + sign_factor = _compute_sign_factor( + x, + self.channel_dim, + self.min_positive, + self.max_positive, + gain_factor=self.sign_gain_factor / prob, + max_factor=self.max_factor, + ) else: sign_factor = None - - scale_factor = _compute_scale_factor(x, self.channel_dim, - min_abs=self.min_abs, - max_abs=self.max_abs, - gain_factor=self.scale_gain_factor / prob, - max_factor=self.max_factor) + scale_factor = _compute_scale_factor( + x, + self.channel_dim, + min_abs=self.min_abs, + max_abs=self.max_abs, + gain_factor=self.scale_gain_factor / prob, + max_factor=self.max_factor, + ) return ActivationBalancerFunction.apply( - x, scale_factor, sign_factor, self.channel_dim, + x, + scale_factor, + sign_factor, + self.channel_dim, ) else: return _no_op(x) @@ -594,13 +613,12 @@ def _diag(x: Tensor): # like .diag(), but works for tensors with 3 dims. else: (batch, dim, dim) = x.shape x = x.reshape(batch, dim * dim) - x = x[:, ::dim+1] + x = x[:, :: dim + 1] assert x.shape == (batch, dim) return x -def _whitening_metric(x: Tensor, - num_groups: int): +def _whitening_metric(x: Tensor, num_groups: int): """ Computes the "whitening metric", a value which will be 1.0 if all the eigenvalues of the centered feature covariance are the same within each group's covariance matrix @@ -630,19 +648,17 @@ def _whitening_metric(x: Tensor, # the following expression is what we'd get if we took the matrix product # of each covariance and measured the mean of its trace, i.e. # the same as _diag(torch.matmul(x_covar, x_covar)).mean(). - x_covarsq_mean_diag = (x_covar ** 2).sum() / (num_groups * channels_per_group) + x_covarsq_mean_diag = (x_covar**2).sum() / (num_groups * channels_per_group) # this metric will be >= 1.0; the larger it is, the less 'white' the data was.
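# Equivalently, metric = mean(lambda**2) / mean(lambda)**2 over the eigenvalues
# lambda of each group's covariance; by Cauchy-Schwarz this is >= 1.0, with
# equality exactly when all eigenvalues are equal, i.e. fully "white" features.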
- metric = x_covarsq_mean_diag / (x_covar_mean_diag ** 2 + 1.0e-20) + metric = x_covarsq_mean_diag / (x_covar_mean_diag**2 + 1.0e-20) return metric class WhiteningPenaltyFunction(torch.autograd.Function): @staticmethod - def forward(ctx, - x: Tensor, - num_groups: int, - whitening_limit: float, - grad_scale: float) -> Tensor: + def forward( + ctx, x: Tensor, num_groups: int, whitening_limit: float, grad_scale: float + ) -> Tensor: ctx.save_for_backward(x) ctx.num_groups = num_groups ctx.whitening_limit = whitening_limit @@ -650,9 +666,8 @@ def forward(ctx, return x @staticmethod - def backward(ctx, - x_grad: Tensor): - x_orig, = ctx.saved_tensors + def backward(ctx, x_grad: Tensor): + (x_orig,) = ctx.saved_tensors with torch.enable_grad(): with torch.cuda.amp.autocast(enabled=False): x_detached = x_orig.to(torch.float32).detach() @@ -661,25 +676,28 @@ def backward(ctx, metric = _whitening_metric(x_detached, ctx.num_groups) if random.random() < 0.005 or __name__ == "__main__": - logging.info(f"Whitening: num_groups={ctx.num_groups}, num_channels={x_orig.shape[-1]}, " - f"metric={metric.item():.2f} vs. limit={ctx.whitening_limit}") + logging.info( + f"Whitening: num_groups={ctx.num_groups}, num_channels={x_orig.shape[-1]}, " + f"metric={metric.item():.2f} vs. limit={ctx.whitening_limit}" + ) (metric - ctx.whitening_limit).relu().backward() penalty_grad = x_detached.grad - scale = ctx.grad_scale * (x_grad.to(torch.float32).norm() / - (penalty_grad.norm() + 1.0e-20)) + scale = ctx.grad_scale * ( + x_grad.to(torch.float32).norm() / (penalty_grad.norm() + 1.0e-20) + ) penalty_grad = penalty_grad * scale return x_grad + penalty_grad.to(x_grad.dtype), None, None, None - class Whiten(nn.Module): def __init__( - self, - num_groups: int, - whitening_limit: float, - prob: Union[float, Tuple[float,float]], - grad_scale: float): + self, + num_groups: int, + whitening_limit: float, + prob: Union[float, Tuple[float, float]], + grad_scale: float, + ): """ Args: num_groups: the number of groups to divide the channel dim into before @@ -714,8 +732,7 @@ def __init__( self.grad_scale = grad_scale - def forward(self, - x: Tensor) -> Tensor: + def forward(self, x: Tensor) -> Tensor: """ In the forward pass, this function just returns the input unmodified. In the backward pass, it will modify the gradients to ensure that the @@ -735,19 +752,21 @@ def forward(self, if not x.requires_grad or random.random() > self.prob or self.grad_scale == 0: return _no_op(x) else: - if hasattr(self, 'min_prob') and random.random() < 0.25: + if hasattr(self, "min_prob") and random.random() < 0.25: # occasionally switch between min_prob and max_prob, based on whether # we are above or below the threshold. - if _whitening_metric(x.to(torch.float32), self.num_groups) > self.whitening_limit: + if ( + _whitening_metric(x.to(torch.float32), self.num_groups) + > self.whitening_limit + ): # there would be a change to the grad. 
self.prob = self.max_prob else: self.prob = self.min_prob - return WhiteningPenaltyFunction.apply(x, - self.num_groups, - self.whitening_limit, - self.grad_scale) + return WhiteningPenaltyFunction.apply( + x, self.num_groups, self.whitening_limit, self.grad_scale + ) class WithLoss(torch.autograd.Function): @@ -755,11 +774,14 @@ class WithLoss(torch.autograd.Function): def forward(ctx, x: Tensor, y: Tensor): ctx.y_shape = y.shape return x + @staticmethod def backward(ctx, ans_grad: Tensor): - return ans_grad, torch.ones(ctx.y_shape, - dtype=ans_grad.dtype, - device=ans_grad.device) + return ans_grad, torch.ones( + ctx.y_shape, dtype=ans_grad.dtype, device=ans_grad.device + ) + + def with_loss(x, y): if torch.jit.is_scripting(): return x @@ -768,7 +790,7 @@ def with_loss(x, y): def _no_op(x: Tensor) -> Tensor: - if (torch.jit.is_scripting()): + if torch.jit.is_scripting(): return x else: # a no-op function that will have a node in the autograd graph, @@ -783,6 +805,7 @@ def __init__(self): def forward(self, x): return _no_op(x) + class MaxEig(torch.nn.Module): """ Modifies the backpropped derivatives of a function to try to discourage @@ -803,13 +826,14 @@ class MaxEig(torch.nn.Module): scale: determines the scale with which we modify the gradients, relative to the existing / unmodified gradients """ + def __init__( - self, - num_channels: int, - channel_dim: int, - max_var_per_eig: float = 0.2, - min_prob: float = 0.01, - scale: float = 0.01, + self, + num_channels: int, + channel_dim: int, + max_var_per_eig: float = 0.2, + min_prob: float = 0.01, + scale: float = 0.01, ): super(MaxEig, self).__init__() self.num_channels = num_channels @@ -825,7 +849,7 @@ def __init__( # random parameters unchanged for comparison direction = torch.arange(num_channels).to(torch.float) direction = direction / direction.norm() - self.register_buffer('max_eig_direction', direction) + self.register_buffer("max_eig_direction", direction) self.min_prob = min_prob # cur_prob is the current probability we'll use to apply the ActivationBalancer. @@ -833,12 +857,12 @@ def __init__( # active. self.cur_prob = 1.0 - - def forward(self, x: Tensor) -> Tensor: - if (torch.jit.is_scripting() or - self.max_var_per_eig <= 0 or - random.random() > self.cur_prob): + if ( + torch.jit.is_scripting() + or self.max_var_per_eig <= 0 + or random.random() > self.cur_prob + ): return _no_op(x) with torch.cuda.amp.autocast(enabled=False): @@ -848,7 +872,9 @@ def forward(self, x: Tensor) -> Tensor: with torch.no_grad(): x = x.transpose(self.channel_dim, -1).reshape(-1, self.num_channels) x = x - x.mean(dim=0) - new_direction, coeffs = self._find_direction_coeffs(x, self.max_eig_direction) + new_direction, coeffs = self._find_direction_coeffs( + x, self.max_eig_direction + ) x_var = (x**2).mean() x_residual = x - coeffs * new_direction x_residual_var = (x_residual**2).mean() @@ -861,7 +887,9 @@ def forward(self, x: Tensor) -> Tensor: self._set_direction(0.1 * self.max_eig_direction + new_direction) if random.random() < 0.01 or __name__ == "__main__": - logging.info(f"variance_proportion = {variance_proportion.item()}, shape={tuple(orig_x.shape)}, cur_prob={self.cur_prob}") + logging.info( + f"variance_proportion = {variance_proportion.item()}, shape={tuple(orig_x.shape)}, cur_prob={self.cur_prob}" + ) if variance_proportion >= self.max_var_per_eig: # The constraint is active. Note, we should quite rarely @@ -869,17 +897,16 @@ def forward(self, x: Tensor) -> Tensor: # starting to diverge, should this constraint be active. 
cur_prob = self.cur_prob self.cur_prob = 1.0 # next time, do the update with probability 1.0. - return MaxEigLimiterFunction.apply(orig_x, coeffs, new_direction, - self.channel_dim, self.scale) + return MaxEigLimiterFunction.apply( + orig_x, coeffs, new_direction, self.channel_dim, self.scale + ) else: # let self.cur_prob exponentially approach self.min_prob, as # long as the constraint is inactive. self.cur_prob = 0.75 * self.cur_prob + 0.25 * self.min_prob return orig_x - - def _set_direction(self, - direction: Tensor): + def _set_direction(self, direction: Tensor): """ Sets self.max_eig_direction to a normalized version of `direction` """ @@ -889,40 +916,39 @@ def _set_direction(self, if direction_sum - direction_sum == 0: # no inf/nan self.max_eig_direction[:] = direction else: - logging.info(f"Warning: sum of direction in MaxEig is {direction_sum}, " - "num_channels={self.num_channels}, channel_dim={self.channel_dim}") - + logging.info( + f"Warning: sum of direction in MaxEig is {direction_sum}, " + "num_channels={self.num_channels}, channel_dim={self.channel_dim}" + ) - def _find_direction_coeffs(self, - x: Tensor, - prev_direction: Tensor) -> Tuple[Tensor, Tensor, Tensor]: + def _find_direction_coeffs( + self, x: Tensor, prev_direction: Tensor + ) -> Tuple[Tensor, Tensor, Tensor]: + """ + Figure out (an approximation to) the proportion of the variance of a set of + feature vectors that can be attributed to the top eigen-direction. + Args: + x: a Tensor of shape (num_frames, num_channels), with num_frames > 1. + prev_direction: a Tensor of shape (num_channels,), that is our previous estimate + of the top eigen-direction, or a random direction if this is the first + iteration. Does not have to be normalized, but should be nonzero. + + Returns: (cur_direction, coeffs), where: + cur_direction: a Tensor of shape (num_channels,) that is the current + estimate of the top eigen-direction. + coeffs: a Tensor of shape (num_frames, 1) that minimizes, or + approximately minimizes, (x - coeffs * cur_direction).norm() """ - Figure out (an approximation to) the proportion of the variance of a set of - feature vectors that can be attributed to the top eigen-direction. - Args: - x: a Tensor of shape (num_frames, num_channels), with num_frames > 1. - prev_direction: a Tensor of shape (num_channels,), that is our previous estimate - of the top eigen-direction, or a random direction if this is the first - iteration. Does not have to be normalized, but should be nonzero. - - Returns: (cur_direction, coeffs), where: - cur_direction: a Tensor of shape (num_channels,) that is the current - estimate of the top eigen-direction. - coeffs: a Tensor of shape (num_frames, 1) that minimizes, or - approximately minimizes, (x - coeffs * cur_direction).norm() - """ (num_frames, num_channels) = x.shape assert num_channels > 1 and num_frames > 1 assert prev_direction.shape == (num_channels,) # `coeffs` are the coefficients of `prev_direction` in x. # actually represent the coeffs up to a constant positive factor. 
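# This is, in effect, a single step of power iteration on the covariance of x:
# project the frames onto the previous direction to get per-frame coeffs, then
# re-estimate the direction as the coeffs-weighted average of the frames.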
coeffs = (x * prev_direction).sum(dim=1, keepdim=True) + 1.0e-10 - cur_direction = (x * coeffs).sum(dim=0) / ((coeffs ** 2).sum() + 1.0e-20) + cur_direction = (x * coeffs).sum(dim=0) / ((coeffs**2).sum() + 1.0e-20) return cur_direction, coeffs - - class DoubleSwishFunction(torch.autograd.Function): """ double_swish(x) = x * torch.sigmoid(x-1) @@ -950,7 +976,7 @@ def forward(ctx, x: Tensor) -> Tensor: y = x * s if requires_grad: - deriv = (y * (1 - s) + s) + deriv = y * (1 - s) + s # notes on derivative of x * sigmoid(x - 1): # https://www.wolframalpha.com/input?i=d%2Fdx+%28x+*+sigmoid%28x-1%29%29 # min \simeq -0.043638. Take floor as -0.043637 so it's a lower bound; # the max is just under 1.2, so take ceil as 1.2 for an upper bound. The # combination of "+ torch.rand_like(deriv)" and the cast to torch.uint8 (which # floors), should be expectation-preserving. floor = -0.043637 ceil = 1.2 - d_scaled = ((deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like(deriv)) + d_scaled = (deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like( + deriv + ) if __name__ == "__main__": # for self-testing only. assert d_scaled.min() >= 0.0 @@ -972,12 +1000,12 @@ def forward(ctx, x: Tensor) -> Tensor: @staticmethod def backward(ctx, y_grad: Tensor) -> Tensor: - d, = ctx.saved_tensors + (d,) = ctx.saved_tensors # the same constants as used in forward pass. floor = -0.043637 ceil = 1.2 - d = (d * ((ceil - floor) / 255.0) + floor) - return (y_grad * d) + d = d * ((ceil - floor) / 255.0) + floor + return y_grad * d class DoubleSwish(torch.nn.Module): @@ -990,7 +1018,6 @@ def forward(self, x: Tensor) -> Tensor: return DoubleSwishFunction.apply(x) - def _test_max_eig(): for proportion in [0.1, 0.5, 10.0]: logging.info(f"proportion = {proportion}") @@ -1002,11 +1029,9 @@ def _test_max_eig(): x.requires_grad = True num_channels = 128 - m = MaxEig(num_channels, - 1, # channel_dim - 0.5, # max_var_per_eig - scale=0.1) # grad_scale - + m = MaxEig( + num_channels, 1, 0.5, scale=0.1 # channel_dim # max_var_per_eig + ) # grad_scale for _ in range(4): y = m(x) @@ -1031,11 +1056,9 @@ def _test_whiten(): x.requires_grad = True num_channels = 128 - m = Whiten(1, # num_groups - 5.0, # whitening_limit, - prob=1.0, - grad_scale=0.1) # grad_scale - + m = Whiten( + 1, 5.0, prob=1.0, grad_scale=0.1 # num_groups # whitening_limit, + ) # grad_scale for _ in range(4): y = m(x) @@ -1049,7 +1072,6 @@ def _test_whiten(): assert not torch.allclose(x.grad, y_grad) - def _test_activation_balancer_sign(): probs = torch.arange(0, 1, 0.01) N = 1000 @@ -1077,9 +1099,7 @@ def _test_activation_balancer_sign(): def _test_activation_balancer_magnitude(): magnitudes = torch.arange(0, 1, 0.01) N = 1000 - x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze( - -1 - ) + x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1) x = x.detach() x.requires_grad = True m = ActivationBalancer( @@ -1111,8 +1131,8 @@ def _test_basic_norm(): y = m(x) assert y.shape == x.shape - x_rms = (x ** 2).mean().sqrt() - y_rms = (y ** 2).mean().sqrt() + x_rms = (x**2).mean().sqrt() + y_rms = (y**2).mean().sqrt() print("x rms = ", x_rms) print("y rms = ", y_rms) assert y_rms < x_rms @@ -1124,30 +1144,27 @@ def _test_double_swish_deriv(): x.requires_grad = True m = DoubleSwish() - tol = ((1.2-(-0.043637))/255.0) + tol = (1.2 - (-0.043637)) / 255.0 torch.autograd.gradcheck(m, x, atol=tol) - # for self-test.
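# The atol above is one quantization step: DoubleSwishFunction stores the
# derivative in 8 bits over [-0.043637, 1.2], so the value recovered in
# backward() can be off by up to (1.2 - (-0.043637)) / 255.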
x = torch.randn(1000, 1000, dtype=torch.double) * 3.0 x.requires_grad = True y = m(x) - def _test_softmax(): a = torch.randn(2, 10, dtype=torch.float64) b = a.clone() a.requires_grad = True b.requires_grad = True - a.softmax(dim=1)[:,0].sum().backward() + a.softmax(dim=1)[:, 0].sum().backward() print("a grad = ", a.grad) - softmax(b, dim=1)[:,0].sum().backward() + softmax(b, dim=1)[:, 0].sum().backward() print("b grad = ", b.grad) assert torch.allclose(a.grad, b.grad) - if __name__ == "__main__": logging.getLogger().setLevel(logging.INFO) torch.set_num_threads(1) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling_converter.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling_converter.py index 8d357b15f1..56165d1f92 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling_converter.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling_converter.py @@ -26,11 +26,7 @@ import torch import torch.nn as nn -from scaling import ( - ActivationBalancer, - BasicNorm, - Whiten, -) +from scaling import ActivationBalancer, BasicNorm, Whiten class NonScaledNorm(nn.Module): diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py index 3f27736b3c..7160fc54ab 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py @@ -84,9 +84,7 @@ from icefall.hooks import register_inf_check_hooks from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: @@ -269,8 +267,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -293,8 +290,7 @@ def get_parser(): "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network) part.", ) parser.add_argument( @@ -646,11 +642,7 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -697,9 +689,7 @@ def compute_loss( info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() @@ -870,9 +860,7 @@ def train_one_epoch( # of the grad scaler is configurable, but we can't configure it to have different # behavior depending on the current grad scale.
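# The heuristic below assumes fp16 training: double the scale whenever it has
# fallen below 1.0, and also every 400 batches while it is still below 8.0,
# rather than waiting for the scaler's default growth interval.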
cur_grad_scale = scaler._scale.item() - if cur_grad_scale < 1.0 or ( - cur_grad_scale < 8.0 and batch_idx % 400 == 0 - ): + if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): scaler.update(cur_grad_scale * 2.0) if cur_grad_scale < 0.01: logging.warning(f"Grad scale is small: {cur_grad_scale}") @@ -890,11 +878,7 @@ def train_one_epoch( f"batch {batch_idx}, loss[{loss_info}], " f"tot_loss[{tot_loss}], batch size: {batch_size}, " f"lr: {cur_lr:.2e}, " - + ( - f"grad_scale: {scaler._scale.item()}" - if params.use_fp16 - else "" - ) + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") ) if tb_writer is not None: @@ -905,9 +889,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if params.use_fp16: tb_writer.add_scalar( "train/grad_scale", @@ -915,10 +897,7 @@ def train_one_epoch( params.batch_idx_train, ) - if ( - batch_idx % params.valid_interval == 0 - and not params.print_diagnostics - ): + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: logging.info("Computing validation loss") valid_info = compute_validation_loss( params=params, @@ -1009,9 +988,7 @@ def run(rank, world_size, args): logging.info("Using DDP") model = DDP(model, device_ids=[rank], find_unused_parameters=True) - optimizer = ScaledAdam( - model.parameters(), lr=params.base_lr, clipping_scale=2.0 - ) + optimizer = ScaledAdam(model.parameters(), lr=params.base_lr, clipping_scale=2.0) scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) @@ -1029,7 +1006,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2 ** 22 + 2**22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) @@ -1054,8 +1031,7 @@ def remove_short_and_long_utt(c: Cut): # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. " f"Duration: {c.duration}" ) return False diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py index 023dec97d5..b007a73084 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py @@ -16,32 +16,35 @@ # limitations under the License. import copy +import itertools +import logging import math +import random import warnings -import itertools from typing import List, Optional, Tuple, Union -import logging + import torch -import random from encoder_interface import EncoderInterface +from scaling import ( + ScaledLinear, # not as in other dirs.. just scales down initial parameter values. +) from scaling import ( ActivationBalancer, BasicNorm, - MaxEig, DoubleSwish, + Identity, + MaxEig, ScaledConv1d, - ScaledLinear, # not as in other dirs.. just scales down initial parameter values. 
Whiten, - Identity, _diag, - random_clamp, penalize_abs_values_gt, + random_clamp, softmax, ) from torch import Tensor, nn -from icefall.utils import make_pad_mask from icefall.dist import get_rank +from icefall.utils import make_pad_mask class Zipformer(EncoderInterface): @@ -89,7 +92,7 @@ def __init__( self.batch_count = 0 self.warmup_end = warmup_batches - for u,d in zip(encoder_unmasked_dims, encoder_dims): + for u, d in zip(encoder_unmasked_dims, encoder_dims): assert u <= d, (u, d) # self.encoder_embed converts the input of shape (N, T, num_features) @@ -97,9 +100,9 @@ def __init__( # That is, it does two things simultaneously: # (1) subsampling: T -> (T - 7)//2 # (2) embedding: num_features -> encoder_dims - self.encoder_embed = Conv2dSubsampling(num_features, encoder_dims[0], - dropout=dropout) - + self.encoder_embed = Conv2dSubsampling( + num_features, encoder_dims[0], dropout=dropout + ) # each one will be ZipformerEncoder or DownsampledZipformerEncoder encoders = [] @@ -123,13 +126,13 @@ def __init__( num_encoder_layers[i], dropout, warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1), - warmup_end=warmup_batches * (i + 2) / (num_encoders + 1) + warmup_end=warmup_batches * (i + 2) / (num_encoders + 1), ) if zipformer_downsampling_factors[i] != 1: encoder = DownsampledZipformerEncoder( encoder, - input_dim=encoder_dims[i-1] if i > 0 else encoder_dims[0], + input_dim=encoder_dims[i - 1] if i > 0 else encoder_dims[0], output_dim=encoder_dims[i], downsample=zipformer_downsampling_factors[i], ) @@ -139,10 +142,9 @@ def __init__( # initializes self.skip_layers and self.skip_modules self._init_skip_modules() - self.downsample_output = AttentionDownsample(encoder_dims[-1], - encoder_dims[-1], - downsample=output_downsampling_factor) - + self.downsample_output = AttentionDownsample( + encoder_dims[-1], encoder_dims[-1], downsample=output_downsampling_factor + ) def _get_layer_skip_dropout_prob(self): if not self.training: @@ -166,27 +168,31 @@ def _init_skip_modules(self): skip_modules = [] z = self.zipformer_downsampling_factors for i in range(len(z)): - if i <= 1 or z[i-1] <= z[i]: + if i <= 1 or z[i - 1] <= z[i]: skip_layers.append(None) skip_modules.append(SimpleCombinerIdentity()) else: # TEMP - for j in range(i-2, -1, -1): + for j in range(i - 2, -1, -1): if z[j] <= z[i] or j == 0: # TEMP logging statement. - logging.info(f"At encoder stack {i}, which has downsampling_factor={z[i]}, we will " - f"combine the outputs of layers {j} and {i-1}, with downsampling_factors={z[j]} and {z[i-1]}.") + logging.info( + f"At encoder stack {i}, which has downsampling_factor={z[i]}, we will " + f"combine the outputs of layers {j} and {i-1}, with downsampling_factors={z[j]} and {z[i-1]}." 
+ ) skip_layers.append(j) - skip_modules.append(SimpleCombiner(self.encoder_dims[j], - self.encoder_dims[i-1], - min_weight=(0.0,0.25))) + skip_modules.append( + SimpleCombiner( + self.encoder_dims[j], + self.encoder_dims[i - 1], + min_weight=(0.0, 0.25), + ) + ) break self.skip_layers = skip_layers self.skip_modules = nn.ModuleList(skip_modules) - def get_feature_masks( - self, - x: torch.Tensor) -> List[float]: + def get_feature_masks(self, x: torch.Tensor) -> List[float]: # Note: The actual return type is Union[List[float], List[Tensor]], # but to make torch.jit.script() work, we use List[float] """ @@ -206,46 +212,56 @@ def get_feature_masks( """ num_encoders = len(self.encoder_dims) if torch.jit.is_scripting() or not self.training: - return [ 1.0 ] * num_encoders + return [1.0] * num_encoders (num_frames0, batch_size, _encoder_dims0) = x.shape - - assert self.encoder_dims[0] == _encoder_dims0, (self.encoder_dims, _encoder_dims0) + assert self.encoder_dims[0] == _encoder_dims0, ( + self.encoder_dims, + _encoder_dims0, + ) max_downsampling_factor = max(self.zipformer_downsampling_factors) - num_frames_max = (num_frames0 + max_downsampling_factor - 1) - + num_frames_max = num_frames0 + max_downsampling_factor - 1 feature_mask_dropout_prob = 0.15 # frame_mask_max shape: (num_frames_max, batch_size, 1) - frame_mask_max = (torch.rand(num_frames_max, batch_size, 1, - device=x.device) > - feature_mask_dropout_prob).to(x.dtype) + frame_mask_max = ( + torch.rand(num_frames_max, batch_size, 1, device=x.device) + > feature_mask_dropout_prob + ).to(x.dtype) feature_masks = [] for i in range(num_encoders): ds = self.zipformer_downsampling_factors[i] - upsample_factor = (max_downsampling_factor // ds) + upsample_factor = max_downsampling_factor // ds - frame_mask = (frame_mask_max.unsqueeze(1).expand(num_frames_max, upsample_factor, - batch_size, 1) - .reshape(num_frames_max * upsample_factor, batch_size, 1)) + frame_mask = ( + frame_mask_max.unsqueeze(1) + .expand(num_frames_max, upsample_factor, batch_size, 1) + .reshape(num_frames_max * upsample_factor, batch_size, 1) + ) num_frames = (num_frames0 + ds - 1) // ds frame_mask = frame_mask[:num_frames] - feature_mask = torch.ones(num_frames, batch_size, self.encoder_dims[i], - dtype=x.dtype, device=x.device) + feature_mask = torch.ones( + num_frames, + batch_size, + self.encoder_dims[i], + dtype=x.dtype, + device=x.device, + ) u = self.encoder_unmasked_dims[i] feature_mask[:, :, u:] *= frame_mask feature_masks.append(feature_mask) return feature_masks - def forward( - self, x: torch.Tensor, x_lens: torch.Tensor, + self, + x: torch.Tensor, + x_lens: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Args: @@ -271,7 +287,9 @@ def forward( outputs = [] feature_masks = self.get_feature_masks(x) - for i, (module, skip_module) in enumerate(zip(self.encoders, self.skip_modules)): + for i, (module, skip_module) in enumerate( + zip(self.encoders, self.skip_modules) + ): ds = self.zipformer_downsampling_factors[i] k = self.skip_layers[i] if isinstance(k, int): @@ -280,9 +298,11 @@ def forward( x = skip_module(outputs[k], x) elif (not self.training) or random.random() > layer_skip_dropout_prob: x = skip_module(outputs[k], x) - x = module(x, - feature_mask=feature_masks[i], - src_key_padding_mask=None if mask is None else mask[...,::ds]) + x = module( + x, + feature_mask=feature_masks[i], + src_key_padding_mask=None if mask is None else mask[..., ::ds], + ) outputs.append(x) x = self.downsample_output(x) @@ -312,15 +332,16 @@ class 
ZipformerEncoderLayer(nn.Module): >>> pos_emb = torch.rand(32, 19, 512) >>> out = encoder_layer(src, pos_emb) """ + def __init__( - self, - d_model: int, - attention_dim: int, - nhead: int, - feedforward_dim: int = 2048, - dropout: float = 0.1, - cnn_module_kernel: int = 31, - pos_dim: int = 4, + self, + d_model: int, + attention_dim: int, + nhead: int, + feedforward_dim: int = 2048, + dropout: float = 0.1, + cnn_module_kernel: int = 31, + pos_dim: int = 4, ) -> None: super(ZipformerEncoderLayer, self).__init__() @@ -330,29 +351,24 @@ def __init__( self.batch_count = 0 self.self_attn = RelPositionMultiheadAttention( - d_model, attention_dim, nhead, pos_dim, dropout=0.0, + d_model, + attention_dim, + nhead, + pos_dim, + dropout=0.0, ) self.pooling = PoolingModule(d_model) - self.feed_forward1 = FeedforwardModule(d_model, - feedforward_dim, - dropout) + self.feed_forward1 = FeedforwardModule(d_model, feedforward_dim, dropout) - self.feed_forward2 = FeedforwardModule(d_model, - feedforward_dim, - dropout) + self.feed_forward2 = FeedforwardModule(d_model, feedforward_dim, dropout) - self.feed_forward3 = FeedforwardModule(d_model, - feedforward_dim, - dropout) + self.feed_forward3 = FeedforwardModule(d_model, feedforward_dim, dropout) + self.conv_module1 = ConvolutionModule(d_model, cnn_module_kernel) - self.conv_module1 = ConvolutionModule(d_model, - cnn_module_kernel) - - self.conv_module2 = ConvolutionModule(d_model, - cnn_module_kernel) + self.conv_module2 = ConvolutionModule(d_model, cnn_module_kernel) self.norm_final = BasicNorm(d_model) @@ -360,14 +376,15 @@ def __init__( # try to ensure the output is close to zero-mean (or at least, zero-median). self.balancer = ActivationBalancer( - d_model, channel_dim=-1, - min_positive=0.45, max_positive=0.55, + d_model, + channel_dim=-1, + min_positive=0.45, + max_positive=0.55, max_abs=6.0, ) - self.whiten = Whiten(num_groups=1, - whitening_limit=5.0, - prob=(0.025, 0.25), - grad_scale=0.01) + self.whiten = Whiten( + num_groups=1, whitening_limit=5.0, prob=(0.025, 0.25), grad_scale=0.01 + ) def get_bypass_scale(self): if torch.jit.is_scripting() or not self.training: @@ -382,8 +399,9 @@ def get_bypass_scale(self): if self.batch_count > warmup_period: clamp_min = final_clamp_min else: - clamp_min = (initial_clamp_min - - (self.batch_count / warmup_period) * (initial_clamp_min - final_clamp_min)) + clamp_min = initial_clamp_min - (self.batch_count / warmup_period) * ( + initial_clamp_min - final_clamp_min + ) return self.bypass_scale.clamp(min=clamp_min, max=1.0) def get_dynamic_dropout_rate(self): @@ -398,8 +416,9 @@ def get_dynamic_dropout_rate(self): if self.batch_count > warmup_period: return final_dropout_rate else: - return (initial_dropout_rate - - (initial_dropout_rate * final_dropout_rate) * (self.batch_count / warmup_period)) + return initial_dropout_rate - ( + initial_dropout_rate * final_dropout_rate + ) * (self.batch_count / warmup_period) def forward( self, @@ -508,13 +527,14 @@ class ZipformerEncoder(nn.Module): >>> src = torch.rand(10, 32, 512) >>> out = zipformer_encoder(src) """ + def __init__( - self, - encoder_layer: nn.Module, - num_layers: int, - dropout: float, - warmup_begin: float, - warmup_end: float + self, + encoder_layer: nn.Module, + num_layers: int, + dropout: float, + warmup_begin: float, + warmup_end: float, ) -> None: super().__init__() # will be written to, see set_batch_count() Note: in inference time this @@ -528,8 +548,7 @@ def __init__( # so that we can keep this consistent across worker tasks (for efficiency). 
self.module_seed = torch.randint(0, 1000, ()).item() - self.encoder_pos = RelPositionalEncoding(encoder_layer.d_model, - dropout) + self.encoder_pos = RelPositionalEncoding(encoder_layer.d_model, dropout) self.layers = nn.ModuleList( [copy.deepcopy(encoder_layer) for i in range(num_layers)] ) @@ -538,15 +557,13 @@ def __init__( assert 0 <= warmup_begin <= warmup_end, (warmup_begin, warmup_end) - - delta = (1. / num_layers) * (warmup_end - warmup_begin) + delta = (1.0 / num_layers) * (warmup_end - warmup_begin) cur_begin = warmup_begin for i in range(num_layers): self.layers[i].warmup_begin = cur_begin cur_begin += delta self.layers[i].warmup_end = cur_begin - def get_layers_to_drop(self, rnd_seed: int): ans = set() if not self.training: @@ -579,12 +596,14 @@ def get_layerdrop_prob(layer: int) -> float: # linearly interpolate t = (batch_count - layer_warmup_begin) / layer_warmup_end assert 0.0 <= t < 1.001, t - return initial_layerdrop_prob + t * (final_layerdrop_prob - initial_layerdrop_prob) + return initial_layerdrop_prob + t * ( + final_layerdrop_prob - initial_layerdrop_prob + ) shared_rng = random.Random(batch_count + self.module_seed) independent_rng = random.Random(rnd_seed) - layerdrop_probs = [ get_layerdrop_prob(i) for i in range(num_layers) ] + layerdrop_probs = [get_layerdrop_prob(i) for i in range(num_layers)] tot = sum(layerdrop_probs) # Instead of drawing the samples independently, we first randomly decide # how many layers to drop out, using the same random number generator between @@ -604,11 +623,12 @@ def get_layerdrop_prob(layer: int) -> float: if len(ans) == num_to_drop: break if shared_rng.random() < 0.005 or __name__ == "__main__": - logging.info(f"warmup_begin={self.warmup_begin:.1f}, warmup_end={self.warmup_end:.1f}, " - f"batch_count={batch_count:.1f}, num_to_drop={num_to_drop}, layers_to_drop={ans}") + logging.info( + f"warmup_begin={self.warmup_begin:.1f}, warmup_end={self.warmup_end:.1f}, " + f"batch_count={batch_count:.1f}, num_to_drop={num_to_drop}, layers_to_drop={ans}" + ) return ans - def forward( self, src: Tensor, @@ -639,7 +659,6 @@ def forward( pos_emb = self.encoder_pos(src) output = src - if torch.jit.is_scripting(): layers_to_drop = [] else: @@ -670,28 +689,27 @@ class DownsampledZipformerEncoder(nn.Module): after convolutional downsampling, and then upsampled again at the output, and combined with the original input, so that the output has the same shape as the input.
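For example, with downsample=2 an input of shape (T, N, input_dim) is reduced to (ceil(T/2), N, output_dim), run through the wrapped encoder, upsampled back (frames beyond T are dropped), and finally combined with the original input by out_combiner.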
""" - def __init__(self, - encoder: nn.Module, - input_dim: int, - output_dim: int, - downsample: int): + + def __init__( + self, encoder: nn.Module, input_dim: int, output_dim: int, downsample: int + ): super(DownsampledZipformerEncoder, self).__init__() self.downsample_factor = downsample self.downsample = AttentionDownsample(input_dim, output_dim, downsample) self.encoder = encoder self.upsample = SimpleUpsample(output_dim, downsample) - self.out_combiner = SimpleCombiner(input_dim, - output_dim, - min_weight=(0.0, 0.25)) - - - def forward(self, - src: Tensor, - # Note: the type of feature_mask should be Unino[float, Tensor], - # but to make torch.jit.script() happ, we use float here - feature_mask: float = 1.0, - mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, + self.out_combiner = SimpleCombiner( + input_dim, output_dim, min_weight=(0.0, 0.25) + ) + + def forward( + self, + src: Tensor, + # Note: the type of feature_mask should be Unino[float, Tensor], + # but to make torch.jit.script() happ, we use float here + feature_mask: float = 1.0, + mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, ) -> Tensor: r"""Downsample, go through encoder, upsample. @@ -718,42 +736,43 @@ def forward(self, src = self.downsample(src) ds = self.downsample_factor if mask is not None: - mask = mask[::ds,::ds] + mask = mask[::ds, ::ds] src = self.encoder( - src, feature_mask=feature_mask, mask=mask, src_key_padding_mask=mask, + src, + feature_mask=feature_mask, + mask=mask, + src_key_padding_mask=mask, ) src = self.upsample(src) # remove any extra frames that are not a multiple of downsample_factor - src = src[:src_orig.shape[0]] + src = src[: src_orig.shape[0]] return self.out_combiner(src_orig, src) + class AttentionDownsample(torch.nn.Module): """ Does downsampling with attention, by weighted sum, and a projection.. """ - def __init__(self, - in_channels: int, - out_channels: int, - downsample: int): + + def __init__(self, in_channels: int, out_channels: int, downsample: int): """ Require out_channels > in_channels. """ super(AttentionDownsample, self).__init__() - self.query = nn.Parameter(torch.randn(in_channels) * (in_channels ** -0.5)) + self.query = nn.Parameter(torch.randn(in_channels) * (in_channels**-0.5)) # fill in the extra dimensions with a projection of the input if out_channels > in_channels: - self.extra_proj = nn.Linear(in_channels * downsample, - out_channels - in_channels, - bias=False) + self.extra_proj = nn.Linear( + in_channels * downsample, out_channels - in_channels, bias=False + ) else: self.extra_proj = None self.downsample = downsample - def forward(self, - src: Tensor) -> Tensor: + def forward(self, src: Tensor) -> Tensor: """ x: (seq_len, batch_size, in_channels) Returns a tensor of shape @@ -767,16 +786,14 @@ def forward(self, if seq_len != d_seq_len * ds: # right-pad src, repeating the last element. 
pad = d_seq_len * ds - seq_len - src_extra = src[src.shape[0]-1:].expand(pad, src.shape[1], src.shape[2]) + src_extra = src[src.shape[0] - 1 :].expand(pad, src.shape[1], src.shape[2]) src = torch.cat((src, src_extra), dim=0) assert src.shape[0] == d_seq_len * ds, (src.shape[0], d_seq_len, ds) src = src.reshape(d_seq_len, ds, batch_size, in_channels) scores = (src * self.query).sum(dim=-1, keepdim=True) - scores = penalize_abs_values_gt(scores, - limit=10.0, - penalty=1.0e-04) + scores = penalize_abs_values_gt(scores, limit=10.0, penalty=1.0e-04) weights = scores.softmax(dim=1) @@ -795,14 +812,12 @@ class SimpleUpsample(torch.nn.Module): A very simple form of upsampling that mostly just repeats the input, but also adds a position-specific bias. """ - def __init__(self, - num_channels: int, - upsample: int): + + def __init__(self, num_channels: int, upsample: int): super(SimpleUpsample, self).__init__() self.bias = nn.Parameter(torch.randn(upsample, num_channels) * 0.01) - def forward(self, - src: Tensor) -> Tensor: + def forward(self, src: Tensor) -> Tensor: """ x: (seq_len, batch_size, num_channels) Returns a tensor of shape @@ -815,6 +830,7 @@ def forward(self, src = src.reshape(seq_len * upsample, batch_size, num_channels) return src + class SimpleCombinerIdentity(nn.Module): def __init__(self, *args, **kwargs): super().__init__() @@ -822,6 +838,7 @@ def __init__(self, *args, **kwargs): def forward(self, src1: Tensor, src2: Tensor) -> Tensor: return src1 + class SimpleCombiner(torch.nn.Module): """ A very simple way of combining 2 vectors of 2 different dims, via a @@ -831,18 +848,14 @@ class SimpleCombiner(torch.nn.Module): dim2: the dimension of the second input, e.g. 384. The output will have the same dimension as dim2. """ - def __init__(self, - dim1: int, - dim2: int, - min_weight: Tuple[float] = (0., 0.)): + + def __init__(self, dim1: int, dim2: int, min_weight: Tuple[float] = (0.0, 0.0)): super(SimpleCombiner, self).__init__() assert dim2 >= dim1, (dim2, dim1) self.weight1 = nn.Parameter(torch.zeros(())) self.min_weight = min_weight - def forward(self, - src1: Tensor, - src2: Tensor) -> Tensor: + def forward(self, src1: Tensor, src2: Tensor) -> Tensor: """ src1: (*, dim1) src2: (*, dim2) @@ -853,10 +866,14 @@ def forward(self, weight1 = self.weight1 if not torch.jit.is_scripting(): - if self.training and random.random() < 0.25 and self.min_weight != (0., 0.): - weight1 = weight1.clamp(min=self.min_weight[0], - max=1.0-self.min_weight[1]) - + if ( + self.training + and random.random() < 0.25 + and self.min_weight != (0.0, 0.0) + ): + weight1 = weight1.clamp( + min=self.min_weight[0], max=1.0 - self.min_weight[1] + ) src1 = src1 * weight1 src2 = src2 * (1.0 - weight1) @@ -869,12 +886,9 @@ def forward(self, else: src1 = src1[:src2_dim] - return src1 + src2 - - class RelPositionalEncoding(torch.nn.Module): """Relative positional encoding module. 
@@ -888,9 +902,7 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__( - self, d_model: int, dropout_rate: float, max_len: int = 5000 - ) -> None: + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: """Construct a PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -905,9 +917,7 @@ def extend_pe(self, x: Tensor) -> None: # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x.size(0) * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str( - x.device - ): + if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means the position of the query vector and `j` means the @@ -955,7 +965,6 @@ def forward(self, x: torch.Tensor) -> Tensor: return self.dropout(pos_emb) - class RelPositionMultiheadAttention(nn.Module): r"""Multi-Head Attention layer with relative position encoding @@ -992,34 +1001,43 @@ def __init__( self.head_dim = attention_dim // num_heads self.pos_dim = pos_dim assert self.head_dim % 2 == 0, self.head_dim - assert ( - self.head_dim * num_heads == attention_dim - ), (self.head_dim, num_heads, attention_dim) + assert self.head_dim * num_heads == attention_dim, ( + self.head_dim, + num_heads, + attention_dim, + ) # the initial_scale is supposed to take over the "scaling" factor of # head_dim ** -0.5, dividing it between the query and key. - in_proj_dim = (2 * attention_dim + # query, key - attention_dim // 2 + # value - pos_dim * num_heads) # positional encoding query - - self.in_proj = ScaledLinear(embed_dim, in_proj_dim, bias=True, - initial_scale=self.head_dim**-0.25) + in_proj_dim = ( + 2 * attention_dim + + attention_dim // 2 # query, key + + pos_dim * num_heads # value + ) # positional encoding query + + self.in_proj = ScaledLinear( + embed_dim, in_proj_dim, bias=True, initial_scale=self.head_dim**-0.25 + ) # self.whiten_values is applied on the values in forward(); # it just copies the keys but prevents low-rank distribution by modifying grads. - self.whiten_values = Whiten(num_groups=num_heads, - whitening_limit=2.0, - prob=(0.025, 0.25), - grad_scale=0.025) - self.whiten_keys = Whiten(num_groups=num_heads, - whitening_limit=2.0, - prob=(0.025, 0.25), - grad_scale=0.025) - + self.whiten_values = Whiten( + num_groups=num_heads, + whitening_limit=2.0, + prob=(0.025, 0.25), + grad_scale=0.025, + ) + self.whiten_keys = Whiten( + num_groups=num_heads, + whitening_limit=2.0, + prob=(0.025, 0.25), + grad_scale=0.025, + ) # linear transformation for positional encoding. - self.linear_pos = ScaledLinear(embed_dim, num_heads * pos_dim, bias=False, - initial_scale=0.05) + self.linear_pos = ScaledLinear( + embed_dim, num_heads * pos_dim, bias=False, initial_scale=0.05 + ) # the following are for diagnostics only, see --print-diagnostics option. # they only copy their inputs.
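# For orientation (hypothetical sizes): with embed_dim=384, attention_dim=192,
# num_heads=8 and pos_dim=4, in_proj packs per frame the query and key (192
# each), the value (96) and the positional-encoding query (32), giving
# in_proj_dim=512; multi_head_attention_forward below slices x_proj apart in
# exactly that order.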
@@ -1031,14 +1049,16 @@ def __init__( ) self.in_proj2 = nn.Linear(embed_dim, attention_dim // 2, bias=False) - self.out_proj2 = ScaledLinear(attention_dim // 2, embed_dim, bias=True, - initial_scale=0.05) + self.out_proj2 = ScaledLinear( + attention_dim // 2, embed_dim, bias=True, initial_scale=0.05 + ) # self.whiten_values2 is applied on the values in forward2() - self.whiten_values2 = Whiten(num_groups=num_heads, - whitening_limit=2.0, - prob=(0.025, 0.25), - grad_scale=0.025) - + self.whiten_values2 = Whiten( + num_groups=num_heads, + whitening_limit=2.0, + prob=(0.025, 0.25), + grad_scale=0.025, + ) def forward( self, @@ -1098,7 +1118,6 @@ def forward( ) return x, weights - def multi_head_attention_forward( self, x_proj: Tensor, @@ -1158,24 +1177,21 @@ def multi_head_attention_forward( pos_dim = self.pos_dim # positional-encoding dim per head assert ( head_dim * num_heads == attention_dim - ), f"attention_dim must be divisible by num_heads: {head_dim}, {num_heads}, {attention_dim}" - + ), f"attention_dim must be divisible by num_heads: {head_dim}, {num_heads}, {attention_dim}" # self-attention - q = x_proj[...,0:attention_dim] - k = x_proj[...,attention_dim:2*attention_dim] + q = x_proj[..., 0:attention_dim] + k = x_proj[..., attention_dim : 2 * attention_dim] value_dim = attention_dim // 2 - v = x_proj[...,2*attention_dim:2*attention_dim+value_dim] + v = x_proj[..., 2 * attention_dim : 2 * attention_dim + value_dim] # p is the position-encoding query, its dimension is num_heads*pos_dim.. - p = x_proj[...,2*attention_dim+value_dim:] - + p = x_proj[..., 2 * attention_dim + value_dim :] k = self.whiten_keys(k) # does nothing in the forward pass. v = self.whiten_values(v) # does nothing in the forward pass. q = self.copy_query(q) # for diagnostics only, does nothing. p = self.copy_pos_query(p) # for diagnostics only, does nothing. - if attn_mask is not None: assert ( attn_mask.dtype == torch.float32 @@ -1195,31 +1211,22 @@ def multi_head_attention_forward( if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, seq_len, seq_len]: - raise RuntimeError( - "The size of the 2D attn_mask is not correct." - ) + raise RuntimeError("The size of the 2D attn_mask is not correct.") elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, seq_len, seq_len, ]: - raise RuntimeError( - "The size of the 3D attn_mask is not correct." - ) + raise RuntimeError("The size of the 3D attn_mask is not correct.") else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format( - attn_mask.dim() - ) + "attn_mask's dimension {} is not supported".format(attn_mask.dim()) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if ( - key_padding_mask is not None - and key_padding_mask.dtype == torch.uint8 - ): + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: warnings.warn( "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." 
) @@ -1230,7 +1237,6 @@ def multi_head_attention_forward( k = k.reshape(seq_len, bsz, num_heads, head_dim) v = v.reshape(seq_len, bsz * num_heads, head_dim // 2).transpose(0, 1) - if key_padding_mask is not None: assert key_padding_mask.size(0) == bsz, "{} == {}".format( key_padding_mask.size(0), bsz @@ -1239,13 +1245,10 @@ def multi_head_attention_forward( key_padding_mask.size(1), seq_len ) - - q = q.permute(1, 2, 0, 3) # (batch, head, time1, head_dim) p = p.permute(1, 2, 0, 3) # (batch, head, time1, pos_dim) k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - seq_len2 = 2 * seq_len - 1 pos = pos.reshape(1, seq_len2, num_heads, pos_dim).permute(0, 2, 3, 1) # pos shape now: (batch, head, pos_dim, seq_len2) @@ -1256,13 +1259,16 @@ def multi_head_attention_forward( # the following .as_strided() expression converts the last axis of pos_weights from relative # to absolute position. I don't know whether I might have got the time-offsets backwards or # not, but let this code define which way round it is supposed to be. - pos_weights = pos_weights.as_strided((bsz, num_heads, seq_len, seq_len), - (pos_weights.stride(0), - pos_weights.stride(1), - pos_weights.stride(2)-pos_weights.stride(3), - pos_weights.stride(3)), - storage_offset=pos_weights.stride(3) * (seq_len - 1)) - + pos_weights = pos_weights.as_strided( + (bsz, num_heads, seq_len, seq_len), + ( + pos_weights.stride(0), + pos_weights.stride(1), + pos_weights.stride(2) - pos_weights.stride(3), + pos_weights.stride(3), + ), + storage_offset=pos_weights.stride(3) * (seq_len - 1), + ) # caution: they are really scores at this point. attn_output_weights = torch.matmul(q, k) + pos_weights @@ -1275,10 +1281,9 @@ def multi_head_attention_forward( # this mechanism instead of, say, a limit on entropy, because once the entropy # gets very small gradients through the softmax can become very small, and # some mechanisms like that become ineffective. - attn_output_weights = penalize_abs_values_gt(attn_output_weights, - limit=25.0, - penalty=1.0e-04) - + attn_output_weights = penalize_abs_values_gt( + attn_output_weights, limit=25.0, penalty=1.0e-04 + ) # attn_output_weights: (batch, head, time1, time2) attn_output_weights = attn_output_weights.view( @@ -1320,20 +1325,16 @@ def multi_head_attention_forward( ) attn_output = torch.bmm(attn_output_weights, v) - assert list(attn_output.size()) == [bsz * num_heads, seq_len, - head_dim // 2] + assert list(attn_output.size()) == [bsz * num_heads, seq_len, head_dim // 2] attn_output = ( attn_output.transpose(0, 1) .contiguous() .view(seq_len, bsz, attention_dim // 2) ) - attn_output = nn.functional.linear( - attn_output, out_proj_weight, out_proj_bias - ) + attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) return attn_output, attn_output_weights - def forward2( self, x: Tensor, @@ -1372,11 +1373,7 @@ def forward2( # returned value is of shape (seq_len, bsz, embed_dim), like x. 
return self.out_proj2(attn_output) - - def _print_attn_stats( - self, - attn_weights: Tensor, - attn_output: Tensor): + def _print_attn_stats(self, attn_weights: Tensor, attn_output: Tensor): # attn_weights: (batch_size * num_heads, seq_len, seq_len) # attn_output: (bsz * num_heads, seq_len, head_dim) (n, seq_len, head_dim) = attn_output.shape @@ -1387,39 +1384,48 @@ def _print_attn_stats( with torch.cuda.amp.autocast(enabled=False): attn_weights = attn_weights.to(torch.float32) attn_output = attn_output.to(torch.float32) - attn_weights_entropy = -((attn_weights + 1.0e-20).log() * attn_weights).sum( - dim=-1).reshape(bsz, num_heads, seq_len).mean(dim=(0,2)) + attn_weights_entropy = ( + -((attn_weights + 1.0e-20).log() * attn_weights) + .sum(dim=-1) + .reshape(bsz, num_heads, seq_len) + .mean(dim=(0, 2)) + ) attn_output = attn_output.reshape(bsz, num_heads, seq_len, head_dim) - attn_output = attn_output.permute(1, 0, 2, 3).reshape(num_heads, bsz * seq_len, head_dim) + attn_output = attn_output.permute(1, 0, 2, 3).reshape( + num_heads, bsz * seq_len, head_dim + ) attn_output_mean = attn_output.mean(dim=1, keepdim=True) attn_output = attn_output - attn_output_mean - attn_covar = torch.matmul(attn_output.transpose(1, 2), attn_output) / (bsz * seq_len) + attn_covar = torch.matmul(attn_output.transpose(1, 2), attn_output) / ( + bsz * seq_len + ) # attn_covar: (num_heads, head_dim, head_dim) - #eigs, _ = torch.symeig(attn_covar) - #logging.info(f"attn_weights_entropy = {attn_weights_entropy}, output_eigs = {eigs}") + # eigs, _ = torch.symeig(attn_covar) + # logging.info(f"attn_weights_entropy = {attn_weights_entropy}, output_eigs = {eigs}") attn_covar = _diag(attn_covar).mean(dim=1) # (num_heads,) embed_dim = self.in_proj2.weight.shape[1] - in_proj_covar = (self.in_proj2.weight.reshape(num_heads, head_dim, embed_dim) ** 2).mean(dim=(1,2)) - out_proj_covar = (self.out_proj2.weight.reshape(embed_dim, num_heads, head_dim) ** 2).mean(dim=(0,2)) - logging.info(f"attn_weights_entropy = {attn_weights_entropy}, covar={attn_covar}, in_proj_covar={in_proj_covar}, out_proj_covar={out_proj_covar}") - - + in_proj_covar = ( + self.in_proj2.weight.reshape(num_heads, head_dim, embed_dim) ** 2 + ).mean(dim=(1, 2)) + out_proj_covar = ( + self.out_proj2.weight.reshape(embed_dim, num_heads, head_dim) ** 2 + ).mean(dim=(0, 2)) + logging.info( + f"attn_weights_entropy = {attn_weights_entropy}, covar={attn_covar}, in_proj_covar={in_proj_covar}, out_proj_covar={out_proj_covar}" + ) class PoolingModule(nn.Module): """ Averages the input over the time dimension and project with a square matrix. 
""" - def __init__(self, - d_model: int): + + def __init__(self, d_model: int): super().__init__() - self.proj = ScaledLinear(d_model, d_model, - initial_scale=0.1, bias=False) + self.proj = ScaledLinear(d_model, d_model, initial_scale=0.1, bias=False) - def forward(self, - x: Tensor, - key_padding_mask: Optional[Tensor] = None): + def forward(self, x: Tensor, key_padding_mask: Optional[Tensor] = None): """ Args: x: a Tensor of shape (T, N, C) @@ -1430,7 +1436,7 @@ def forward(self, """ if key_padding_mask is not None: pooling_mask = key_padding_mask.logical_not().to(x.dtype) # (N, T) - pooling_mask = (pooling_mask / pooling_mask.sum(dim=1, keepdim=True)) + pooling_mask = pooling_mask / pooling_mask.sum(dim=1, keepdim=True) pooling_mask = pooling_mask.transpose(0, 1).contiguous().unsqueeze(-1) # now pooling_mask: (T, N, 1) x = (x * pooling_mask).sum(dim=0, keepdim=True) @@ -1444,24 +1450,19 @@ def forward(self, class FeedforwardModule(nn.Module): - """Feedforward module in Zipformer model. - """ - def __init__(self, - d_model: int, - feedforward_dim: int, - dropout: float): + """Feedforward module in Zipformer model.""" + + def __init__(self, d_model: int, feedforward_dim: int, dropout: float): super(FeedforwardModule, self).__init__() self.in_proj = nn.Linear(d_model, feedforward_dim) - self.balancer = ActivationBalancer(feedforward_dim, - channel_dim=-1, max_abs=10.0, - min_prob=0.25) + self.balancer = ActivationBalancer( + feedforward_dim, channel_dim=-1, max_abs=10.0, min_prob=0.25 + ) self.activation = DoubleSwish() self.dropout = nn.Dropout(dropout) - self.out_proj = ScaledLinear(feedforward_dim, d_model, - initial_scale=0.01) + self.out_proj = ScaledLinear(feedforward_dim, d_model, initial_scale=0.01) - def forward(self, - x: Tensor): + def forward(self, x: Tensor): x = self.in_proj(x) x = self.balancer(x) x = self.activation(x) @@ -1481,9 +1482,7 @@ class ConvolutionModule(nn.Module): """ - def __init__( - self, channels: int, kernel_size: int, bias: bool = True - ) -> None: + def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None: """Construct an ConvolutionModule object.""" super(ConvolutionModule, self).__init__() # kernerl_size should be a odd number for 'SAME' padding @@ -1513,7 +1512,10 @@ def __init__( # the correct range. self.deriv_balancer1 = ActivationBalancer( 2 * channels, - channel_dim=1, max_abs=10.0, min_positive=0.05, max_positive=1.0 + channel_dim=1, + max_abs=10.0, + min_positive=0.05, + max_positive=1.0, ) self.depthwise_conv = nn.Conv1d( @@ -1527,8 +1529,10 @@ def __init__( ) self.deriv_balancer2 = ActivationBalancer( - channels, channel_dim=1, - min_positive=0.05, max_positive=1.0, + channels, + channel_dim=1, + min_positive=0.05, + max_positive=1.0, max_abs=20.0, ) @@ -1544,9 +1548,10 @@ def __init__( initial_scale=0.05, ) - def forward(self, - x: Tensor, - src_key_padding_mask: Optional[Tensor] = None, + def forward( + self, + x: Tensor, + src_key_padding_mask: Optional[Tensor] = None, ) -> Tensor: """Compute convolution module. 
@@ -1626,8 +1631,7 @@ def __init__( kernel_size=3, padding=(0, 1), # (time, freq) ), - ActivationBalancer(layer1_channels, - channel_dim=1), + ActivationBalancer(layer1_channels, channel_dim=1), DoubleSwish(), nn.Conv2d( in_channels=layer1_channels, @@ -1636,24 +1640,21 @@ def __init__( stride=2, padding=0, ), - ActivationBalancer(layer2_channels, - channel_dim=1), + ActivationBalancer(layer2_channels, channel_dim=1), DoubleSwish(), nn.Conv2d( in_channels=layer2_channels, out_channels=layer3_channels, kernel_size=3, - stride=(1, 2), # (time, freq) + stride=(1, 2), # (time, freq) ), - ActivationBalancer(layer3_channels, - channel_dim=1), + ActivationBalancer(layer3_channels, channel_dim=1), DoubleSwish(), ) out_height = (((in_channels - 1) // 2) - 1) // 2 self.out = ScaledLinear(out_height * layer3_channels, out_channels) self.dropout = nn.Dropout(dropout) - def forward(self, x: torch.Tensor) -> torch.Tensor: """Subsample x. @@ -1674,6 +1675,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.dropout(x) return x + class AttentionCombine(nn.Module): """ This module combines a list of Tensors, all with the same shape, to @@ -1717,15 +1719,12 @@ def __init__( self.random_prob = random_prob self.single_prob = single_prob - self.weight = torch.nn.Parameter(torch.zeros(num_channels, - num_inputs)) + self.weight = torch.nn.Parameter(torch.zeros(num_channels, num_inputs)) self.bias = torch.nn.Parameter(torch.zeros(num_inputs)) assert 0 <= random_prob <= 1, random_prob assert 0 <= single_prob <= 1, single_prob - - def forward(self, inputs: List[Tensor]) -> Tensor: """Forward function. Args: @@ -1756,28 +1755,35 @@ def forward(self, inputs: List[Tensor]) -> Tensor: if self.training: # random masking.. - mask_start = torch.randint(low=1, high=int(num_inputs / self.random_prob), - size=(num_frames,), device=scores.device).unsqueeze(1) + mask_start = torch.randint( + low=1, + high=int(num_inputs / self.random_prob), + size=(num_frames,), + device=scores.device, + ).unsqueeze(1) # mask will have rows like: [ False, False, False, True, True, .. 
] - arange = torch.arange(num_inputs, device=scores.device).unsqueeze(0).expand( - num_frames, num_inputs) + arange = ( + torch.arange(num_inputs, device=scores.device) + .unsqueeze(0) + .expand(num_frames, num_inputs) + ) mask = arange >= mask_start - apply_single_prob = torch.logical_and(torch.rand(size=(num_frames, 1), - device=scores.device) < self.single_prob, - mask_start < num_inputs) - single_prob_mask = torch.logical_and(apply_single_prob, - arange < mask_start - 1) + apply_single_prob = torch.logical_and( + torch.rand(size=(num_frames, 1), device=scores.device) + < self.single_prob, + mask_start < num_inputs, + ) + single_prob_mask = torch.logical_and( + apply_single_prob, arange < mask_start - 1 + ) - mask = torch.logical_or(mask, - single_prob_mask) + mask = torch.logical_or(mask, single_prob_mask) - scores = scores.masked_fill(mask, float('-inf')) + scores = scores.masked_fill(mask, float("-inf")) if self.training and random.random() < 0.1: - scores = penalize_abs_values_gt(scores, - limit=10.0, - penalty=1.0e-04) + scores = penalize_abs_values_gt(scores, limit=10.0, penalty=1.0e-04) weights = scores.softmax(dim=1) @@ -1792,7 +1798,6 @@ def forward(self, inputs: List[Tensor]) -> Tensor: return ans - def _test_random_combine(): print("_test_random_combine()") num_inputs = 3 @@ -1801,8 +1806,8 @@ def _test_random_combine(): num_channels=num_channels, num_inputs=num_inputs, random_prob=0.5, - single_prob=0.0) - + single_prob=0.0, + ) x = [torch.ones(3, 4, num_channels) for _ in range(num_inputs)] @@ -1819,7 +1824,10 @@ def _test_zipformer_main(): # Just make sure the forward pass runs. c = Zipformer( - num_features=feature_dim, encoder_dims=(64,96), encoder_unmasked_dims=(48,64), nhead=(4,4) + num_features=feature_dim, + encoder_dims=(64, 96), + encoder_unmasked_dims=(48, 64), + nhead=(4, 4), ) batch_size = 5 seq_len = 20 @@ -1837,19 +1845,18 @@ def _test_zipformer_main(): ) f # to remove flake8 warnings + def _test_conv2d_subsampling(): num_features = 80 encoder_dims = 384 dropout = 0.1 - encoder_embed = Conv2dSubsampling(num_features, encoder_dims, - dropout=dropout) + encoder_embed = Conv2dSubsampling(num_features, encoder_dims, dropout=dropout) for i in range(20, 40): x = torch.rand(2, i, num_features) y = encoder_embed(x) assert (x.shape[1] - 7) // 2 == y.shape[1], (x.shape[1], y.shape[1]) - if __name__ == "__main__": logging.getLogger().setLevel(logging.INFO) torch.set_num_threads(1) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py index 9d7335e777..3d89ae00ae 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py @@ -273,8 +273,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; " "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -394,9 +393,7 @@ def decode_one_batch( simulate_streaming=True, ) else: - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] @@ -455,10 +452,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -589,9 +583,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -624,8 +616,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -680,9 +671,7 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -719,9 +708,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -753,9 +742,9 @@ def main(): ) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -816,9 +805,7 @@ def main(): decoding_graph.scores *= params.ngram_lm_scale else: word_table = None - decoding_graph = k2.trivial_graph( - params.vocab_size - 1, device=device - ) + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) else: decoding_graph = None word_table = None diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/export.py b/egs/librispeech/ASR/pruned_transducer_stateless8/export.py index 49f469e294..0a962149d7 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/export.py @@ -176,8 +176,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; " "2 means tri-gram", ) add_model_arguments(parser) @@ -217,9 +216,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -252,9 +251,9 @@ def main(): ) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -326,9 +325,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/jit_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless8/jit_pretrained.py index e79a3a3aa3..c458ee5a9d 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/jit_pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/jit_pretrained.py @@ -94,8 +94,7 @@ def read_sound_files( for f in filenames: wave, sample_rate = torchaudio.load(f) assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" + f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}" ) # We use only the first channel ans.append(wave[0]) @@ -267,9 +266,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/model.py b/egs/librispeech/ASR/pruned_transducer_stateless8/model.py index 497b891360..39a3607968 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/model.py @@ -160,9 +160,7 @@ def forward( y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros( - (x.size(0), 4), dtype=torch.int64, device=x.device - ) + boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) boundary[:, 2] = y_lens boundary[:, 3] = x_lens diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py index 373a48fc1a..f1f0771efe 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py @@ -177,8 +177,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -210,8 +209,7 @@ def read_sound_files( for f in filenames: wave, sample_rate = torchaudio.load(f) assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" + f"expected sample rate: {expected_sample_rate}. 
" f"Given: {sample_rate}" ) # We use only the first channel ans.append(wave[0]) @@ -275,15 +273,11 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) feature_lengths = torch.tensor(feature_lengths, device=device) - encoder_out, encoder_out_lens = model.encoder( - x=features, x_lens=feature_lengths - ) + encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths) num_waves = encoder_out.size(0) hyps = [] @@ -355,9 +349,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/train.py b/egs/librispeech/ASR/pruned_transducer_stateless8/train.py index 2603bb8547..ba8ed3ea89 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/train.py @@ -92,9 +92,7 @@ from icefall.hooks import register_inf_check_hooks from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: @@ -214,8 +212,7 @@ def get_parser(): "--full-libri", type=str2bool, default=True, - help="When enabled, use 960h LibriSpeech. " - "Otherwise, use 100h subset.", + help="When enabled, use 960h LibriSpeech. " "Otherwise, use 100h subset.", ) parser.add_argument( @@ -285,8 +282,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -309,8 +305,7 @@ def get_parser(): "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)" "part.", ) parser.add_argument( @@ -691,11 +686,7 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -744,9 +735,7 @@ def compute_loss( info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() @@ -952,9 +941,7 @@ def train_one_epoch( # of the grad scaler is configurable, but we can't configure it to have different # behavior depending on the current grad scale. 
cur_grad_scale = scaler._scale.item() - if cur_grad_scale < 1.0 or ( - cur_grad_scale < 8.0 and batch_idx % 400 == 0 - ): + if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): scaler.update(cur_grad_scale * 2.0) if cur_grad_scale < 0.01: logging.warning(f"Grad scale is small: {cur_grad_scale}") @@ -975,11 +962,7 @@ def train_one_epoch( f"giga_tot_loss[{giga_tot_loss}], " f"batch size: {batch_size}, " f"lr: {cur_lr:.2e}, " - + ( - f"grad_scale: {scaler._scale.item()}" - if params.use_fp16 - else "" - ) + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") ) if tb_writer is not None: @@ -992,12 +975,8 @@ def train_one_epoch( f"train/current_{prefix}_", params.batch_idx_train, ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) libri_tot_loss.write_summary( tb_writer, "train/libri_tot_", params.batch_idx_train ) @@ -1011,10 +990,7 @@ def train_one_epoch( params.batch_idx_train, ) - if ( - batch_idx % params.valid_interval == 0 - and not params.print_diagnostics - ): + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: logging.info("Computing validation loss") valid_info = compute_validation_loss( params=params, @@ -1054,8 +1030,7 @@ def remove_short_and_long_utt(c: Cut): # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. " f"Duration: {c.duration}" ) return False @@ -1152,9 +1127,7 @@ def run(rank, world_size, args): logging.info("Using DDP") model = DDP(model, device_ids=[rank], find_unused_parameters=True) - optimizer = ScaledAdam( - model.parameters(), lr=params.base_lr, clipping_scale=2.0 - ) + optimizer = ScaledAdam(model.parameters(), lr=params.base_lr, clipping_scale=2.0) scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) @@ -1172,7 +1145,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2 ** 22 + 2**22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) @@ -1207,9 +1180,7 @@ def run(rank, world_size, args): train_giga_cuts = train_giga_cuts.repeat(times=None) if args.enable_musan: - cuts_musan = load_manifest( - Path(args.manifest_dir) / "musan_cuts.jsonl.gz" - ) + cuts_musan = load_manifest(Path(args.manifest_dir) / "musan_cuts.jsonl.gz") else: cuts_musan = None diff --git a/egs/librispeech/ASR/streaming_conformer_ctc/README.md b/egs/librispeech/ASR/streaming_conformer_ctc/README.md index 01be7090b5..53f383c996 100644 --- a/egs/librispeech/ASR/streaming_conformer_ctc/README.md +++ b/egs/librispeech/ASR/streaming_conformer_ctc/README.md @@ -1,20 +1,20 @@ ## Train and Decode -Commands of data preparation/train/decode steps are almost the same with +Commands of data preparation/train/decode steps are almost the same with ../conformer_ctc experiment except some options. Please read the code and understand following new added options before running this experiment: For data preparation: - + Nothing new. 
For streaming_conformer_ctc/train.py: - + --dynamic-chunk-training --short-chunk-proportion For streaming_conformer_ctc/streaming_decode.py: - + --chunk-size --tailing-num-frames --simulate-streaming @@ -57,10 +57,10 @@ And check md5sum values again. Finally, following files will be downloaded: -streaming_models/ -|-- lang_bpe -| |-- L.pt -| |-- Linv.pt +streaming_models/ +|-- lang_bpe +| |-- L.pt +| |-- Linv.pt | |-- bpe.model | |-- tokens.txt | `-- words.txt diff --git a/egs/librispeech/ASR/streaming_conformer_ctc/conformer.py b/egs/librispeech/ASR/streaming_conformer_ctc/conformer.py index ff4c914462..5fe92172ee 100644 --- a/egs/librispeech/ASR/streaming_conformer_ctc/conformer.py +++ b/egs/librispeech/ASR/streaming_conformer_ctc/conformer.py @@ -309,36 +309,26 @@ def eval_run_encoder( # start chunk_by_chunk decoding offset = 0 - for cur in range( - 0, num_frames - embed_left_context + 1, stride - ): + for cur in range(0, num_frames - embed_left_context + 1, stride): end = min(cur + decoding_window, num_frames) cur_feature = feature[:, cur:end, :] cur_feature = self.encoder_embed(cur_feature) - cur_embed, cur_pos_emb = self.encoder_pos( - cur_feature, offset - ) - cur_embed = cur_embed.permute( - 1, 0, 2 - ) # (B, T, F) -> (T, B, F) + cur_embed, cur_pos_emb = self.encoder_pos(cur_feature, offset) + cur_embed = cur_embed.permute(1, 0, 2) # (B, T, F) -> (T, B, F) cur_T = cur_feature.size(1) if cur == 0: # for first chunk extract the central pos embedding - pos_emb_central = cur_pos_emb[ - 0, (chunk_size - 1), : - ].view(1, 1, -1) + pos_emb_central = cur_pos_emb[0, (chunk_size - 1), :].view( + 1, 1, -1 + ) cur_T -= 1 pos_emb_positive.append(cur_pos_emb[0, :cur_T].flip(0)) pos_emb_negative.append(cur_pos_emb[0, -cur_T:]) assert pos_emb_positive[-1].size(0) == cur_T - pos_emb_pos = torch.cat(pos_emb_positive, dim=0).unsqueeze( - 0 - ) - pos_emb_neg = torch.cat(pos_emb_negative, dim=0).unsqueeze( - 0 - ) + pos_emb_pos = torch.cat(pos_emb_positive, dim=0).unsqueeze(0) + pos_emb_neg = torch.cat(pos_emb_negative, dim=0).unsqueeze(0) cur_pos_emb = torch.cat( [pos_emb_pos.flip(1), pos_emb_central, pos_emb_neg], dim=1, @@ -413,9 +403,7 @@ def __init__( causal: bool = False, ) -> None: super(ConformerEncoderLayer, self).__init__() - self.self_attn = RelPositionMultiheadAttention( - d_model, nhead, dropout=0.0 - ) + self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) self.feed_forward = nn.Sequential( nn.Linear(d_model, dim_feedforward), @@ -431,22 +419,16 @@ def __init__( nn.Linear(dim_feedforward, d_model), ) - self.conv_module = ConvolutionModule( - d_model, cnn_module_kernel, causal=causal - ) + self.conv_module = ConvolutionModule(d_model, cnn_module_kernel, causal=causal) - self.norm_ff_macaron = nn.LayerNorm( - d_model - ) # for the macaron style FNN module + self.norm_ff_macaron = nn.LayerNorm(d_model) # for the macaron style FNN module self.norm_ff = nn.LayerNorm(d_model) # for the FNN module self.norm_mha = nn.LayerNorm(d_model) # for the MHA module self.ff_scale = 0.5 self.norm_conv = nn.LayerNorm(d_model) # for the CNN module - self.norm_final = nn.LayerNorm( - d_model - ) # for the final output of the block + self.norm_final = nn.LayerNorm(d_model) # for the final output of the block self.dropout = nn.Dropout(dropout) @@ -480,9 +462,7 @@ def forward( residual = src if self.normalize_before: src = self.norm_ff_macaron(src) - src = residual + self.ff_scale * self.dropout( - self.feed_forward_macaron(src) - ) + src = residual + self.ff_scale * 
self.dropout(self.feed_forward_macaron(src)) if not self.normalize_before: src = self.norm_ff_macaron(src) @@ -554,9 +534,7 @@ def chunk_forward( residual = src if self.normalize_before: src = self.norm_ff_macaron(src) - src = residual + self.ff_scale * self.dropout( - self.feed_forward_macaron(src) - ) + src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src)) if not self.normalize_before: src = self.norm_ff_macaron(src) @@ -736,9 +714,7 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__( - self, d_model: int, dropout_rate: float, max_len: int = 5000 - ) -> None: + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -755,9 +731,7 @@ def extend_pe(self, x: Tensor, offset: int = 0) -> None: # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x_size_1 * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str( - x.device - ): + if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vector and `j` means the @@ -783,9 +757,7 @@ def extend_pe(self, x: Tensor, offset: int = 0) -> None: pe = torch.cat([pe_positive, pe_negative], dim=1) self.pe = pe.to(device=x.device, dtype=x.dtype) - def forward( - self, x: torch.Tensor, offset: int = 0 - ) -> Tuple[Tensor, Tensor]: + def forward(self, x: torch.Tensor, offset: int = 0) -> Tuple[Tensor, Tensor]: """Add positional encoding. Args: @@ -813,9 +785,7 @@ def forward( pos_emb = torch.cat( [ pos_emb[:, : (x_T - 1)], - self.pe[0, self.pe.size(1) // 2].view( - 1, 1, self.pe.size(-1) - ), + self.pe[0, self.pe.size(1) // 2].view(1, 1, self.pe.size(-1)), pos_emb[:, -(x_T - 1) :], # noqa: E203 ], dim=1, @@ -1050,9 +1020,9 @@ def multi_head_attention_forward( if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear( - query, in_proj_weight, in_proj_bias - ).chunk(3, dim=-1) + q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( + 3, dim=-1 + ) elif torch.equal(key, value): # encoder-decoder attention @@ -1120,31 +1090,22 @@ def multi_head_attention_forward( if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError( - "The size of the 2D attn_mask is not correct." - ) + raise RuntimeError("The size of the 2D attn_mask is not correct.") elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError( - "The size of the 3D attn_mask is not correct." - ) + raise RuntimeError("The size of the 3D attn_mask is not correct.") else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format( - attn_mask.dim() - ) + "attn_mask's dimension {} is not supported".format(attn_mask.dim()) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if ( - key_padding_mask is not None - and key_padding_mask.dtype == torch.uint8 - ): + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: warnings.warn( "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." 
) @@ -1185,24 +1146,16 @@ def multi_head_attention_forward( # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul( - q_with_bias_u, k - ) # (batch, head, time1, time2) + matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) # compute matrix b and matrix d - matrix_bd = torch.matmul( - q_with_bias_v, p - ) # (batch, head, time1, 2*time1-1) - matrix_bd = self.rel_shift( - matrix_bd, offset=offset - ) # [B, head, time1, time2] + matrix_bd = torch.matmul(q_with_bias_v, p) # (batch, head, time1, 2*time1-1) + matrix_bd = self.rel_shift(matrix_bd, offset=offset) # [B, head, time1, time2] attn_output_weights = ( matrix_ac + matrix_bd ) * scaling # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view( - bsz * num_heads, tgt_len, -1 - ) + attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) assert list(attn_output_weights.size()) == [ bsz * num_heads, @@ -1236,13 +1189,9 @@ def multi_head_attention_forward( attn_output = torch.bmm(attn_output_weights, v) assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] attn_output = ( - attn_output.transpose(0, 1) - .contiguous() - .view(tgt_len, bsz, embed_dim) - ) - attn_output = nn.functional.linear( - attn_output, out_proj_weight, out_proj_bias + attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) ) + attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads diff --git a/egs/librispeech/ASR/streaming_conformer_ctc/streaming_decode.py b/egs/librispeech/ASR/streaming_conformer_ctc/streaming_decode.py index a74c518366..3965bd5c3f 100755 --- a/egs/librispeech/ASR/streaming_conformer_ctc/streaming_decode.py +++ b/egs/librispeech/ASR/streaming_conformer_ctc/streaming_decode.py @@ -28,6 +28,7 @@ import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule from conformer import Conformer + from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler from icefall.checkpoint import average_checkpoints, load_checkpoint from icefall.lexicon import Lexicon @@ -86,8 +87,7 @@ def get_parser(): "--tailing-num-frames", type=int, default=20, - help="tailing dummy frames padded to the right," - "only used during decoding", + help="tailing dummy frames padded to the right," "only used during decoding", ) parser.add_argument( @@ -248,13 +248,9 @@ def decode_one_batch( maxlen = nnet_output.size(1) topk_prob, topk_index = nnet_output.topk(1, dim=2) # (B, maxlen, 1) topk_index = topk_index.view(batch_size, maxlen) # (B, maxlen) - topk_index = topk_index.masked_fill_( - memory_key_padding_mask, 0 - ) # (B, maxlen) + topk_index = topk_index.masked_fill_(memory_key_padding_mask, 0) # (B, maxlen) token_ids = [token_id.tolist() for token_id in topk_index] - token_ids = [ - remove_duplicates_and_blank(token_id) for token_id in token_ids - ] + token_ids = [remove_duplicates_and_blank(token_id) for token_id in token_ids] hyps = bpe_model.decode(token_ids) hyps = [s.split() for s in hyps] return {key: hyps} @@ -337,9 +333,7 @@ def decode_dataset( if batch_idx % 100 == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -364,8 +358,7 @@ def save_results( 
-{params.chunk_size}-tailing-num-frames-{params.tailing_num_frames}-" for key, results in results_dict.items(): recog_path = ( - params.exp_dir - / f"{result_file_prefix}recogs-{test_set_name}-{key}.txt" + params.exp_dir / f"{result_file_prefix}recogs-{test_set_name}-{key}.txt" ) store_transcripts(filename=recog_path, texts=results) if enable_log: @@ -374,8 +367,7 @@ def save_results( # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. errs_filename = ( - params.exp_dir - / f"{result_file_prefix}-errs-{test_set_name}-{key}.txt" + params.exp_dir / f"{result_file_prefix}-errs-{test_set_name}-{key}.txt" ) with open(errs_filename, "w") as f: wer = write_error_stats( @@ -384,9 +376,7 @@ def save_results( test_set_wers[key] = wer if enable_log: - logging.info( - "Wrote detailed error stats to {}".format(errs_filename) - ) + logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt" @@ -474,9 +464,7 @@ def main(): if params.export: logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt") - torch.save( - {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt" - ) + torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt") return model.to(device) @@ -507,9 +495,7 @@ def main(): simulate_streaming=params.simulate_streaming, ) - save_results( - params=params, test_set_name=test_set, results_dict=results_dict - ) + save_results(params=params, test_set_name=test_set, results_dict=results_dict) logging.info("Done!") diff --git a/egs/librispeech/ASR/streaming_conformer_ctc/train.py b/egs/librispeech/ASR/streaming_conformer_ctc/train.py index e41b7ea78a..553b7d0922 100755 --- a/egs/librispeech/ASR/streaming_conformer_ctc/train.py +++ b/egs/librispeech/ASR/streaming_conformer_ctc/train.py @@ -405,9 +405,7 @@ def compute_loss( # # See https://github.com/k2-fsa/icefall/issues/97 # for more details - unsorted_token_ids = graph_compiler.texts_to_ids( - supervisions["text"] - ) + unsorted_token_ids = graph_compiler.texts_to_ids(supervisions["text"]) att_loss = mmodel.decoder_forward( encoder_memory, memory_mask, @@ -436,9 +434,7 @@ def compute_loss( info["utt_duration"] = supervisions["num_frames"].sum().item() # averaged padding proportion over utterances info["utt_pad_proportion"] = ( - ((feature.size(1) - supervisions["num_frames"]) / feature.size(1)) - .sum() - .item() + ((feature.size(1) - supervisions["num_frames"]) / feature.size(1)).sum().item() ) return loss, info @@ -551,9 +547,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -668,9 +662,7 @@ def run(rank, world_size, args): cur_lr = optimizer._rate if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) + tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) if rank == 0: diff --git a/egs/librispeech/ASR/streaming_conformer_ctc/transformer.py b/egs/librispeech/ASR/streaming_conformer_ctc/transformer.py index bc78e4a417..0c87fdf1bb 100644 --- 
a/egs/librispeech/ASR/streaming_conformer_ctc/transformer.py +++ b/egs/librispeech/ASR/streaming_conformer_ctc/transformer.py @@ -149,9 +149,7 @@ def __init__( norm=decoder_norm, ) - self.decoder_output_layer = torch.nn.Linear( - d_model, self.decoder_num_class - ) + self.decoder_output_layer = torch.nn.Linear(d_model, self.decoder_num_class) self.decoder_criterion = LabelSmoothingLoss() else: @@ -286,23 +284,17 @@ def decoder_forward( """ ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence( - ys_in, batch_first=True, padding_value=float(eos_id) - ) + ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence( - ys_out, batch_first=True, padding_value=float(-1) - ) + ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) device = memory.device ys_in_pad = ys_in_pad.to(device) ys_out_pad = ys_out_pad.to(device) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( - device - ) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -363,23 +355,17 @@ def decoder_nll( ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence( - ys_in, batch_first=True, padding_value=float(eos_id) - ) + ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence( - ys_out, batch_first=True, padding_value=float(-1) - ) + ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) device = memory.device ys_in_pad = ys_in_pad.to(device, dtype=torch.int64) ys_out_pad = ys_out_pad.to(device, dtype=torch.int64) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( - device - ) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -652,9 +638,7 @@ def _get_activation_fn(activation: str): elif activation == "gelu": return nn.functional.gelu - raise RuntimeError( - "activation should be relu/gelu, not {}".format(activation) - ) + raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) class PositionalEncoding(nn.Module): @@ -856,9 +840,7 @@ def encoder_padding_mask( 1, ).to(torch.int32) - lengths = [ - 0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1) - ] + lengths = [0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)] for idx in range(supervision_segments.size(0)): # Note: TorchScript doesn't allow to unpack tensors as tuples sequence_idx = supervision_segments[idx, 0].item() @@ -879,9 +861,7 @@ def encoder_padding_mask( return mask -def decoder_padding_mask( - ys_pad: torch.Tensor, ignore_id: int = -1 -) -> torch.Tensor: +def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor: """Generate a length mask for input. 
The masked position are filled with True, diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py index 355ccc99a9..993a7cab50 100644 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py @@ -86,8 +86,7 @@ def add_arguments(cls, parser: argparse.ArgumentParser): "--full-libri", type=str2bool, default=True, - help="When enabled, use 960h LibriSpeech. " - "Otherwise, use 100h subset.", + help="When enabled, use 960h LibriSpeech. " "Otherwise, use 100h subset.", ) group.add_argument( "--manifest-dir", @@ -224,13 +223,9 @@ def train_dataloaders( if self.args.enable_musan: logging.info("Enable MUSAN") logging.info("About to get Musan cuts") - cuts_musan = load_manifest( - self.args.manifest_dir / "musan_cuts.jsonl.gz" - ) + cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") transforms.append( - CutMix( - cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True - ) + CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") @@ -252,9 +247,7 @@ def train_dataloaders( input_transforms = [] if self.args.enable_spec_aug: logging.info("Enable SpecAugment") - logging.info( - f"Time warp factor: {self.args.spec_aug_time_warp_factor}" - ) + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") # Set the value of num_frame_masks according to Lhotse's version. # In different Lhotse's versions, the default of num_frame_masks is # different. @@ -298,9 +291,7 @@ def train_dataloaders( # Drop feats to be on the safe side. train = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), input_transforms=input_transforms, return_cuts=self.args.return_cuts, ) @@ -356,9 +347,7 @@ def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: if self.args.on_the_fly_feats: validate = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), return_cuts=self.args.return_cuts, ) else: diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py index 7d0cd0bf32..92529e06c5 100755 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py @@ -336,9 +336,7 @@ def decode_dataset( if batch_idx % 100 == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -400,9 +398,7 @@ def main(): logging.info(f"device: {device}") - HLG = k2.Fsa.from_dict( - torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu") - ) + HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")) HLG = HLG.to(device) assert HLG.requires_grad is False @@ -467,9 +463,7 @@ def main(): if params.export: logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt") - torch.save( - {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt" - ) + torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt") return model.to(device) @@ -498,9 +492,7 @@ def main(): G=G, ) - save_results( - params=params, test_set_name=test_set, results_dict=results_dict - ) 
+ save_results(params=params, test_set_name=test_set, results_dict=results_dict) logging.info("Done!") diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/model.py b/egs/librispeech/ASR/tdnn_lstm_ctc/model.py index 5e04c11b4a..1731e1ebef 100644 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/model.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/model.py @@ -66,10 +66,7 @@ def __init__( nn.BatchNorm1d(num_features=500, affine=False), ) self.lstms = nn.ModuleList( - [ - nn.LSTM(input_size=500, hidden_size=500, num_layers=1) - for _ in range(5) - ] + [nn.LSTM(input_size=500, hidden_size=500, num_layers=1) for _ in range(5)] ) self.lstm_bnorms = nn.ModuleList( [nn.BatchNorm1d(num_features=500, affine=False) for _ in range(5)] diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py b/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py index 2baeb6bba4..addadbe4e4 100755 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py @@ -29,11 +29,7 @@ from model import TdnnLstm from torch.nn.utils.rnn import pad_sequence -from icefall.decode import ( - get_lattice, - one_best_decoding, - rescore_with_whole_lattice, -) +from icefall.decode import get_lattice, one_best_decoding, rescore_with_whole_lattice from icefall.utils import AttributeDict, get_texts @@ -58,9 +54,7 @@ def get_parser(): help="Path to words.txt", ) - parser.add_argument( - "--HLG", type=str, required=True, help="Path to HLG.pt." - ) + parser.add_argument("--HLG", type=str, required=True, help="Path to HLG.pt.") parser.add_argument( "--method", @@ -145,8 +139,7 @@ def read_sound_files( for f in filenames: wave, sample_rate = torchaudio.load(f) assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" + f"expected sample rate: {expected_sample_rate}. 
" f"Given: {sample_rate}" ) # We use only the first channel ans.append(wave[0]) @@ -215,9 +208,7 @@ def main(): logging.info("Decoding started") features = fbank(waves) - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) features = features.permute(0, 2, 1) # now features is (N, C, T) with torch.no_grad(): @@ -269,9 +260,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/train.py b/egs/librispeech/ASR/tdnn_lstm_ctc/train.py index 6b37d5c231..071ac792bd 100755 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/train.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/train.py @@ -355,9 +355,7 @@ def compute_loss( info["utt_duration"] = supervisions["num_frames"].sum().item() # averaged padding proportion over utterances info["utt_pad_proportion"] = ( - ((feature.size(2) - supervisions["num_frames"]) / feature.size(2)) - .sum() - .item() + ((feature.size(2) - supervisions["num_frames"]) / feature.size(2)).sum().item() ) return loss, info @@ -470,9 +468,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: valid_info = compute_validation_loss( diff --git a/egs/librispeech/ASR/transducer/beam_search.py b/egs/librispeech/ASR/transducer/beam_search.py index 11032f31ae..b45b6a9d8d 100644 --- a/egs/librispeech/ASR/transducer/beam_search.py +++ b/egs/librispeech/ASR/transducer/beam_search.py @@ -38,9 +38,7 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]: blank_id = model.decoder.blank_id device = model.device - sos = torch.tensor([blank_id], device=device, dtype=torch.int64).reshape( - 1, 1 - ) + sos = torch.tensor([blank_id], device=device, dtype=torch.int64).reshape(1, 1) decoder_out, (h, c) = model.decoder(sos) T = encoder_out.size(1) t = 0 @@ -123,9 +121,7 @@ def beam_search( max_u = 20000 # terminate after this number of steps u = 0 - cache: Dict[ - str, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] - ] = {} + cache: Dict[str, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = {} while t < T and u < max_u: # fmt: off @@ -157,9 +153,9 @@ def beam_search( cached_key = "_".join(map(str, y_star.ys)) if cached_key not in cache: - decoder_input = torch.tensor( - [y_star.ys[-1]], device=device - ).reshape(1, 1) + decoder_input = torch.tensor([y_star.ys[-1]], device=device).reshape( + 1, 1 + ) decoder_out, decoder_state = model.decoder( decoder_input, diff --git a/egs/librispeech/ASR/transducer/decode.py b/egs/librispeech/ASR/transducer/decode.py index 5f233df87e..804713a20c 100755 --- a/egs/librispeech/ASR/transducer/decode.py +++ b/egs/librispeech/ASR/transducer/decode.py @@ -228,9 +228,7 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] batch_size = encoder_out.size(0) @@ -245,9 
+243,7 @@ def decode_one_batch( model=model, encoder_out=encoder_out_i, beam=params.beam_size ) else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") hyps.append(sp.decode(hyp).split()) if params.decoding_method == "greedy_search": @@ -318,9 +314,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -353,8 +347,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) diff --git a/egs/librispeech/ASR/transducer/export.py b/egs/librispeech/ASR/transducer/export.py index 5a5db30c41..6db0272f0a 100755 --- a/egs/librispeech/ASR/transducer/export.py +++ b/egs/librispeech/ASR/transducer/export.py @@ -238,9 +238,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/transducer/pretrained.py b/egs/librispeech/ASR/transducer/pretrained.py index 1db2df6481..b1ff7b2b13 100755 --- a/egs/librispeech/ASR/transducer/pretrained.py +++ b/egs/librispeech/ASR/transducer/pretrained.py @@ -189,8 +189,7 @@ def read_sound_files( for f in filenames: wave, sample_rate = torchaudio.load(f) assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" + f"expected sample rate: {expected_sample_rate}. 
" f"Given: {sample_rate}" ) # We use only the first channel ans.append(wave[0]) @@ -249,9 +248,7 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) feature_lengths = torch.tensor(feature_lengths, device=device) @@ -287,9 +284,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/transducer/rnn.py b/egs/librispeech/ASR/transducer/rnn.py index 2a165b0c1e..fe8732301a 100644 --- a/egs/librispeech/ASR/transducer/rnn.py +++ b/egs/librispeech/ASR/transducer/rnn.py @@ -117,12 +117,8 @@ def __init__( ) if bias: - self.bias_ih = nn.Parameter( - torch.empty(4 * hidden_size, **factory_kwargs) - ) - self.bias_hh = nn.Parameter( - torch.empty(4 * hidden_size, **factory_kwargs) - ) + self.bias_ih = nn.Parameter(torch.empty(4 * hidden_size, **factory_kwargs)) + self.bias_hh = nn.Parameter(torch.empty(4 * hidden_size, **factory_kwargs)) else: self.register_parameter("bias_ih", None) self.register_parameter("bias_hh", None) @@ -348,9 +344,7 @@ def __init__( device=device, dtype=dtype, ) - first_layer = LayerNormLSTMLayer( - input_size=input_size, **factory_kwargs - ) + first_layer = LayerNormLSTMLayer(input_size=input_size, **factory_kwargs) layers = [first_layer] for i in range(1, num_layers): layers.append( @@ -385,9 +379,7 @@ def forward( - List[(next_h, next_c)] containing the hidden states for all layers """ - output_states = torch.jit.annotate( - List[Tuple[torch.Tensor, torch.Tensor]], [] - ) + output_states = torch.jit.annotate(List[Tuple[torch.Tensor, torch.Tensor]], []) output = input for i, rnn_layer in enumerate(self.layers): state = states[i] @@ -456,12 +448,8 @@ def __init__( ) if bias: - self.bias_ih = nn.Parameter( - torch.empty(3 * hidden_size, **factory_kwargs) - ) - self.bias_hh = nn.Parameter( - torch.empty(3 * hidden_size, **factory_kwargs) - ) + self.bias_ih = nn.Parameter(torch.empty(3 * hidden_size, **factory_kwargs)) + self.bias_hh = nn.Parameter(torch.empty(3 * hidden_size, **factory_kwargs)) else: self.register_parameter("bias_ih", None) self.register_parameter("bias_hh", None) diff --git a/egs/librispeech/ASR/transducer/test_rnn.py b/egs/librispeech/ASR/transducer/test_rnn.py index 8591e2d8aa..74c94cc703 100755 --- a/egs/librispeech/ASR/transducer/test_rnn.py +++ b/egs/librispeech/ASR/transducer/test_rnn.py @@ -254,9 +254,7 @@ def test_layernorm_lstm_layer_with_projection_forward(device="cpu"): for name, self_param in self_layer.cell.named_parameters(): getattr(torch_layer, f"{name}_l0").copy_(self_param) - torch_y, (torch_h, torch_c) = torch_layer( - x_clone, (h.unsqueeze(0), c.unsqueeze(0)) - ) + torch_y, (torch_h, torch_c) = torch_layer(x_clone, (h.unsqueeze(0), c.unsqueeze(0))) assert_allclose(self_y, torch_y) assert_allclose(self_h, torch_h) assert_allclose(self_c, torch_c) @@ -303,9 +301,7 @@ def test_layernorm_lstm_layer_forward(device="cpu"): for name, self_param in self_layer.cell.named_parameters(): getattr(torch_layer, f"{name}_l0").copy_(self_param) - torch_y, (torch_h, torch_c) = torch_layer( - x_clone, (h.unsqueeze(0), c.unsqueeze(0)) - ) + torch_y, (torch_h, torch_c) = torch_layer(x_clone, (h.unsqueeze(0), 
c.unsqueeze(0))) assert_allclose(self_y, torch_y) assert_allclose(self_h, torch_h) assert_allclose(self_c, torch_c) @@ -594,9 +590,7 @@ def test_layernorm_gru_cell_forward(device="cpu"): assert_allclose(self_h, torch_h, atol=1e-5) - ( - self_h.reshape(-1) * torch.arange(self_h.numel(), device=device) - ).sum().backward() + (self_h.reshape(-1) * torch.arange(self_h.numel(), device=device)).sum().backward() ( torch_h.reshape(-1) * torch.arange(torch_h.numel(), device=device) ).sum().backward() @@ -718,9 +712,7 @@ def test_layernorm_gru_forward(device="cpu"): T = torch.randint(low=2, high=100, size=(1,)) x = torch.rand(N, T, input_size, device=device).requires_grad_() - states = [ - torch.rand(N, hidden_size, device=device) for _ in range(num_layers) - ] + states = [torch.rand(N, hidden_size, device=device) for _ in range(num_layers)] x_clone = x.detach().clone().requires_grad_() diff --git a/egs/librispeech/ASR/transducer/train.py b/egs/librispeech/ASR/transducer/train.py index 1dd65eddb5..674ea10a63 100755 --- a/egs/librispeech/ASR/transducer/train.py +++ b/egs/librispeech/ASR/transducer/train.py @@ -396,9 +396,7 @@ def compute_loss( info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() @@ -520,9 +518,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -659,9 +655,7 @@ def remove_short_and_long_utt(c: Cut): cur_lr = optimizer._rate if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) + tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) if rank == 0: diff --git a/egs/librispeech/ASR/transducer_lstm/beam_search.py b/egs/librispeech/ASR/transducer_lstm/beam_search.py index 3531a96334..5342c3e8c5 100644 --- a/egs/librispeech/ASR/transducer_lstm/beam_search.py +++ b/egs/librispeech/ASR/transducer_lstm/beam_search.py @@ -38,9 +38,7 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]: blank_id = model.decoder.blank_id device = model.device - sos = torch.tensor([blank_id], device=device, dtype=torch.int64).reshape( - 1, 1 - ) + sos = torch.tensor([blank_id], device=device, dtype=torch.int64).reshape(1, 1) decoder_out, (h, c) = model.decoder(sos) T = encoder_out.size(1) t = 0 @@ -124,9 +122,7 @@ def beam_search( max_u = 20000 # terminate after this number of steps u = 0 - cache: Dict[ - str, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] - ] = {} + cache: Dict[str, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = {} while t < T and u < max_u: # fmt: off @@ -158,9 +154,9 @@ def beam_search( cached_key = "_".join(map(str, y_star.ys)) if cached_key not in cache: - decoder_input = torch.tensor( - [y_star.ys[-1]], device=device - ).reshape(1, 1) + decoder_input = torch.tensor([y_star.ys[-1]], device=device).reshape( + 1, 1 + ) decoder_out, decoder_state = model.decoder( decoder_input, diff --git 
a/egs/librispeech/ASR/transducer_lstm/decode.py b/egs/librispeech/ASR/transducer_lstm/decode.py index 604235e2ab..9511ca6d75 100755 --- a/egs/librispeech/ASR/transducer_lstm/decode.py +++ b/egs/librispeech/ASR/transducer_lstm/decode.py @@ -225,9 +225,7 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] batch_size = encoder_out.size(0) @@ -242,9 +240,7 @@ def decode_one_batch( model=model, encoder_out=encoder_out_i, beam=params.beam_size ) else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") hyps.append(sp.decode(hyp).split()) if params.decoding_method == "greedy_search": @@ -315,9 +311,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -350,8 +344,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) diff --git a/egs/librispeech/ASR/transducer_lstm/encoder.py b/egs/librispeech/ASR/transducer_lstm/encoder.py index 3dc992dd29..038d800773 100644 --- a/egs/librispeech/ASR/transducer_lstm/encoder.py +++ b/egs/librispeech/ASR/transducer_lstm/encoder.py @@ -48,9 +48,7 @@ def __init__( if vgg_frontend: self.encoder_embed = VggSubsampling(num_features, real_hidden_size) else: - self.encoder_embed = Conv2dSubsampling( - num_features, real_hidden_size - ) + self.encoder_embed = Conv2dSubsampling(num_features, real_hidden_size) self.rnn = nn.LSTM( input_size=hidden_size, diff --git a/egs/librispeech/ASR/transducer_lstm/train.py b/egs/librispeech/ASR/transducer_lstm/train.py index cdb801e795..57bda63fd5 100755 --- a/egs/librispeech/ASR/transducer_lstm/train.py +++ b/egs/librispeech/ASR/transducer_lstm/train.py @@ -400,9 +400,7 @@ def compute_loss( info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # Note: We use reduction=sum while computing the loss. 
info["loss"] = loss.detach().cpu().item() @@ -524,9 +522,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -665,9 +661,7 @@ def remove_short_and_long_utt(c: Cut): cur_lr = optimizer._rate if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) + tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) if rank == 0: diff --git a/egs/librispeech/ASR/transducer_stateless/alignment.py b/egs/librispeech/ASR/transducer_stateless/alignment.py index f143611eae..65f2c58d87 100644 --- a/egs/librispeech/ASR/transducer_stateless/alignment.py +++ b/egs/librispeech/ASR/transducer_stateless/alignment.py @@ -193,9 +193,7 @@ def force_alignment( decoder_out = model.decoder(decoder_input, need_pad=False) # decoder_output is of shape (num_active_items, 1, decoder_output_dim) - current_encoder_out = current_encoder_out.expand( - decoder_out.size(0), 1, -1 - ) + current_encoder_out = current_encoder_out.expand(decoder_out.size(0), 1, -1) logits = model.joiner( current_encoder_out, diff --git a/egs/librispeech/ASR/transducer_stateless/beam_search.py b/egs/librispeech/ASR/transducer_stateless/beam_search.py index ea985f30da..1d79eef9d9 100644 --- a/egs/librispeech/ASR/transducer_stateless/beam_search.py +++ b/egs/librispeech/ASR/transducer_stateless/beam_search.py @@ -316,9 +316,9 @@ def greedy_search( y = logits.argmax().item() if y != blank_id: hyp.append(y) - decoder_input = torch.tensor( - [hyp[-context_size:]], device=device - ).reshape(1, context_size) + decoder_input = torch.tensor([hyp[-context_size:]], device=device).reshape( + 1, context_size + ) decoder_out = model.decoder(decoder_input, need_pad=False) @@ -478,9 +478,7 @@ def add(self, hyp: Hypothesis) -> None: key = hyp.key if key in self: old_hyp = self._data[key] # shallow copy - torch.logaddexp( - old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob - ) + torch.logaddexp(old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob) else: self._data[key] = hyp @@ -496,9 +494,7 @@ def get_most_probable(self, length_norm: bool = False) -> Hypothesis: Return the hypothesis that has the largest `log_prob`. 
""" if length_norm: - return max( - self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys) - ) + return max(self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys)) else: return max(self._data.values(), key=lambda hyp: hyp.log_prob) @@ -786,9 +782,7 @@ def modified_beam_search( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor( - shape=log_probs_shape, value=log_probs - ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) for i in range(batch_size): topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) @@ -887,9 +881,7 @@ def _deprecated_modified_beam_search( decoder_out = model.decoder(decoder_input, need_pad=False) # decoder_output is of shape (num_hyps, 1, decoder_output_dim) - current_encoder_out = current_encoder_out.expand( - decoder_out.size(0), 1, -1 - ) + current_encoder_out = current_encoder_out.expand(decoder_out.size(0), 1, -1) logits = model.joiner( current_encoder_out, @@ -959,9 +951,9 @@ def beam_search( device = model.device - decoder_input = torch.tensor( - [blank_id] * context_size, device=device - ).reshape(1, context_size) + decoder_input = torch.tensor([blank_id] * context_size, device=device).reshape( + 1, context_size + ) decoder_out = model.decoder(decoder_input, need_pad=False) diff --git a/egs/librispeech/ASR/transducer_stateless/compute_ali.py b/egs/librispeech/ASR/transducer_stateless/compute_ali.py index 48769e9d17..c91198bb9c 100755 --- a/egs/librispeech/ASR/transducer_stateless/compute_ali.py +++ b/egs/librispeech/ASR/transducer_stateless/compute_ali.py @@ -124,8 +124,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) return parser @@ -162,9 +161,7 @@ def compute_alignments( feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) batch_size = encoder_out.size(0) @@ -204,9 +201,7 @@ def compute_alignments( if batch_idx % 2 == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return CutSet.from_cuts(cuts) diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index cde52c9fc7..01e8c5b213 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -209,10 +209,7 @@ def get_init_state( NOTE: the returned tensors are on the given device. 
""" - if ( - len(self._init_state) == 2 - and self._init_state[0].size(1) == left_context - ): + if len(self._init_state) == 2 and self._init_state[0].size(1) == left_context: # Note: It is OK to share the init state as it is # not going to be modified by the model return self._init_state @@ -421,9 +418,7 @@ def __init__( causal: bool = False, ) -> None: super(ConformerEncoderLayer, self).__init__() - self.self_attn = RelPositionMultiheadAttention( - d_model, nhead, dropout=0.0 - ) + self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) self.feed_forward = nn.Sequential( nn.Linear(d_model, dim_feedforward), @@ -439,22 +434,16 @@ def __init__( nn.Linear(dim_feedforward, d_model), ) - self.conv_module = ConvolutionModule( - d_model, cnn_module_kernel, causal=causal - ) + self.conv_module = ConvolutionModule(d_model, cnn_module_kernel, causal=causal) - self.norm_ff_macaron = nn.LayerNorm( - d_model - ) # for the macaron style FNN module + self.norm_ff_macaron = nn.LayerNorm(d_model) # for the macaron style FNN module self.norm_ff = nn.LayerNorm(d_model) # for the FNN module self.norm_mha = nn.LayerNorm(d_model) # for the MHA module self.ff_scale = 0.5 self.norm_conv = nn.LayerNorm(d_model) # for the CNN module - self.norm_final = nn.LayerNorm( - d_model - ) # for the final output of the block + self.norm_final = nn.LayerNorm(d_model) # for the final output of the block self.dropout = nn.Dropout(dropout) @@ -486,9 +475,7 @@ def forward( residual = src if self.normalize_before: src = self.norm_ff_macaron(src) - src = residual + self.ff_scale * self.dropout( - self.feed_forward_macaron(src) - ) + src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src)) if not self.normalize_before: src = self.norm_ff_macaron(src) @@ -514,9 +501,7 @@ def forward( if self.normalize_before: src = self.norm_conv(src) - src, _ = self.conv_module( - src, src_key_padding_mask=src_key_padding_mask - ) + src, _ = self.conv_module(src, src_key_padding_mask=src_key_padding_mask) src = residual + self.dropout(src) if not self.normalize_before: @@ -581,9 +566,7 @@ def chunk_forward( residual = src if self.normalize_before: src = self.norm_ff_macaron(src) - src = residual + self.ff_scale * self.dropout( - self.feed_forward_macaron(src) - ) + src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src)) if not self.normalize_before: src = self.norm_ff_macaron(src) @@ -625,9 +608,7 @@ def chunk_forward( if self.normalize_before: src = self.norm_conv(src) - src, conv_cache = self.conv_module( - src, states[1], right_context=right_context - ) + src, conv_cache = self.conv_module(src, states[1], right_context=right_context) states[1] = conv_cache src = residual + self.dropout(src) @@ -779,9 +760,7 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__( - self, d_model: int, dropout_rate: float, max_len: int = 5000 - ) -> None: + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -798,9 +777,7 @@ def extend_pe(self, x: Tensor, left_context: int = 0) -> None: # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x_size_1 * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str( - x.device - ): + if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # 
Suppose `i` means to the position of query vector and `j` means the @@ -826,9 +803,7 @@ def extend_pe(self, x: Tensor, left_context: int = 0) -> None: pe = torch.cat([pe_positive, pe_negative], dim=1) self.pe = pe.to(device=x.device, dtype=x.dtype) - def forward( - self, x: torch.Tensor, left_context: int = 0 - ) -> Tuple[Tensor, Tensor]: + def forward(self, x: torch.Tensor, left_context: int = 0) -> Tuple[Tensor, Tensor]: """Add positional encoding. Args: @@ -1092,9 +1067,9 @@ def multi_head_attention_forward( if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear( - query, in_proj_weight, in_proj_bias - ).chunk(3, dim=-1) + q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( + 3, dim=-1 + ) elif torch.equal(key, value): # encoder-decoder attention @@ -1163,31 +1138,22 @@ def multi_head_attention_forward( if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError( - "The size of the 2D attn_mask is not correct." - ) + raise RuntimeError("The size of the 2D attn_mask is not correct.") elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError( - "The size of the 3D attn_mask is not correct." - ) + raise RuntimeError("The size of the 3D attn_mask is not correct.") else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format( - attn_mask.dim() - ) + "attn_mask's dimension {} is not supported".format(attn_mask.dim()) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if ( - key_padding_mask is not None - and key_padding_mask.dtype == torch.uint8 - ): + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: warnings.warn( "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." 
) @@ -1228,14 +1194,10 @@ def multi_head_attention_forward( # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul( - q_with_bias_u, k - ) # (batch, head, time1, time2) + matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) # compute matrix b and matrix d - matrix_bd = torch.matmul( - q_with_bias_v, p - ) # (batch, head, time1, 2*time1-1) + matrix_bd = torch.matmul(q_with_bias_v, p) # (batch, head, time1, 2*time1-1) matrix_bd = self.rel_shift(matrix_bd, left_context=left_context) @@ -1243,9 +1205,7 @@ def multi_head_attention_forward( matrix_ac + matrix_bd ) * scaling # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view( - bsz * num_heads, tgt_len, -1 - ) + attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) assert list(attn_output_weights.size()) == [ bsz * num_heads, @@ -1290,9 +1250,7 @@ def multi_head_attention_forward( attn_output_weights = attn_output_weights.view( bsz, num_heads, tgt_len, src_len ) - attn_output_weights = attn_output_weights.masked_fill( - combined_mask, 0.0 - ) + attn_output_weights = attn_output_weights.masked_fill(combined_mask, 0.0) attn_output_weights = attn_output_weights.view( bsz * num_heads, tgt_len, src_len ) @@ -1304,13 +1262,9 @@ def multi_head_attention_forward( attn_output = torch.bmm(attn_output_weights, v) assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] attn_output = ( - attn_output.transpose(0, 1) - .contiguous() - .view(tgt_len, bsz, embed_dim) - ) - attn_output = nn.functional.linear( - attn_output, out_proj_weight, out_proj_bias + attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) ) + attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads @@ -1418,16 +1372,12 @@ def forward( # manualy padding self.lorder zeros to the left x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0) else: - assert ( - not self.training - ), "Cache should be None in training time" + assert not self.training, "Cache should be None in training time" assert cache.size(0) == self.lorder x = torch.cat([cache.permute(1, 2, 0), x], dim=2) if right_context > 0: cache = x.permute(2, 0, 1)[ - -(self.lorder + right_context) : ( # noqa - -right_context - ), + -(self.lorder + right_context) : (-right_context), # noqa ..., ] else: diff --git a/egs/librispeech/ASR/transducer_stateless/decode.py b/egs/librispeech/ASR/transducer_stateless/decode.py index 74bba9cad6..688e214c83 100755 --- a/egs/librispeech/ASR/transducer_stateless/decode.py +++ b/egs/librispeech/ASR/transducer_stateless/decode.py @@ -171,8 +171,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; " "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -230,9 +229,7 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] @@ -248,10 +245,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -374,9 +368,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -409,8 +401,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -450,9 +441,7 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" diff --git a/egs/librispeech/ASR/transducer_stateless/decoder.py b/egs/librispeech/ASR/transducer_stateless/decoder.py index fbc2373a9d..a182d91e2a 100644 --- a/egs/librispeech/ASR/transducer_stateless/decoder.py +++ b/egs/librispeech/ASR/transducer_stateless/decoder.py @@ -87,9 +87,7 @@ def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor: if self.context_size > 1: embedding_out = embedding_out.permute(0, 2, 1) if need_pad is True: - embedding_out = F.pad( - embedding_out, pad=(self.context_size - 1, 0) - ) + embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0)) else: # During inference time, there is no need to do extra padding # as we only need one output diff --git a/egs/librispeech/ASR/transducer_stateless/export.py b/egs/librispeech/ASR/transducer_stateless/export.py index 8bd0bdea1e..c617e6c4c0 100755 --- a/egs/librispeech/ASR/transducer_stateless/export.py +++ b/egs/librispeech/ASR/transducer_stateless/export.py @@ -109,8 +109,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; " "2 means tri-gram", ) return parser @@ -244,9 +243,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/transducer_stateless/joiner.py b/egs/librispeech/ASR/transducer_stateless/joiner.py index 93cccbd8cd..e1625992dd 100644 --- a/egs/librispeech/ASR/transducer_stateless/joiner.py +++ b/egs/librispeech/ASR/transducer_stateless/joiner.py @@ -60,13 +60,9 @@ def forward( encoder_out_len: List[int] = encoder_out_len.tolist() decoder_out_len: List[int] = decoder_out_len.tolist() - encoder_out_list = [ - encoder_out[i, : encoder_out_len[i], :] for i in range(N) - ] + encoder_out_list = [encoder_out[i, : encoder_out_len[i], :] for i in range(N)] - decoder_out_list = [ - decoder_out[i, : decoder_out_len[i], :] for i in range(N) - ] + decoder_out_list = [decoder_out[i, : decoder_out_len[i], :] for i in range(N)] x = [ e.unsqueeze(1) + d.unsqueeze(0) diff --git a/egs/librispeech/ASR/transducer_stateless/pretrained.py b/egs/librispeech/ASR/transducer_stateless/pretrained.py index b645218015..c393974e6d 100755 --- a/egs/librispeech/ASR/transducer_stateless/pretrained.py +++ b/egs/librispeech/ASR/transducer_stateless/pretrained.py @@ -167,8 +167,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -198,8 +197,7 @@ def read_sound_files( for f in filenames: wave, sample_rate = torchaudio.load(f) assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" + f"expected sample rate: {expected_sample_rate}. 
" f"Given: {sample_rate}" ) # We use only the first channel ans.append(wave[0]) @@ -259,9 +257,7 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) feature_lengths = torch.tensor(feature_lengths, device=device) @@ -334,9 +330,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/transducer_stateless/test_compute_ali.py b/egs/librispeech/ASR/transducer_stateless/test_compute_ali.py index b00fc34f13..9af46846a1 100755 --- a/egs/librispeech/ASR/transducer_stateless/test_compute_ali.py +++ b/egs/librispeech/ASR/transducer_stateless/test_compute_ali.py @@ -140,16 +140,13 @@ def main(): token_alignment[i, : token_alignment_length[i]].tolist(), sp=sp ) word_starting_time = [ - "{:.2f}".format(i * frame_shift_in_second) - for i in word_starting_frames + "{:.2f}".format(i * frame_shift_in_second) for i in word_starting_frames ] words = supervisions["text"][i].split() assert len(word_starting_frames) == len(words) - word_starting_time_dict[cuts[i].id] = list( - zip(words, word_starting_time) - ) + word_starting_time_dict[cuts[i].id] = list(zip(words, word_starting_time)) # This is a demo script and we exit here after processing # one batch. @@ -160,9 +157,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/transducer_stateless/test_conformer.py b/egs/librispeech/ASR/transducer_stateless/test_conformer.py index d1350c8ab4..65b08d425f 100755 --- a/egs/librispeech/ASR/transducer_stateless/test_conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/test_conformer.py @@ -29,9 +29,7 @@ def test_conformer(): feature_dim = 50 - c = Conformer( - num_features=feature_dim, output_dim=256, d_model=128, nhead=4 - ) + c = Conformer(num_features=feature_dim, output_dim=256, d_model=128, nhead=4) batch_size = 5 seq_len = 20 # Just make sure the forward pass runs. diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index ae93f3348d..c86125f445 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -136,8 +136,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -422,9 +421,7 @@ def compute_loss( info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # Note: We use reduction=sum while computing the loss. 
info["loss"] = loss.detach().cpu().item() @@ -545,9 +542,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -664,13 +659,9 @@ def remove_short_and_long_utt(c: Cut): num_removed = num_in_total - num_left removed_percent = num_removed / num_in_total * 100 - logging.info( - f"Before removing short and long utterances: {num_in_total}" - ) + logging.info(f"Before removing short and long utterances: {num_in_total}") logging.info(f"After removing short and long utterances: {num_left}") - logging.info( - f"Removed {num_removed} utterances ({removed_percent:.5f}%)" - ) + logging.info(f"Removed {num_removed} utterances ({removed_percent:.5f}%)") except TypeError as e: # You can ignore this error as previous versions of Lhotse work fine # for the above code. In recent versions of Lhotse, it uses @@ -698,9 +689,7 @@ def remove_short_and_long_utt(c: Cut): cur_lr = optimizer._rate if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) + tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) if rank == 0: diff --git a/egs/librispeech/ASR/transducer_stateless/transformer.py b/egs/librispeech/ASR/transducer_stateless/transformer.py index e851dcc328..b3ff153c18 100644 --- a/egs/librispeech/ASR/transducer_stateless/transformer.py +++ b/egs/librispeech/ASR/transducer_stateless/transformer.py @@ -250,9 +250,7 @@ def _get_activation_fn(activation: str): elif activation == "gelu": return nn.functional.gelu - raise RuntimeError( - "activation should be relu/gelu, not {}".format(activation) - ) + raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) class PositionalEncoding(nn.Module): diff --git a/egs/librispeech/ASR/transducer_stateless2/decode.py b/egs/librispeech/ASR/transducer_stateless2/decode.py index ac2807241c..c642b16bd1 100755 --- a/egs/librispeech/ASR/transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/transducer_stateless2/decode.py @@ -171,8 +171,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; " "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -230,9 +229,7 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] @@ -248,10 +245,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -374,9 +368,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -409,8 +401,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -450,9 +441,7 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" diff --git a/egs/librispeech/ASR/transducer_stateless2/export.py b/egs/librispeech/ASR/transducer_stateless2/export.py index 57c1a60947..229c514b9d 100755 --- a/egs/librispeech/ASR/transducer_stateless2/export.py +++ b/egs/librispeech/ASR/transducer_stateless2/export.py @@ -104,8 +104,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) return parser @@ -176,9 +175,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/transducer_stateless2/pretrained.py b/egs/librispeech/ASR/transducer_stateless2/pretrained.py index 292f77f036..9053bc6e03 100755 --- a/egs/librispeech/ASR/transducer_stateless2/pretrained.py +++ b/egs/librispeech/ASR/transducer_stateless2/pretrained.py @@ -167,8 +167,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -198,8 +197,7 @@ def read_sound_files( for f in filenames: wave, sample_rate = torchaudio.load(f) assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" + f"expected sample rate: {expected_sample_rate}. 
" f"Given: {sample_rate}" ) # We use only the first channel ans.append(wave[0]) @@ -259,9 +257,7 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) feature_lengths = torch.tensor(feature_lengths, device=device) @@ -334,9 +330,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/transducer_stateless2/train.py b/egs/librispeech/ASR/transducer_stateless2/train.py index ea15c9040c..71c9c5df75 100755 --- a/egs/librispeech/ASR/transducer_stateless2/train.py +++ b/egs/librispeech/ASR/transducer_stateless2/train.py @@ -136,8 +136,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -410,9 +409,7 @@ def compute_loss( info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() @@ -533,9 +530,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -652,13 +647,9 @@ def remove_short_and_long_utt(c: Cut): num_removed = num_in_total - num_left removed_percent = num_removed / num_in_total * 100 - logging.info( - f"Before removing short and long utterances: {num_in_total}" - ) + logging.info(f"Before removing short and long utterances: {num_in_total}") logging.info(f"After removing short and long utterances: {num_left}") - logging.info( - f"Removed {num_removed} utterances ({removed_percent:.5f}%)" - ) + logging.info(f"Removed {num_removed} utterances ({removed_percent:.5f}%)") except TypeError as e: # You can ignore this error as previous versions of Lhotse work fine # for the above code. In recent versions of Lhotse, it uses @@ -686,9 +677,7 @@ def remove_short_and_long_utt(c: Cut): cur_lr = optimizer._rate if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) + tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) if rank == 0: diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/decode.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/decode.py index d596e05cb7..2538210281 100755 --- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/decode.py +++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/decode.py @@ -172,8 +172,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 
1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -231,9 +230,7 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] @@ -249,10 +246,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -375,9 +369,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -410,8 +402,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -451,9 +442,7 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/export.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/export.py index b6b69d9327..97b0eea4a7 100755 --- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/export.py +++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/export.py @@ -110,8 +110,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) return parser @@ -247,9 +246,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/pretrained.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/pretrained.py index f297fa2b2f..c698a35b01 100755 --- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/pretrained.py +++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/pretrained.py @@ -167,8 +167,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; " "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -198,8 +197,7 @@ def read_sound_files( for f in filenames: wave, sample_rate = torchaudio.load(f) assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" + f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}" ) # We use only the first channel ans.append(wave[0]) @@ -259,9 +257,7 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) feature_lengths = torch.tensor(feature_lengths, device=device) @@ -334,9 +330,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/test_asr_datamodule.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/test_asr_datamodule.py index ef51a78118..1e1188ca67 100755 --- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/test_asr_datamodule.py +++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/test_asr_datamodule.py @@ -41,9 +41,7 @@ def test_dataset(): print(args) if args.enable_musan: - cuts_musan = load_manifest( - Path(args.manifest_dir) / "musan_cuts.jsonl.gz" - ) + cuts_musan = load_manifest(Path(args.manifest_dir) / "musan_cuts.jsonl.gz") else: cuts_musan = None diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/train.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/train.py index 27912738c8..e5b7dc390c 100755 --- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/train.py +++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/train.py @@ -114,8 +114,7 @@ def get_parser(): "--full-libri", type=str2bool, default=True, - help="When enabled, use 960h LibriSpeech. " - "Otherwise, use 100h subset.", + help="When enabled, use 960h LibriSpeech. " "Otherwise, use 100h subset.", ) parser.add_argument( @@ -170,8 +169,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -469,9 +467,7 @@ def compute_loss( info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # Note: We use reduction=sum while computing the loss. 
info["loss"] = loss.detach().cpu().item() @@ -635,9 +631,7 @@ def train_one_epoch( f"train/current_{prefix}_", params.batch_idx_train, ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) libri_tot_loss.write_summary( tb_writer, "train/libri_tot_", params.batch_idx_train ) @@ -784,9 +778,7 @@ def run(rank, world_size, args): train_giga_cuts = train_giga_cuts.repeat(times=None) if args.enable_musan: - cuts_musan = load_manifest( - Path(args.manifest_dir) / "musan_cuts.jsonl.gz" - ) + cuts_musan = load_manifest(Path(args.manifest_dir) / "musan_cuts.jsonl.gz") else: cuts_musan = None @@ -825,9 +817,7 @@ def run(rank, world_size, args): cur_lr = optimizer._rate if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) + tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) if rank == 0: diff --git a/egs/ptb/LM/local/sort_lm_training_data.py b/egs/ptb/LM/local/sort_lm_training_data.py index af54dbd073..bed3856e44 100755 --- a/egs/ptb/LM/local/sort_lm_training_data.py +++ b/egs/ptb/LM/local/sort_lm_training_data.py @@ -135,9 +135,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/ptb/LM/local/test_prepare_lm_training_data.py b/egs/ptb/LM/local/test_prepare_lm_training_data.py index 877720e7b0..3790045fa7 100755 --- a/egs/ptb/LM/local/test_prepare_lm_training_data.py +++ b/egs/ptb/LM/local/test_prepare_lm_training_data.py @@ -54,9 +54,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/spgispeech/ASR/local/compute_fbank_musan.py b/egs/spgispeech/ASR/local/compute_fbank_musan.py index 6cb8b65aed..9bea28a41c 100755 --- a/egs/spgispeech/ASR/local/compute_fbank_musan.py +++ b/egs/spgispeech/ASR/local/compute_fbank_musan.py @@ -87,9 +87,7 @@ def compute_fbank_musan(): # create chunks of Musan with duration 5 - 10 seconds musan_cuts = ( CutSet.from_manifests( - recordings=combine( - part["recordings"] for part in manifests.values() - ) + recordings=combine(part["recordings"] for part in manifests.values()) ) .cut_into_windows(10.0) .filter(lambda c: c.duration > 5) @@ -108,8 +106,6 @@ def compute_fbank_musan(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) compute_fbank_musan() diff --git a/egs/spgispeech/ASR/local/compute_fbank_spgispeech.py b/egs/spgispeech/ASR/local/compute_fbank_spgispeech.py index 8116e7605e..20ff6d7ab1 100755 --- a/egs/spgispeech/ASR/local/compute_fbank_spgispeech.py +++ b/egs/spgispeech/ASR/local/compute_fbank_spgispeech.py @@ -103,11 +103,7 @@ def compute_fbank_spgispeech(args): chunk_size=chunk_size, ) start = args.start - stop = ( - min(args.stop, args.num_splits) - if args.stop > 0 - else args.num_splits - ) + stop = min(args.stop, args.num_splits) if 
args.stop > 0 else args.num_splits num_digits = len(str(args.num_splits)) for i in range(start, stop): idx = f"{i + 1}".zfill(num_digits) @@ -129,9 +125,7 @@ def compute_fbank_spgispeech(args): logging.info(f"{partition} already exists - skipping.") continue logging.info(f"Processing {partition}") - cut_set = load_manifest_lazy( - src_dir / f"cuts_{partition}_raw.jsonl.gz" - ) + cut_set = load_manifest_lazy(src_dir / f"cuts_{partition}_raw.jsonl.gz") cut_set = cut_set.compute_and_store_features_batch( extractor=extractor, storage_path=output_dir / f"feats_{partition}", @@ -144,9 +138,7 @@ def compute_fbank_spgispeech(args): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) args = get_args() diff --git a/egs/spgispeech/ASR/local/prepare_splits.py b/egs/spgispeech/ASR/local/prepare_splits.py index 8c8f1c1335..508d4acd80 100755 --- a/egs/spgispeech/ASR/local/prepare_splits.py +++ b/egs/spgispeech/ASR/local/prepare_splits.py @@ -55,9 +55,7 @@ def split_spgispeech_train(): # Add speed perturbation train_cuts = ( - train_cuts - + train_cuts.perturb_speed(0.9) - + train_cuts.perturb_speed(1.1) + train_cuts + train_cuts.perturb_speed(0.9) + train_cuts.perturb_speed(1.1) ) # Write the manifests to disk. @@ -73,9 +71,7 @@ def split_spgispeech_train(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) split_spgispeech_train() diff --git a/egs/spgispeech/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/spgispeech/ASR/pruned_transducer_stateless2/asr_datamodule.py index f165f6e602..d94a925038 100644 --- a/egs/spgispeech/ASR/pruned_transducer_stateless2/asr_datamodule.py +++ b/egs/spgispeech/ASR/pruned_transducer_stateless2/asr_datamodule.py @@ -176,17 +176,13 @@ def train_dataloaders( The state dict for the training sampler. 
""" logging.info("About to get Musan cuts") - cuts_musan = load_manifest( - self.args.manifest_dir / "cuts_musan.jsonl.gz" - ) + cuts_musan = load_manifest(self.args.manifest_dir / "cuts_musan.jsonl.gz") transforms = [] if self.args.enable_musan: logging.info("Enable MUSAN") transforms.append( - CutMix( - cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True - ) + CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") @@ -208,9 +204,7 @@ def train_dataloaders( input_transforms = [] if self.args.enable_spec_aug: logging.info("Enable SpecAugment") - logging.info( - f"Time warp factor: {self.args.spec_aug_time_warp_factor}" - ) + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") input_transforms.append( SpecAugment( time_warp_factor=self.args.spec_aug_time_warp_factor, @@ -227,9 +221,7 @@ def train_dataloaders( if self.args.on_the_fly_feats: train = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), input_transforms=input_transforms, ) else: @@ -282,9 +274,7 @@ def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: if self.args.on_the_fly_feats: validate = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), ) else: validate = K2SpeechRecognitionDataset( @@ -328,9 +318,7 @@ def test_dataloaders(self, cuts: CutSet) -> DataLoader: @lru_cache() def train_cuts(self) -> CutSet: logging.info("About to get SPGISpeech train cuts") - return load_manifest_lazy( - self.args.manifest_dir / "cuts_train_shuf.jsonl.gz" - ) + return load_manifest_lazy(self.args.manifest_dir / "cuts_train_shuf.jsonl.gz") @lru_cache() def dev_cuts(self) -> CutSet: diff --git a/egs/spgispeech/ASR/pruned_transducer_stateless2/decode.py b/egs/spgispeech/ASR/pruned_transducer_stateless2/decode.py index c39bd05300..098da3ff01 100755 --- a/egs/spgispeech/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/spgispeech/ASR/pruned_transducer_stateless2/decode.py @@ -76,11 +76,7 @@ ) from train import get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.utils import ( AttributeDict, setup_logger, @@ -187,8 +183,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; " "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -246,9 +241,7 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] if params.decoding_method == "fast_beam_search": @@ -263,10 +256,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -389,9 +379,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -424,9 +412,7 @@ def save_results( # we also compute CER for spgispeech dataset. results_char = [] for res in results: - results_char.append( - (res[0], list("".join(res[1])), list("".join(res[2]))) - ) + results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) cers_filename = ( params.res_dir / f"cers-{test_set_name}-{key}-{params.suffix}.txt" ) @@ -438,32 +424,23 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(wers_filename)) - test_set_wers = { - k: v for k, v in sorted(test_set_wers.items(), key=lambda x: x[1]) - } - test_set_cers = { - k: v for k, v in sorted(test_set_cers.items(), key=lambda x: x[1]) - } + test_set_wers = {k: v for k, v in sorted(test_set_wers.items(), key=lambda x: x[1])} + test_set_cers = {k: v for k, v in sorted(test_set_cers.items(), key=lambda x: x[1])} errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER\tCER", file=f) for key in test_set_wers: print( - "{}\t{}\t{}".format( - key, test_set_wers[key], test_set_cers[key] - ), + "{}\t{}\t{}".format(key, test_set_wers[key], test_set_cers[key]), file=f, ) s = "\nFor {}, WER/CER of different settings are:\n".format(test_set_name) note = "\tbest for {}".format(test_set_name) for key in test_set_wers: - s += "{}\t{}\t{}{}\n".format( - key, test_set_wers[key], test_set_cers[key], note - ) + s += "{}\t{}\t{}{}\n".format(key, test_set_wers[key], test_set_cers[key], note) note = "" logging.info(s) @@ -496,9 +473,7 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -530,8 +505,7 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( diff --git a/egs/spgispeech/ASR/pruned_transducer_stateless2/export.py 
b/egs/spgispeech/ASR/pruned_transducer_stateless2/export.py index 77faa3c0e3..e79cb300d7 100755 --- a/egs/spgispeech/ASR/pruned_transducer_stateless2/export.py +++ b/egs/spgispeech/ASR/pruned_transducer_stateless2/export.py @@ -50,11 +50,7 @@ import torch from train import get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.utils import str2bool @@ -119,8 +115,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) return parser @@ -196,9 +191,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/spgispeech/ASR/pruned_transducer_stateless2/train.py b/egs/spgispeech/ASR/pruned_transducer_stateless2/train.py index dda29b3e5a..2136358944 100755 --- a/egs/spgispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/spgispeech/ASR/pruned_transducer_stateless2/train.py @@ -77,9 +77,7 @@ from icefall.env import get_env_info from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def get_parser(): @@ -155,8 +153,7 @@ def get_parser(): "--initial-lr", type=float, default=0.003, - help="The initial learning rate. This value should not need to be " - "changed.", + help="The initial learning rate. This value should not need to be " "changed.", ) parser.add_argument( @@ -179,8 +176,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -203,8 +199,7 @@ def get_parser(): "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)" "part.", ) parser.add_argument( @@ -554,23 +549,16 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # Note: We use reduction=sum while computing the loss. 
info["loss"] = loss.detach().cpu().item() @@ -733,9 +721,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") diff --git a/egs/tal_csasr/ASR/local/compute_fbank_tal_csasr.py b/egs/tal_csasr/ASR/local/compute_fbank_tal_csasr.py index 4582609ace..602e50d29d 100755 --- a/egs/tal_csasr/ASR/local/compute_fbank_tal_csasr.py +++ b/egs/tal_csasr/ASR/local/compute_fbank_tal_csasr.py @@ -84,9 +84,7 @@ def compute_fbank_tal_csasr(num_mel_bins: int = 80): ) if "train" in partition: cut_set = ( - cut_set - + cut_set.perturb_speed(0.9) - + cut_set.perturb_speed(1.1) + cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) ) cut_set = cut_set.compute_and_store_features( extractor=extractor, @@ -112,9 +110,7 @@ def get_args(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/tal_csasr/ASR/local/prepare_char.py b/egs/tal_csasr/ASR/local/prepare_char.py index 2c5b8b8b3f..1262baf63e 100755 --- a/egs/tal_csasr/ASR/local/prepare_char.py +++ b/egs/tal_csasr/ASR/local/prepare_char.py @@ -87,9 +87,7 @@ def lexicon_to_fst_no_sil( cur_state = loop_state word = word2id[word] - pieces = [ - token2id[i] if i in token2id else token2id[""] for i in pieces - ] + pieces = [token2id[i] if i in token2id else token2id[" "] for i in pieces] for i in range(len(pieces) - 1): w = word if i == 0 else eps diff --git a/egs/tal_csasr/ASR/local/prepare_lang.py b/egs/tal_csasr/ASR/local/prepare_lang.py index e5ae89ec4f..c8cf9b881b 100755 --- a/egs/tal_csasr/ASR/local/prepare_lang.py +++ b/egs/tal_csasr/ASR/local/prepare_lang.py @@ -317,9 +317,7 @@ def lexicon_to_fst( def get_args(): parser = argparse.ArgumentParser() - parser.add_argument( - "--lang-dir", type=str, help="The lang dir, data/lang_phone" - ) + parser.add_argument("--lang-dir", type=str, help="The lang dir, data/lang_phone") return parser.parse_args() diff --git a/egs/tal_csasr/ASR/local/test_prepare_lang.py b/egs/tal_csasr/ASR/local/test_prepare_lang.py index d4cf62bba5..74e025ad78 100755 --- a/egs/tal_csasr/ASR/local/test_prepare_lang.py +++ b/egs/tal_csasr/ASR/local/test_prepare_lang.py @@ -88,9 +88,7 @@ def test_read_lexicon(filename: str): fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt") fsa.draw("L.pdf", title="L") - fsa_disambig = lexicon_to_fst( - lexicon_disambig, phone2id=phone2id, word2id=word2id - ) + fsa_disambig = lexicon_to_fst(lexicon_disambig, phone2id=phone2id, word2id=word2id) fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt") fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt") fsa_disambig.draw("L_disambig.pdf", title="L_disambig") diff --git a/egs/tal_csasr/ASR/local/text2token.py b/egs/tal_csasr/ASR/local/text2token.py index 71be2a6133..85047c367f 100755 --- a/egs/tal_csasr/ASR/local/text2token.py +++ b/egs/tal_csasr/ASR/local/text2token.py @@ -56,9 +56,7 @@ def get_parser(): parser.add_argument( "--skip-ncols", "-s", default=0, type=int, help="skip first n columns" ) - parser.add_argument( - "--space", default=" ", type=str, help="space symbol" - ) + 
parser.add_argument("--space", default=" ", type=str, help="space symbol") parser.add_argument( "--non-lang-syms", "-l", @@ -66,9 +64,7 @@ def get_parser(): type=str, help="list of non-linguistic symobles, e.g., etc.", ) - parser.add_argument( - "text", type=str, default=False, nargs="?", help="input text" - ) + parser.add_argument("text", type=str, default=False, nargs="?", help="input text") parser.add_argument( "--trans_type", "-t", @@ -108,8 +104,7 @@ def token2id( if token_type == "lazy_pinyin": text = lazy_pinyin(chars_list) sub_ids = [ - token_table[txt] if txt in token_table else oov_id - for txt in text + token_table[txt] if txt in token_table else oov_id for txt in text ] ids.append(sub_ids) else: # token_type = "pinyin" @@ -135,9 +130,7 @@ def main(): if args.text: f = codecs.open(args.text, encoding="utf-8") else: - f = codecs.getreader("utf-8")( - sys.stdin if is_python2 else sys.stdin.buffer - ) + f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer) sys.stdout = codecs.getwriter("utf-8")( sys.stdout if is_python2 else sys.stdout.buffer diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/asr_datamodule.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/asr_datamodule.py index 49bfb148b6..2240c1c1dd 100644 --- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/asr_datamodule.py +++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/asr_datamodule.py @@ -222,17 +222,13 @@ def train_dataloaders( The state dict for the training sampler. """ logging.info("About to get Musan cuts") - cuts_musan = load_manifest( - self.args.manifest_dir / "musan_cuts.jsonl.gz" - ) + cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") transforms = [] if self.args.enable_musan: logging.info("Enable MUSAN") transforms.append( - CutMix( - cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True - ) + CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") @@ -254,9 +250,7 @@ def train_dataloaders( input_transforms = [] if self.args.enable_spec_aug: logging.info("Enable SpecAugment") - logging.info( - f"Time warp factor: {self.args.spec_aug_time_warp_factor}" - ) + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") # Set the value of num_frame_masks according to Lhotse's version. # In different Lhotse's versions, the default of num_frame_masks is # different. @@ -300,9 +294,7 @@ def train_dataloaders( # Drop feats to be on the safe side. train = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), input_transforms=input_transforms, return_cuts=self.args.return_cuts, ) @@ -360,9 +352,7 @@ def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: if self.args.on_the_fly_feats: validate = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), return_cuts=self.args.return_cuts, ) else: diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/decode.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/decode.py index b624913f56..82e1a94374 100755 --- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/decode.py @@ -208,8 +208,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 
1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -268,9 +267,7 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] zh_hyps = [] en_hyps = [] @@ -303,10 +300,7 @@ def decode_one_batch( hyps.append(chars_new) zh_hyps.append(zh_text) en_hyps.append(en_text) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -375,9 +369,7 @@ def decode_one_batch( f"Unsupported decoding method: {params.decoding_method}" ) for i in range(encoder_out.size(0)): - hyp = sp.decode( - [lexicon.token_table[idx] for idx in hyp_tokens[i]] - ) + hyp = sp.decode([lexicon.token_table[idx] for idx in hyp_tokens[i]]) chars = pattern.split(hyp.upper()) chars_new = [] zh_text = [] @@ -506,9 +498,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results, zh_results, en_results @@ -541,8 +531,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -585,9 +574,7 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -619,9 +606,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -648,9 +635,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/export.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/export.py index 8f900208ae..d0875c5f5b 100755 --- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/export.py +++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/export.py @@ -139,8 +139,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; " "2 means tri-gram", ) add_model_arguments(parser) @@ -176,9 +175,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -205,9 +204,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -277,9 +276,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/pretrained.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/pretrained.py index dbe213b244..da4e3bc2ff 100755 --- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/pretrained.py +++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/pretrained.py @@ -165,8 +165,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -198,8 +197,7 @@ def read_sound_files( for f in filenames: wave, sample_rate = torchaudio.load(f) assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" + f"expected sample rate: {expected_sample_rate}. 
" f"Given: {sample_rate}" ) # We use only the first channel ans.append(wave[0]) @@ -263,15 +261,11 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) feature_lengths = torch.tensor(feature_lengths, device=device) - encoder_out, encoder_out_lens = model.encoder( - x=features, x_lens=feature_lengths - ) + encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths) num_waves = encoder_out.size(0) hyps = [] @@ -367,9 +361,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/train.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/train.py index ca35eba452..97d4341573 100755 --- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/train.py +++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/train.py @@ -86,9 +86,7 @@ from icefall.lexicon import Lexicon from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def add_model_arguments(parser: argparse.ArgumentParser): @@ -214,8 +212,7 @@ def get_parser(): "--initial-lr", type=float, default=0.003, - help="The initial learning rate. This value should not need " - "to be changed.", + help="The initial learning rate. This value should not need " "to be changed.", ) parser.add_argument( @@ -238,8 +235,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -262,8 +258,7 @@ def get_parser(): "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)" "part.", ) parser.add_argument( @@ -600,11 +595,7 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -634,22 +625,15 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. 
pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() @@ -828,9 +812,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -944,7 +926,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2 ** 22 + 2**22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/tedlium3/ASR/local/compute_fbank_tedlium.py b/egs/tedlium3/ASR/local/compute_fbank_tedlium.py index 327962a799..733ebf2354 100755 --- a/egs/tedlium3/ASR/local/compute_fbank_tedlium.py +++ b/egs/tedlium3/ASR/local/compute_fbank_tedlium.py @@ -83,9 +83,7 @@ def compute_fbank_tedlium(): ) if "train" in partition: cut_set = ( - cut_set - + cut_set.perturb_speed(0.9) - + cut_set.perturb_speed(1.1) + cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) ) cur_num_jobs = num_jobs if ex is None else 80 cur_num_jobs = min(cur_num_jobs, len(cut_set)) @@ -104,9 +102,7 @@ def compute_fbank_tedlium(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/tedlium3/ASR/local/convert_transcript_words_to_bpe_ids.py b/egs/tedlium3/ASR/local/convert_transcript_words_to_bpe_ids.py index 49544ccb30..9dbcc9d9e6 100644 --- a/egs/tedlium3/ASR/local/convert_transcript_words_to_bpe_ids.py +++ b/egs/tedlium3/ASR/local/convert_transcript_words_to_bpe_ids.py @@ -25,9 +25,7 @@ def get_args(): parser = argparse.ArgumentParser() - parser.add_argument( - "--texts", type=List[str], help="The input transcripts list." - ) + parser.add_argument("--texts", type=List[str], help="The input transcripts list.") parser.add_argument( "--bpe-model", type=str, diff --git a/egs/tedlium3/ASR/local/prepare_lexicon.py b/egs/tedlium3/ASR/local/prepare_lexicon.py index 35dd332e8a..b9160b6d4d 100755 --- a/egs/tedlium3/ASR/local/prepare_lexicon.py +++ b/egs/tedlium3/ASR/local/prepare_lexicon.py @@ -23,11 +23,12 @@ 1. Generate lexicon_words.txt. 
""" -import lhotse import argparse import logging from pathlib import Path +import lhotse + def get_args(): parser = argparse.ArgumentParser() @@ -61,9 +62,7 @@ def prepare_lexicon(manifests_dir: str, lang_dir: str): words = set() lexicon = Path(lang_dir) / "lexicon_words.txt" - sups = lhotse.load_manifest( - f"{manifests_dir}/tedlium_supervisions_train.jsonl.gz" - ) + sups = lhotse.load_manifest(f"{manifests_dir}/tedlium_supervisions_train.jsonl.gz") for s in sups: # list the words units and filter the empty item words_list = list(filter(None, s.text.split())) @@ -88,9 +87,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/tedlium3/ASR/local/prepare_transcripts.py b/egs/tedlium3/ASR/local/prepare_transcripts.py index 1039ac5bbd..7ea4e89a44 100755 --- a/egs/tedlium3/ASR/local/prepare_transcripts.py +++ b/egs/tedlium3/ASR/local/prepare_transcripts.py @@ -23,11 +23,12 @@ 1. Generate train.text. """ -import lhotse import argparse import logging from pathlib import Path +import lhotse + def get_args(): parser = argparse.ArgumentParser() @@ -61,9 +62,7 @@ def prepare_transcripts(manifests_dir: str, lang_dir: str): texts = [] train_text = Path(lang_dir) / "train.text" - sups = lhotse.load_manifest( - f"{manifests_dir}/tedlium_supervisions_train.jsonl.gz" - ) + sups = lhotse.load_manifest(f"{manifests_dir}/tedlium_supervisions_train.jsonl.gz") for s in sups: texts.append(s.text) @@ -83,9 +82,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/decode.py b/egs/tedlium3/ASR/pruned_transducer_stateless/decode.py index 2b294e601f..8ca875c242 100755 --- a/egs/tedlium3/ASR/pruned_transducer_stateless/decode.py +++ b/egs/tedlium3/ASR/pruned_transducer_stateless/decode.py @@ -172,8 +172,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; " "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -231,9 +230,7 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] if params.decoding_method == "fast_beam_search": @@ -248,10 +245,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -374,9 +368,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -409,8 +401,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/export.py b/egs/tedlium3/ASR/pruned_transducer_stateless/export.py index a1c3bcea32..71a9e2d715 100644 --- a/egs/tedlium3/ASR/pruned_transducer_stateless/export.py +++ b/egs/tedlium3/ASR/pruned_transducer_stateless/export.py @@ -106,8 +106,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) return parser @@ -179,9 +178,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/pretrained.py b/egs/tedlium3/ASR/pruned_transducer_stateless/pretrained.py index 8480ac029c..e8a453c802 100644 --- a/egs/tedlium3/ASR/pruned_transducer_stateless/pretrained.py +++ b/egs/tedlium3/ASR/pruned_transducer_stateless/pretrained.py @@ -165,8 +165,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -204,8 +203,7 @@ def read_sound_files( for f in filenames: wave, sample_rate = torchaudio.load(f) assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" + f"expected sample rate: {expected_sample_rate}. 
" f"Given: {sample_rate}" ) # We use only the first channel ans.append(wave[0]) @@ -271,9 +269,7 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) feature_lengths = torch.tensor(feature_lengths, device=device) @@ -298,10 +294,7 @@ def main(): ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -353,9 +346,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/train.py b/egs/tedlium3/ASR/pruned_transducer_stateless/train.py index 8d5cdf6831..59d80a0d82 100755 --- a/egs/tedlium3/ASR/pruned_transducer_stateless/train.py +++ b/egs/tedlium3/ASR/pruned_transducer_stateless/train.py @@ -133,8 +133,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -157,8 +156,7 @@ def get_parser(): "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)" "part.", ) parser.add_argument( @@ -556,9 +554,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -678,9 +674,7 @@ def remove_short_and_long_utt(c: Cut): cur_lr = optimizer._rate if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) + tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) if rank == 0: diff --git a/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py b/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py index 94784c4c4f..c647392f06 100644 --- a/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py +++ b/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py @@ -18,7 +18,6 @@ import argparse import logging - from functools import lru_cache from pathlib import Path from typing import Any, Dict, Optional @@ -171,9 +170,7 @@ def add_arguments(cls, parser: argparse.ArgumentParser): ) def train_dataloaders( - self, - cuts_train: CutSet, - sampler_state_dict: Optional[Dict[str, Any]] = None + self, cuts_train: CutSet, sampler_state_dict: Optional[Dict[str, Any]] = None ) -> DataLoader: """ Args: @@ -186,9 +183,7 @@ def train_dataloaders( input_transforms = [] if self.args.enable_spec_aug: logging.info("Enable SpecAugment") - logging.info( - f"Time warp factor: {self.args.spec_aug_time_warp_factor}" - ) + 
logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") input_transforms.append( SpecAugment( @@ -208,13 +203,9 @@ def train_dataloaders( transforms = [] if self.args.enable_musan: logging.info("Enable MUSAN") - cuts_musan = load_manifest( - self.args.manifest_dir / "musan_cuts.jsonl.gz" - ) + cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") transforms.append( - CutMix( - cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True - ) + CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") @@ -247,9 +238,7 @@ def train_dataloaders( # Drop feats to be on the safe side. train = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), input_transforms=input_transforms, return_cuts=self.args.return_cuts, ) @@ -306,9 +295,7 @@ def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: if self.args.on_the_fly_feats: validate = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), return_cuts=self.args.return_cuts, ) else: @@ -339,9 +326,7 @@ def test_dataloaders(self, cuts_test: CutSet) -> DataLoader: logging.debug("About to create test dataset") if self.args.on_the_fly_feats: test = K2SpeechRecognitionDataset( - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), return_cuts=self.args.return_cuts, ) else: @@ -375,13 +360,9 @@ def train_cuts(self) -> CutSet: @lru_cache() def dev_cuts(self) -> CutSet: logging.info("About to get dev cuts") - return load_manifest_lazy( - self.args.manifest_dir / "tedlium_cuts_dev.jsonl.gz" - ) + return load_manifest_lazy(self.args.manifest_dir / "tedlium_cuts_dev.jsonl.gz") @lru_cache() def test_cuts(self) -> CutSet: logging.info("About to get test cuts") - return load_manifest_lazy( - self.args.manifest_dir / "tedlium_cuts_test.jsonl.gz" - ) + return load_manifest_lazy(self.args.manifest_dir / "tedlium_cuts_test.jsonl.gz") diff --git a/egs/tedlium3/ASR/transducer_stateless/beam_search.py b/egs/tedlium3/ASR/transducer_stateless/beam_search.py index 77caf64601..1f99edaf3e 100644 --- a/egs/tedlium3/ASR/transducer_stateless/beam_search.py +++ b/egs/tedlium3/ASR/transducer_stateless/beam_search.py @@ -87,9 +87,9 @@ def greedy_search( y = logits.argmax().item() if y != blank_id and y != unk_id: hyp.append(y) - decoder_input = torch.tensor( - [hyp[-context_size:]], device=device - ).reshape(1, context_size) + decoder_input = torch.tensor([hyp[-context_size:]], device=device).reshape( + 1, context_size + ) decoder_out = model.decoder(decoder_input, need_pad=False) @@ -148,9 +148,7 @@ def add(self, hyp: Hypothesis) -> None: key = hyp.key if key in self: old_hyp = self._data[key] # shallow copy - torch.logaddexp( - old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob - ) + torch.logaddexp(old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob) else: self._data[key] = hyp @@ -166,9 +164,7 @@ def get_most_probable(self, length_norm: bool = False) -> Hypothesis: Return the hypothesis that has the largest `log_prob`. 
""" if length_norm: - return max( - self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys) - ) + return max(self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys)) else: return max(self._data.values(), key=lambda hyp: hyp.log_prob) @@ -344,9 +340,9 @@ def modified_beam_search( device = model.device - decoder_input = torch.tensor( - [blank_id] * context_size, device=device - ).reshape(1, context_size) + decoder_input = torch.tensor([blank_id] * context_size, device=device).reshape( + 1, context_size + ) decoder_out = model.decoder(decoder_input, need_pad=False) @@ -383,9 +379,7 @@ def modified_beam_search( decoder_out = model.decoder(decoder_input, need_pad=False) # decoder_output is of shape (num_hyps, 1, decoder_output_dim) - current_encoder_out = current_encoder_out.expand( - decoder_out.size(0), 1, -1 - ) + current_encoder_out = current_encoder_out.expand(decoder_out.size(0), 1, -1) logits = model.joiner( current_encoder_out, @@ -454,9 +448,9 @@ def beam_search( device = model.device - decoder_input = torch.tensor( - [blank_id] * context_size, device=device - ).reshape(1, context_size) + decoder_input = torch.tensor([blank_id] * context_size, device=device).reshape( + 1, context_size + ) decoder_out = model.decoder(decoder_input, need_pad=False) diff --git a/egs/tedlium3/ASR/transducer_stateless/decode.py b/egs/tedlium3/ASR/transducer_stateless/decode.py index d3e9e55e7e..e5ab2c1076 100755 --- a/egs/tedlium3/ASR/transducer_stateless/decode.py +++ b/egs/tedlium3/ASR/transducer_stateless/decode.py @@ -130,8 +130,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -250,9 +249,7 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] batch_size = encoder_out.size(0) @@ -275,9 +272,7 @@ def decode_one_batch( model=model, encoder_out=encoder_out_i, beam=params.beam_size ) else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") hyps.append(sp.decode(hyp).split()) if params.decoding_method == "greedy_search": @@ -348,9 +343,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -383,8 +376,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) diff --git a/egs/tedlium3/ASR/transducer_stateless/decoder.py b/egs/tedlium3/ASR/transducer_stateless/decoder.py index f0c6f32b6a..f9a3814c6a 100644 --- a/egs/tedlium3/ASR/transducer_stateless/decoder.py +++ b/egs/tedlium3/ASR/transducer_stateless/decoder.py @@ -90,9 +90,7 @@ def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor: if self.context_size > 1: embedding_out = 
embedding_out.permute(0, 2, 1) if need_pad is True: - embedding_out = F.pad( - embedding_out, pad=(self.context_size - 1, 0) - ) + embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0)) else: # During inference time, there is no need to do extra padding # as we only need one output diff --git a/egs/tedlium3/ASR/transducer_stateless/export.py b/egs/tedlium3/ASR/transducer_stateless/export.py index c32b1d0022..c2ec7a5901 100644 --- a/egs/tedlium3/ASR/transducer_stateless/export.py +++ b/egs/tedlium3/ASR/transducer_stateless/export.py @@ -110,8 +110,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) return parser @@ -247,9 +246,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/tedlium3/ASR/transducer_stateless/pretrained.py b/egs/tedlium3/ASR/transducer_stateless/pretrained.py index c0e3bb844b..070b070a72 100644 --- a/egs/tedlium3/ASR/transducer_stateless/pretrained.py +++ b/egs/tedlium3/ASR/transducer_stateless/pretrained.py @@ -127,8 +127,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -223,8 +222,7 @@ def read_sound_files( for f in filenames: wave, sample_rate = torchaudio.load(f) assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" + f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}" ) # We use only the first channel ans.append(wave[0]) @@ -285,9 +283,7 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) feature_lengths = torch.tensor(feature_lengths, device=device) @@ -335,9 +331,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/tedlium3/ASR/transducer_stateless/train.py b/egs/tedlium3/ASR/transducer_stateless/train.py index 09cbf4a004..4fc13b1da4 100755 --- a/egs/tedlium3/ASR/transducer_stateless/train.py +++ b/egs/tedlium3/ASR/transducer_stateless/train.py @@ -133,8 +133,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -525,9 +524,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -647,9 +644,7 @@ def remove_short_and_long_utt(c: Cut): cur_lr = optimizer._rate if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) + tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) if rank == 0: diff --git a/egs/timit/ASR/RESULTS.md b/egs/timit/ASR/RESULTS.md index b78c16b889..d8ceb82b62 100644 --- a/egs/timit/ASR/RESULTS.md +++ b/egs/timit/ASR/RESULTS.md @@ -71,4 +71,4 @@ python tdnn_ligru_ctc/decode.py --epoch 25 \ --avg 17 \ --max-duration 20 \ --lang-dir data/lang_phone -``` \ No newline at end of file +``` diff --git a/egs/timit/ASR/local/compile_hlg.py b/egs/timit/ASR/local/compile_hlg.py index 58cab4cf2a..32c248d7e4 100644 --- a/egs/timit/ASR/local/compile_hlg.py +++ b/egs/timit/ASR/local/compile_hlg.py @@ -146,9 +146,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/timit/ASR/local/compute_fbank_timit.py b/egs/timit/ASR/local/compute_fbank_timit.py index f25786a0c4..ecdf10ba92 100644 --- a/egs/timit/ASR/local/compute_fbank_timit.py +++ b/egs/timit/ASR/local/compute_fbank_timit.py @@ -85,9 +85,7 @@ def compute_fbank_timit(): ) if partition == "TRAIN": cut_set = ( - cut_set - + cut_set.perturb_speed(0.9) - + cut_set.perturb_speed(1.1) + cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) ) cut_set = cut_set.compute_and_store_features( extractor=extractor, @@ -101,9 +99,7 @@ def compute_fbank_timit(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/timit/ASR/local/prepare_lexicon.py b/egs/timit/ASR/local/prepare_lexicon.py index 04023a9ab5..0cf0f0debc 100644 --- a/egs/timit/ASR/local/prepare_lexicon.py +++ b/egs/timit/ASR/local/prepare_lexicon.py @@ -62,9 +62,7 @@ def prepare_lexicon(manifests_dir: str, lang_dir: str): phones = set() - supervisions_train = ( - Path(manifests_dir) / "timit_supervisions_TRAIN.jsonl.gz" - ) + supervisions_train = Path(manifests_dir) / "timit_supervisions_TRAIN.jsonl.gz" lexicon = Path(lang_dir) / "lexicon.txt" logging.info(f"Loading {supervisions_train}!") @@ -97,9 +95,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/timit/ASR/prepare.sh b/egs/timit/ASR/prepare.sh index ae1b96a68c..d11cd3a059 100644 --- a/egs/timit/ASR/prepare.sh +++ b/egs/timit/ASR/prepare.sh @@ -20,9 +20,9 @@ stop_stage=100 # - $dl_dir/lm # This directory contains the language model(LM) downloaded from # 
https://huggingface.co/luomingshuang/timit_lm, and the LM is based -# on 39 phones. About how to get these LM files, you can know it +# on 39 phones. About how to get these LM files, you can know it # from https://github.com/luomingshuang/Train_LM_with_kaldilm. -# +# # - lm_3_gram.arpa # - lm_4_gram.arpa # diff --git a/egs/timit/ASR/tdnn_ligru_ctc/decode.py b/egs/timit/ASR/tdnn_ligru_ctc/decode.py index 4f2aa23407..4beeed18c7 100644 --- a/egs/timit/ASR/tdnn_ligru_ctc/decode.py +++ b/egs/timit/ASR/tdnn_ligru_ctc/decode.py @@ -336,9 +336,7 @@ def decode_dataset( if batch_idx % 100 == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -400,9 +398,7 @@ def main(): logging.info(f"device: {device}") - HLG = k2.Fsa.from_dict( - torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu") - ) + HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")) HLG = HLG.to(device) assert HLG.requires_grad is False @@ -462,9 +458,7 @@ def main(): if params.export: logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt") - torch.save( - {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt" - ) + torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt") return model.to(device) @@ -485,9 +479,7 @@ def main(): G=G, ) - save_results( - params=params, test_set_name=test_set, results_dict=results_dict - ) + save_results(params=params, test_set_name=test_set, results_dict=results_dict) logging.info("Done!") diff --git a/egs/timit/ASR/tdnn_ligru_ctc/model.py b/egs/timit/ASR/tdnn_ligru_ctc/model.py index 4d2199ace4..9a594a9694 100644 --- a/egs/timit/ASR/tdnn_ligru_ctc/model.py +++ b/egs/timit/ASR/tdnn_ligru_ctc/model.py @@ -16,11 +16,11 @@ # limitations under the License. +from typing import Optional + import torch import torch.nn as nn - from torch import Tensor -from typing import Optional class TdnnLiGRU(nn.Module): @@ -261,9 +261,7 @@ def _forward_ligru(self, x, hx: Optional[Tensor]): h = [] if hx is not None: if self.bidirectional: - hx = hx.reshape( - self.num_layers, self.batch_size * 2, self.hidden_size - ) + hx = hx.reshape(self.num_layers, self.batch_size * 2, self.hidden_size) # Processing the different layers for i, ligru_lay in enumerate(self.rnn): if hx is not None: @@ -445,9 +443,7 @@ def _sample_drop_mask(self, w): if self.drop_mask_cnt + self.batch_size > self.N_drop_masks: self.drop_mask_cnt = 0 self.drop_masks = self.drop( - torch.ones( - self.N_drop_masks, self.hidden_size, device=w.device - ) + torch.ones(self.N_drop_masks, self.hidden_size, device=w.device) ).data # Sampling the mask diff --git a/egs/timit/ASR/tdnn_ligru_ctc/pretrained.py b/egs/timit/ASR/tdnn_ligru_ctc/pretrained.py index 7da2859442..4ef1344126 100644 --- a/egs/timit/ASR/tdnn_ligru_ctc/pretrained.py +++ b/egs/timit/ASR/tdnn_ligru_ctc/pretrained.py @@ -29,11 +29,7 @@ from model import TdnnLiGRU from torch.nn.utils.rnn import pad_sequence -from icefall.decode import ( - get_lattice, - one_best_decoding, - rescore_with_whole_lattice, -) +from icefall.decode import get_lattice, one_best_decoding, rescore_with_whole_lattice from icefall.utils import AttributeDict, get_texts @@ -58,9 +54,7 @@ def get_parser(): help="Path to words.txt", ) - parser.add_argument( - "--HLG", type=str, required=True, help="Path to HLG.pt." 
- ) + parser.add_argument("--HLG", type=str, required=True, help="Path to HLG.pt.") parser.add_argument( "--method", @@ -145,8 +139,7 @@ def read_sound_files( for f in filenames: wave, sample_rate = torchaudio.load(f) assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" + f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}" ) # We use only the first channel ans.append(wave[0]) @@ -215,9 +208,7 @@ def main(): logging.info("Decoding started") features = fbank(waves) - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) features = features.permute(0, 2, 1) # now features is (N, C, T) with torch.no_grad(): @@ -269,9 +260,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/timit/ASR/tdnn_ligru_ctc/train.py b/egs/timit/ASR/tdnn_ligru_ctc/train.py index 452c2a7cb6..48b7feda07 100644 --- a/egs/timit/ASR/tdnn_ligru_ctc/train.py +++ b/egs/timit/ASR/tdnn_ligru_ctc/train.py @@ -449,9 +449,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: valid_info = compute_validation_loss( diff --git a/egs/timit/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/timit/ASR/tdnn_lstm_ctc/asr_datamodule.py index 1554e987fa..51ca4cc6eb 100644 --- a/egs/timit/ASR/tdnn_lstm_ctc/asr_datamodule.py +++ b/egs/timit/ASR/tdnn_lstm_ctc/asr_datamodule.py @@ -154,9 +154,7 @@ def train_dataloaders(self) -> DataLoader: cuts_train = self.train_cuts() logging.info("About to get Musan cuts") - cuts_musan = load_manifest( - self.args.feature_dir / "musan_cuts.jsonl.gz" - ) + cuts_musan = load_manifest(self.args.feature_dir / "musan_cuts.jsonl.gz") logging.info("About to create train dataset") transforms = [CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20))] @@ -178,9 +176,9 @@ def train_dataloaders(self) -> DataLoader: # In different Lhotse's versions, the default of num_frame_masks is # different. num_frame_masks = 10 - num_frame_masks_parameter = inspect.signature( - SpecAugment.__init__ - ).parameters["num_frame_masks"] + num_frame_masks_parameter = inspect.signature(SpecAugment.__init__).parameters[ + "num_frame_masks" + ] if num_frame_masks_parameter.default == 1: num_frame_masks = 2 logging.info(f"Num frame mask: {num_frame_masks}") @@ -212,9 +210,7 @@ def train_dataloaders(self) -> DataLoader: # Drop feats to be on the safe side. 
train = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), input_transforms=input_transforms, return_cuts=self.args.return_cuts, ) @@ -263,9 +259,7 @@ def valid_dataloaders(self) -> DataLoader: if self.args.on_the_fly_feats: validate = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), return_cuts=self.args.return_cuts, ) else: @@ -299,20 +293,14 @@ def test_dataloaders(self) -> Union[DataLoader, List[DataLoader]]: for cuts_test in cuts: logging.debug("About to create test dataset") test = K2SpeechRecognitionDataset( - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ) + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) if self.args.on_the_fly_feats else PrecomputedFeatures(), return_cuts=self.args.return_cuts, ) - sampler = SingleCutSampler( - cuts_test, max_duration=self.args.max_duration - ) + sampler = SingleCutSampler(cuts_test, max_duration=self.args.max_duration) logging.debug("About to create test dataloader") - test_dl = DataLoader( - test, batch_size=None, sampler=sampler, num_workers=1 - ) + test_dl = DataLoader(test, batch_size=None, sampler=sampler, num_workers=1) test_loaders.append(test_dl) if is_list: diff --git a/egs/timit/ASR/tdnn_lstm_ctc/decode.py b/egs/timit/ASR/tdnn_lstm_ctc/decode.py index 5e7300cf2a..502a48def8 100644 --- a/egs/timit/ASR/tdnn_lstm_ctc/decode.py +++ b/egs/timit/ASR/tdnn_lstm_ctc/decode.py @@ -335,9 +335,7 @@ def decode_dataset( if batch_idx % 100 == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -399,9 +397,7 @@ def main(): logging.info(f"device: {device}") - HLG = k2.Fsa.from_dict( - torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu") - ) + HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")) HLG = HLG.to(device) assert HLG.requires_grad is False @@ -461,9 +457,7 @@ def main(): if params.export: logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt") - torch.save( - {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt" - ) + torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt") return model.to(device) @@ -483,9 +477,7 @@ def main(): G=G, ) - save_results( - params=params, test_set_name=test_set, results_dict=results_dict - ) + save_results(params=params, test_set_name=test_set, results_dict=results_dict) logging.info("Done!") diff --git a/egs/timit/ASR/tdnn_lstm_ctc/model.py b/egs/timit/ASR/tdnn_lstm_ctc/model.py index 51edb97e2c..e211ad80d8 100644 --- a/egs/timit/ASR/tdnn_lstm_ctc/model.py +++ b/egs/timit/ASR/tdnn_lstm_ctc/model.py @@ -74,10 +74,7 @@ def __init__( nn.BatchNorm1d(num_features=512, affine=False), ) self.lstms = nn.ModuleList( - [ - nn.LSTM(input_size=512, hidden_size=512, num_layers=1) - for _ in range(4) - ] + [nn.LSTM(input_size=512, hidden_size=512, num_layers=1) for _ in range(4)] ) self.lstm_bnorms = nn.ModuleList( [nn.BatchNorm1d(num_features=512, affine=False) for _ in range(5)] diff --git a/egs/timit/ASR/tdnn_lstm_ctc/pretrained.py b/egs/timit/ASR/tdnn_lstm_ctc/pretrained.py index 5f478da1c0..3f143912e3 100644 --- 
a/egs/timit/ASR/tdnn_lstm_ctc/pretrained.py +++ b/egs/timit/ASR/tdnn_lstm_ctc/pretrained.py @@ -29,11 +29,7 @@ from model import TdnnLstm from torch.nn.utils.rnn import pad_sequence -from icefall.decode import ( - get_lattice, - one_best_decoding, - rescore_with_whole_lattice, -) +from icefall.decode import get_lattice, one_best_decoding, rescore_with_whole_lattice from icefall.utils import AttributeDict, get_texts @@ -58,9 +54,7 @@ def get_parser(): help="Path to words.txt", ) - parser.add_argument( - "--HLG", type=str, required=True, help="Path to HLG.pt." - ) + parser.add_argument("--HLG", type=str, required=True, help="Path to HLG.pt.") parser.add_argument( "--method", @@ -145,8 +139,7 @@ def read_sound_files( for f in filenames: wave, sample_rate = torchaudio.load(f) assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" + f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}" ) # We use only the first channel ans.append(wave[0]) @@ -215,9 +208,7 @@ def main(): logging.info("Decoding started") features = fbank(waves) - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) features = features.permute(0, 2, 1) # now features is (N, C, T) with torch.no_grad(): @@ -269,9 +260,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/timit/ASR/tdnn_lstm_ctc/train.py b/egs/timit/ASR/tdnn_lstm_ctc/train.py index 849256b986..be1ecffaa2 100644 --- a/egs/timit/ASR/tdnn_lstm_ctc/train.py +++ b/egs/timit/ASR/tdnn_lstm_ctc/train.py @@ -449,9 +449,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: valid_info = compute_validation_loss( diff --git a/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_dev_test.py b/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_dev_test.py index 8a9f6ed301..bd73e520e0 100755 --- a/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_dev_test.py +++ b/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_dev_test.py @@ -20,12 +20,7 @@ from pathlib import Path import torch -from lhotse import ( - CutSet, - KaldifeatFbank, - KaldifeatFbankConfig, - LilcomHdf5Writer, -) +from lhotse import CutSet, KaldifeatFbank, KaldifeatFbankConfig, LilcomHdf5Writer # Torch's multithreaded behavior needs to be disabled or # it wastes a lot of CPU and slow things down. 
@@ -83,9 +78,7 @@ def compute_fbank_wenetspeech_dev_test(): def main(): - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) compute_fbank_wenetspeech_dev_test() diff --git a/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_splits.py b/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_splits.py index a882b61138..1b257fb702 100755 --- a/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_splits.py +++ b/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_splits.py @@ -152,9 +152,7 @@ def main(): date_time = now.strftime("%Y-%m-%d-%H-%M-%S") log_filename = "log-compute_fbank_wenetspeech_splits" - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" log_filename = f"{log_filename}-{date_time}" logging.basicConfig( diff --git a/egs/wenetspeech/ASR/local/prepare_char.py b/egs/wenetspeech/ASR/local/prepare_char.py index 8bc073c75f..d8622842f2 100755 --- a/egs/wenetspeech/ASR/local/prepare_char.py +++ b/egs/wenetspeech/ASR/local/prepare_char.py @@ -83,9 +83,7 @@ def lexicon_to_fst_no_sil( cur_state = loop_state word = word2id[word] - pieces = [ - token2id[i] if i in token2id else token2id["<unk>"] for i in pieces - ] + pieces = [token2id[i] if i in token2id else token2id["<unk>"] for i in pieces] for i in range(len(pieces) - 1): w = word if i == 0 else eps @@ -138,9 +136,7 @@ def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool: return False -def generate_lexicon( - token_sym_table: Dict[str, int], words: List[str] -) -> Lexicon: +def generate_lexicon(token_sym_table: Dict[str, int], words: List[str]) -> Lexicon: """Generate a lexicon from a word list and token_sym_table. 
Args: token_sym_table: diff --git a/egs/wenetspeech/ASR/local/preprocess_wenetspeech.py b/egs/wenetspeech/ASR/local/preprocess_wenetspeech.py index 817969c470..93ce750f82 100755 --- a/egs/wenetspeech/ASR/local/preprocess_wenetspeech.py +++ b/egs/wenetspeech/ASR/local/preprocess_wenetspeech.py @@ -115,11 +115,7 @@ def preprocess_wenet_speech(): f"Speed perturb for {partition} with factors 0.9 and 1.1 " "(Perturbing may take 8 minutes and saving may take 20 minutes)" ) - cut_set = ( - cut_set - + cut_set.perturb_speed(0.9) - + cut_set.perturb_speed(1.1) - ) + cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) logging.info(f"Saving to {raw_cuts_path}") cut_set.to_file(raw_cuts_path) diff --git a/egs/wenetspeech/ASR/local/text2token.py b/egs/wenetspeech/ASR/local/text2token.py index 1c463cf1ca..d1d237a681 100755 --- a/egs/wenetspeech/ASR/local/text2token.py +++ b/egs/wenetspeech/ASR/local/text2token.py @@ -56,9 +56,7 @@ def get_parser(): parser.add_argument( "--skip-ncols", "-s", default=0, type=int, help="skip first n columns" ) - parser.add_argument( - "--space", default="<space>", type=str, help="space symbol" - ) + parser.add_argument("--space", default="<space>", type=str, help="space symbol") parser.add_argument( "--non-lang-syms", "-l", @@ -66,9 +64,7 @@ def get_parser(): type=str, help="list of non-linguistic symobles, e.g., <NOISE> etc.", ) - parser.add_argument( - "text", type=str, default=False, nargs="?", help="input text" - ) + parser.add_argument("text", type=str, default=False, nargs="?", help="input text") parser.add_argument( "--trans_type", "-t", @@ -108,8 +104,7 @@ def token2id( if token_type == "lazy_pinyin": text = lazy_pinyin(chars_list) sub_ids = [ - token_table[txt] if txt in token_table else oov_id - for txt in text + token_table[txt] if txt in token_table else oov_id for txt in text ] ids.append(sub_ids) else: # token_type = "pinyin" @@ -135,9 +130,7 @@ def main(): if args.text: f = codecs.open(args.text, encoding="utf-8") else: - f = codecs.getreader("utf-8")( - sys.stdin if is_python2 else sys.stdin.buffer - ) + f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer) sys.stdout = codecs.getwriter("utf-8")( sys.stdout if is_python2 else sys.stdout.buffer diff --git a/egs/wenetspeech/ASR/prepare.sh b/egs/wenetspeech/ASR/prepare.sh index 755fbb2d70..da7d7e061a 100755 --- a/egs/wenetspeech/ASR/prepare.sh +++ b/egs/wenetspeech/ASR/prepare.sh @@ -190,7 +190,7 @@ if [ $stage -le 15 ] && [ $stop_stage -ge 15 ]; then mkdir -p $lang_char_dir if ! which jq; then - echo "This script is intended to be used with jq but you have not installed jq + echo "This script is intended to be used with jq but you have not installed jq Note: in Linux, you can install jq with the following command: 1. wget -O jq https://github.com/stedolan/jq/releases/download/jq-1.6/jq-linux64 2. chmod +x ./jq diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py index 10c953e3b3..9c07263a23 100644 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py @@ -212,17 +212,13 @@ def train_dataloaders( The state dict for the training sampler. 
""" logging.info("About to get Musan cuts") - cuts_musan = load_manifest( - self.args.manifest_dir / "musan_cuts.jsonl.gz" - ) + cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") transforms = [] if self.args.enable_musan: logging.info("Enable MUSAN") transforms.append( - CutMix( - cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True - ) + CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") @@ -244,9 +240,7 @@ def train_dataloaders( input_transforms = [] if self.args.enable_spec_aug: logging.info("Enable SpecAugment") - logging.info( - f"Time warp factor: {self.args.spec_aug_time_warp_factor}" - ) + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") # Set the value of num_frame_masks according to Lhotse's version. # In different Lhotse's versions, the default of num_frame_masks is # different. @@ -289,9 +283,7 @@ def train_dataloaders( # Drop feats to be on the safe side. train = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), input_transforms=input_transforms, return_cuts=self.args.return_cuts, ) @@ -348,9 +340,7 @@ def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: if self.args.on_the_fly_feats: validate = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), return_cuts=self.args.return_cuts, ) else: @@ -414,8 +404,7 @@ def test_dataloaders(self, cuts: CutSet) -> DataLoader: def train_cuts(self) -> CutSet: logging.info("About to get train cuts") cuts_train = load_manifest_lazy( - self.args.manifest_dir - / f"cuts_{self.args.training_subset}.jsonl.gz" + self.args.manifest_dir / f"cuts_{self.args.training_subset}.jsonl.gz" ) return cuts_train @@ -427,13 +416,9 @@ def valid_cuts(self) -> CutSet: @lru_cache() def test_net_cuts(self) -> List[CutSet]: logging.info("About to get TEST_NET cuts") - return load_manifest_lazy( - self.args.manifest_dir / "cuts_TEST_NET.jsonl.gz" - ) + return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST_NET.jsonl.gz") @lru_cache() def test_meeting_cuts(self) -> List[CutSet]: logging.info("About to get TEST_MEETING cuts") - return load_manifest_lazy( - self.args.manifest_dir / "cuts_TEST_MEETING.jsonl.gz" - ) + return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST_MEETING.jsonl.gz") diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py index f0c9bebec6..cd9ed57b9f 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py @@ -114,11 +114,7 @@ from train import get_params, get_transducer_model from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, @@ -252,8 +248,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; " "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -328,9 +323,7 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] if params.decoding_method == "fast_beam_search": @@ -389,10 +382,7 @@ def decode_one_batch( ) for i in range(encoder_out.size(0)): hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -515,9 +505,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -550,8 +538,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -663,9 +650,7 @@ def main(): ) decoding_graph.scores *= params.ngram_lm_scale else: - decoding_graph = k2.trivial_graph( - params.vocab_size - 1, device=device - ) + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) else: decoding_graph = None @@ -716,8 +701,7 @@ def main(): ) dev_shards = [ - str(path) - for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar"))) + str(path) for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar"))) ] cuts_dev_webdataset = CutSet.from_webdataset( dev_shards, @@ -727,8 +711,7 @@ def main(): ) test_net_shards = [ - str(path) - for path in sorted(glob.glob(os.path.join(test_net, "shared-*.tar"))) + str(path) for path in sorted(glob.glob(os.path.join(test_net, "shared-*.tar"))) ] cuts_test_net_webdataset = CutSet.from_webdataset( test_net_shards, @@ -739,9 +722,7 @@ def main(): test_meeting_shards = [ str(path) - for path in sorted( - glob.glob(os.path.join(test_meeting, "shared-*.tar")) - ) + for path in sorted(glob.glob(os.path.join(test_meeting, "shared-*.tar"))) ] cuts_test_meeting_webdataset = CutSet.from_webdataset( test_meeting_shards, diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/export.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/export.py index 933642a0fa..df2fc5df55 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/export.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/export.py @@ -205,8 +205,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; " "2 means tri-gram", ) return parser @@ -468,13 +467,9 @@ def export_joiner_model_onnx( - projected_decoder_out: a tensor of shape (N, joiner_dim) """ - encoder_proj_filename = str(joiner_filename).replace( - ".onnx", "_encoder_proj.onnx" - ) + encoder_proj_filename = str(joiner_filename).replace(".onnx", "_encoder_proj.onnx") - decoder_proj_filename = str(joiner_filename).replace( - ".onnx", "_decoder_proj.onnx" - ) + decoder_proj_filename = str(joiner_filename).replace(".onnx", "_decoder_proj.onnx") encoder_out_dim = joiner_model.encoder_proj.weight.shape[1] decoder_out_dim = joiner_model.decoder_proj.weight.shape[1] @@ -645,9 +640,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/jit_pretrained.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/jit_pretrained.py index e5cc47bfef..42ffbcfb8d 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/jit_pretrained.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/jit_pretrained.py @@ -146,8 +146,7 @@ def read_sound_files( for f in filenames: wave, sample_rate = torchaudio.load(f) assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" + f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}" ) # We use only the first channel ans.append(wave[0]) @@ -331,9 +330,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_check.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_check.py index c396c50ef9..a46ff5a07a 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_check.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_check.py @@ -219,9 +219,7 @@ def test_joiner( ) # Now test encoder_proj - joiner_encoder_proj_inputs = { - encoder_proj_input_name: encoder_out.numpy() - } + joiner_encoder_proj_inputs = {encoder_proj_input_name: encoder_out.numpy()} joiner_encoder_proj_out = joiner_encoder_proj_session.run( [encoder_proj_output_name], joiner_encoder_proj_inputs )[0] @@ -230,16 +228,10 @@ def test_joiner( torch_joiner_encoder_proj_out = model.joiner.encoder_proj(encoder_out) assert torch.allclose( joiner_encoder_proj_out, torch_joiner_encoder_proj_out, atol=1e-5 - ), ( - (joiner_encoder_proj_out - torch_joiner_encoder_proj_out) - .abs() - .max() - ) + ), ((joiner_encoder_proj_out - torch_joiner_encoder_proj_out).abs().max()) # Now test decoder_proj - joiner_decoder_proj_inputs = { - decoder_proj_input_name: decoder_out.numpy() - } + joiner_decoder_proj_inputs = {decoder_proj_input_name: decoder_out.numpy()} joiner_decoder_proj_out = joiner_decoder_proj_session.run( [decoder_proj_output_name], joiner_decoder_proj_inputs )[0] @@ -248,11 +240,7 @@ def test_joiner( torch_joiner_decoder_proj_out = model.joiner.decoder_proj(decoder_out) assert torch.allclose( joiner_decoder_proj_out, torch_joiner_decoder_proj_out, atol=1e-5 - ), ( - (joiner_decoder_proj_out - torch_joiner_decoder_proj_out) - .abs() - .max() - ) + ), 
((joiner_decoder_proj_out - torch_joiner_decoder_proj_out).abs().max()) @torch.no_grad() @@ -304,9 +292,7 @@ def main(): if __name__ == "__main__": torch.manual_seed(20220727) - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_pretrained.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_pretrained.py index 3770fbbb42..ca1e408fa7 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_pretrained.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_pretrained.py @@ -150,8 +150,7 @@ def read_sound_files( for f in filenames: wave, sample_rate = torchaudio.load(f) assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" + f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}" ) # We use only the first channel ans.append(wave[0]) @@ -200,11 +199,7 @@ def greedy_search( projected_encoder_out = joiner_encoder_proj.run( [joiner_encoder_proj.get_outputs()[0].name], - { - joiner_encoder_proj.get_inputs()[ - 0 - ].name: packed_encoder_out.data.numpy() - }, + {joiner_encoder_proj.get_inputs()[0].name: packed_encoder_out.data.numpy()}, )[0] blank_id = 0 # hard-code to 0 @@ -389,9 +384,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/pretrained.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/pretrained.py index 9a549efd94..aaf7ac8740 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/pretrained.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/pretrained.py @@ -158,8 +158,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -190,8 +189,7 @@ def read_sound_files( for f in filenames: wave, sample_rate = torchaudio.load(f) assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" + f"expected sample rate: {expected_sample_rate}. 
" f"Given: {sample_rate}" ) # We use only the first channel ans.append(wave[0]) @@ -253,9 +251,7 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) feature_lengths = torch.tensor(feature_lengths, device=device) @@ -280,10 +276,7 @@ def main(): ) for i in range(encoder_out.size(0)): hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -335,9 +328,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py index d3cc7c9c92..7aba0711da 100644 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py @@ -115,9 +115,7 @@ from icefall.lexicon import Lexicon from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def get_parser(): @@ -219,8 +217,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -243,8 +240,7 @@ def get_parser(): "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)" "part.", ) parser.add_argument( @@ -590,22 +586,15 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # Note: We use reduction=sum while computing the loss. 
info["loss"] = loss.detach().cpu().item() @@ -762,9 +751,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -864,7 +851,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2 ** 22 + 2**22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/conformer.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/conformer.py index dd27c17f0c..9bb55d07ad 100644 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/conformer.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/conformer.py @@ -210,10 +210,7 @@ def get_init_state( (num_encoder_layers, cnn_module_kernel - 1, encoder_dim). NOTE: the returned tensors are on the given device. """ - if ( - len(self._init_state) == 2 - and self._init_state[0].size(1) == left_context - ): + if len(self._init_state) == 2 and self._init_state[0].size(1) == left_context: # Note: It is OK to share the init state as it is # not going to be modified by the model return self._init_state @@ -433,9 +430,7 @@ def __init__( self.d_model = d_model - self.self_attn = RelPositionMultiheadAttention( - d_model, nhead, dropout=0.0 - ) + self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) self.feed_forward = nn.Sequential( ScaledLinear(d_model, dim_feedforward), @@ -453,9 +448,7 @@ def __init__( ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), ) - self.conv_module = ConvolutionModule( - d_model, cnn_module_kernel, causal=causal - ) + self.conv_module = ConvolutionModule(d_model, cnn_module_kernel, causal=causal) self.norm_final = BasicNorm(d_model) @@ -520,9 +513,7 @@ def forward( src = src + self.dropout(src_att) # convolution module - conv, _ = self.conv_module( - src, src_key_padding_mask=src_key_padding_mask - ) + conv, _ = self.conv_module(src, src_key_padding_mask=src_key_padding_mask) src = src + self.dropout(conv) # feed forward module @@ -766,9 +757,7 @@ class RelPositionalEncoding(torch.nn.Module): max_len: Maximum input length. 
""" - def __init__( - self, d_model: int, dropout_rate: float, max_len: int = 5000 - ) -> None: + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -784,9 +773,7 @@ def extend_pe(self, x: Tensor, left_context: int = 0) -> None: # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x_size_1 * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str( - x.device - ): + if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vector and `j` means the @@ -1073,9 +1060,9 @@ def multi_head_attention_forward( if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear( - query, in_proj_weight, in_proj_bias - ).chunk(3, dim=-1) + q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( + 3, dim=-1 + ) elif torch.equal(key, value): # encoder-decoder attention @@ -1144,31 +1131,22 @@ def multi_head_attention_forward( if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError( - "The size of the 2D attn_mask is not correct." - ) + raise RuntimeError("The size of the 2D attn_mask is not correct.") elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError( - "The size of the 3D attn_mask is not correct." - ) + raise RuntimeError("The size of the 3D attn_mask is not correct.") else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format( - attn_mask.dim() - ) + "attn_mask's dimension {} is not supported".format(attn_mask.dim()) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if ( - key_padding_mask is not None - and key_padding_mask.dtype == torch.uint8 - ): + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: warnings.warn( "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." 
) @@ -1208,23 +1186,15 @@ def multi_head_attention_forward( # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul( - q_with_bias_u, k - ) # (batch, head, time1, time2) + matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) # compute matrix b and matrix d - matrix_bd = torch.matmul( - q_with_bias_v, p - ) # (batch, head, time1, 2*time1-1) + matrix_bd = torch.matmul(q_with_bias_v, p) # (batch, head, time1, 2*time1-1) matrix_bd = self.rel_shift(matrix_bd, left_context) - attn_output_weights = ( - matrix_ac + matrix_bd - ) # (batch, head, time1, time2) + attn_output_weights = matrix_ac + matrix_bd # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view( - bsz * num_heads, tgt_len, -1 - ) + attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) assert list(attn_output_weights.size()) == [ bsz * num_heads, @@ -1265,21 +1235,17 @@ def multi_head_attention_forward( ): if attn_mask.size(0) != 1: attn_mask = attn_mask.view(bsz, num_heads, tgt_len, src_len) - combined_mask = attn_mask | key_padding_mask.unsqueeze( - 1 - ).unsqueeze(2) + combined_mask = attn_mask | key_padding_mask.unsqueeze(1).unsqueeze(2) else: # attn_mask.shape == (1, tgt_len, src_len) - combined_mask = attn_mask.unsqueeze( - 0 - ) | key_padding_mask.unsqueeze(1).unsqueeze(2) + combined_mask = attn_mask.unsqueeze(0) | key_padding_mask.unsqueeze( + 1 + ).unsqueeze(2) attn_output_weights = attn_output_weights.view( bsz, num_heads, tgt_len, src_len ) - attn_output_weights = attn_output_weights.masked_fill( - combined_mask, 0.0 - ) + attn_output_weights = attn_output_weights.masked_fill(combined_mask, 0.0) attn_output_weights = attn_output_weights.view( bsz * num_heads, tgt_len, src_len ) @@ -1291,13 +1257,9 @@ def multi_head_attention_forward( attn_output = torch.bmm(attn_output_weights, v) assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] attn_output = ( - attn_output.transpose(0, 1) - .contiguous() - .view(tgt_len, bsz, embed_dim) - ) - attn_output = nn.functional.linear( - attn_output, out_proj_weight, out_proj_bias + attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) ) + attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads @@ -1430,16 +1392,12 @@ def forward( # manualy padding self.lorder zeros to the left x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0) else: - assert ( - not self.training - ), "Cache should be None in training time" + assert not self.training, "Cache should be None in training time" assert cache.size(0) == self.lorder x = torch.cat([cache.permute(1, 2, 0), x], dim=2) if right_context > 0: cache = x.permute(2, 0, 1)[ - -(self.lorder + right_context) : ( # noqa - -right_context - ), + -(self.lorder + right_context) : (-right_context), # noqa ..., ] else: diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py index 344e312835..166497c31f 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py @@ -244,8 +244,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; " "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -342,9 +341,7 @@ def decode_one_batch( simulate_streaming=True, ) else: - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] @@ -360,10 +357,7 @@ def decode_one_batch( ) for i in range(encoder_out.size(0)): hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -484,9 +478,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -519,8 +511,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -589,9 +580,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -618,9 +609,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -720,8 +711,7 @@ def main(): ) dev_shards = [ - str(path) - for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar"))) + str(path) for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar"))) ] cuts_dev_webdataset = CutSet.from_webdataset( dev_shards, @@ -731,8 +721,7 @@ def main(): ) test_net_shards = [ - str(path) - for path in sorted(glob.glob(os.path.join(test_net, "shared-*.tar"))) + str(path) for path in sorted(glob.glob(os.path.join(test_net, "shared-*.tar"))) ] cuts_test_net_webdataset = CutSet.from_webdataset( test_net_shards, @@ -743,9 +732,7 @@ def main(): test_meeting_shards = [ str(path) - for path in sorted( - glob.glob(os.path.join(test_meeting, "shared-*.tar")) - ) + for path in sorted(glob.glob(os.path.join(test_meeting, "shared-*.tar"))) ] cuts_test_meeting_webdataset = CutSet.from_webdataset( test_meeting_shards, diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode_stream.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode_stream.py index 3862485546..e522943c07 100644 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode_stream.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode_stream.py @@ -75,9 +75,7 @@ def __init__( # encoder.streaming_forward self.done_frames: int = 0 - self.pad_length = ( - params.right_context + 2 - ) * params.subsampling_factor + 3 + self.pad_length = (params.right_context + 2) * params.subsampling_factor + 3 if params.decoding_method == 
"greedy_search": self.hyp = [params.blank_id] * params.context_size @@ -91,13 +89,11 @@ def __init__( ) elif params.decoding_method == "fast_beam_search": # The rnnt_decoding_stream for fast_beam_search. - self.rnnt_decoding_stream: k2.RnntDecodingStream = ( - k2.RnntDecodingStream(decoding_graph) + self.rnnt_decoding_stream: k2.RnntDecodingStream = k2.RnntDecodingStream( + decoding_graph ) else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") @property def done(self) -> bool: @@ -126,13 +122,10 @@ def get_feature_frames(self, chunk_size: int) -> Tuple[torch.Tensor, int]: """Consume chunk_size frames of features""" chunk_length = chunk_size + self.pad_length - ret_length = min( - self.num_frames - self.num_processed_frames, chunk_length - ) + ret_length = min(self.num_frames - self.num_processed_frames, chunk_length) ret_features = self.features[ - self.num_processed_frames : self.num_processed_frames # noqa - + ret_length + self.num_processed_frames : self.num_processed_frames + ret_length # noqa ] self.num_processed_frames += chunk_size diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/export.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/export.py index d0a7fd69f6..ff2c4db38a 100644 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/export.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/export.py @@ -131,8 +131,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) add_model_arguments(parser) @@ -201,9 +200,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/pretrained.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/pretrained.py index 1b064c874c..7e4829a609 100644 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/pretrained.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/pretrained.py @@ -157,8 +157,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -190,8 +189,7 @@ def read_sound_files( for f in filenames: wave, sample_rate = torchaudio.load(f) assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" + f"expected sample rate: {expected_sample_rate}. 
" f"Given: {sample_rate}" ) # We use only the first channel ans.append(wave[0]) @@ -253,9 +251,7 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) feature_lengths = torch.tensor(feature_lengths, device=device) @@ -280,10 +276,7 @@ def main(): ) for i in range(encoder_out.size(0)): hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -335,9 +328,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_beam_search.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_beam_search.py index 651aff6c93..810d941354 100644 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_beam_search.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_beam_search.py @@ -173,14 +173,10 @@ def modified_beam_search( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor( - shape=log_probs_shape, value=log_probs - ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) for i in range(batch_size): - topk_log_probs, topk_indexes = ragged_log_probs[i].topk( - num_active_paths - ) + topk_log_probs, topk_indexes = ragged_log_probs[i].topk(num_active_paths) with warnings.catch_warnings(): warnings.simplefilter("ignore") diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py index ff96c64875..6909f40bef 100644 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py @@ -201,8 +201,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -311,9 +310,7 @@ def decode_one_chunk( encoder_out = model.joiner.encoder_proj(encoder_out) if params.decoding_method == "greedy_search": - greedy_search( - model=model, encoder_out=encoder_out, streams=decode_streams - ) + greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) elif params.decoding_method == "fast_beam_search": processed_lens = processed_lens + encoder_out_lens fast_beam_search_one_best( @@ -333,9 +330,7 @@ def decode_one_chunk( num_active_paths=params.num_active_paths, ) else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") states = [torch.unbind(states[0], dim=2), torch.unbind(states[1], dim=2)] @@ -389,9 +384,7 @@ def decode_dataset( decode_results = [] # Contain decode streams currently running. 
decode_streams = [] - initial_states = model.encoder.get_init_state( - params.left_context, device=device - ) + initial_states = model.encoder.get_init_state(params.left_context, device=device) for num, cut in enumerate(cuts): # each utterance has a DecodeStream. decode_stream = DecodeStream( @@ -461,9 +454,7 @@ def decode_dataset( elif params.decoding_method == "modified_beam_search": key = f"num_active_paths_{params.num_active_paths}" else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") return {key: decode_results} @@ -499,8 +490,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -565,9 +555,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -594,9 +584,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py index 2052e9da76..5f614e77c4 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py @@ -98,9 +98,7 @@ from icefall.lexicon import Lexicon from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def add_model_arguments(parser: argparse.ArgumentParser): @@ -260,8 +258,7 @@ def get_parser(): "--initial-lr", type=float, default=0.003, - help="The initial learning rate. This value should not need " - "to be changed.", + help="The initial learning rate. This value should not need " "to be changed.", ) parser.add_argument( @@ -284,8 +281,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -308,8 +304,7 @@ def get_parser(): "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)" "part.", ) parser.add_argument( @@ -665,11 +660,7 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. 
""" - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -701,23 +692,16 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() @@ -841,9 +825,7 @@ def train_one_epoch( scaler.update() optimizer.zero_grad() except: # noqa - display_and_save_batch( - batch, params=params, graph_compiler=graph_compiler - ) + display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) raise if params.print_diagnostics and batch_idx == 5: @@ -901,9 +883,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -1016,7 +996,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2 ** 22 + 2**22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) @@ -1184,9 +1164,7 @@ def scan_pessimistic_batches_for_oom( f"Failing criterion: {criterion} " f"(={crit_values[criterion]}) ..." ) - display_and_save_batch( - batch, params=params, graph_compiler=graph_compiler - ) + display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) raise diff --git a/egs/yesno/ASR/local/compile_hlg.py b/egs/yesno/ASR/local/compile_hlg.py index f83be05cfb..7234ca929e 100755 --- a/egs/yesno/ASR/local/compile_hlg.py +++ b/egs/yesno/ASR/local/compile_hlg.py @@ -128,9 +128,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/yesno/ASR/local/compute_fbank_yesno.py b/egs/yesno/ASR/local/compute_fbank_yesno.py index 9a4e8a36f6..75d95df687 100755 --- a/egs/yesno/ASR/local/compute_fbank_yesno.py +++ b/egs/yesno/ASR/local/compute_fbank_yesno.py @@ -54,9 +54,7 @@ def compute_fbank_yesno(): dataset_parts, ) - extractor = Fbank( - FbankConfig(sampling_rate=8000, num_mel_bins=num_mel_bins) - ) + extractor = Fbank(FbankConfig(sampling_rate=8000, num_mel_bins=num_mel_bins)) with get_executor() as ex: # Initialize the executor only once. 
for partition, m in manifests.items(): @@ -71,9 +69,7 @@ def compute_fbank_yesno(): ) if "train" in partition: cut_set = ( - cut_set - + cut_set.perturb_speed(0.9) - + cut_set.perturb_speed(1.1) + cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) ) cut_set = cut_set.compute_and_store_features( extractor=extractor, @@ -87,9 +83,7 @@ def compute_fbank_yesno(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/yesno/ASR/tdnn/decode.py b/egs/yesno/ASR/tdnn/decode.py index 9d4ab4b613..d5efb41df9 100755 --- a/egs/yesno/ASR/tdnn/decode.py +++ b/egs/yesno/ASR/tdnn/decode.py @@ -201,9 +201,7 @@ def decode_dataset( if batch_idx % 100 == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -274,9 +272,7 @@ def main(): logging.info(f"device: {device}") - HLG = k2.Fsa.from_dict( - torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu") - ) + HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")) HLG = HLG.to(device) assert HLG.requires_grad is False @@ -297,9 +293,7 @@ def main(): if params.export: logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt") - torch.save( - {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt" - ) + torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt") return model.to(device) @@ -317,9 +311,7 @@ def main(): word_table=lexicon.word_table, ) - save_results( - exp_dir=params.exp_dir, test_set_name="test_set", results=results - ) + save_results(exp_dir=params.exp_dir, test_set_name="test_set", results=results) logging.info("Done!") diff --git a/egs/yesno/ASR/tdnn/pretrained.py b/egs/yesno/ASR/tdnn/pretrained.py index 14220be192..88d5eca5d4 100755 --- a/egs/yesno/ASR/tdnn/pretrained.py +++ b/egs/yesno/ASR/tdnn/pretrained.py @@ -53,9 +53,7 @@ def get_parser(): help="Path to words.txt", ) - parser.add_argument( - "--HLG", type=str, required=True, help="Path to HLG.pt." - ) + parser.add_argument("--HLG", type=str, required=True, help="Path to HLG.pt.") parser.add_argument( "sound_files", @@ -102,8 +100,7 @@ def read_sound_files( for f in filenames: wave, sample_rate = torchaudio.load(f) assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" + f"expected sample rate: {expected_sample_rate}. 
" f"Given: {sample_rate}" ) # We use only the first channel ans.append(wave[0]) @@ -159,9 +156,7 @@ def main(): logging.info("Decoding started") features = fbank(waves) - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) # Note: We don't use key padding mask for attention during decoding with torch.no_grad(): @@ -201,9 +196,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/yesno/ASR/tdnn/train.py b/egs/yesno/ASR/tdnn/train.py index f32a27f356..335493491a 100755 --- a/egs/yesno/ASR/tdnn/train.py +++ b/egs/yesno/ASR/tdnn/train.py @@ -430,9 +430,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: valid_info = compute_validation_loss( diff --git a/egs/yesno/ASR/transducer/decode.py b/egs/yesno/ASR/transducer/decode.py index 6714180db8..7f13e417af 100755 --- a/egs/yesno/ASR/transducer/decode.py +++ b/egs/yesno/ASR/transducer/decode.py @@ -116,9 +116,7 @@ def decode_one_batch( # at entry, feature is (N, T, C) feature_lens = batch["supervisions"]["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] batch_size = encoder_out.size(0) @@ -186,9 +184,7 @@ def decode_dataset( if batch_idx % 100 == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -303,9 +299,7 @@ def main(): model=model, ) - save_results( - exp_dir=params.exp_dir, test_set_name="test_set", results=results - ) + save_results(exp_dir=params.exp_dir, test_set_name="test_set", results=results) logging.info("Done!") diff --git a/egs/yesno/ASR/transducer/train.py b/egs/yesno/ASR/transducer/train.py index deb92107dd..88866ae815 100755 --- a/egs/yesno/ASR/transducer/train.py +++ b/egs/yesno/ASR/transducer/train.py @@ -430,9 +430,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: valid_info = compute_validation_loss( diff --git a/icefall/char_graph_compiler.py b/icefall/char_graph_compiler.py index 235160e149..c31db6e4c3 100644 --- a/icefall/char_graph_compiler.py +++ b/icefall/char_graph_compiler.py @@ -71,9 +71,7 @@ def texts_to_ids(self, texts: List[str]) -> List[List[int]]: for text in texts: text = re.sub(whitespace, "", text) sub_ids = [ - self.token_table[txt] - if txt in self.token_table - else self.oov_id + self.token_table[txt] if txt in self.token_table else self.oov_id for txt in text ] ids.append(sub_ids) @@ -96,9 +94,7 @@ def texts_to_ids_with_bpe(self, texts: List[str]) -> List[List[int]]: for text in texts: text = 
text.split("/") sub_ids = [ - self.token_table[txt] - if txt in self.token_table - else self.oov_id + self.token_table[txt] if txt in self.token_table else self.oov_id for txt in text ] ids.append(sub_ids) diff --git a/icefall/checkpoint.py b/icefall/checkpoint.py index 5069b78e82..8aa0a8eeb1 100644 --- a/icefall/checkpoint.py +++ b/icefall/checkpoint.py @@ -292,15 +292,11 @@ def find_checkpoints(out_dir: Path, iteration: int = 0) -> List[str]: """ checkpoints = list(glob.glob(f"{out_dir}/checkpoint-[0-9]*.pt")) pattern = re.compile(r"checkpoint-([0-9]+).pt") - iter_checkpoints = [ - (int(pattern.search(c).group(1)), c) for c in checkpoints - ] + iter_checkpoints = [(int(pattern.search(c).group(1)), c) for c in checkpoints] # iter_checkpoints is a list of tuples. Each tuple contains # two elements: (iteration_number, checkpoint-iteration_number.pt) - iter_checkpoints = sorted( - iter_checkpoints, reverse=True, key=lambda x: x[0] - ) + iter_checkpoints = sorted(iter_checkpoints, reverse=True, key=lambda x: x[0]) if iteration >= 0: ans = [ic[1] for ic in iter_checkpoints if ic[0] >= iteration] else: @@ -469,7 +465,5 @@ def average_state_dict( v = state_dict_1[k] if torch.is_floating_point(v): v *= weight_1 - v += ( - state_dict_2[k].to(device=state_dict_1[k].device) * weight_2 - ) + v += state_dict_2[k].to(device=state_dict_1[k].device) * weight_2 v *= scaling_factor diff --git a/icefall/decode.py b/icefall/decode.py index 099e2d171b..e4c614c4e1 100644 --- a/icefall/decode.py +++ b/icefall/decode.py @@ -334,13 +334,9 @@ def intersect(self, lattice: k2.Fsa, use_double_scores=True) -> "Nbest": if hasattr(lattice, "aux_labels"): # delete token IDs as it is not needed del word_fsa.aux_labels - word_fsa_with_epsilon_loops = k2.linear_fsa_with_self_loops( - word_fsa - ) + word_fsa_with_epsilon_loops = k2.linear_fsa_with_self_loops(word_fsa) else: - word_fsa_with_epsilon_loops = k2.linear_fst_with_self_loops( - word_fsa - ) + word_fsa_with_epsilon_loops = k2.linear_fst_with_self_loops(word_fsa) path_to_utt_map = self.shape.row_ids(1) @@ -370,9 +366,7 @@ def intersect(self, lattice: k2.Fsa, use_double_scores=True) -> "Nbest": # path_lattice has word IDs as labels and token IDs as aux_labels path_lattice = k2.top_sort(k2.connect(path_lattice)) - one_best = k2.shortest_path( - path_lattice, use_double_scores=use_double_scores - ) + one_best = k2.shortest_path(path_lattice, use_double_scores=use_double_scores) one_best = k2.invert(one_best) # Now one_best has token IDs as labels and word IDs as aux_labels @@ -442,9 +436,7 @@ def tot_scores(self) -> k2.RaggedTensor: scores_shape = self.fsa.arcs.shape().remove_axis(1) # scores_shape has axes [path][arc] - ragged_scores = k2.RaggedTensor( - scores_shape, self.fsa.scores.contiguous() - ) + ragged_scores = k2.RaggedTensor(scores_shape, self.fsa.scores.contiguous()) tot_scores = ragged_scores.sum() @@ -483,9 +475,7 @@ def one_best_decoding( am_scores = saved_am_scores / lm_scale lattice.scores = am_scores + lattice.lm_scores - best_path = k2.shortest_path( - lattice, use_double_scores=use_double_scores - ) + best_path = k2.shortest_path(lattice, use_double_scores=use_double_scores) key = f"lm_scale_{lm_scale}" ans[key] = best_path return ans @@ -696,9 +686,7 @@ def rescore_with_n_best_list( logging.info(f"num_paths before decreasing: {num_paths}") num_paths = int(num_paths / 2) if loop_count >= max_loop_count or num_paths <= 0: - logging.info( - "Return None as the resulting lattice is too large." 
- ) + logging.info("Return None as the resulting lattice is too large.") return None logging.info( "This OOM is not an error. You can ignore it. " @@ -805,13 +793,9 @@ def rescore_with_whole_lattice( except RuntimeError as e: logging.info(f"Caught exception:\n{e}\n") if loop_count >= max_loop_count: - logging.info( - "Return None as the resulting lattice is too large." - ) + logging.info("Return None as the resulting lattice is too large.") return None - logging.info( - f"num_arcs before pruning: {inv_lattice.arcs.num_elements()}" - ) + logging.info(f"num_arcs before pruning: {inv_lattice.arcs.num_elements()}") logging.info( "This OOM is not an error. You can ignore it. " "If your model does not converge well, or --max-duration " @@ -823,9 +807,7 @@ def rescore_with_whole_lattice( prune_th_list[loop_count], True, ) - logging.info( - f"num_arcs after pruning: {inv_lattice.arcs.num_elements()}" - ) + logging.info(f"num_arcs after pruning: {inv_lattice.arcs.num_elements()}") loop_count += 1 # lat has token IDs as labels @@ -912,9 +894,7 @@ def rescore_with_attention_decoder( logging.info(f"num_paths before decreasing: {num_paths}") num_paths = int(num_paths / 2) if loop_count >= max_loop_count or num_paths <= 0: - logging.info( - "Return None as the resulting lattice is too large." - ) + logging.info("Return None as the resulting lattice is too large.") return None logging.info( "This OOM is not an error. You can ignore it. " diff --git a/icefall/diagnostics.py b/icefall/diagnostics.py index b075aceac3..207c12bf16 100644 --- a/icefall/diagnostics.py +++ b/icefall/diagnostics.py @@ -19,7 +19,7 @@ import random from dataclasses import dataclass -from typing import Optional, Tuple, List +from typing import List, Optional, Tuple import torch from torch import Tensor, nn @@ -78,11 +78,11 @@ def get_tensor_stats( elif stats_type == "abs": x = x.abs() elif stats_type == "rms": - x = x ** 2 + x = x**2 elif stats_type == "positive": x = (x > 0).to(dtype=torch.float) else: - assert stats_type in [ "value", "max", "min" ] + assert stats_type in ["value", "max", "min"] sum_dims = [d for d in range(x.ndim) if d != dim] if len(sum_dims) > 0: @@ -121,7 +121,9 @@ def __init__(self, opts: TensorDiagnosticOptions, name: str): self.name = name self.class_name = None # will assign in accumulate() - self.stats = None # we'll later assign a list to this data member. It's a list of dict. + self.stats = ( + None # we'll later assign a list to this data member. It's a list of dict. + ) # the keys into self.stats[dim] are strings, whose values can be # "abs", "max", "min" ,"value", "positive", "rms", "value". @@ -133,7 +135,6 @@ def __init__(self, opts: TensorDiagnosticOptions, name: str): # only adding a new element to the list if there was a different dim. # if the string in the key is "eigs", if we detect a length mismatch we put None as the value. - def accumulate(self, x, class_name: Optional[str] = None): """ Accumulate tensors. @@ -185,17 +186,12 @@ def accumulate(self, x, class_name: Optional[str] = None): done = True break if not done: - if ( - this_dim_stats[stats_type] != [] - and stats_type == "eigs" - ): + if this_dim_stats[stats_type] != [] and stats_type == "eigs": # >1 size encountered on this dim, e.g. 
it's a batch or time dimension, # don't accumulat "eigs" stats type, it uses too much memory this_dim_stats[stats_type] = None else: - this_dim_stats[stats_type].append( - TensorAndCount(stats, count) - ) + this_dim_stats[stats_type].append(TensorAndCount(stats, count)) def print_diagnostics(self): """Print diagnostics for each dimension of the tensor.""" @@ -211,7 +207,6 @@ def print_diagnostics(self): assert stats_type == "eigs" continue - def get_count(count): return 1 if stats_type in ["max", "min"] else count @@ -229,9 +224,7 @@ def get_count(count): eigs, _ = torch.symeig(stats) stats = eigs.abs().sqrt() except: # noqa - print( - "Error getting eigenvalues, trying another method." - ) + print("Error getting eigenvalues, trying another method.") eigs, _ = torch.eig(stats) stats = eigs.abs().sqrt() # sqrt so it reflects data magnitude, like stddev- not variance @@ -242,9 +235,9 @@ def get_count(count): # if `summarize` we print percentiles of the stats; else, # we print out individual elements. - summarize = ( - len(stats_list) > 1 - ) or self.opts.dim_is_summarized(stats.numel()) + summarize = (len(stats_list) > 1) or self.opts.dim_is_summarized( + stats.numel() + ) if summarize: # usually `summarize` will be true # print out percentiles. stats = stats.sort()[0] @@ -261,15 +254,15 @@ def get_count(count): ans = stats.tolist() ans = ["%.2g" % x for x in ans] ans = "[" + " ".join(ans) + "]" - if stats_type in [ "value", "rms", "eigs" ]: + if stats_type in ["value", "rms", "eigs"]: # This norm is useful because it is strictly less than the largest # sqrt(eigenvalue) of the variance, which we print out, and shows, # speaking in an approximate way, how much of that largest eigenvalue # can be attributed to the mean of the distribution. - norm = (stats ** 2).sum().sqrt().item() + norm = (stats**2).sum().sqrt().item() ans += f", norm={norm:.2g}" mean = stats.mean().item() - rms = (stats ** 2).mean().sqrt().item() + rms = (stats**2).mean().sqrt().item() ans += f", mean={mean:.3g}, rms={rms:.3g}" # OK, "ans" contains the actual stats, e.g. @@ -277,17 +270,16 @@ def get_count(count): sizes = [x.tensor.shape[0] for x in stats_list] size_str = ( - f"{sizes[0]}" - if len(sizes) == 1 - else f"{min(sizes)}..{max(sizes)}" + f"{sizes[0]}" if len(sizes) == 1 else f"{min(sizes)}..{max(sizes)}" + ) + maybe_class_name = ( + f" type={self.class_name}," if self.class_name is not None else "" ) - maybe_class_name = f" type={self.class_name}," if self.class_name is not None else "" print( f"module={self.name},{maybe_class_name} dim={dim}, size={size_str}, {stats_type} {ans}" ) - class ModelDiagnostic(object): """This class stores diagnostics for all tensors in the torch.nn.Module. @@ -345,32 +337,32 @@ def attach_diagnostics( # (matters for name, since the variable gets overwritten). 
# These closures don't really capture by value, only by # "the final value the variable got in the function" :-( - def forward_hook( - _module, _input, _output, _model_diagnostic=ans, _name=name - ): + def forward_hook(_module, _input, _output, _model_diagnostic=ans, _name=name): if isinstance(_output, tuple) and len(_output) == 1: _output = _output[0] if isinstance(_output, Tensor): - _model_diagnostic[f"{_name}.output"].accumulate(_output, - class_name=type(_module).__name__) + _model_diagnostic[f"{_name}.output"].accumulate( + _output, class_name=type(_module).__name__ + ) elif isinstance(_output, tuple): for i, o in enumerate(_output): - _model_diagnostic[f"{_name}.output[{i}]"].accumulate(o, - class_name=type(_module).__name__) + _model_diagnostic[f"{_name}.output[{i}]"].accumulate( + o, class_name=type(_module).__name__ + ) - def backward_hook( - _module, _input, _output, _model_diagnostic=ans, _name=name - ): + def backward_hook(_module, _input, _output, _model_diagnostic=ans, _name=name): if isinstance(_output, tuple) and len(_output) == 1: _output = _output[0] if isinstance(_output, Tensor): - _model_diagnostic[f"{_name}.grad"].accumulate(_output, - class_name=type(_module).__name__) + _model_diagnostic[f"{_name}.grad"].accumulate( + _output, class_name=type(_module).__name__ + ) elif isinstance(_output, tuple): for i, o in enumerate(_output): - _model_diagnostic[f"{_name}.grad[{i}]"].accumulate(o, - class_name=type(_module).__name__) + _model_diagnostic[f"{_name}.grad[{i}]"].accumulate( + o, class_name=type(_module).__name__ + ) module.register_forward_hook(forward_hook) module.register_backward_hook(backward_hook) diff --git a/icefall/dist.py b/icefall/dist.py index 7016beafb0..9df1c5bd13 100644 --- a/icefall/dist.py +++ b/icefall/dist.py @@ -29,9 +29,7 @@ def setup_dist(rank, world_size, master_port=None, use_ddp_launch=False): os.environ["MASTER_ADDR"] = "localhost" if "MASTER_PORT" not in os.environ: - os.environ["MASTER_PORT"] = ( - "12354" if master_port is None else str(master_port) - ) + os.environ["MASTER_PORT"] = "12354" if master_port is None else str(master_port) if use_ddp_launch is False: dist.init_process_group("nccl", rank=rank, world_size=world_size) diff --git a/icefall/env.py b/icefall/env.py index 8aeda6be2f..373e9a9ff5 100644 --- a/icefall/env.py +++ b/icefall/env.py @@ -53,9 +53,7 @@ def get_git_sha1(): ) > 0 ) - git_commit = ( - git_commit + "-dirty" if dirty_commit else git_commit + "-clean" - ) + git_commit = git_commit + "-dirty" if dirty_commit else git_commit + "-clean" except: # noqa return None diff --git a/icefall/graph_compiler.py b/icefall/graph_compiler.py index 570ed7d7a3..e2ff03f614 100644 --- a/icefall/graph_compiler.py +++ b/icefall/graph_compiler.py @@ -75,9 +75,7 @@ def compile(self, texts: List[str]) -> k2.Fsa: # NOTE: k2.compose runs on CUDA only when treat_epsilons_specially # is False, so we add epsilon self-loops here - fsa_with_self_loops = k2.remove_epsilon_and_add_self_loops( - transcript_fsa - ) + fsa_with_self_loops = k2.remove_epsilon_and_add_self_loops(transcript_fsa) fsa_with_self_loops = k2.arc_sort(fsa_with_self_loops) diff --git a/icefall/hooks.py b/icefall/hooks.py index fbcf5e1485..398a5f689d 100644 --- a/icefall/hooks.py +++ b/icefall/hooks.py @@ -14,10 +14,11 @@ # See the License for the specific language governing permissions and # limitations under the License. 
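# [editor's note] The `_name=name` / `_model_diagnostic=ans` defaults in the
# hook closures above work around Python's capture-by-reference behavior in
# loops. Isolated, runnable illustration (names are illustrative):
fns = [lambda: tag for tag in ("a", "b")]            # both close over `tag`
assert [f() for f in fns] == ["b", "b"]              # the last value wins

fns = [lambda _tag=tag: _tag for tag in ("a", "b")]  # default binds per iteration
assert [f() for f in fns] == ["a", "b"]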
+import logging import random + import torch from torch import Tensor, nn -import logging def register_inf_check_hooks(model: nn.Module) -> None: @@ -56,7 +57,7 @@ def backward_hook(_module, _input, _output, _name=name): if isinstance(_output, Tensor): if not torch.isfinite(_output.to(torch.float32).sum()): logging.warning( - f"The sum of {_name}.grad is not finite" # ": {_output}" + f"The sum of {_name}.grad is not finite" # ": {_output}" ) elif isinstance(_output, tuple): for i, o in enumerate(_output): @@ -65,28 +66,20 @@ def backward_hook(_module, _input, _output, _name=name): if not isinstance(o, Tensor): continue if not torch.isfinite(o.to(torch.float32).sum()): - logging.warning( - f"The sum of {_name}.grad[{i}] is not finite" - ) + logging.warning(f"The sum of {_name}.grad[{i}] is not finite") module.register_forward_hook(forward_hook) module.register_backward_hook(backward_hook) - for name, parameter in model.named_parameters(): - def param_backward_hook( - grad, _name=name - ): + def param_backward_hook(grad, _name=name): if not torch.isfinite(grad.to(torch.float32).sum()): - logging.warning( - f"The sum of {_name}.param_grad is not finite" - ) + logging.warning(f"The sum of {_name}.param_grad is not finite") parameter.register_hook(param_backward_hook) - def _test_inf_check_hooks(): model = nn.Sequential(nn.Linear(100, 50), nn.Linear(50, 80)) diff --git a/icefall/lexicon.py b/icefall/lexicon.py index 80bd7c1ee8..22e1b78bbc 100644 --- a/icefall/lexicon.py +++ b/icefall/lexicon.py @@ -49,18 +49,12 @@ def read_lexicon(filename: str) -> List[Tuple[str, List[str]]]: continue if len(a) < 2: - logging.info( - f"Found bad line {line} in lexicon file {filename}" - ) - logging.info( - "Every line is expected to contain at least 2 fields" - ) + logging.info(f"Found bad line {line} in lexicon file {filename}") + logging.info("Every line is expected to contain at least 2 fields") sys.exit(1) word = a[0] if word == " ": - logging.info( - f"Found bad line {line} in lexicon file {filename}" - ) + logging.info(f"Found bad line {line} in lexicon file {filename}") logging.info(" should not be a valid word") sys.exit(1) @@ -119,9 +113,7 @@ def convert_lexicon_to_ragged( lexicon_tmp = read_lexicon(filename) lexicon = dict(lexicon_tmp) if len(lexicon_tmp) != len(lexicon): - raise RuntimeError( - "It's assumed that each word has a unique pronunciation" - ) + raise RuntimeError("It's assumed that each word has a unique pronunciation") for i in range(disambig_id): w = word_table[i] diff --git a/icefall/mmi.py b/icefall/mmi.py index 2c479fc2c9..16ed6e032e 100644 --- a/icefall/mmi.py +++ b/icefall/mmi.py @@ -63,10 +63,7 @@ def _compute_mmi_loss_exact_optimized( # [0, num_fsas, 1, num_fsas, 2, num_fsas, ... 
] num_den_graphs_indexes = ( - torch.stack([num_graphs_indexes, den_graphs_indexes]) - .t() - .reshape(-1) - .to(device) + torch.stack([num_graphs_indexes, den_graphs_indexes]).t().reshape(-1).to(device) ) num_den_reordered_graphs = k2.index(num_den_graphs, num_den_graphs_indexes) @@ -115,20 +112,12 @@ def _compute_mmi_loss_exact_non_optimized( num_graphs, den_graphs = graph_compiler.compile(texts, replicate_den=True) # TODO: pass output_beam as function argument - num_lats = k2.intersect_dense( - num_graphs, dense_fsa_vec, output_beam=beam_size - ) - den_lats = k2.intersect_dense( - den_graphs, dense_fsa_vec, output_beam=beam_size - ) + num_lats = k2.intersect_dense(num_graphs, dense_fsa_vec, output_beam=beam_size) + den_lats = k2.intersect_dense(den_graphs, dense_fsa_vec, output_beam=beam_size) - num_tot_scores = num_lats.get_tot_scores( - log_semiring=True, use_double_scores=True - ) + num_tot_scores = num_lats.get_tot_scores(log_semiring=True, use_double_scores=True) - den_tot_scores = den_lats.get_tot_scores( - log_semiring=True, use_double_scores=True - ) + den_tot_scores = den_lats.get_tot_scores(log_semiring=True, use_double_scores=True) tot_scores = num_tot_scores - den_scale * den_tot_scores @@ -168,13 +157,9 @@ def _compute_mmi_loss_pruned( max_active_states=10000, ) - num_tot_scores = num_lats.get_tot_scores( - log_semiring=True, use_double_scores=True - ) + num_tot_scores = num_lats.get_tot_scores(log_semiring=True, use_double_scores=True) - den_tot_scores = den_lats.get_tot_scores( - log_semiring=True, use_double_scores=True - ) + den_tot_scores = den_lats.get_tot_scores(log_semiring=True, use_double_scores=True) tot_scores = num_tot_scores - den_scale * den_tot_scores diff --git a/icefall/mmi_graph_compiler.py b/icefall/mmi_graph_compiler.py index 0d901227d0..9f680f83da 100644 --- a/icefall/mmi_graph_compiler.py +++ b/icefall/mmi_graph_compiler.py @@ -137,9 +137,7 @@ def compile( transcript_fsa ) - transcript_fsa_with_self_loops = k2.arc_sort( - transcript_fsa_with_self_loops - ) + transcript_fsa_with_self_loops = k2.arc_sort(transcript_fsa_with_self_loops) num = k2.compose( self.ctc_topo_P, @@ -155,9 +153,7 @@ def compile( ctc_topo_P_vec = k2.create_fsa_vec([self.ctc_topo_P]) if replicate_den: - indexes = torch.zeros( - len(texts), dtype=torch.int32, device=self.device - ) + indexes = torch.zeros(len(texts), dtype=torch.int32, device=self.device) den = k2.index_fsa(ctc_topo_P_vec, indexes) else: den = ctc_topo_P_vec diff --git a/icefall/rnn_lm/dataset.py b/icefall/rnn_lm/dataset.py index 598e329c4c..4bf9825034 100644 --- a/icefall/rnn_lm/dataset.py +++ b/icefall/rnn_lm/dataset.py @@ -155,12 +155,8 @@ def __call__( sentence_tokens_with_sos = add_sos(sentence_tokens, self.sos_id) sentence_tokens_with_eos = add_eos(sentence_tokens, self.eos_id) - x = sentence_tokens_with_sos.pad( - mode="constant", padding_value=self.blank_id - ) - y = sentence_tokens_with_eos.pad( - mode="constant", padding_value=self.blank_id - ) + x = sentence_tokens_with_sos.pad(mode="constant", padding_value=self.blank_id) + y = sentence_tokens_with_eos.pad(mode="constant", padding_value=self.blank_id) sentence_token_lengths += 1 # plus 1 since we added a SOS return x.to(torch.int64), y.to(torch.int64), sentence_token_lengths diff --git a/icefall/rnn_lm/export.py b/icefall/rnn_lm/export.py index 094035fcea..2411cb1f0b 100644 --- a/icefall/rnn_lm/export.py +++ b/icefall/rnn_lm/export.py @@ -159,9 +159,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s 
[%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/icefall/rnn_lm/model.py b/icefall/rnn_lm/model.py index a6144727a9..9eef88840a 100644 --- a/icefall/rnn_lm/model.py +++ b/icefall/rnn_lm/model.py @@ -129,9 +129,7 @@ def predict_batch(self, tokens, token_lens, sos_id, eos_id, blank_id): tokens_eos = add_eos(tokens, eos_id) sos_tokens_row_splits = sos_tokens.shape.row_splits(1) - sentence_lengths = ( - sos_tokens_row_splits[1:] - sos_tokens_row_splits[:-1] - ) + sentence_lengths = sos_tokens_row_splits[1:] - sos_tokens_row_splits[:-1] x_tokens = sos_tokens.pad(mode="constant", padding_value=blank_id) y_tokens = tokens_eos.pad(mode="constant", padding_value=blank_id) @@ -161,12 +159,12 @@ def score_token(self, tokens: torch.Tensor, state=None): if state: h, c = state else: - h = torch.zeros( - self.rnn.num_layers, batch_size, self.rnn.input_size - ).to(device) - c = torch.zeros( - self.rnn.num_layers, batch_size, self.rnn.input_size - ).to(device) + h = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size).to( + device + ) + c = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size).to( + device + ) embedding = self.input_embedding(tokens) rnn_out, states = self.rnn(embedding, (h, c)) @@ -181,12 +179,8 @@ def forward_with_state( if state: h, c = state else: - h = torch.zeros( - self.rnn.num_layers, batch_size, self.rnn.input_size - ) - c = torch.zeros( - self.rnn.num_layers, batch_size, self.rnn.input_size - ) + h = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size) + c = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size) device = next(self.parameters()).device @@ -194,9 +188,7 @@ def forward_with_state( tokens_eos = add_eos(tokens, eos_id) sos_tokens_row_splits = sos_tokens.shape.row_splits(1) - sentence_lengths = ( - sos_tokens_row_splits[1:] - sos_tokens_row_splits[:-1] - ) + sentence_lengths = sos_tokens_row_splits[1:] - sos_tokens_row_splits[:-1] x_tokens = sos_tokens.pad(mode="constant", padding_value=blank_id) y_tokens = tokens_eos.pad(mode="constant", padding_value=blank_id) diff --git a/icefall/rnn_lm/train.py b/icefall/rnn_lm/train.py index bb5f03fb91..3ba5bfbee7 100755 --- a/icefall/rnn_lm/train.py +++ b/icefall/rnn_lm/train.py @@ -446,17 +446,13 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) tb_writer.add_scalar( "train/current_ppl", this_batch_ppl, params.batch_idx_train ) - tb_writer.add_scalar( - "train/tot_ppl", tot_ppl, params.batch_idx_train - ) + tb_writer.add_scalar("train/tot_ppl", tot_ppl, params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") diff --git a/icefall/shared/make_kn_lm.py b/icefall/shared/make_kn_lm.py index c2edd823ee..b1220d55e7 100755 --- a/icefall/shared/make_kn_lm.py +++ b/icefall/shared/make_kn_lm.py @@ -15,30 +15,43 @@ # The data structure is based on: kaldi/egs/wsj/s5/utils/lang/make_phone_lm.py # The smoothing algorithm is based on: http://www.speech.sri.com/projects/srilm/manpages/ngram-discount.7.html -import sys -import os -import re +import argparse import io import math -import argparse +import os +import re +import sys from collections import Counter, defaultdict - 
-parser = argparse.ArgumentParser(description="""
+parser = argparse.ArgumentParser(
+    description="""
     Generate kneser-ney language model as arpa format. By default,
     it will read the corpus from standard input, and output to standard output.
-    """)
-parser.add_argument("-ngram-order", type=int, default=4, choices=[2, 3, 4, 5, 6, 7], help="Order of n-gram")
+    """
+)
+parser.add_argument(
+    "-ngram-order",
+    type=int,
+    default=4,
+    choices=[2, 3, 4, 5, 6, 7],
+    help="Order of n-gram",
+)
 parser.add_argument("-text", type=str, default=None, help="Path to the corpus file")
-parser.add_argument("-lm", type=str, default=None, help="Path to output arpa file for language models")
-parser.add_argument("-verbose", type=int, default=0, choices=[0, 1, 2, 3, 4, 5], help="Verbose level")
+parser.add_argument(
+    "-lm", type=str, default=None, help="Path to output arpa file for language models"
+)
+parser.add_argument(
+    "-verbose", type=int, default=0, choices=[0, 1, 2, 3, 4, 5], help="Verbose level"
+)
 
 args = parser.parse_args()
 
-default_encoding = "latin-1"  # For encoding-agnostic scripts, we assume byte stream as input.
-                              # Need to be very careful about the use of strip() and split()
-                              # in this case, because there is a latin-1 whitespace character
-                              # (nbsp) which is part of the unicode encoding range.
-                              # Ref: kaldi/egs/wsj/s5/utils/lang/bpe/prepend_words.py @ 69cd717
+default_encoding = (
+    "latin-1"  # For encoding-agnostic scripts, we assume byte stream as input.
+)
+# Need to be very careful about the use of strip() and split()
+# in this case, because there is a latin-1 whitespace character
+# (nbsp) which is part of the unicode encoding range.
+# Ref: kaldi/egs/wsj/s5/utils/lang/bpe/prepend_words.py @ 69cd717
 strip_chars = " \t\r\n"
 whitespace = re.compile("[ \t]+")
 
@@ -52,7 +65,9 @@ def __init__(self):
         # The 'lambda: defaultdict(float)' is an anonymous function taking no
         # arguments that returns a new defaultdict(float).
         self.word_to_count = defaultdict(int)
-        self.word_to_context = defaultdict(set)  # using a set to count the number of unique contexts
+        self.word_to_context = defaultdict(
+            set
+        )  # using a set to count the number of unique contexts
         self.word_to_f = dict()  # discounted probability
         self.word_to_bow = dict()  # back-off weight
         self.total_count = 0
 
     def words(self):
@@ -62,10 +77,15 @@ def words(self):
 
     def __str__(self):
         # e.g. returns ' total=12: 3->4, 4->6, -1->2'
-        return ' total={0}: {1}'.format(
+        return " total={0}: {1}".format(
             str(self.total_count),
-            ', '.join(['{0} -> {1}'.format(word, count)
-                       for word, count in self.word_to_count.items()]))
+            ", ".join(
+                [
+                    "{0} -> {1}".format(word, count)
+                    for word, count in self.word_to_count.items()
+                ]
+            ),
+        )
 
     def add_count(self, predicted_word, context_word, count):
         assert count >= 0
@@ -85,7 +105,7 @@ class NgramCounts:
     # accumulating the 4-gram count for the '8' in the sequence '5 6 7 8', we'd
     # do as follows: self.counts[3][[5,6,7]][8] += 1.0 where the [3] indexes an
    # array, the [[5,6,7]] indexes a dict, and the [8] indexes a dict.
-    def __init__(self, ngram_order, bos_symbol='<s>', eos_symbol='</s>'):
+    def __init__(self, ngram_order, bos_symbol="<s>", eos_symbol="</s>"):
         assert ngram_order >= 2
 
         self.ngram_order = ngram_order
@@ -103,39 +123,48 @@ def __init__(self, ngram_order, bos_symbol='<s>', eos_symbol='</s>'):
     # would be (6,7,8) and 'predicted_word' would be 9; 'count' would be
     # 1.
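     # (Editorial illustration, not part of the patch: with ngram_order=4 and the
     # word stream "5 6 7 8 9" (after <s>/</s> are added by add_raw_counts_from_line),
     # the trigram entry for '9' with history (7, 8) also records '6' as a context
     # word; cal_f() later uses the number of such unique contexts, the "modified
     # count", for lower-order N-grams.)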
     def add_count(self, history, predicted_word, context_word, count):
-        self.counts[len(history)][history].add_count(predicted_word, context_word, count)
+        self.counts[len(history)][history].add_count(
+            predicted_word, context_word, count
+        )
 
     # 'line' is a string containing a sequence of integer word-ids.
     # This function adds the un-smoothed counts from this line of text.
     def add_raw_counts_from_line(self, line):
-        if line == '':
+        if line == "":
             words = [self.bos_symbol, self.eos_symbol]
         else:
             words = [self.bos_symbol] + whitespace.split(line) + [self.eos_symbol]
 
         for i in range(len(words)):
-            for n in range(1, self.ngram_order+1):
+            for n in range(1, self.ngram_order + 1):
                 if i + n > len(words):
                     break
-                ngram = words[i: i + n]
+                ngram = words[i : i + n]
                 predicted_word = ngram[-1]
-                history = tuple(ngram[: -1])
+                history = tuple(ngram[:-1])
                 if i == 0 or n == self.ngram_order:
                     context_word = None
                 else:
-                    context_word = words[i-1]
+                    context_word = words[i - 1]
 
                 self.add_count(history, predicted_word, context_word, 1)
 
     def add_raw_counts_from_standard_input(self):
         lines_processed = 0
-        infile = io.TextIOWrapper(sys.stdin.buffer, encoding=default_encoding) # byte stream as input
+        infile = io.TextIOWrapper(
+            sys.stdin.buffer, encoding=default_encoding
+        )  # byte stream as input
         for line in infile:
             line = line.strip(strip_chars)
             self.add_raw_counts_from_line(line)
             lines_processed += 1
         if lines_processed == 0 or args.verbose > 0:
-            print("make_phone_lm.py: processed {0} lines of input".format(lines_processed), file=sys.stderr)
+            print(
+                "make_phone_lm.py: processed {0} lines of input".format(
+                    lines_processed
+                ),
+                file=sys.stderr,
+            )
 
     def add_raw_counts_from_file(self, filename):
         lines_processed = 0
@@ -145,7 +174,12 @@ def add_raw_counts_from_file(self, filename):
             self.add_raw_counts_from_line(line)
             lines_processed += 1
         if lines_processed == 0 or args.verbose > 0:
-            print("make_phone_lm.py: processed {0} lines of input".format(lines_processed), file=sys.stderr)
+            print(
+                "make_phone_lm.py: processed {0} lines of input".format(
+                    lines_processed
+                ),
+                file=sys.stderr,
+            )
 
     def cal_discounting_constants(self):
         # For each order N of N-grams, we calculate discounting constant D_N = n1_N / (n1_N + 2 * n2_N),
         # This constant is used similarly to absolute discounting.
         # Return value: d is a list of floats, where d[N+1] = D_N
 
-        self.d = [0] # for the lowest order, i.e., 1-gram, we do not need to discount, thus the constant is 0
-        # This is a special case: as we currently assumed having seen all vocabularies in the dictionary,
-        # but perhaps this is not the case for some other scenarios.
+        self.d = [
+            0
+        ]  # for the lowest order, i.e., 1-gram, we do not need to discount, thus the constant is 0
+        # This is a special case: as we currently assumed having seen all vocabularies in the dictionary,
+        # but perhaps this is not the case for some other scenarios.
         for n in range(1, self.ngram_order):
             this_order_counts = self.counts[n]
             n1 = 0
@@ -165,9 +201,11 @@ def cal_discounting_constants(self):
                 n1 += stat[1]
                 n2 += stat[2]
             assert n1 + 2 * n2 > 0
-            self.d.append(max(0.1, n1 * 1.0) / (n1 + 2 * n2))  # We are doing this max(0.1, xxx) to avoid zero discounting constant D due to n1=0,
-            # which could happen if the number of symbols is small.
-            # Otherwise, zero discounting constant can cause division by zero in computing BOW.
+            self.d.append(
+                max(0.1, n1 * 1.0) / (n1 + 2 * n2)
+            )  # We are doing this max(0.1, xxx) to avoid zero discounting constant D due to n1=0,
+            # which could happen if the number of symbols is small.
+            # Otherwise, zero discounting constant can cause division by zero in computing BOW.
 
     def cal_f(self):
         # f(a_z) is a probability distribution of word sequence a_z.
@@ -182,7 +220,9 @@ def cal_f(self):
             this_order_counts = self.counts[n]
             for hist, counts_for_hist in this_order_counts.items():
                 for w, c in counts_for_hist.word_to_count.items():
-                    counts_for_hist.word_to_f[w] = max((c - self.d[n]), 0) * 1.0 / counts_for_hist.total_count
+                    counts_for_hist.word_to_f[w] = (
+                        max((c - self.d[n]), 0) * 1.0 / counts_for_hist.total_count
+                    )
 
         # lower order N-grams
         for n in range(0, self.ngram_order - 1):
@@ -196,11 +236,17 @@ def cal_f(self):
                 if n_star_star != 0:
                     for w in counts_for_hist.word_to_count.keys():
                         n_star_z = len(counts_for_hist.word_to_context[w])
-                        counts_for_hist.word_to_f[w] = max((n_star_z - self.d[n]), 0) * 1.0 / n_star_star
+                        counts_for_hist.word_to_f[w] = (
+                            max((n_star_z - self.d[n]), 0) * 1.0 / n_star_star
+                        )
                 else:  # patterns begin with <s>, they do not have "modified count", so use raw count instead
                     for w in counts_for_hist.word_to_count.keys():
                         n_star_z = counts_for_hist.word_to_count[w]
-                        counts_for_hist.word_to_f[w] = max((n_star_z - self.d[n]), 0) * 1.0 / counts_for_hist.total_count
+                        counts_for_hist.word_to_f[w] = (
+                            max((n_star_z - self.d[n]), 0)
+                            * 1.0
+                            / counts_for_hist.total_count
+                        )
 
     def cal_bow(self):
         # Backoff weights are only necessary for ngrams which form a prefix of a longer ngram.
@@ -240,12 +286,18 @@ def cal_bow(self):
                     sum_z1_f_z = 0
                     _ = a_[1:]
                     _counts_for_hist = self.counts[len(_)][_]
-                    for u in a_counts_for_hist.word_to_count.keys():  # Should be careful here: what is Z1
+                    for (
+                        u
+                    ) in (
+                        a_counts_for_hist.word_to_count.keys()
+                    ):  # Should be careful here: what is Z1
                         sum_z1_f_z += _counts_for_hist.word_to_f[u]
 
                     if sum_z1_f_z < 1:
                         # assert sum_z1_f_a_z < 1
-                        counts_for_hist.word_to_bow[w] = (1.0 - sum_z1_f_a_z) / (1.0 - sum_z1_f_z)
+                        counts_for_hist.word_to_bow[w] = (1.0 - sum_z1_f_a_z) / (
+                            1.0 - sum_z1_f_z
+                        )
                     else:
                         counts_for_hist.word_to_bow[w] = None
 
@@ -259,7 +311,9 @@ def print_raw_counts(self, info_string):
                     ngram = " ".join(hist) + " " + w
                     ngram = ngram.strip(strip_chars)
 
-                    res.append("{0}\t{1}".format(ngram, counts_for_hist.word_to_count[w]))
+                    res.append(
+                        "{0}\t{1}".format(ngram, counts_for_hist.word_to_count[w])
+                    )
         res.sort(reverse=True)
         for r in res:
             print(r)
@@ -322,27 +376,40 @@ def print_f_and_bow(self, info_string):
                     if bow is None:
                         res.append("{1}\t{0}".format(ngram, math.log(f, 10)))
                     else:
-                        res.append("{1}\t{0}\t{2}".format(ngram, math.log(f, 10), math.log(bow, 10)))
+                        res.append(
+                            "{1}\t{0}\t{2}".format(
+                                ngram, math.log(f, 10), math.log(bow, 10)
+                            )
+                        )
         res.sort(reverse=True)
         for r in res:
             print(r)
 
-    def print_as_arpa(self, fout=io.TextIOWrapper(sys.stdout.buffer, encoding='latin-1')):
+    def print_as_arpa(
+        self, fout=io.TextIOWrapper(sys.stdout.buffer, encoding="latin-1")
+    ):
         # print as ARPA format.
-        print('\\data\\', file=fout)
+        print("\\data\\", file=fout)
 
         for hist_len in range(self.ngram_order):
             # print the number of n-grams.
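             # (Editorial illustration, not part of the patch: with ngram_order=3
             # the \data\ section gets one line per order, e.g. "ngram 1=...",
             # "ngram 2=...", "ngram 3=..."; each count is the number of word_to_f
             # entries summed over all histories of that order.)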
-            print('ngram {0}={1}'.format(
-                hist_len + 1,
-                sum([len(counts_for_hist.word_to_f) for counts_for_hist in self.counts[hist_len].values()])),
-                file=fout
+            print(
+                "ngram {0}={1}".format(
+                    hist_len + 1,
+                    sum(
+                        [
+                            len(counts_for_hist.word_to_f)
+                            for counts_for_hist in self.counts[hist_len].values()
+                        ]
+                    ),
+                ),
+                file=fout,
             )
-        print('', file=fout)
+        print("", file=fout)
 
         for hist_len in range(self.ngram_order):
-            print('\\{0}-grams:'.format(hist_len + 1), file=fout)
+            print("\\{0}-grams:".format(hist_len + 1), file=fout)
 
             this_order_counts = self.counts[hist_len]
             for hist, counts_for_hist in this_order_counts.items():
@@ -354,12 +421,12 @@ def print_as_arpa(self, fout=io.TextIOWrapper(sys.stdout.buffer, encoding='latin
                 if prob == 0:  # f(<unk>) is always 0
                     prob = 1e-99
 
-                line = '{0}\t{1}'.format('%.7f' % math.log10(prob), ' '.join(ngram))
+                line = "{0}\t{1}".format("%.7f" % math.log10(prob), " ".join(ngram))
                 if bow is not None:
-                    line += '\t{0}'.format('%.7f' % math.log10(bow))
+                    line += "\t{0}".format("%.7f" % math.log10(bow))
                 print(line, file=fout)
-            print('', file=fout)
-        print('\\end\\', file=fout)
+            print("", file=fout)
+        print("\\end\\", file=fout)
 
 
 if __name__ == "__main__":
@@ -379,5 +446,5 @@ def print_as_arpa(self, fout=io.TextIOWrapper(sys.stdout.buffer, encoding='latin
     if args.lm is None:
         ngram_counts.print_as_arpa()
     else:
-        with open(args.lm, 'w', encoding=default_encoding) as f:
+        with open(args.lm, "w", encoding=default_encoding) as f:
             ngram_counts.print_as_arpa(fout=f)
diff --git a/icefall/utils.py b/icefall/utils.py
index 143c794974..b4d8e9a51e 100644
--- a/icefall/utils.py
+++ b/icefall/utils.py
@@ -130,9 +130,7 @@ def setup_logger(
         formatter = f"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] ({rank}/{world_size}) %(message)s"  # noqa
         log_filename = f"{log_filename}-{date_time}-{rank}"
     else:
-        formatter = (
-            "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-        )
+        formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
         log_filename = f"{log_filename}-{date_time}"
 
     os.makedirs(os.path.dirname(log_filename), exist_ok=True)
@@ -203,7 +201,7 @@ def encode_supervisions(
                 supervisions["num_frames"],
                 subsampling_factor,
                 rounding_mode="floor",
-            )
+            ),
         ),
         1,
     ).to(torch.int32)
@@ -288,13 +286,9 @@ def get_texts_with_timestamp(
     """
     if isinstance(best_paths.aux_labels, k2.RaggedTensor):
         all_aux_shape = (
-            best_paths.arcs.shape()
-            .remove_axis(1)
-            .compose(best_paths.aux_labels.shape)
-        )
-        all_aux_labels = k2.RaggedTensor(
-            all_aux_shape, best_paths.aux_labels.values
+            best_paths.arcs.shape().remove_axis(1).compose(best_paths.aux_labels.shape)
         )
+        all_aux_labels = k2.RaggedTensor(all_aux_shape, best_paths.aux_labels.values)
         # remove 0's and -1's.
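         # (Editorial note: remove_values_leq(0) covers both in one call, since 0
         # and -1 are each <= 0.)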
        aux_labels = best_paths.aux_labels.remove_values_leq(0)
 
     # TODO: change arcs.shape() to arcs.shape
@@ -363,9 +357,7 @@ def get_alignments(best_paths: k2.Fsa, kind: str) -> List[List[int]]:
     # arc.shape() has axes [fsa][state][arc], we remove "state"-axis here
     token_shape = best_paths.arcs.shape().remove_axis(1)
     # token_shape has axes [fsa][arc]
-    tokens = k2.RaggedTensor(
-        token_shape, getattr(best_paths, kind).contiguous()
-    )
+    tokens = k2.RaggedTensor(token_shape, getattr(best_paths, kind).contiguous())
     tokens = tokens.remove_values_eq(-1)
 
     return tokens.tolist()
@@ -586,9 +578,7 @@ def write_error_stats(
                 f"{cut_id}:\t"
                 + " ".join(
                     (
-                        ref_word
-                        if ref_word == hyp_word
-                        else f"({ref_word}->{hyp_word})"
+                        ref_word if ref_word == hyp_word else f"({ref_word}->{hyp_word})"
                         for ref_word, hyp_word in ali
                    )
                 ),
@@ -598,9 +588,7 @@ def write_error_stats(
     print("", file=f)
 
     print("SUBSTITUTIONS: count ref -> hyp", file=f)
-    for count, (ref, hyp) in sorted(
-        [(v, k) for k, v in subs.items()], reverse=True
-    ):
+    for count, (ref, hyp) in sorted([(v, k) for k, v in subs.items()], reverse=True):
         print(f"{count} {ref} -> {hyp}", file=f)
 
     print("", file=f)
@@ -614,9 +602,7 @@ def write_error_stats(
         print(f"{count} {hyp}", file=f)
 
     print("", file=f)
-    print(
-        "PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp", file=f
-    )
+    print("PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp", file=f)
     for _, word, counts in sorted(
         [(sum(v[1:]), k, v) for k, v in words.items()], reverse=True
     ):
@@ -791,9 +777,7 @@ def write_error_stats_with_timestamps(
                 f"{cut_id}:\t"
                 + " ".join(
                     (
-                        ref_word
-                        if ref_word == hyp_word
-                        else f"({ref_word}->{hyp_word})"
+                        ref_word if ref_word == hyp_word else f"({ref_word}->{hyp_word})"
                         for ref_word, hyp_word in ali
                     )
                 ),
@@ -803,9 +787,7 @@ def write_error_stats_with_timestamps(
     print("", file=f)
 
     print("SUBSTITUTIONS: count ref -> hyp", file=f)
-    for count, (ref, hyp) in sorted(
-        [(v, k) for k, v in subs.items()], reverse=True
-    ):
+    for count, (ref, hyp) in sorted([(v, k) for k, v in subs.items()], reverse=True):
         print(f"{count} {ref} -> {hyp}", file=f)
 
     print("", file=f)
@@ -819,9 +801,7 @@ def write_error_stats_with_timestamps(
         print(f"{count} {hyp}", file=f)
 
     print("", file=f)
-    print(
-        "PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp", file=f
-    )
+    print("PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp", file=f)
     for _, word, counts in sorted(
         [(sum(v[1:]), k, v) for k, v in words.items()], reverse=True
     ):
@@ -891,9 +871,7 @@ def norm_items(self) -> List[Tuple[str, float]]:
             if k == "frames" or k == "utterances":
                 continue
             norm_value = (
-                float(v) / num_frames
-                if "utt_" not in k
-                else float(v) / num_utterances
+                float(v) / num_frames if "utt_" not in k else float(v) / num_utterances
             )
             ans.append((k, norm_value))
         return ans
@@ -927,9 +905,7 @@ def write_summary(
         tb_writer.add_scalar(prefix + k, v, batch_idx)
 
 
-def concat(
-    ragged: k2.RaggedTensor, value: int, direction: str
-) -> k2.RaggedTensor:
+def concat(ragged: k2.RaggedTensor, value: int, direction: str) -> k2.RaggedTensor:
     """Prepend a value to the beginning of each sublist or append a value
     to the end of each sublist.
 
@@ -1101,9 +1077,7 @@ def linf_norm(x):
     return torch.max(torch.abs(x))
 
 
-def measure_weight_norms(
-    model: nn.Module, norm: str = "l2"
-) -> Dict[str, float]:
+def measure_weight_norms(model: nn.Module, norm: str = "l2") -> Dict[str, float]:
     """
     Compute the norms of the model's parameters.
@@ -1126,9 +1100,7 @@ def measure_weight_norms( return norms -def measure_gradient_norms( - model: nn.Module, norm: str = "l1" -) -> Dict[str, float]: +def measure_gradient_norms(model: nn.Module, norm: str = "l1") -> Dict[str, float]: """ Compute the norms of the gradients for each of model's parameters. @@ -1413,9 +1385,7 @@ def parse_hyp_and_timestamp( use_word_table = True for i in range(N): - time = convert_timestamp( - res.timestamps[i], subsampling_factor, frame_shift_ms - ) + time = convert_timestamp(res.timestamps[i], subsampling_factor, frame_shift_ms) if use_word_table: words = [word_table[i] for i in res.hyps[i]] else: diff --git a/pyproject.toml b/pyproject.toml index b4f8c3377e..3183055d46 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,7 +3,7 @@ profile = "black" skip = ["icefall/__init__.py"] [tool.black] -line-length = 80 +line-length = 88 exclude = ''' /( \.git diff --git a/setup.py b/setup.py index 6c720e1219..ccd2503ffc 100644 --- a/setup.py +++ b/setup.py @@ -1,8 +1,9 @@ #!/usr/bin/env python3 -from setuptools import find_packages, setup from pathlib import Path +from setuptools import find_packages, setup + icefall_dir = Path(__file__).parent install_requires = (icefall_dir / "requirements.txt").read_text().splitlines() diff --git a/test/test_checkpoint.py b/test/test_checkpoint.py index 511a11c23e..34e829642b 100644 --- a/test/test_checkpoint.py +++ b/test/test_checkpoint.py @@ -20,11 +20,7 @@ import torch import torch.nn as nn -from icefall.checkpoint import ( - average_checkpoints, - load_checkpoint, - save_checkpoint, -) +from icefall.checkpoint import average_checkpoints, load_checkpoint, save_checkpoint @pytest.fixture diff --git a/test/test_decode.py b/test/test_decode.py index 97964ac67b..4c2e192a79 100644 --- a/test/test_decode.py +++ b/test/test_decode.py @@ -23,6 +23,7 @@ """ import k2 + from icefall.decode import Nbest diff --git a/test/test_graph_compiler.py b/test/test_graph_compiler.py index ccfb57d495..10443cf221 100644 --- a/test/test_graph_compiler.py +++ b/test/test_graph_compiler.py @@ -154,9 +154,7 @@ def test_compile(compiler, lexicon): fsas = k2.Fsa.from_fsas([fsa1, fsa2]) decoding_graph = k2.arc_sort(decoding_graph) - lattice = k2.intersect( - decoding_graph, fsas, treat_epsilons_specially=False - ) + lattice = k2.intersect(decoding_graph, fsas, treat_epsilons_specially=False) lattice = k2.connect(lattice) aux_labels0 = lattice[0].aux_labels[:-1] diff --git a/test/test_utils.py b/test/test_utils.py index 6a9ce7853e..31f06bd51b 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -50,9 +50,7 @@ def test_encode_supervisions(sup): assert torch.all( torch.eq( supervision_segments, - torch.tensor( - [[1, 0, 30 // 4], [0, 0, 20 // 4], [2, 9 // 4, 10 // 4]] - ), + torch.tensor([[1, 0, 30 // 4], [0, 0, 20 // 4], [2, 9 // 4, 10 // 4]]), ) ) assert texts == ["two", "one", "three"]
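
A quick numeric sanity check of the Kneser-Ney formulas in icefall/shared/make_kn_lm.py
as reformatted above. This is a minimal sketch: the counts below are made up, and only
the two formulas are taken from the script.

    # Discounting constant, as in cal_discounting_constants():
    #   D_N = n1 / (n1 + 2 * n2), with the same max(0.1, ...) floor on n1 as the script.
    # n1 and n2 are hypothetical counts of N-grams seen exactly once and exactly twice.
    n1, n2 = 1000, 400
    d = max(0.1, n1 * 1.0) / (n1 + 2 * n2)
    print(d)  # 1000 / 1800 = 0.5555...

    # Highest-order discounted probability, as in cal_f():
    #   f = max(c - D, 0) / total_count, for a hypothetical raw count and history total.
    count, total = 7, 100
    f = max(count - d, 0) * 1.0 / total
    print(f)  # (7 - 0.5555...) / 100 = 0.0644...

A typical invocation, with placeholder file names (the flags themselves appear
unchanged in the diff):

    python3 icefall/shared/make_kn_lm.py -ngram-order 4 -text corpus.txt -lm lm.arpa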