diff --git a/wenet/transformer/asr_model.py b/wenet/transformer/asr_model.py
index d790a0919..20503ef02 100644
--- a/wenet/transformer/asr_model.py
+++ b/wenet/transformer/asr_model.py
@@ -44,14 +44,16 @@ def __init__(
         reverse_weight: float = 0.0,
         lsm_weight: float = 0.0,
         length_normalized_loss: bool = False,
+        special_tokens: dict = None,
     ):
         assert 0.0 <= ctc_weight <= 1.0, ctc_weight

         super().__init__()
         # note that eos is the same as sos (equivalent ID)
-        self.sos = vocab_size - 1
-        self.eos = vocab_size - 1
+        self.sos = special_tokens.get("sos", vocab_size - 1)
+        self.eos = special_tokens.get("eos", vocab_size - 1)
         self.vocab_size = vocab_size
+        self.special_tokens = special_tokens
         self.ignore_id = ignore_id
         self.ctc_weight = ctc_weight
         self.reverse_weight = reverse_weight
diff --git a/wenet/transformer/search.py b/wenet/transformer/search.py
index 9f8827a85..7140897fe 100644
--- a/wenet/transformer/search.py
+++ b/wenet/transformer/search.py
@@ -19,7 +19,8 @@
 import torch
 from torch.nn.utils.rnn import pad_sequence

-from wenet.utils.common import (add_sos_eos, log_add)
+from wenet.utils.common import (add_sos_eos, log_add, WHISPER_LANGS,
+                                add_whisper_tokens)
 from wenet.utils.ctc_utils import remove_duplicates_and_blank
 from wenet.utils.mask import (make_pad_mask, mask_finished_preds,
                               mask_finished_scores, subsequent_mask)
@@ -259,8 +260,18 @@ def attention_beam_search(
     encoder_mask = encoder_mask.unsqueeze(1).repeat(1, beam_size, 1, 1).view(
         running_size, 1, maxlen)  # (B*N, 1, max_len)

-    hyps = torch.ones([running_size, 1], dtype=torch.long,
-                      device=device).fill_(model.sos)  # (B*N, 1)
+    if model.special_tokens is not None and "transcribe" in model.special_tokens:
+        hyps = torch.ones([running_size, 4], dtype=torch.long,
+                          device=device)  # (B*N, 4)
+        # TODO(xcsong): add args for language, task, etc
+        hyps[:, 0] = model.special_tokens["sot"]
+        hyps[:, 1] = model.special_tokens["sot"] + 1 + WHISPER_LANGS.index("zh")
+        hyps[:, 2] = model.special_tokens["transcribe"]
+        hyps[:, 3] = model.special_tokens["no_timestamps"]
+    else:
+        hyps = torch.ones([running_size, 1], dtype=torch.long,
+                          device=device).fill_(model.sos)  # (B*N, 1)
+    prefix_len = hyps.size(1)
     scores = torch.tensor([0.0] + [-float('inf')] * (beam_size - 1),
                           dtype=torch.float)
     scores = scores.to(device).repeat([batch_size
@@ -268,7 +279,7 @@
     end_flag = torch.zeros_like(scores, dtype=torch.bool, device=device)
     cache: Optional[List[torch.Tensor]] = None
     # 2. Decoder forward step by step
-    for i in range(1, maxlen + 1):
+    for i in range(prefix_len, maxlen + 1):
         # Stop if all batch and all beam produce eos
         if end_flag.sum() == running_size:
             break
@@ -323,7 +334,7 @@
     best_hyps_index = best_index + torch.arange(
         batch_size, dtype=torch.long, device=device) * beam_size
     best_hyps = torch.index_select(hyps, dim=0, index=best_hyps_index)
-    best_hyps = best_hyps[:, 1:]
+    best_hyps = best_hyps[:, prefix_len:]

     results = []
     for i in range(batch_size):
@@ -360,8 +371,20 @@
     hyps_lens = torch.tensor([len(hyp) for hyp in hyps],
                              device=device,
                              dtype=torch.long)  # (beam_size,)
-    hyps_pad, _ = add_sos_eos(hyps_pad, sos, eos, model.ignore_id)
-    hyps_lens = hyps_lens + 1  # Add <sos> at begining
+    if model.special_tokens is not None and "transcribe" in model.special_tokens:
+        # TODO(xcsong): add args for language, task, etc
+        prev_len = hyps_pad.size(1)
+        hyps_pad, _ = add_whisper_tokens(
+            model.special_tokens, hyps_pad, model.ignore_id, task="transcribe",
+            no_timestamp=True, language="zh", use_prev=False
+        )
+        cur_len = hyps_pad.size(1)
+        hyps_lens = hyps_lens + cur_len - prev_len
+        prefix_len = 4
+    else:
+        hyps_pad, _ = add_sos_eos(hyps_pad, sos, eos, model.ignore_id)
+        hyps_lens = hyps_lens + 1  # Add <sos> at begining
+        prefix_len = 1
     decoder_out, r_decoder_out = model.forward_attention_decoder(
         hyps_pad, hyps_lens, encoder_out, reverse_weight)
     # Only use decoder score for rescoring
@@ -373,18 +396,18 @@
         score = 0.0
         tc = []  # tokens confidences
         for j, w in enumerate(hyp):
-            s = decoder_out[i][j][w]
+            s = decoder_out[i][j + (prefix_len - 1)][w]
             score += s
             tc.append(math.exp(s))
-        score += decoder_out[i][len(hyp)][eos]
+        score += decoder_out[i][len(hyp) + (prefix_len - 1)][eos]
         # add right to left decoder score
         if reverse_weight > 0 and r_decoder_out.dim() > 0:
             r_score = 0.0
             for j, w in enumerate(hyp):
-                s = r_decoder_out[i][len(hyp) - j - 1][w]
+                s = r_decoder_out[i][len(hyp) - j - 1 + (prefix_len - 1)][w]
                 r_score += s
                 tc[j] = (tc[j] + math.exp(s)) / 2
-            r_score += r_decoder_out[i][len(hyp)][eos]
+            r_score += r_decoder_out[i][len(hyp) + (prefix_len - 1)][eos]
             score = score * (1 - reverse_weight) + r_score * reverse_weight
         confidences.append(math.exp(score / (len(hyp) + 1)))
         # add ctc score
diff --git a/wenet/whisper/whisper.py b/wenet/whisper/whisper.py
index 7505704f5..0cc73e6ea 100644
--- a/wenet/whisper/whisper.py
+++ b/wenet/whisper/whisper.py
@@ -40,11 +40,11 @@ def __init__(
         special_tokens: dict = None,
     ):
         super().__init__(vocab_size, encoder, decoder, ctc, ctc_weight, ignore_id,
-                         reverse_weight, lsm_weight, length_normalized_loss)
+                         reverse_weight, lsm_weight, length_normalized_loss,
+                         special_tokens)
         assert reverse_weight == 0.0
         self.sos = special_tokens["sot"]
         self.eos = special_tokens["eot"]
-        self.special_tokens = special_tokens

     # TODO(xcsong): time align
     def set_alignment_heads(self, dump: bytes):
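
Note on the decoding changes above: the sketch below is a minimal, self-contained illustration (not the WeNet implementation itself) of the Whisper-style prompt that attention_beam_search now builds, and of why attention_rescoring adds a (prefix_len - 1) offset when indexing decoder_out. The special-token IDs, the toy WHISPER_LANGS tuple, and the example hypothesis are placeholders chosen only for illustration.

import torch

# Hypothetical special-token IDs for illustration; the real values come from the
# Whisper tokenizer that WeNet loads.
special_tokens = {"sot": 50258, "transcribe": 50359, "no_timestamps": 50363}
WHISPER_LANGS = ("en", "zh")  # toy subset; the full list lives in wenet.utils.common

running_size = 2  # batch_size * beam_size in the beam search

# Whisper-style 4-token prompt: <|startoftranscript|><|zh|><|transcribe|><|notimestamps|>
hyps = torch.ones([running_size, 4], dtype=torch.long)
hyps[:, 0] = special_tokens["sot"]
hyps[:, 1] = special_tokens["sot"] + 1 + WHISPER_LANGS.index("zh")
hyps[:, 2] = special_tokens["transcribe"]
hyps[:, 3] = special_tokens["no_timestamps"]
prefix_len = hyps.size(1)  # 4 here, 1 in the plain <sos> case

# During rescoring, row t of decoder_out[i] scores the token that follows
# position t of the padded input.  With a 1-token <sos> prefix the j-th
# hypothesis token is scored at row j; with a longer prompt every index shifts
# by (prefix_len - 1), which is exactly the offset added in attention_rescoring.
hyp = [7, 8, 9]  # a toy hypothesis from CTC prefix beam search
for j, w in enumerate(hyp):
    row = j + (prefix_len - 1)  # rows 3, 4, 5 instead of 0, 1, 2
    print(f"token {w} is scored at decoder_out[i][{row}][{w}]")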