From f92c4e94c36dfd480cc86c4f084056324b5aa6c1 Mon Sep 17 00:00:00 2001
From: Dinghao Zhou
Date: Fri, 1 Nov 2024 11:35:10 +0800
Subject: [PATCH] [cli] fix ts (#2649)

* [cli] paraformer support batch infer

* fix device

* fix ts

* fix lint
---
 wenet/cli/paraformer_model.py  |  9 +++++----
 wenet/paraformer/paraformer.py | 23 +++++++++++++----------
 2 files changed, 18 insertions(+), 14 deletions(-)

diff --git a/wenet/cli/paraformer_model.py b/wenet/cli/paraformer_model.py
index 20233255a5..81d04091b6 100644
--- a/wenet/cli/paraformer_model.py
+++ b/wenet/cli/paraformer_model.py
@@ -55,24 +55,25 @@ def transcribe_batch(self,
             feats_lst, batch_first=True).to(device=self.device)
         feats_lens_tensor = torch.tensor(feats_lens_lst, device=self.device)
 
-        decoder_out, token_num, tp_alphas = self.model.forward_paraformer(
+        decoder_out, token_num, tp_alphas, frames = self.model.forward_paraformer(
             feats_tensor, feats_lens_tensor)
+        frames = frames.cpu().numpy()
 
         cif_peaks = self.model.forward_cif_peaks(tp_alphas, token_num)
         results = paraformer_greedy_search(decoder_out, token_num, cif_peaks)
 
         r = []
-        for res in results:
+        for (i, res) in enumerate(results):
             result = {}
             result['confidence'] = res.confidence
             result['text'] = self.tokenizer.detokenize(res.tokens)[0]
             if tokens_info:
                 tokens_info_l = []
                 times = gen_timestamps_from_peak(res.times,
-                                                 num_frames=tp_alphas.size(1),
+                                                 num_frames=frames[i],
                                                  frame_rate=0.02)
 
-                for i, x in enumerate(res.tokens):
+                for i, x in enumerate(res.tokens[:len(times)]):
                     tokens_info_l.append({
                         'token':
                         self.tokenizer.char_dict[x],
diff --git a/wenet/paraformer/paraformer.py b/wenet/paraformer/paraformer.py
index be19f15b49..7824225640 100644
--- a/wenet/paraformer/paraformer.py
+++ b/wenet/paraformer/paraformer.py
@@ -19,9 +19,7 @@
 import torch
 
 from wenet.paraformer.cif import Cif, cif_without_hidden
-
-from wenet.paraformer.layers import SanmDecoder, SanmEncoder
-from wenet.paraformer.layers import LFR
+from wenet.paraformer.layers import LFR, SanmDecoder, SanmEncoder
 from wenet.paraformer.search import (paraformer_beam_search,
                                      paraformer_greedy_search)
 from wenet.transformer.asr_model import ASRModel
@@ -99,7 +97,8 @@ def forward(self,
         tp_alphas = tp_alphas.squeeze(-1)
         tp_token_num = tp_alphas.sum(-1)
 
-        return acoustic_embeds, token_num, alphas, cif_peak, tp_alphas, tp_token_num
+        return acoustic_embeds, token_num, alphas, cif_peak, tp_alphas, \
+            tp_token_num, mask
 
 
 class Paraformer(ASRModel):
@@ -170,7 +169,7 @@ def forward(
         if self.add_eos:
             _, ys_pad = add_sos_eos(text, self.sos, self.eos, self.ignore_id)
             ys_pad_lens = text_lengths + 1
-        acoustic_embd, token_num, _, _, _, tp_token_num = self.predictor(
+        acoustic_embd, token_num, _, _, _, tp_token_num, _ = self.predictor(
             encoder_out, ys_pad, encoder_out_mask, self.ignore_id)
 
         # 2 decoder with sampler
@@ -295,9 +294,10 @@ def forward_paraformer(
         self,
         speech: torch.Tensor,
         speech_lengths: torch.Tensor,
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
         res = self._forward_paraformer(speech, speech_lengths)
-        return res['decoder_out'], res['decoder_out_lens'], res['tp_alphas']
+        return res['decoder_out'], res['decoder_out_lens'], res[
+            'tp_alphas'], res['tp_mask'].sum(1).squeeze(-1)
 
     @torch.jit.export
     def forward_encoder_chunk(
@@ -336,8 +336,10 @@ def _forward_paraformer(
                                                       num_decoding_left_chunks)
 
         # cif predictor
-        acoustic_embed, token_num, _, _, tp_alphas, _ = self.predictor(
-            encoder_out, mask=encoder_out_mask)
+        acoustic_embed, token_num, _, _, tp_alphas, _, tp_mask = self.predictor(
+            encoder_out,
+            mask=encoder_out_mask,
+        )
         token_num = token_num.floor().to(speech_lengths.dtype)
 
         # decoder
@@ -350,7 +352,8 @@
             "encoder_out_mask": encoder_out_mask,
             "decoder_out": decoder_out,
             "tp_alphas": tp_alphas,
-            "decoder_out_lens": token_num
+            "decoder_out_lens": token_num,
+            "tp_mask": tp_mask
         }
 
     def decode(
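
Note on the timestamp ("ts") fix, with a minimal standalone sketch (illustrative
values only, not wenet code): in batch inference `tp_alphas` is padded to the
longest utterance in the batch, so the old `num_frames=tp_alphas.size(1)` scaled
every utterance's timestamps to the padded length. The patch instead returns the
predictor mask as `tp_mask` and reduces it to per-utterance frame counts. The
sketch assumes the (B, T, 1) mask layout that the patch's `sum(1).squeeze(-1)`
implies, and a 20 ms frame rate as in the `frame_rate=0.02` call above.

    import torch

    # Two utterances padded to 100 encoder frames; the second is
    # actually only 60 frames long.
    tp_alphas = torch.rand(2, 100)
    tp_mask = torch.zeros(2, 100, 1)
    tp_mask[0, :100, 0] = 1.0
    tp_mask[1, :60, 0] = 1.0

    # Per-utterance valid frame counts, the same reduction
    # forward_paraformer now applies to tp_mask.
    frames = tp_mask.sum(1).squeeze(-1)  # tensor([100., 60.])

    # Old behaviour: both utterances span 100 * 0.02 = 2.0 s of timestamps.
    # New behaviour: utterance 1 is capped at its true 60 * 0.02 = 1.2 s.
    print(tp_alphas.size(1) * 0.02)  # 2.0
    print(frames[1].item() * 0.02)   # 1.2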