feat(whisper): Support whisper-style decoding (#2196)

xingchensong authored Dec 6, 2023
1 parent 46dd19d commit eccc66d
Showing 3 changed files with 40 additions and 15 deletions.
6 changes: 4 additions & 2 deletions wenet/transformer/asr_model.py
@@ -44,14 +44,16 @@ def __init__(
         reverse_weight: float = 0.0,
         lsm_weight: float = 0.0,
         length_normalized_loss: bool = False,
+        special_tokens: dict = None,
     ):
         assert 0.0 <= ctc_weight <= 1.0, ctc_weight
 
         super().__init__()
         # note that eos is the same as sos (equivalent ID)
-        self.sos = vocab_size - 1
-        self.eos = vocab_size - 1
+        self.sos = special_tokens.get("sos", vocab_size - 1)
+        self.eos = special_tokens.get("eos", vocab_size - 1)
         self.vocab_size = vocab_size
+        self.special_tokens = special_tokens
         self.ignore_id = ignore_id
         self.ctc_weight = ctc_weight
         self.reverse_weight = reverse_weight
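Note: this hunk makes sos/eos configurable via the new special_tokens dict instead of hard-coding vocab_size - 1. A minimal standalone sketch of the lookup, with hypothetical ids (the concrete numbers are illustrative, not from this commit):

vocab_size = 51866                             # hypothetical vocab size
special_tokens = {"sos": 50258, "eos": 50257}  # hypothetical ids

sos = special_tokens.get("sos", vocab_size - 1)  # named id wins -> 50258
eos = special_tokens.get("eos", vocab_size - 1)  # named id wins -> 50257
assert {}.get("sos", vocab_size - 1) == vocab_size - 1  # old fallback preserved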
45 changes: 34 additions & 11 deletions wenet/transformer/search.py
@@ -19,7 +19,8 @@
 import torch
 from torch.nn.utils.rnn import pad_sequence
 
-from wenet.utils.common import (add_sos_eos, log_add)
+from wenet.utils.common import (add_sos_eos, log_add, WHISPER_LANGS,
+                                add_whisper_tokens)
 from wenet.utils.ctc_utils import remove_duplicates_and_blank
 from wenet.utils.mask import (make_pad_mask, mask_finished_preds,
                               mask_finished_scores, subsequent_mask)
@@ -259,16 +260,26 @@ def attention_beam_search(
     encoder_mask = encoder_mask.unsqueeze(1).repeat(1, beam_size, 1, 1).view(
         running_size, 1, maxlen)  # (B*N, 1, max_len)
 
-    hyps = torch.ones([running_size, 1], dtype=torch.long,
-                      device=device).fill_(model.sos)  # (B*N, 1)
+    if model.special_tokens is not None and "transcribe" in model.special_tokens:
+        hyps = torch.ones([running_size, 4], dtype=torch.long,
+                          device=device)  # (B*N, 4)
+        # TODO(xcsong): add args for language, task, etc
+        hyps[:, 0] = model.special_tokens["sot"]
+        hyps[:, 1] = model.special_tokens["sot"] + 1 + WHISPER_LANGS.index("zh")
+        hyps[:, 2] = model.special_tokens["transcribe"]
+        hyps[:, 3] = model.special_tokens["no_timestamps"]
+    else:
+        hyps = torch.ones([running_size, 1], dtype=torch.long,
+                          device=device).fill_(model.sos)  # (B*N, 1)
+    prefix_len = hyps.size(1)
     scores = torch.tensor([0.0] + [-float('inf')] * (beam_size - 1),
                           dtype=torch.float)
     scores = scores.to(device).repeat([batch_size
                                        ]).unsqueeze(1).to(device)  # (B*N, 1)
     end_flag = torch.zeros_like(scores, dtype=torch.bool, device=device)
     cache: Optional[List[torch.Tensor]] = None
     # 2. Decoder forward step by step
-    for i in range(1, maxlen + 1):
+    for i in range(prefix_len, maxlen + 1):
         # Stop if all batch and all beam produce eos
         if end_flag.sum() == running_size:
             break
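For context: the four-token prompt seeded above follows Whisper's <sot><lang><task><no_timestamps> convention, and language tokens sit immediately after sot in WHISPER_LANGS order, hence the sot + 1 + index arithmetic. A standalone sketch with a truncated WHISPER_LANGS and illustrative ids (both are assumptions; the real values live in wenet/utils/common.py):

import torch

WHISPER_LANGS = ["en", "zh", "de", "es"]  # truncated; real order assumed
special_tokens = {"sot": 50258, "transcribe": 50359, "no_timestamps": 50363}

batch_size, beam_size = 2, 5
running_size = batch_size * beam_size
hyps = torch.ones([running_size, 4], dtype=torch.long)
hyps[:, 0] = special_tokens["sot"]                                  # start of transcript
hyps[:, 1] = special_tokens["sot"] + 1 + WHISPER_LANGS.index("zh")  # language token
hyps[:, 2] = special_tokens["transcribe"]                           # task token
hyps[:, 3] = special_tokens["no_timestamps"]                        # suppress timestamps
prefix_len = hyps.size(1)  # == 4; decoding starts here, and the prompt is
                           # stripped later via best_hyps[:, prefix_len:]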
@@ -323,7 +334,7 @@ def attention_beam_search(
     best_hyps_index = best_index + torch.arange(
         batch_size, dtype=torch.long, device=device) * beam_size
     best_hyps = torch.index_select(hyps, dim=0, index=best_hyps_index)
-    best_hyps = best_hyps[:, 1:]
+    best_hyps = best_hyps[:, prefix_len:]
 
     results = []
     for i in range(batch_size):
@@ -360,8 +371,20 @@ def attention_rescoring(
     hyps_lens = torch.tensor([len(hyp) for hyp in hyps],
                              device=device,
                              dtype=torch.long)  # (beam_size,)
-    hyps_pad, _ = add_sos_eos(hyps_pad, sos, eos, model.ignore_id)
-    hyps_lens = hyps_lens + 1  # Add <sos> at begining
+    if model.special_tokens is not None and "transcribe" in model.special_tokens:
+        # TODO(xcsong): add args for language, task, etc
+        prev_len = hyps_pad.size(1)
+        hyps_pad, _ = add_whisper_tokens(
+            model.special_tokens, hyps_pad, model.ignore_id, task="transcribe",
+            no_timestamp=True, language="zh", use_prev=False
+        )
+        cur_len = hyps_pad.size(1)
+        hyps_lens = hyps_lens + cur_len - prev_len
+        prefix_len = 4
+    else:
+        hyps_pad, _ = add_sos_eos(hyps_pad, sos, eos, model.ignore_id)
+        hyps_lens = hyps_lens + 1  # Add <sos> at beginning
+        prefix_len = 1
     decoder_out, r_decoder_out = model.forward_attention_decoder(
         hyps_pad, hyps_lens, encoder_out, reverse_weight)
     # Only use decoder score for rescoring
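Note that add_whisper_tokens is defined in wenet/utils/common.py and its body is not part of this diff; the cur_len - prev_len bookkeeping only assumes it widens each padded hypothesis row by the prompt length. A hedged sketch of that effect, assuming it left-prepends the same 4-token prompt used in beam search:

import torch

hyps_pad = torch.zeros((3, 7), dtype=torch.long)     # (beam, prev_len), dummy ids
prompt = torch.tensor([50258, 50260, 50359, 50363])  # illustrative prompt ids
prev_len = hyps_pad.size(1)                          # 7
hyps_pad = torch.cat([prompt.expand(3, -1), hyps_pad], dim=1)  # assumed behavior
cur_len = hyps_pad.size(1)                           # 11
hyps_lens = torch.tensor([5, 7, 6]) + (cur_len - prev_len)     # each grows by 4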
@@ -373,18 +396,18 @@
         score = 0.0
         tc = []  # tokens confidences
         for j, w in enumerate(hyp):
-            s = decoder_out[i][j][w]
+            s = decoder_out[i][j + (prefix_len - 1)][w]
             score += s
             tc.append(math.exp(s))
-        score += decoder_out[i][len(hyp)][eos]
+        score += decoder_out[i][len(hyp) + (prefix_len - 1)][eos]
         # add right to left decoder score
         if reverse_weight > 0 and r_decoder_out.dim() > 0:
             r_score = 0.0
             for j, w in enumerate(hyp):
-                s = r_decoder_out[i][len(hyp) - j - 1][w]
+                s = r_decoder_out[i][len(hyp) - j - 1 + (prefix_len - 1)][w]
                 r_score += s
                 tc[j] = (tc[j] + math.exp(s)) / 2
-            r_score += r_decoder_out[i][len(hyp)][eos]
+            r_score += r_decoder_out[i][len(hyp) + (prefix_len - 1)][eos]
             score = score * (1 - reverse_weight) + r_score * reverse_weight
         confidences.append(math.exp(score / (len(hyp) + 1)))
         # add ctc score
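The (prefix_len - 1) shift follows from teacher forcing: decoder output position t scores the token that comes after input position t, so with a 1-token <sos> prefix the offset is 0 (the old indexing), while the 4-token whisper prompt pushes hyp[j] to output position j + 3. A worked check:

prefix_len = 4
offsets = [j + (prefix_len - 1) for j in range(3)]
assert offsets == [3, 4, 5]  # hyp[0] -> decoder_out[i][3], hyp[1] -> [4], ...
# The closing eos/eot is likewise scored at len(hyp) + (prefix_len - 1).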
4 changes: 2 additions & 2 deletions wenet/whisper/whisper.py
@@ -40,11 +40,11 @@ def __init__(
         special_tokens: dict = None,
     ):
         super().__init__(vocab_size, encoder, decoder, ctc, ctc_weight, ignore_id,
-                         reverse_weight, lsm_weight, length_normalized_loss)
+                         reverse_weight, lsm_weight, length_normalized_loss,
+                         special_tokens)
         assert reverse_weight == 0.0
         self.sos = special_tokens["sot"]
         self.eos = special_tokens["eot"]
-        self.special_tokens = special_tokens
 
     # TODO(xcsong): time align
     def set_alignment_heads(self, dump: bytes):
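The net effect of this file's change: special_tokens is now forwarded to, and stored by, the ASRModel base class (which is why the local assignment is deleted), after which the subclass pins sos/eos to whisper's sot/eot ids. A condensed sketch of the resulting wiring, with stand-in class names and illustrative ids:

class _Base:  # stands in for ASRModel after this commit
    def __init__(self, vocab_size, special_tokens=None):
        self.sos = special_tokens.get("sos", vocab_size - 1)
        self.eos = special_tokens.get("eos", vocab_size - 1)
        self.special_tokens = special_tokens

class _Whisper(_Base):  # stands in for the Whisper subclass
    def __init__(self, vocab_size, special_tokens):
        super().__init__(vocab_size, special_tokens)
        self.sos = special_tokens["sot"]  # <|startoftranscript|>
        self.eos = special_tokens["eot"]  # <|endoftext|>

m = _Whisper(51866, {"sot": 50258, "eot": 50257})  # illustrative ids
assert (m.sos, m.eos) == (50258, 50257)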