From 2c6a7dbdd32cfd4ae0289993db71bf6e96576ecb Mon Sep 17 00:00:00 2001 From: Mddct Date: Tue, 31 Oct 2023 12:27:02 +0800 Subject: [PATCH 1/2] [paraformer] delete training relatd and simplified paraformer --- .../{ali_paraformer => }/assets/config.yaml | 0 .../{ali_paraformer => }/assets/global_cmvn | 0 .../{ali_paraformer => }/assets/units.txt | 0 .../{ali_paraformer => }/attention.py | 0 .../{ali_paraformer => }/export_jit.py | 6 +- .../{ali_paraformer/model.py => layers.py} | 81 +------- wenet/paraformer/paraformer.py | 179 +++++------------- wenet/utils/init_model.py | 9 - 8 files changed, 53 insertions(+), 222 deletions(-) rename wenet/paraformer/{ali_paraformer => }/assets/config.yaml (100%) rename wenet/paraformer/{ali_paraformer => }/assets/global_cmvn (100%) rename wenet/paraformer/{ali_paraformer => }/assets/units.txt (100%) rename wenet/paraformer/{ali_paraformer => }/attention.py (100%) rename wenet/paraformer/{ali_paraformer => }/export_jit.py (92%) rename wenet/paraformer/{ali_paraformer/model.py => layers.py} (84%) diff --git a/wenet/paraformer/ali_paraformer/assets/config.yaml b/wenet/paraformer/assets/config.yaml similarity index 100% rename from wenet/paraformer/ali_paraformer/assets/config.yaml rename to wenet/paraformer/assets/config.yaml diff --git a/wenet/paraformer/ali_paraformer/assets/global_cmvn b/wenet/paraformer/assets/global_cmvn similarity index 100% rename from wenet/paraformer/ali_paraformer/assets/global_cmvn rename to wenet/paraformer/assets/global_cmvn diff --git a/wenet/paraformer/ali_paraformer/assets/units.txt b/wenet/paraformer/assets/units.txt similarity index 100% rename from wenet/paraformer/ali_paraformer/assets/units.txt rename to wenet/paraformer/assets/units.txt diff --git a/wenet/paraformer/ali_paraformer/attention.py b/wenet/paraformer/attention.py similarity index 100% rename from wenet/paraformer/ali_paraformer/attention.py rename to wenet/paraformer/attention.py diff --git a/wenet/paraformer/ali_paraformer/export_jit.py b/wenet/paraformer/export_jit.py similarity index 92% rename from wenet/paraformer/ali_paraformer/export_jit.py rename to wenet/paraformer/export_jit.py index 8d38e5ff0..6b7f8531b 100644 --- a/wenet/paraformer/ali_paraformer/export_jit.py +++ b/wenet/paraformer/export_jit.py @@ -5,8 +5,8 @@ import torch import yaml from wenet.cif.predictor import Predictor -from wenet.paraformer.ali_paraformer.model import (AliParaformer, SanmDecoer, - SanmEncoder) +from wenet.paraformer.layers import (SanmDecoer, SanmEncoder) +from wenet.paraformer.paraformer import Paraformer from wenet.transformer.cmvn import GlobalCMVN from wenet.utils.checkpoint import load_checkpoint from wenet.utils.cmvn import load_cmvn @@ -41,7 +41,7 @@ def init_model(configs): encoder_output_size=encoder.output_size(), **configs['decoder_conf']) predictor = Predictor(**configs['cif_predictor_conf']) - model = AliParaformer( + model = Paraformer( encoder=encoder, decoder=decoder, predictor=predictor, diff --git a/wenet/paraformer/ali_paraformer/model.py b/wenet/paraformer/layers.py similarity index 84% rename from wenet/paraformer/ali_paraformer/model.py rename to wenet/paraformer/layers.py index 2d507c4bd..d408ee492 100644 --- a/wenet/paraformer/ali_paraformer/model.py +++ b/wenet/paraformer/layers.py @@ -5,10 +5,9 @@ from typing import Dict, List, Optional, Tuple import torch from wenet.cif.predictor import Predictor -from wenet.paraformer.ali_paraformer.attention import (DummyMultiHeadSANM, - MultiHeadAttentionCross, - MultiHeadedAttentionSANM - ) +from 
wenet.paraformer.attention import (DummyMultiHeadSANM, + MultiHeadAttentionCross, + MultiHeadedAttentionSANM) from wenet.paraformer.search import paraformer_beam_search, paraformer_greedy_search from wenet.transformer.search import DecodeResult from wenet.transformer.encoder import BaseEncoder @@ -455,77 +454,3 @@ def forward( if self.output_layer is not None: x = self.output_layer(x) return x, torch.tensor(0.0), ys_pad_lens - - -class AliParaformer(torch.nn.Module): - - def __init__(self, encoder: SanmEncoder, decoder: SanmDecoer, - predictor: Predictor): - super().__init__() - self.encoder = encoder - self.decoder = decoder - self.predictor = predictor - self.lfr = LFR() - self.sos = 1 - self.eos = 2 - - @torch.jit.ignore(drop=True) - def forward( - self, speech: torch.Tensor, speech_lengths: torch.Tensor, - text: torch.Tensor, - text_lengths: torch.Tensor) -> Dict[str, Optional[torch.Tensor]]: - raise NotImplementedError - - @torch.jit.export - def forward_paraformer( - self, - speech: torch.Tensor, - speech_lengths: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor]: - features, features_lens = self.lfr(speech, speech_lengths) - features_lens = features_lens.to(speech_lengths.dtype) - # encoder - encoder_out, encoder_out_mask = self.encoder(features, features_lens) - - # cif predictor - acoustic_embed, token_num, _, _ = self.predictor(encoder_out, - mask=encoder_out_mask) - token_num = token_num.floor().to(speech_lengths.dtype) - - # decoder - decoder_out, _, _ = self.decoder(encoder_out, encoder_out_mask, - acoustic_embed, token_num) - decoder_out = decoder_out.log_softmax(dim=-1) - return decoder_out, token_num - - def decode(self, - methods: List[str], - speech: torch.Tensor, - speech_lengths: torch.Tensor, - beam_size: int, - decoding_chunk_size: int = -1, - num_decoding_left_chunks: int = -1, - ctc_weight: float = 0, - simulate_streaming: bool = False, - reverse_weight: float = 0) -> Dict[str, List[DecodeResult]]: - decoder_out, decoder_out_lens = self.forward_paraformer( - speech, speech_lengths) - - results = {} - if 'paraformer_greedy_search' in methods: - assert decoder_out is not None - assert decoder_out_lens is not None - paraformer_greedy_result = paraformer_greedy_search( - decoder_out, decoder_out_lens) - results['paraformer_greedy_search'] = paraformer_greedy_result - if 'paraformer_beam_search' in methods: - assert decoder_out is not None - assert decoder_out_lens is not None - paraformer_beam_result = paraformer_beam_search( - decoder_out, - decoder_out_lens, - beam_size=beam_size, - eos=self.eos) - results['paraformer_beam_search'] = paraformer_beam_result - - return results diff --git a/wenet/paraformer/paraformer.py b/wenet/paraformer/paraformer.py index b4ed44150..f53d72c92 100644 --- a/wenet/paraformer/paraformer.py +++ b/wenet/paraformer/paraformer.py @@ -15,54 +15,41 @@ # Modified from ESPnet(https://github.com/espnet/espnet) and # FunASR(https://github.com/alibaba-damo-academy/FunASR) +# NOTE: This file is only for loading ali-paraformer-large-model and inference + from typing import Dict, List, Optional, Tuple import torch -from wenet.cif.predictor import MAELoss -from wenet.paraformer.search import paraformer_beam_search, paraformer_greedy_search -from wenet.transformer.asr_model import ASRModel -from wenet.transformer.ctc import CTC -from wenet.transformer.decoder import TransformerDecoder -from wenet.transformer.encoder import TransformerEncoder -from wenet.transformer.search import (DecodeResult, ctc_greedy_search, - ctc_prefix_beam_search) -from 
wenet.utils.common import (IGNORE_ID, add_sos_eos, th_accuracy) +from wenet.cif.predictor import Predictor +from wenet.paraformer.layers import SanmDecoer, SanmEncoder +from wenet.paraformer.layers import LFR +from wenet.paraformer.search import (paraformer_beam_search, + paraformer_greedy_search) +from wenet.transformer.search import DecodeResult from wenet.utils.mask import make_pad_mask -class Paraformer(ASRModel): +class Paraformer(torch.nn.Module): """ Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition see https://arxiv.org/pdf/2206.08317.pdf + """ - def __init__( - self, - vocab_size: int, - encoder: TransformerEncoder, - decoder: TransformerDecoder, - ctc: CTC, - predictor, - ctc_weight: float = 0.5, - predictor_weight: float = 1.0, - predictor_bias: int = 0, - ignore_id: int = IGNORE_ID, - reverse_weight: float = 0.0, - lsm_weight: float = 0.0, - length_normalized_loss: bool = False, - ): - assert 0.0 <= ctc_weight <= 1.0, ctc_weight - assert 0.0 <= predictor_weight <= 1.0, predictor_weight - - super().__init__(vocab_size, encoder, decoder, ctc, ctc_weight, - ignore_id, reverse_weight, lsm_weight, - length_normalized_loss) + def __init__(self, encoder: SanmEncoder, decoder: SanmDecoer, + predictor: Predictor): + + super().__init__() + self.lfr = LFR() + self.encoder = encoder + self.decoder = decoder self.predictor = predictor - self.predictor_weight = predictor_weight - self.predictor_bias = predictor_bias - self.criterion_pre = MAELoss(normalize_length=length_normalized_loss) + self.sos = 1 + self.eos = 2 + + @torch.jit.ignore(drop=True) def forward( self, speech: torch.Tensor, @@ -78,76 +65,7 @@ def forward( text: (Batch, Length) text_lengths: (Batch,) """ - assert text_lengths.dim() == 1, text_lengths.shape - # Check that batch_size is unified - assert (speech.shape[0] == speech_lengths.shape[0] == text.shape[0] == - text_lengths.shape[0]), (speech.shape, speech_lengths.shape, - text.shape, text_lengths.shape) - # 1. Encoder - encoder_out, encoder_mask = self.encoder(speech, speech_lengths) - encoder_out_lens = encoder_mask.squeeze(1).sum(1) - - # 2a. Attention-decoder branch - if self.ctc_weight != 1.0: - loss_att, acc_att, loss_pre = self._calc_att_loss( - encoder_out, encoder_mask, text, text_lengths) - else: - # loss_att = None - # loss_pre = None - loss_att: torch.Tensor = torch.tensor(0) - loss_pre: torch.Tensor = torch.tensor(0) - - # 2b. CTC branch - if self.ctc_weight != 0.0: - loss_ctc = self.ctc(encoder_out, encoder_out_lens, text, - text_lengths) - else: - loss_ctc = None - - if loss_ctc is None: - loss = loss_att + self.predictor_weight * loss_pre - # elif loss_att is None: - elif loss_att == torch.tensor(0): - loss = loss_ctc - else: - loss = self.ctc_weight * loss_ctc + \ - (1 - self.ctc_weight) * loss_att + \ - self.predictor_weight * loss_pre - return { - "loss": loss, - "loss_att": loss_att, - "loss_ctc": loss_ctc, - "loss_pre": loss_pre - } - - def _calc_att_loss( - self, - encoder_out: torch.Tensor, - encoder_mask: torch.Tensor, - ys_pad: torch.Tensor, - ys_pad_lens: torch.Tensor, - ) -> Tuple[torch.Tensor, float, torch.Tensor]: - if self.predictor_bias == 1: - _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id) - ys_pad_lens = ys_pad_lens + self.predictor_bias - pre_acoustic_embeds, pre_token_length, _, pre_peak_index = \ - self.predictor(encoder_out, ys_pad, encoder_mask, - ignore_id=self.ignore_id) - # 1. 
Forward decoder - decoder_out, _, _ = self.decoder(encoder_out, encoder_mask, - pre_acoustic_embeds, ys_pad_lens) - - # 2. Compute attention loss - loss_att = self.criterion_att(decoder_out, ys_pad) - acc_att = th_accuracy( - decoder_out.view(-1, self.vocab_size), - ys_pad, - ignore_label=self.ignore_id, - ) - loss_pre: torch.Tensor = self.criterion_pre( - ys_pad_lens.type_as(pre_token_length), pre_token_length) - - return loss_att, acc_att, loss_pre + raise NotImplementedError def calc_predictor(self, encoder_out, encoder_out_lens): encoder_mask = (~make_pad_mask( @@ -166,6 +84,28 @@ def cal_decoder_with_predictor(self, encoder_out, encoder_mask, sematic_embeds, ys_pad_lens) return decoder_out, ys_pad_lens + @torch.jit.export + def forward_paraformer( + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + features, features_lens = self.lfr(speech, speech_lengths) + features_lens = features_lens.to(speech_lengths.dtype) + # encoder + encoder_out, encoder_out_mask = self.encoder(features, features_lens) + + # cif predictor + acoustic_embed, token_num, _, _ = self.predictor(encoder_out, + mask=encoder_out_mask) + token_num = token_num.floor().to(speech_lengths.dtype) + + # decoder + decoder_out, _, _ = self.decoder(encoder_out, encoder_out_mask, + acoustic_embed, token_num) + decoder_out = decoder_out.log_softmax(dim=-1) + return decoder_out, token_num + def decode(self, methods: List[str], speech: torch.Tensor, @@ -176,35 +116,10 @@ def decode(self, ctc_weight: float = 0, simulate_streaming: bool = False, reverse_weight: float = 0) -> Dict[str, List[DecodeResult]]: - assert speech.shape[0] == speech_lengths.shape[0] - assert decoding_chunk_size != 0 - encoder_out, encoder_mask = self._forward_encoder( - speech, speech_lengths, decoding_chunk_size, - num_decoding_left_chunks, simulate_streaming) - encoder_lens = encoder_mask.squeeze(1).sum(1) - results = {} + decoder_out, decoder_out_lens = self.forward_paraformer( + speech, speech_lengths) - ctc_probs: Optional[torch.Tensor] = None - if 'ctc_greedy_search' in methods: - ctc_probs = self.ctc.log_softmax(encoder_out) - results['ctc_greedy_search'] = ctc_greedy_search( - ctc_probs, encoder_lens) - if 'ctc_prefix_beam_search' in methods: - if ctc_probs is None: - ctc_probs = self.ctc.log_softmax(encoder_out) - ctc_prefix_result = ctc_prefix_beam_search(ctc_probs, encoder_lens, - beam_size) - results['ctc_prefix_beam_search'] = ctc_prefix_result - - decoder_out: Optional[torch.Tensor] = None - decoder_out_lens: Optional[torch.Tensor] = None - # TODO(Mddct): add timestamp from predictor's alpha - if ('paraformer_greedy_search' in methods - or 'paraformer_beam_search' in methods): - acoustic_embed, token_nums, _, _ = self.calc_predictor( - encoder_out, encoder_lens) - decoder_out, decoder_out_lens = self.cal_decoder_with_predictor( - encoder_out, encoder_mask, acoustic_embed, token_nums) + results = {} if 'paraformer_greedy_search' in methods: assert decoder_out is not None assert decoder_out_lens is not None diff --git a/wenet/utils/init_model.py b/wenet/utils/init_model.py index 4c8544ced..80df97d07 100644 --- a/wenet/utils/init_model.py +++ b/wenet/utils/init_model.py @@ -28,7 +28,6 @@ from wenet.e_branchformer.encoder import EBranchformerEncoder from wenet.squeezeformer.encoder import SqueezeformerEncoder from wenet.efficient_conformer.encoder import EfficientConformerEncoder -from wenet.paraformer.paraformer import Paraformer from wenet.cif.predictor import Predictor from wenet.utils.cmvn 
import load_cmvn @@ -115,14 +114,6 @@ def init_model(configs): joint=joint, ctc=ctc, **configs['model_conf']) - elif 'paraformer' in configs: - predictor = Predictor(**configs['cif_predictor_conf']) - model = Paraformer(vocab_size=vocab_size, - encoder=encoder, - decoder=decoder, - ctc=ctc, - predictor=predictor, - **configs['model_conf']) else: print(configs) if configs.get('lfmmi_dir', '') != '': From 43265f6215ff200745c61fe9f83be27a241de994 Mon Sep 17 00:00:00 2001 From: Mddct Date: Tue, 31 Oct 2023 12:32:48 +0800 Subject: [PATCH 2/2] fix lint --- wenet/cli/paraformer_model.py | 2 +- wenet/paraformer/layers.py | 5 +---- wenet/paraformer/paraformer.py | 2 +- wenet/utils/init_model.py | 1 - 4 files changed, 3 insertions(+), 7 deletions(-) diff --git a/wenet/cli/paraformer_model.py b/wenet/cli/paraformer_model.py index 1134e13f8..fab4e0090 100644 --- a/wenet/cli/paraformer_model.py +++ b/wenet/cli/paraformer_model.py @@ -57,7 +57,7 @@ def transcribe(self, audio_file: str, tokens_info: bool = False) -> dict: return result def align(self, audio_file: str, label: str) -> dict: - raise NotImplementedError + raise NotImplementedError("Align is currently not supported") def load_model(language: str = None, model_dir: str = None) -> Paraformer: diff --git a/wenet/paraformer/layers.py b/wenet/paraformer/layers.py index d408ee492..064e4ca48 100644 --- a/wenet/paraformer/layers.py +++ b/wenet/paraformer/layers.py @@ -2,14 +2,11 @@ """ import math -from typing import Dict, List, Optional, Tuple +from typing import Optional, Tuple import torch -from wenet.cif.predictor import Predictor from wenet.paraformer.attention import (DummyMultiHeadSANM, MultiHeadAttentionCross, MultiHeadedAttentionSANM) -from wenet.paraformer.search import paraformer_beam_search, paraformer_greedy_search -from wenet.transformer.search import DecodeResult from wenet.transformer.encoder import BaseEncoder from wenet.transformer.decoder import TransformerDecoder from wenet.transformer.decoder_layer import DecoderLayer diff --git a/wenet/paraformer/paraformer.py b/wenet/paraformer/paraformer.py index f53d72c92..779eb87b7 100644 --- a/wenet/paraformer/paraformer.py +++ b/wenet/paraformer/paraformer.py @@ -65,7 +65,7 @@ def forward( text: (Batch, Length) text_lengths: (Batch,) """ - raise NotImplementedError + raise NotImplementedError("Training is currently not supported") def calc_predictor(self, encoder_out, encoder_out_lens): encoder_mask = (~make_pad_mask( diff --git a/wenet/utils/init_model.py b/wenet/utils/init_model.py index 80df97d07..dff0bd534 100644 --- a/wenet/utils/init_model.py +++ b/wenet/utils/init_model.py @@ -28,7 +28,6 @@ from wenet.e_branchformer.encoder import EBranchformerEncoder from wenet.squeezeformer.encoder import SqueezeformerEncoder from wenet.efficient_conformer.encoder import EfficientConformerEncoder -from wenet.cif.predictor import Predictor from wenet.utils.cmvn import load_cmvn
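
For reviewers: a minimal inference sketch against the simplified Paraformer that this patch introduces. It reuses init_model(configs) from wenet/paraformer/export_jit.py (shown in the diff above) and the relocated assets under wenet/paraformer/assets/. The checkpoint path, feature shapes, and beam size are placeholders; it also assumes init_model() is importable as a plain function, returns the constructed model, and finds everything it needs (encoder/decoder/predictor confs, cmvn) in the asset config. This is a usage sketch, not part of the patch.

    import torch
    import yaml

    from wenet.paraformer.export_jit import init_model  # builds SanmEncoder, SanmDecoer, Predictor
    from wenet.utils.checkpoint import load_checkpoint

    # Config and cmvn now live under wenet/paraformer/assets/ after the rename.
    with open('wenet/paraformer/assets/config.yaml', 'r') as fin:
        configs = yaml.load(fin, Loader=yaml.FullLoader)

    model = init_model(configs)                            # Paraformer(encoder, decoder, predictor)
    load_checkpoint(model, 'exp/ali_paraformer/model.pt')  # placeholder checkpoint path
    model.eval()

    # The model applies LFR internally, so the input is plain fbank features.
    # Shapes are illustrative only: (batch, frames, feat_dim) and (batch,).
    speech = torch.randn(1, 200, 80)
    speech_lengths = torch.tensor([200], dtype=torch.int32)

    with torch.no_grad():
        # TorchScript-exported entry point: token log-probs plus predicted token counts.
        decoder_out, token_num = model.forward_paraformer(speech, speech_lengths)

        # High-level decode() kept by this patch (paraformer_* methods only).
        results = model.decode(['paraformer_greedy_search'],
                               speech, speech_lengths, beam_size=10)
        hyp = results['paraformer_greedy_search'][0]       # DecodeResult with predicted token ids

With the CTC and attention-loss branches gone, decode() always runs the non-autoregressive path once (LFR -> SanmEncoder -> CIF predictor -> SanmDecoer) and only post-processes its log-probs with paraformer greedy or beam search.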
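
Because forward() carries @torch.jit.ignore(drop=True) and forward_paraformer() carries @torch.jit.export, the relocated export_jit.py should still be able to script the same object for deployment. A sketch, continuing from the snippet above (the output filename is a placeholder):

    script_model = torch.jit.script(model)
    script_model.save('ali_paraformer_jit.zip')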