[paraformer] delete training related code and simplify paraformer #2093

Merged
merged 2 commits on Oct 31, 2023
2 changes: 1 addition & 1 deletion wenet/cli/paraformer_model.py

@@ -57,7 +57,7 @@ def transcribe(self, audio_file: str, tokens_info: bool = False) -> dict:
         return result

     def align(self, audio_file: str, label: str) -> dict:
-        raise NotImplementedError
+        raise NotImplementedError("Align is currently not supported")


 def load_model(language: str = None, model_dir: str = None) -> Paraformer:
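For reference, a minimal usage sketch of the CLI inference path this PR keeps. The model directory path is a placeholder, not part of the PR; load_model and transcribe signatures are taken from the diff above.

    # Sketch only: model_dir is a placeholder for a downloaded
    # ali-paraformer model directory.
    from wenet.cli.paraformer_model import load_model

    model = load_model(model_dir="/path/to/ali-paraformer")
    result = model.transcribe("test.wav", tokens_info=True)
    print(result)
    # align() remains a stub and now fails with a clear message:
    # NotImplementedError("Align is currently not supported")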
@@ -5,8 +5,8 @@
 import torch
 import yaml
 from wenet.cif.predictor import Predictor
-from wenet.paraformer.ali_paraformer.model import (AliParaformer, SanmDecoer,
-                                                   SanmEncoder)
+from wenet.paraformer.layers import (SanmDecoer, SanmEncoder)
+from wenet.paraformer.paraformer import Paraformer
 from wenet.transformer.cmvn import GlobalCMVN
 from wenet.utils.checkpoint import load_checkpoint
 from wenet.utils.cmvn import load_cmvn
@@ -41,7 +41,7 @@ def init_model(configs):
         encoder_output_size=encoder.output_size(),
         **configs['decoder_conf'])
     predictor = Predictor(**configs['cif_predictor_conf'])
-    model = AliParaformer(
+    model = Paraformer(
         encoder=encoder,
         decoder=decoder,
         predictor=predictor,
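After this change the inference-time classes come from the modules below; a sketch of the new import layout, mirroring the hunks above (the encoder/decoder/predictor wiring itself is unchanged):

    # New import layout after this PR: AliParaformer is gone, and the
    # merged Paraformer plus the Sanm layers cover inference.
    from wenet.cif.predictor import Predictor
    from wenet.paraformer.layers import LFR, SanmDecoer, SanmEncoder
    from wenet.paraformer.paraformer import Paraformer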
@@ -2,15 +2,11 @@
 """

 import math
-from typing import Dict, List, Optional, Tuple
+from typing import Optional, Tuple
 import torch
 from wenet.cif.predictor import Predictor
-from wenet.paraformer.ali_paraformer.attention import (DummyMultiHeadSANM,
-                                                        MultiHeadAttentionCross,
-                                                        MultiHeadedAttentionSANM
-                                                        )
-from wenet.paraformer.search import paraformer_beam_search, paraformer_greedy_search
-from wenet.transformer.search import DecodeResult
+from wenet.paraformer.attention import (DummyMultiHeadSANM,
+                                        MultiHeadAttentionCross,
+                                        MultiHeadedAttentionSANM)
 from wenet.transformer.encoder import BaseEncoder
 from wenet.transformer.decoder import TransformerDecoder
 from wenet.transformer.decoder_layer import DecoderLayer
@@ -455,77 +451,3 @@ def forward(
         if self.output_layer is not None:
             x = self.output_layer(x)
         return x, torch.tensor(0.0), ys_pad_lens
-
-
-class AliParaformer(torch.nn.Module):
-
-    def __init__(self, encoder: SanmEncoder, decoder: SanmDecoer,
-                 predictor: Predictor):
-        super().__init__()
-        self.encoder = encoder
-        self.decoder = decoder
-        self.predictor = predictor
-        self.lfr = LFR()
-        self.sos = 1
-        self.eos = 2
-
-    @torch.jit.ignore(drop=True)
-    def forward(
-            self, speech: torch.Tensor, speech_lengths: torch.Tensor,
-            text: torch.Tensor,
-            text_lengths: torch.Tensor) -> Dict[str, Optional[torch.Tensor]]:
-        raise NotImplementedError
-
-    @torch.jit.export
-    def forward_paraformer(
-        self,
-        speech: torch.Tensor,
-        speech_lengths: torch.Tensor,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        features, features_lens = self.lfr(speech, speech_lengths)
-        features_lens = features_lens.to(speech_lengths.dtype)
-        # encoder
-        encoder_out, encoder_out_mask = self.encoder(features, features_lens)
-
-        # cif predictor
-        acoustic_embed, token_num, _, _ = self.predictor(encoder_out,
-                                                         mask=encoder_out_mask)
-        token_num = token_num.floor().to(speech_lengths.dtype)
-
-        # decoder
-        decoder_out, _, _ = self.decoder(encoder_out, encoder_out_mask,
-                                         acoustic_embed, token_num)
-        decoder_out = decoder_out.log_softmax(dim=-1)
-        return decoder_out, token_num
-
-    def decode(self,
-               methods: List[str],
-               speech: torch.Tensor,
-               speech_lengths: torch.Tensor,
-               beam_size: int,
-               decoding_chunk_size: int = -1,
-               num_decoding_left_chunks: int = -1,
-               ctc_weight: float = 0,
-               simulate_streaming: bool = False,
-               reverse_weight: float = 0) -> Dict[str, List[DecodeResult]]:
-        decoder_out, decoder_out_lens = self.forward_paraformer(
-            speech, speech_lengths)
-
-        results = {}
-        if 'paraformer_greedy_search' in methods:
-            assert decoder_out is not None
-            assert decoder_out_lens is not None
-            paraformer_greedy_result = paraformer_greedy_search(
-                decoder_out, decoder_out_lens)
-            results['paraformer_greedy_search'] = paraformer_greedy_result
-        if 'paraformer_beam_search' in methods:
-            assert decoder_out is not None
-            assert decoder_out_lens is not None
-            paraformer_beam_result = paraformer_beam_search(
-                decoder_out,
-                decoder_out_lens,
-                beam_size=beam_size,
-                eos=self.eos)
-            results['paraformer_beam_search'] = paraformer_beam_result
-
-        return results
179 changes: 47 additions & 132 deletions wenet/paraformer/paraformer.py
@@ -15,54 +15,41 @@
 # Modified from ESPnet(https://github.com/espnet/espnet) and
 # FunASR(https://github.com/alibaba-damo-academy/FunASR)

+# NOTE: This file is only for loading ali-paraformer-large-model and inference
+
 from typing import Dict, List, Optional, Tuple

 import torch

-from wenet.cif.predictor import MAELoss
-from wenet.paraformer.search import paraformer_beam_search, paraformer_greedy_search
-from wenet.transformer.asr_model import ASRModel
-from wenet.transformer.ctc import CTC
-from wenet.transformer.decoder import TransformerDecoder
-from wenet.transformer.encoder import TransformerEncoder
-from wenet.transformer.search import (DecodeResult, ctc_greedy_search,
-                                      ctc_prefix_beam_search)
-from wenet.utils.common import (IGNORE_ID, add_sos_eos, th_accuracy)
+from wenet.cif.predictor import Predictor
+from wenet.paraformer.layers import SanmDecoer, SanmEncoder
+from wenet.paraformer.layers import LFR
+from wenet.paraformer.search import (paraformer_beam_search,
+                                     paraformer_greedy_search)
+from wenet.transformer.search import DecodeResult
 from wenet.utils.mask import make_pad_mask


-class Paraformer(ASRModel):
+class Paraformer(torch.nn.Module):
     """ Paraformer: Fast and Accurate Parallel Transformer for
     Non-autoregressive End-to-End Speech Recognition
     see https://arxiv.org/pdf/2206.08317.pdf

     """

-    def __init__(
-        self,
-        vocab_size: int,
-        encoder: TransformerEncoder,
-        decoder: TransformerDecoder,
-        ctc: CTC,
-        predictor,
-        ctc_weight: float = 0.5,
-        predictor_weight: float = 1.0,
-        predictor_bias: int = 0,
-        ignore_id: int = IGNORE_ID,
-        reverse_weight: float = 0.0,
-        lsm_weight: float = 0.0,
-        length_normalized_loss: bool = False,
-    ):
-        assert 0.0 <= ctc_weight <= 1.0, ctc_weight
-        assert 0.0 <= predictor_weight <= 1.0, predictor_weight
-        super().__init__(vocab_size, encoder, decoder, ctc, ctc_weight,
-                         ignore_id, reverse_weight, lsm_weight,
-                         length_normalized_loss)
+    def __init__(self, encoder: SanmEncoder, decoder: SanmDecoer,
+                 predictor: Predictor):
+
+        super().__init__()
+        self.lfr = LFR()
+        self.encoder = encoder
+        self.decoder = decoder
         self.predictor = predictor
-        self.predictor_weight = predictor_weight
-        self.predictor_bias = predictor_bias
-        self.criterion_pre = MAELoss(normalize_length=length_normalized_loss)

+        self.sos = 1
+        self.eos = 2
+
+    @torch.jit.ignore(drop=True)
     def forward(
         self,
         speech: torch.Tensor,
@@ -78,76 +65,7 @@ def forward(
             text: (Batch, Length)
             text_lengths: (Batch,)
         """
-        assert text_lengths.dim() == 1, text_lengths.shape
-        # Check that batch_size is unified
-        assert (speech.shape[0] == speech_lengths.shape[0] == text.shape[0] ==
-                text_lengths.shape[0]), (speech.shape, speech_lengths.shape,
-                                         text.shape, text_lengths.shape)
-        # 1. Encoder
-        encoder_out, encoder_mask = self.encoder(speech, speech_lengths)
-        encoder_out_lens = encoder_mask.squeeze(1).sum(1)
-
-        # 2a. Attention-decoder branch
-        if self.ctc_weight != 1.0:
-            loss_att, acc_att, loss_pre = self._calc_att_loss(
-                encoder_out, encoder_mask, text, text_lengths)
-        else:
-            # loss_att = None
-            # loss_pre = None
-            loss_att: torch.Tensor = torch.tensor(0)
-            loss_pre: torch.Tensor = torch.tensor(0)
-
-        # 2b. CTC branch
-        if self.ctc_weight != 0.0:
-            loss_ctc = self.ctc(encoder_out, encoder_out_lens, text,
-                                text_lengths)
-        else:
-            loss_ctc = None
-
-        if loss_ctc is None:
-            loss = loss_att + self.predictor_weight * loss_pre
-        # elif loss_att is None:
-        elif loss_att == torch.tensor(0):
-            loss = loss_ctc
-        else:
-            loss = self.ctc_weight * loss_ctc + \
-                (1 - self.ctc_weight) * loss_att + \
-                self.predictor_weight * loss_pre
-        return {
-            "loss": loss,
-            "loss_att": loss_att,
-            "loss_ctc": loss_ctc,
-            "loss_pre": loss_pre
-        }
-
-    def _calc_att_loss(
-        self,
-        encoder_out: torch.Tensor,
-        encoder_mask: torch.Tensor,
-        ys_pad: torch.Tensor,
-        ys_pad_lens: torch.Tensor,
-    ) -> Tuple[torch.Tensor, float, torch.Tensor]:
-        if self.predictor_bias == 1:
-            _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
-            ys_pad_lens = ys_pad_lens + self.predictor_bias
-        pre_acoustic_embeds, pre_token_length, _, pre_peak_index = \
-            self.predictor(encoder_out, ys_pad, encoder_mask,
-                           ignore_id=self.ignore_id)
-        # 1. Forward decoder
-        decoder_out, _, _ = self.decoder(encoder_out, encoder_mask,
-                                         pre_acoustic_embeds, ys_pad_lens)
-
-        # 2. Compute attention loss
-        loss_att = self.criterion_att(decoder_out, ys_pad)
-        acc_att = th_accuracy(
-            decoder_out.view(-1, self.vocab_size),
-            ys_pad,
-            ignore_label=self.ignore_id,
-        )
-        loss_pre: torch.Tensor = self.criterion_pre(
-            ys_pad_lens.type_as(pre_token_length), pre_token_length)
-
-        return loss_att, acc_att, loss_pre
+        raise NotImplementedError("Training is currently not supported")

     def calc_predictor(self, encoder_out, encoder_out_lens):
         encoder_mask = (~make_pad_mask(
@@ -166,6 +84,28 @@ def cal_decoder_with_predictor(self, encoder_out, encoder_mask,
                                          sematic_embeds, ys_pad_lens)
         return decoder_out, ys_pad_lens

+    @torch.jit.export
+    def forward_paraformer(
+        self,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        features, features_lens = self.lfr(speech, speech_lengths)
+        features_lens = features_lens.to(speech_lengths.dtype)
+        # encoder
+        encoder_out, encoder_out_mask = self.encoder(features, features_lens)
+
+        # cif predictor
+        acoustic_embed, token_num, _, _ = self.predictor(encoder_out,
+                                                         mask=encoder_out_mask)
+        token_num = token_num.floor().to(speech_lengths.dtype)
+
+        # decoder
+        decoder_out, _, _ = self.decoder(encoder_out, encoder_out_mask,
+                                         acoustic_embed, token_num)
+        decoder_out = decoder_out.log_softmax(dim=-1)
+        return decoder_out, token_num
+
     def decode(self,
                methods: List[str],
                speech: torch.Tensor,
@@ -176,35 +116,10 @@ def decode(self,
                ctc_weight: float = 0,
                simulate_streaming: bool = False,
                reverse_weight: float = 0) -> Dict[str, List[DecodeResult]]:
-        assert speech.shape[0] == speech_lengths.shape[0]
-        assert decoding_chunk_size != 0
-        encoder_out, encoder_mask = self._forward_encoder(
-            speech, speech_lengths, decoding_chunk_size,
-            num_decoding_left_chunks, simulate_streaming)
-        encoder_lens = encoder_mask.squeeze(1).sum(1)
-        results = {}
-
-        ctc_probs: Optional[torch.Tensor] = None
-        if 'ctc_greedy_search' in methods:
-            ctc_probs = self.ctc.log_softmax(encoder_out)
-            results['ctc_greedy_search'] = ctc_greedy_search(
-                ctc_probs, encoder_lens)
-        if 'ctc_prefix_beam_search' in methods:
-            if ctc_probs is None:
-                ctc_probs = self.ctc.log_softmax(encoder_out)
-            ctc_prefix_result = ctc_prefix_beam_search(ctc_probs, encoder_lens,
-                                                       beam_size)
-            results['ctc_prefix_beam_search'] = ctc_prefix_result
-
-        decoder_out: Optional[torch.Tensor] = None
-        decoder_out_lens: Optional[torch.Tensor] = None
-        # TODO(Mddct): add timestamp from predictor's alpha
-        if ('paraformer_greedy_search' in methods
-                or 'paraformer_beam_search' in methods):
-            acoustic_embed, token_nums, _, _ = self.calc_predictor(
-                encoder_out, encoder_lens)
-            decoder_out, decoder_out_lens = self.cal_decoder_with_predictor(
-                encoder_out, encoder_mask, acoustic_embed, token_nums)
+        decoder_out, decoder_out_lens = self.forward_paraformer(
+            speech, speech_lengths)
+
+        results = {}
         if 'paraformer_greedy_search' in methods:
             assert decoder_out is not None
             assert decoder_out_lens is not None
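With the training branches stripped out, the remaining public surface is forward_paraformer plus decode. A minimal inference sketch; the model object and the feature shapes here are assumptions for illustration, not part of the PR:

    import torch

    from wenet.paraformer.paraformer import Paraformer

    # Assumption: `model` was built by the paraformer init code shown
    # earlier (SanmEncoder + SanmDecoer + CIF Predictor from a checkpoint).
    model: Paraformer = ...

    speech = torch.randn(1, 200, 80)      # (batch, frames, feat_dim), placeholder
    speech_lengths = torch.tensor([200])

    model.eval()
    with torch.no_grad():
        results = model.decode(['paraformer_greedy_search'],
                               speech, speech_lengths, beam_size=10)
    hyp = results['paraformer_greedy_search'][0]   # one DecodeResult per utterance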
10 changes: 0 additions & 10 deletions wenet/utils/init_model.py

@@ -28,8 +28,6 @@
 from wenet.e_branchformer.encoder import EBranchformerEncoder
 from wenet.squeezeformer.encoder import SqueezeformerEncoder
 from wenet.efficient_conformer.encoder import EfficientConformerEncoder
-from wenet.paraformer.paraformer import Paraformer
-from wenet.cif.predictor import Predictor
 from wenet.utils.cmvn import load_cmvn


@@ -115,14 +113,6 @@ def init_model(configs):
                        joint=joint,
                        ctc=ctc,
                        **configs['model_conf'])
-    elif 'paraformer' in configs:
-        predictor = Predictor(**configs['cif_predictor_conf'])
-        model = Paraformer(vocab_size=vocab_size,
-                           encoder=encoder,
-                           decoder=decoder,
-                           ctc=ctc,
-                           predictor=predictor,
-                           **configs['model_conf'])
     else:
         print(configs)
         if configs.get('lfmmi_dir', '') != '':