Skip to content

Commit

Permalink
[cli] fix ts (#2649)
Browse files Browse the repository at this point in the history
* [cli] paraformer support batch infer

* fix device

* fix ts

* fix lint
  • Loading branch information
Mddct authored Nov 1, 2024
1 parent 93eb806 commit f92c4e9
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 14 deletions.
9 changes: 5 additions & 4 deletions wenet/cli/paraformer_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,24 +55,25 @@ def transcribe_batch(self,
feats_lst, batch_first=True).to(device=self.device)
feats_lens_tensor = torch.tensor(feats_lens_lst, device=self.device)

decoder_out, token_num, tp_alphas = self.model.forward_paraformer(
decoder_out, token_num, tp_alphas, frames = self.model.forward_paraformer(
feats_tensor, feats_lens_tensor)
frames = frames.cpu().numpy()
cif_peaks = self.model.forward_cif_peaks(tp_alphas, token_num)

results = paraformer_greedy_search(decoder_out, token_num, cif_peaks)

r = []
for res in results:
for (i, res) in enumerate(results):
result = {}
result['confidence'] = res.confidence
result['text'] = self.tokenizer.detokenize(res.tokens)[0]
if tokens_info:
tokens_info_l = []
times = gen_timestamps_from_peak(res.times,
num_frames=tp_alphas.size(1),
num_frames=frames[i],
frame_rate=0.02)

for i, x in enumerate(res.tokens):
for i, x in enumerate(res.tokens[:len(times)]):
tokens_info_l.append({
'token':
self.tokenizer.char_dict[x],
Expand Down
23 changes: 13 additions & 10 deletions wenet/paraformer/paraformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,7 @@

import torch
from wenet.paraformer.cif import Cif, cif_without_hidden

from wenet.paraformer.layers import SanmDecoder, SanmEncoder
from wenet.paraformer.layers import LFR
from wenet.paraformer.layers import LFR, SanmDecoder, SanmEncoder
from wenet.paraformer.search import (paraformer_beam_search,
paraformer_greedy_search)
from wenet.transformer.asr_model import ASRModel
Expand Down Expand Up @@ -99,7 +97,8 @@ def forward(self,
tp_alphas = tp_alphas.squeeze(-1)
tp_token_num = tp_alphas.sum(-1)

return acoustic_embeds, token_num, alphas, cif_peak, tp_alphas, tp_token_num
return acoustic_embeds, token_num, alphas, cif_peak, tp_alphas, \
tp_token_num, mask


class Paraformer(ASRModel):
Expand Down Expand Up @@ -170,7 +169,7 @@ def forward(
if self.add_eos:
_, ys_pad = add_sos_eos(text, self.sos, self.eos, self.ignore_id)
ys_pad_lens = text_lengths + 1
acoustic_embd, token_num, _, _, _, tp_token_num = self.predictor(
acoustic_embd, token_num, _, _, _, tp_token_num, _ = self.predictor(
encoder_out, ys_pad, encoder_out_mask, self.ignore_id)

# 2 decoder with sampler
Expand Down Expand Up @@ -295,9 +294,10 @@ def forward_paraformer(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
res = self._forward_paraformer(speech, speech_lengths)
return res['decoder_out'], res['decoder_out_lens'], res['tp_alphas']
return res['decoder_out'], res['decoder_out_lens'], res[
'tp_alphas'], res['tp_mask'].sum(1).squeeze(-1)

@torch.jit.export
def forward_encoder_chunk(
Expand Down Expand Up @@ -336,8 +336,10 @@ def _forward_paraformer(
num_decoding_left_chunks)

# cif predictor
acoustic_embed, token_num, _, _, tp_alphas, _ = self.predictor(
encoder_out, mask=encoder_out_mask)
acoustic_embed, token_num, _, _, tp_alphas, _, tp_mask = self.predictor(
encoder_out,
mask=encoder_out_mask,
)
token_num = token_num.floor().to(speech_lengths.dtype)

# decoder
Expand All @@ -350,7 +352,8 @@ def _forward_paraformer(
"encoder_out_mask": encoder_out_mask,
"decoder_out": decoder_out,
"tp_alphas": tp_alphas,
"decoder_out_lens": token_num
"decoder_out_lens": token_num,
"tp_mask": tp_mask
}

def decode(
Expand Down

0 comments on commit f92c4e9

Please sign in to comment.