diff --git a/egs/librispeech/ASR/zipformer/biasing.py b/egs/librispeech/ASR/zipformer/biasing.py new file mode 100644 index 0000000000..8d66919808 --- /dev/null +++ b/egs/librispeech/ASR/zipformer/biasing.py @@ -0,0 +1,284 @@ +# Copyright 2024 Xiaomi Corp. (authors: Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional, Tuple + +import math +import torch +import torch.nn as nn + +from icefall import ContextGraph +import logging +import random + + +class TCPGen(nn.Module): + def __init__( + self, + encoder_dim: int, + embed_dim: int, + joiner_dim: int, + decoder: nn.Module, + attn_dim: int = 512, + tcpgen_dropout: float = 0.1, + ): + super().__init__() + # embedding for ool (out of biasing list) token + self.oolemb = torch.nn.Embedding(1, embed_dim) + # project symbol embeddings + self.q_proj_sym = torch.nn.Linear(embed_dim, attn_dim) + # project encoder embeddings + self.q_proj_acoustic = torch.nn.Linear(encoder_dim, attn_dim) + # project symbol embeddings (vocabulary + ool) + self.k_proj = torch.nn.Linear(embed_dim, attn_dim) + # generate tcpgen probability + self.tcp_gate = torch.nn.Linear(attn_dim + joiner_dim, 1) + self.dropout_tcpgen = torch.nn.Dropout(tcpgen_dropout) + self.decoder = decoder + self.vocab_size = decoder.vocab_size + + def get_tcpgen_masks( + self, + targets: torch.Tensor, + context_graph: ContextGraph, + vocab_size: int, + blank_id: int = 0, + ) -> Tuple[torch.Tensor, torch.Tensor]: + batch_size, sql_len = targets.shape + dist_masks = torch.ones((batch_size, sql_len, vocab_size + 1)) + gen_masks = [] + yseqs = targets.tolist() + for i, yseq in enumerate(yseqs): + node = context_graph.root + gen_mask = [] + for j, y in enumerate(yseq): + not_matched = False + if y == blank_id: + node = context_graph.root + gen_mask.append(0) + elif y in node.next: + gen_mask.append(0) + node = node.next[y] + if node.is_end: + node = context_graph.root + else: + gen_mask.append(1) + node = context_graph.root + not_matched = True + # unmask_index = ( + # [vocab_size] + # if node.token == -1 + # else list(node.next.keys()) + [vocab_size] + # ) + # logging.info(f"token : {node.token}, keys : {node.next.keys()}") + # dist_masks[i, j, unmask_index] = 0 + if not not_matched: + dist_masks[i, j, list(node.next.keys())] = 0 + + gen_masks.append(gen_mask + [1] * (sql_len - len(gen_mask))) + gen_masks = torch.Tensor(gen_masks).to(targets.device).bool() + dist_masks = dist_masks.to(targets.device).bool() + if random.random() >= 0.95: + logging.info( + f"gen_mask nonzero {gen_masks.shape} : {torch.count_nonzero(torch.logical_not(gen_masks), dim=1)}" + ) + logging.info( + f"dist_masks nonzero {dist_masks.shape} : {torch.count_nonzero(torch.logical_not(dist_masks), dim=2)}" + ) + return dist_masks, gen_masks + + def get_tcpgen_distribution( + self, query: torch.Tensor, dist_masks: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + query: + shape : (B, T, s_range, 
attn_dim) + dist_masks: + shape : (B, T, s_range, V + 1) + """ + # From original paper, k, v share same embeddings + # (V + 1, embed_dim) + kv = torch.cat([self.decoder.embedding.weight.data, self.oolemb.weight], dim=0) + # (V + 1, attn_dim) + kv = self.dropout_tcpgen(self.k_proj(kv)) + # (B, T, s_range, attn_dim) * (attn_dim, V + 1) -> (B, T, s_range, V + 1) + distribution = torch.matmul(query, kv.permute(1, 0)) / math.sqrt(query.size(-1)) + distribution = distribution.masked_fill( + dist_masks, torch.finfo(distribution.dtype).min + ) + distribution = distribution.softmax(dim=-1) + # (B, T, s_range, V) * (V, attn_dim) -> (B, T, s_range, attn_dim) + # logging.info(f"distribution shape : {distribution.shape}") + hptr = torch.matmul(distribution[:, :, :, :-1], kv[:-1, :]) + hptr = self.dropout_tcpgen(hptr) + if random.random() > 0.95: + logging.info( + f"distribution mean : {torch.mean(distribution, dim=3)}, std: {torch.std(distribution, dim=3)}" + ) + logging.info( + f"distribution min : {torch.min(distribution, dim=3)}, max: {torch.max(distribution, dim=3)}" + ) + return hptr, distribution + + def prune_query_and_mask( + self, + query_sym: torch.Tensor, + query_acoustic: torch.Tensor, + dist_masks: torch.Tensor, + gen_masks: torch.Tensor, + ranges: torch.Tensor, + ): + """Prune the queries from symbols and acoustics with ranges + generated by `get_rnnt_prune_ranges` in pruned rnnt loss. + + Args: + query_sym: + The symbol query, with shape (B, S, attn_dim). + query_acoustic: + The acoustic query, with shape (B, T, attn_dim). + dist_masks: + The TCPGen distribution masks, with shape (B, S, V + 1). + gen_masks: + The TCPGen probability masks, with shape (B, S). + ranges: + A tensor containing the symbol indexes for each frame that we want to + keep. Its shape is (B, T, s_range), see the docs in + `get_rnnt_prune_ranges` in rnnt_loss.py for more details of this tensor. + + Returns: + Return the pruned query with the shape (B, T, s_range, attn_dim). 
+ """ + assert ranges.shape[0] == query_sym.shape[0], (ranges.shape, query_sym.shape) + assert ranges.shape[0] == query_acoustic.shape[0], ( + ranges.shape, + query_acoustic.shape, + ) + assert query_acoustic.shape[1] == ranges.shape[1], ( + query_acoustic.shape, + ranges.shape, + ) + (B, T, s_range) = ranges.shape + (B, S, attn_dim) = query_sym.shape + assert query_acoustic.shape == (B, T, attn_dim), ( + query_acoustic.shape, + (B, T, attn_dim), + ) + assert dist_masks.shape == (B, S, self.vocab_size + 1), ( + dist_masks.shape, + (B, S, self.vocab_size + 1), + ) + assert gen_masks.shape == (B, S), (gen_masks.shape, (B, S)) + # (B, T, s_range, attn_dim) + query_acoustic_pruned = query_acoustic.unsqueeze(2).expand( + (B, T, s_range, attn_dim) + ) + # logging.info(f"query_sym : {query_sym.shape}") + # logging.info(f"ranges : {ranges}") + # (B, T, s_range, attn_dim) + query_sym_pruned = torch.gather( + query_sym, + dim=1, + index=ranges.reshape(B, T * s_range, 1).expand((B, T * s_range, attn_dim)), + ).reshape(B, T, s_range, attn_dim) + # (B, T, s_range, V + 1) + dist_masks_pruned = torch.gather( + dist_masks, + dim=1, + index=ranges.reshape(B, T * s_range, 1).expand( + (B, T * s_range, self.vocab_size + 1) + ), + ).reshape(B, T, s_range, self.vocab_size + 1) + # (B, T, s_range) + gen_masks_pruned = torch.gather( + gen_masks, dim=1, index=ranges.reshape(B, T * s_range) + ).reshape(B, T, s_range) + return ( + query_sym_pruned + query_acoustic_pruned, + dist_masks_pruned, + gen_masks_pruned, + ) + + def forward( + self, + targets: torch.Tensor, + encoder_embeddings: torch.Tensor, + ranges: torch.Tensor, + context_graph: ContextGraph, + ): + """ + Args: + target: + The training targets in token ids (padded with blanks). shape : (B, S) + encoder_embeddings: + The encoder outputs. shape: (B, T, attn_dim) + ranges: + The prune ranges from pruned rnnt. shape: (B, T, s_range) + context_graphs: + The context_graphs for each utterance. B == len(context_graphs). + + Return: + returns tcpgen embedding with shape (B, T, s_range, attn_dim) and + tcpgen distribution with shape (B, T, s_range, V + 1). 
+ """ + query_sym = self.decoder.embedding(targets) + + query_sym = self.q_proj_sym(query_sym) # (B, S, attn_dim) + query_acoustic = self.q_proj_acoustic(encoder_embeddings) # (B , T, attn_dim) + + # dist_masks : (B, S, V + 1) + # gen_masks : (B, S) + dist_masks, gen_masks = self.get_tcpgen_masks( + targets=targets, context_graph=context_graph, vocab_size=self.vocab_size + ) + # query : (B, T, s_range, attn_dim) + # dist_masks : (B, T, s_range, V + 1) + query, dist_masks, gen_masks = self.prune_query_and_mask( + query_sym=query_sym, + query_acoustic=query_acoustic, + dist_masks=dist_masks, + gen_masks=gen_masks, + ranges=ranges, + ) + + if random.random() >= 0.95: + logging.info( + f"pruned gen_mask nonzero {gen_masks.shape} : {torch.count_nonzero(torch.logical_not(gen_masks), dim=1)}" + ) + logging.info( + f"pruned dist_masks nonzero {dist_masks.shape} : {torch.count_nonzero(torch.logical_not(dist_masks), dim=3)}" + ) + + # hptr : (B, T, s_range, attn_dim) + # tcpgen_dist : (B, T, s_range, V + 1) + hptr, tcpgen_dist = self.get_tcpgen_distribution(query, dist_masks) + return hptr, tcpgen_dist, gen_masks + + def generator_prob( + self, hptr: torch.Tensor, h_joiner: torch.Tensor, gen_masks: torch.Tensor + ) -> torch.Tensor: + # tcpgen_prob : (B, T, s_range, 1) + tcpgen_prob = self.tcp_gate(torch.cat((h_joiner, hptr), dim=-1)) + tcpgen_prob = torch.sigmoid(tcpgen_prob) + tcpgen_prob = tcpgen_prob.masked_fill(gen_masks.unsqueeze(-1), 0) + if random.random() >= 0.95: + logging.info( + f"tcpgen_prob mean : {torch.mean(tcpgen_prob.squeeze(-1), dim=(1,2))}, std : {torch.std(tcpgen_prob.squeeze(-1), dim=(1, 2))}" + ) + logging.info( + f"tcpgen_prob min : {torch.min(tcpgen_prob.squeeze(-1), dim=1)}, max : {torch.max(tcpgen_prob.squeeze(-1), dim=1)}" + ) + return tcpgen_prob diff --git a/egs/librispeech/ASR/zipformer/joiner.py b/egs/librispeech/ASR/zipformer/joiner.py index dfb0a0057b..d999499a53 100644 --- a/egs/librispeech/ASR/zipformer/joiner.py +++ b/egs/librispeech/ASR/zipformer/joiner.py @@ -17,6 +17,7 @@ import torch import torch.nn as nn from scaling import ScaledLinear +from typing import Optional class Joiner(nn.Module): @@ -37,6 +38,7 @@ def forward( self, encoder_out: torch.Tensor, decoder_out: torch.Tensor, + tcpgen_hptr: Optional[torch.Tensor] = None, project_input: bool = True, ) -> torch.Tensor: """ @@ -62,6 +64,11 @@ def forward( else: logit = encoder_out + decoder_out - logit = self.output_linear(torch.tanh(logit)) + if tcpgen_hptr is not None: + logit += tcpgen_hptr - return logit + activations = torch.tanh(logit) + + logit = self.output_linear(activations) + + return logit if tcpgen_hptr is None else (logit, activations) diff --git a/egs/librispeech/ASR/zipformer/model.py b/egs/librispeech/ASR/zipformer/model.py index 86da3ab29a..21ccdaf180 100644 --- a/egs/librispeech/ASR/zipformer/model.py +++ b/egs/librispeech/ASR/zipformer/model.py @@ -17,14 +17,17 @@ # limitations under the License. 
from typing import Optional, Tuple +import logging import k2 import torch import torch.nn as nn from encoder_interface import EncoderInterface from scaling import ScaledLinear +from biasing import TCPGen from icefall.utils import add_sos, make_pad_mask +from icefall import ContextGraph class AsrModel(nn.Module): @@ -36,9 +39,13 @@ def __init__( joiner: Optional[nn.Module] = None, encoder_dim: int = 384, decoder_dim: int = 512, + joiner_dim: int = 512, vocab_size: int = 500, + tcpgen_attn_dim: int = 512, use_transducer: bool = True, use_ctc: bool = False, + use_tcpgen_biasing: bool = False, + tcpgen_dropout: float = 0.15, ): """A joint CTC & Transducer ASR model. @@ -111,6 +118,19 @@ def __init__( nn.LogSoftmax(dim=-1), ) + self.use_tcpgen_biasing = use_tcpgen_biasing + if use_tcpgen_biasing: + assert use_transducer, "TCPGen biasing only support on transducer model." + self.tcp_gen = TCPGen( + encoder_dim=encoder_dim, + embed_dim=decoder_dim, + attn_dim=tcpgen_attn_dim, + joiner_dim=joiner_dim, + decoder=decoder, + tcpgen_dropout=tcpgen_dropout, + ) + self.hptr_proj = torch.nn.Linear(tcpgen_attn_dim, joiner_dim) + def forward_encoder( self, x: torch.Tensor, x_lens: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: @@ -180,6 +200,7 @@ def forward_transducer( prune_range: int = 5, am_scale: float = 0.0, lm_scale: float = 0.0, + context_graph: Optional[ContextGraph] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """Compute Transducer loss. Args: @@ -202,6 +223,7 @@ def forward_transducer( """ # Now for the decoder, i.e., the prediction network blank_id = self.decoder.blank_id + assert blank_id == 0, f"Assuming blank_id is 0, given {blank_id}" sos_y = add_sos(y, sos_id=blank_id) # sos_y_padded: [B, S + 1], start with SOS. @@ -226,11 +248,6 @@ def forward_transducer( lm = self.simple_lm_proj(decoder_out) am = self.simple_am_proj(encoder_out) - # if self.training and random.random() < 0.25: - # lm = penalize_abs_values_gt(lm, 100.0, 1.0e-04) - # if self.training and random.random() < 0.25: - # am = penalize_abs_values_gt(am, 30.0, 1.0e-04) - with torch.cuda.amp.autocast(enabled=False): simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( lm=lm.float(), @@ -252,19 +269,66 @@ def forward_transducer( s_range=prune_range, ) - # am_pruned : [B, T, prune_range, encoder_dim] - # lm_pruned : [B, T, prune_range, decoder_dim] + # am_pruned : [B, T, prune_range, joiner_dim] + # lm_pruned : [B, T, prune_range, joiner_dim] am_pruned, lm_pruned = k2.do_rnnt_pruning( am=self.joiner.encoder_proj(encoder_out), lm=self.joiner.decoder_proj(decoder_out), ranges=ranges, ) - # logits : [B, T, prune_range, vocab_size] + fused_log_softmax = True + if self.use_tcpgen_biasing and context_graph is not None: + # logging.info(f"using tcpgen baising") + fused_log_softmax = False + # hptr : (B, T, s_range, attn_dim) + # tcpgen_dist : (B, T, s_range, V + 1) + # gen_masks : (B, T, s_range) + hptr, tcpgen_dist, gen_masks = self.tcp_gen( + targets=sos_y_padded, + encoder_embeddings=encoder_out, + ranges=ranges, + context_graph=context_graph, + ) + # (B, T, s_range, joiner_dim) + tcpgen_hptr = self.hptr_proj(hptr) + # logits : (B, T, s_range, V) + # activations : (B, T, s_range, joiner_dim) + logits, activations = self.joiner( + encoder_out=am_pruned, + decoder_out=lm_pruned, + tcpgen_hptr=tcpgen_hptr, + project_input=False, + ) + # (B, T, s_range, 1) + tcpgen_prob = self.tcp_gen.generator_prob( + hptr=hptr, h_joiner=activations, gen_masks=gen_masks + ) + + # Assuming blank_id is 0 + p_mdl = torch.softmax(logits, dim=-1) + 
p_no_blank = 1.0 - p_mdl[:, :, :, 0:1] + + # blank dist (tcpgen_dist[:,:,:,0]) should be 0 + scaled_tcpgen_dist = tcpgen_dist[:, :, :, 1:] * p_no_blank + + # ool probability : tcpgen_dist[:, :, :, -1:] + scaled_tcpgen_prob = tcpgen_prob * (1 - tcpgen_dist[:, :, :, -1:]) + + interpolated_no_blank = scaled_tcpgen_dist[ + :, :, :, :-1 + ] * tcpgen_prob + p_mdl[:, :, :, 1:] * (1 - scaled_tcpgen_prob) + + final_dist = torch.cat([p_mdl[:, :, :, 0:1], interpolated_no_blank], dim=-1) + + # actually logprobs + logits = torch.log(final_dist + torch.finfo(final_dist.dtype).tiny) + else: + # logits : [B, T, prune_range, vocab_size] - # project_input=False since we applied the decoder's input projections - # prior to do_rnnt_pruning (this is an optimization for speed). - logits = self.joiner(am_pruned, lm_pruned, project_input=False) + # project_input=False since we applied the decoder's input projections + # prior to do_rnnt_pruning (this is an optimization for speed). + logits = self.joiner(am_pruned, lm_pruned, project_input=False) with torch.cuda.amp.autocast(enabled=False): pruned_loss = k2.rnnt_loss_pruned( @@ -274,6 +338,7 @@ def forward_transducer( termination_symbol=blank_id, boundary=boundary, reduction="sum", + fused_log_softmax=fused_log_softmax, ) return simple_loss, pruned_loss @@ -286,6 +351,7 @@ def forward( prune_range: int = 5, am_scale: float = 0.0, lm_scale: float = 0.0, + context_graph: Optional[ContextGraph] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Args: @@ -338,6 +404,7 @@ def forward( prune_range=prune_range, am_scale=am_scale, lm_scale=lm_scale, + context_graph=context_graph, ) else: simple_loss = torch.empty(0) diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index 04caf2fd80..edacf968aa 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -54,10 +54,11 @@ import argparse import copy import logging +import random import warnings from pathlib import Path from shutil import copyfile -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import k2 import optim @@ -81,7 +82,7 @@ from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer2 -from icefall import diagnostics +from icefall import ContextGraph, diagnostics from icefall.checkpoint import load_checkpoint, remove_checkpoints from icefall.checkpoint import save_checkpoint as save_checkpoint_impl from icefall.checkpoint import ( @@ -259,6 +260,13 @@ def add_model_arguments(parser: argparse.ArgumentParser): help="If True, use CTC head.", ) + parser.add_argument( + "--use-tcpgen-biasing", + type=str2bool, + default=False, + help="If True, use tcpgen context biasing module", + ) + def get_parser(): parser = argparse.ArgumentParser( @@ -526,6 +534,7 @@ def get_params() -> AttributeDict: "best_train_epoch": -1, "best_valid_epoch": -1, "batch_idx_train": 0, + "epoch_idx_train": 0, "log_interval": 50, "reset_interval": 200, "valid_interval": 3000, # For the 100h subset, use 800 @@ -534,6 +543,10 @@ def get_params() -> AttributeDict: "subsampling_factor": 4, # not passed in, this is fixed. 
"warm_step": 2000, "env_info": get_env_info(), + # parameters for tcpgen + "tcpgen_start_epoch": 10, + "num_distrators": 500, + "distractors_list": [], } ) @@ -628,9 +641,11 @@ def get_model(params: AttributeDict) -> nn.Module: joiner=joiner, encoder_dim=max(_to_int_tuple(params.encoder_dim)), decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, vocab_size=params.vocab_size, use_transducer=params.use_transducer, use_ctc=params.use_ctc, + use_tcpgen_biasing=params.use_tcpgen_biasing, ) return model @@ -688,6 +703,7 @@ def load_checkpoint_if_available( "best_train_epoch", "best_valid_epoch", "batch_idx_train", + "epoch_idx_train", "best_train_loss", "best_valid_loss", ] @@ -751,6 +767,40 @@ def save_checkpoint( copyfile(src=filename, dst=best_valid_filename) +def prepare_context_graph( + params: AttributeDict, + texts: List[str], + sp: spm.SentencePieceProcessor, +) -> ContextGraph: + if params.epoch_idx_train == params.start_epoch: + params.distractors_list += texts + return None + + logging.info(f"distractors_list : {len(params.distractors_list)}") + if params.epoch_idx_train >= params.tcpgen_start_epoch: + # logging.info("prepare context graph") + contexts_list = [] + selected_texts = [] + for i, text in enumerate(texts): + if random.random() >= 0.5: + continue + else: + selected_texts.append(text) + for i in range(params.num_distrators): + index = random.randint(0, len(params.distractors_list) - 1) + selected_texts.append(params.distractors_list[index]) + for st in selected_texts: + word_list = st.split() + start = random.randint(0, len(word_list) - 1) + length = random.randint(1, 3) + contexts_list.append(" ".join(word_list[start : start + length])) + contexts_tokens = sp.encode(contexts_list, out_type=int) + # logging.info(f"contexts_list : {contexts_list}, tokens : {contexts_tokens}") + context_graph = ContextGraph(0) + context_graph.build(contexts_tokens) + return context_graph + + def compute_loss( params: AttributeDict, model: Union[nn.Module, DDP], @@ -792,6 +842,11 @@ def compute_loss( y = sp.encode(texts, out_type=int) y = k2.RaggedTensor(y) + context_graph = None + if params.use_tcpgen_biasing: + context_graph = prepare_context_graph(params=params, texts=texts, sp=sp) + # assert context_graph is not None + with torch.set_grad_enabled(is_training): simple_loss, pruned_loss, ctc_loss = model( x=feature, @@ -800,6 +855,7 @@ def compute_loss( prune_range=params.prune_range, am_scale=params.am_scale, lm_scale=params.lm_scale, + context_graph=context_graph, ) loss = 0.0 @@ -939,6 +995,8 @@ def save_bad_model(suffix: str = ""): rank=0, ) + logging.info(f"epoch_idx_train : {params.epoch_idx_train}") + params.epoch_idx_train += 1 for batch_idx, batch in enumerate(train_dl): if batch_idx % 10 == 0: set_batch_count(model, get_adjusted_batch_count(params)) @@ -967,9 +1025,7 @@ def save_bad_model(suffix: str = ""): scaler.update() optimizer.zero_grad() except Exception as e: - logging.info( - f"Caught exception: {e}." 
- ) + logging.info(f"Caught exception: {e}.") save_bad_model() display_and_save_batch(batch, params=params, sp=sp) raise @@ -1177,16 +1233,16 @@ def run(rank, world_size, args): librispeech = LibriSpeechAsrDataModule(args) if params.full_libri: - train_cuts = librispeech.train_all_shuf_cuts() + # train_cuts = librispeech.train_all_shuf_cuts() # previously we used the following code to load all training cuts, # strictly speaking, shuffled training cuts should be used instead, # but we leave the code here to demonstrate that there is an option # like this to combine multiple cutsets - # train_cuts = librispeech.train_clean_100_cuts() - # train_cuts += librispeech.train_clean_360_cuts() - # train_cuts += librispeech.train_other_500_cuts() + train_cuts = librispeech.train_clean_100_cuts() + train_cuts += librispeech.train_clean_360_cuts() + train_cuts += librispeech.train_other_500_cuts() else: train_cuts = librispeech.train_clean_100_cuts() @@ -1205,26 +1261,6 @@ def remove_short_and_long_utt(c: Cut): # ) return False - # In pruned RNN-T, we require that T >= S - # where T is the number of feature frames after subsampling - # and S is the number of tokens in the utterance - - # In ./zipformer.py, the conv module uses the following expression - # for subsampling - T = ((c.num_frames - 7) // 2 + 1) // 2 - tokens = sp.encode(c.supervisions[0].text, out_type=str) - - if T < len(tokens): - logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Number of frames (before subsampling): {c.num_frames}. " - f"Number of frames (after subsampling): {T}. " - f"Text: {c.supervisions[0].text}. " - f"Tokens: {tokens}. " - f"Number of tokens: {len(tokens)}" - ) - return False - return True train_cuts = train_cuts.filter(remove_short_and_long_utt) @@ -1245,13 +1281,14 @@ def remove_short_and_long_utt(c: Cut): valid_dl = librispeech.valid_dataloaders(valid_cuts) if not params.print_diagnostics: - scan_pessimistic_batches_for_oom( - model=model, - train_dl=train_dl, - optimizer=optimizer, - sp=sp, - params=params, - ) + pass + # scan_pessimistic_batches_for_oom( + # model=model, + # train_dl=train_dl, + # optimizer=optimizer, + # sp=sp, + # params=params, + # ) scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: @@ -1397,4 +1434,5 @@ def main(): torch.set_num_interop_threads(1) if __name__ == "__main__": + torch.set_printoptions(profile="full") main() diff --git a/icefall/context_graph.py b/icefall/context_graph.py index 138bf4673b..dbeb2aacdd 100644 --- a/icefall/context_graph.py +++ b/icefall/context_graph.py @@ -162,6 +162,7 @@ def build( phrases: Optional[List[str]] = None, scores: Optional[List[float]] = None, ac_thresholds: Optional[List[float]] = None, + simple_trie: bool = False, ): """Build the ContextGraph from a list of token list. It first build a trie from the given token lists, then fill the fail arc @@ -189,6 +190,11 @@ def build( 0 means using the default value (i.e. self.ac_threshold). It is used only when this graph applied for the keywords spotting system. The length of `ac_threshold` MUST be equal to the length of `token_ids`. + simple_trie: + True for building only trie (i.e. no fail and output arcs). Needed by + tcpgen biasing training. + False for building a Aho-corasick automata, for hotword / keywords + searching. Note: The phrases would have shared states, the score of the shared states is the MAXIMUM value among all the tokens sharing this state. 
@@ -211,7 +217,6 @@ def build( context_score = self.context_score if score == 0.0 else score threshold = self.ac_threshold if ac_threshold == 0.0 else ac_threshold for i, token in enumerate(tokens): - node_next = {} if token not in node.next: self.num_nodes += 1 is_end = i == len(tokens) - 1 @@ -240,7 +245,9 @@ def build( node.next[token].phrase = phrase node.next[token].ac_threshold = threshold node = node.next[token] - self._fill_fail_output() + + if not simple_trie: + self._fill_fail_output() def forward_one_step( self, state: ContextState, token: int, strict_mode: bool = True
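
Note (not part of the patch): a minimal, self-contained sketch of the masked pointer attention that get_tcpgen_distribution() in biasing.py performs over the (V + 1)-symbol key/value table (decoder vocabulary embeddings plus the OOL row). All sizes, tensors, and the unmasked indices below are placeholders.

import math
import torch

B, T, s_range, attn_dim, V = 2, 4, 3, 8, 10

query = torch.randn(B, T, s_range, attn_dim)
kv = torch.randn(V + 1, attn_dim)  # vocabulary embeddings + one OOL embedding

# dist_masks: True marks symbols that are not reachable from the current
# trie node; here a couple of arbitrary indices are left unmasked.
dist_masks = torch.ones(B, T, s_range, V + 1, dtype=torch.bool)
dist_masks[..., 3] = False
dist_masks[..., V] = False

scores = torch.matmul(query, kv.t()) / math.sqrt(query.size(-1))
scores = scores.masked_fill(dist_masks, torch.finfo(scores.dtype).min)
distribution = scores.softmax(dim=-1)  # (B, T, s_range, V + 1)

# hptr: attention-weighted sum over the real vocabulary rows, OOL excluded.
hptr = torch.matmul(distribution[..., :-1], kv[:-1, :])
print(hptr.shape)  # (B, T, s_range, attn_dim)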
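
Note (not part of the patch): a self-contained sketch of the probability interpolation that forward_transducer() in model.py performs when TCPGen biasing is enabled. Shapes and tensor names mirror the patch; the random inputs are placeholders standing in for the joiner, TCPGen distribution, and generation gate.

import torch

B, T, s_range, V = 2, 4, 3, 10  # V includes the blank at index 0

logits = torch.randn(B, T, s_range, V)  # joiner output
p_mdl = logits.softmax(dim=-1)          # model distribution

# TCPGen distribution over V + 1 symbols: index 0 is the blank (kept near
# zero via dist_masks) and the last index is the OOL symbol.
tcpgen_dist = torch.rand(B, T, s_range, V + 1)
tcpgen_dist[..., 0] = 0.0
tcpgen_dist = tcpgen_dist / tcpgen_dist.sum(dim=-1, keepdim=True)

tcpgen_prob = torch.rand(B, T, s_range, 1)  # generation probability

# Keep the model's blank probability untouched; rescale the pointer
# distribution by the remaining non-blank mass.
p_no_blank = 1.0 - p_mdl[..., 0:1]
scaled_tcpgen_dist = tcpgen_dist[..., 1:] * p_no_blank

# Discount the generation probability by the OOL mass before applying it
# to the model distribution.
scaled_tcpgen_prob = tcpgen_prob * (1.0 - tcpgen_dist[..., -1:])

interpolated_no_blank = (
    scaled_tcpgen_dist[..., :-1] * tcpgen_prob
    + p_mdl[..., 1:] * (1.0 - scaled_tcpgen_prob)
)
final_dist = torch.cat([p_mdl[..., 0:1], interpolated_no_blank], dim=-1)

# rnnt_loss_pruned is then called on log(final_dist) with
# fused_log_softmax=False, since normalization has already been applied.
log_probs = torch.log(final_dist + torch.finfo(final_dist.dtype).tiny)
print(log_probs.shape)  # (B, T, s_range, V)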
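
Note (not part of the patch): a rough usage sketch of building a biasing trie from training transcripts, combining the phrase sampling done by prepare_context_graph() in train.py with the simple_trie option this patch adds to ContextGraph.build() (train.py itself currently calls build() with the default, which also fills the Aho-Corasick fail/output arcs). The SentencePiece model path and the transcripts are placeholders, and the snippet assumes the patched icefall.

import random
import sentencepiece as spm
from icefall import ContextGraph

sp = spm.SentencePieceProcessor()
sp.load("data/lang_bpe_500/bpe.model")  # placeholder path

texts = ["THE QUICK BROWN FOX", "JUMPS OVER THE LAZY DOG"]

contexts_list = []
for text in texts:
    words = text.split()
    start = random.randint(0, len(words) - 1)
    length = random.randint(1, 3)
    # Each biasing phrase is a short span of 1 to 3 consecutive words.
    contexts_list.append(" ".join(words[start : start + length]))

contexts_tokens = sp.encode(contexts_list, out_type=int)

# simple_trie=True builds only the prefix trie, which is all that
# get_tcpgen_masks() needs: it walks node.next to decide which vocabulary
# entries stay unmasked at each prediction step.
context_graph = ContextGraph(0)
context_graph.build(contexts_tokens, simple_trie=True)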