[WIP]: Implement token level shallow fusion #609

Closed
wants to merge 3 commits
32 changes: 32 additions & 0 deletions egs/librispeech/ASR/generate-lm.sh
@@ -0,0 +1,32 @@
#!/usr/bin/env bash

lang_dir=data/lang_bpe_500
if [ ! -f $lang_dir/bigram.arpa ]; then
  ./shared/make_kn_lm.py \
    -ngram-order 2 \
    -text $lang_dir/transcript_tokens.txt \
    -lm $lang_dir/bigram.arpa
fi

if [ ! -f $lang_dir/bigram.fst.txt ]; then
  python3 -m kaldilm \
    --read-symbol-table="$lang_dir/tokens.txt" \
    --disambig-symbol='#0' \
    --max-order=2 \
    $lang_dir/bigram.arpa > $lang_dir/bigram.fst.txt
fi

if [ ! -f $lang_dir/trigram.arpa ]; then
  ./shared/make_kn_lm.py \
    -ngram-order 3 \
    -text $lang_dir/transcript_tokens.txt \
    -lm $lang_dir/trigram.arpa
fi

if [ ! -f $lang_dir/trigram.fst.txt ]; then
  python3 -m kaldilm \
    --read-symbol-table="$lang_dir/tokens.txt" \
    --disambig-symbol='#0' \
    --max-order=3 \
    $lang_dir/trigram.arpa > $lang_dir/trigram.fst.txt
fi
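As a quick sanity check, the text FSTs this script writes are in OpenFst text format and can be loaded directly with k2, the same way icefall's recipes load kaldilm-generated G FSTs. A minimal sketch, with the path assumed to follow the default `lang_dir` above:

```python
# Minimal sketch (not part of this PR): load the trigram FST written by
# generate-lm.sh. The path assumes the default lang_dir above.
import k2

with open("data/lang_bpe_500/trigram.fst.txt") as f:
    G = k2.Fsa.from_openfst(f.read(), acceptor=False)

print("num arcs:", G.num_arcs)
```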
33 changes: 33 additions & 0 deletions egs/librispeech/ASR/lstm_transducer_stateless2/decode.py
@@ -115,10 +115,12 @@
    greedy_search,
    greedy_search_batch,
    modified_beam_search,
    modified_beam_search2,
)
from librispeech import LibriSpeech
from train import add_model_arguments, get_params, get_transducer_model

from icefall import NgramLm
from icefall.checkpoint import (
    average_checkpoints,
    average_checkpoints_with_averaged_model,
@@ -315,6 +317,8 @@ def decode_one_batch(
    batch: dict,
    word_table: Optional[k2.SymbolTable] = None,
    decoding_graph: Optional[k2.Fsa] = None,
    ngram_lm: Optional[NgramLm] = None,
    ngram_lm_scale: float = 1.0,
) -> Dict[str, List[List[str]]]:
    """Decode one batch and return the result in a dict. The dict has the
    following format:
@@ -448,6 +452,17 @@ def decode_one_batch(
        )
        for hyp in sp.decode(hyp_tokens):
            hyps.append(hyp.split())
    elif params.decoding_method == "modified_beam_search2":
        batch_size = encoder_out.size(0)
        for i in range(batch_size):
            encoder_out_i = encoder_out[i, : encoder_out_lens[i]]
            hyp = modified_beam_search2(
                model=model,
                encoder_out=encoder_out_i,
                ngram_lm=ngram_lm,
                ngram_lm_scale=ngram_lm_scale,
            )
            hyps.append(sp.decode(hyp).split())
    else:
        batch_size = encoder_out.size(0)

@@ -497,6 +512,8 @@ def decode_dataset(
    sp: spm.SentencePieceProcessor,
    word_table: Optional[k2.SymbolTable] = None,
    decoding_graph: Optional[k2.Fsa] = None,
    ngram_lm: Optional[NgramLm] = None,
    ngram_lm_scale: float = 1.0,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
    """Decode dataset.

@@ -546,6 +563,8 @@ def decode_dataset(
            decoding_graph=decoding_graph,
            word_table=word_table,
            batch=batch,
            ngram_lm=ngram_lm,
            ngram_lm_scale=ngram_lm_scale,
        )

        for name, hyps in hyps_dict.items():
@@ -631,6 +650,7 @@ def main():
"fast_beam_search_nbest_LG",
"fast_beam_search_nbest_oracle",
"modified_beam_search",
"modified_beam_search2",
)
params.res_dir = params.exp_dir / params.decoding_method

@@ -655,6 +675,7 @@
    else:
        params.suffix += f"-context-{params.context_size}"
        params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
        params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"

    if params.use_averaged_model:
        params.suffix += "-use-averaged-model"
@@ -768,6 +789,16 @@ def main():
    model.to(device)
    model.eval()

    # lm_filename = "bigram.fst.txt"
    lm_filename = "trigram.fst.txt"
    logging.info(f"lm filename: {lm_filename}")
    ngram_lm = NgramLm(
        str(params.lang_dir / lm_filename),
        backoff_id=500,
        is_binary=False,
    )
    logging.info(f"num states: {ngram_lm.lm.num_states}")

    if "fast_beam_search" in params.decoding_method:
        if params.decoding_method == "fast_beam_search_nbest_LG":
            lexicon = Lexicon(params.lang_dir)
@@ -812,6 +843,8 @@ def main():
            sp=sp,
            word_table=word_table,
            decoding_graph=decoding_graph,
            ngram_lm=ngram_lm,
            ngram_lm_scale=params.ngram_lm_scale,
        )

        save_results(
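Scoring note: with this method each hypothesis stores only its transducer (acoustic) log-probability; the n-gram LM score, scaled by `ngram_lm_scale`, is mixed in only while ranking candidates and subtracted again afterwards. A toy illustration of that ranking rule (all numbers made up; the real implementation is `modified_beam_search2` in the `beam_search.py` diff below):

```python
# Toy illustration of the shallow-fusion ranking used by
# modified_beam_search2; all scores below are made up.
import torch

ngram_lm_scale = 0.4  # illustrative value; tuned on a dev set in practice

am_log_prob = torch.tensor([-1.2, -1.5, -0.9])  # transducer scores
lm_log_prob = torch.tensor([-2.0, -0.5, -3.0])  # n-gram LM scores

# Rank hypotheses by the fused score ...
fused = am_log_prob + ngram_lm_scale * lm_log_prob
best = fused.argmax()

# ... but keep only the AM part on the surviving hypothesis, mirroring
# how the LM term is subtracted after topk in modified_beam_search2.
stored_log_prob = fused[best] - ngram_lm_scale * lm_log_prob[best]
print(int(best), float(stored_log_prob))
```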
96 changes: 96 additions & 0 deletions egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py
@@ -23,6 +23,7 @@
import torch
from model import Transducer

from icefall import NgramLm, NgramLmStateCost
from icefall.decode import Nbest, one_best_decoding
from icefall.utils import add_eos, add_sos, get_texts

@@ -656,6 +657,8 @@ class Hypothesis:
    # It contains only one entry.
    log_prob: torch.Tensor

    state_cost: Optional[NgramLmStateCost] = None

    @property
    def key(self) -> str:
        """Return a string representation of self.ys"""
@@ -1539,3 +1542,96 @@ def fast_beam_search_with_nbest_rnn_rescoring(
        ans[key] = hyps

    return ans


def modified_beam_search2(
    model: Transducer,
    encoder_out: torch.Tensor,
    ngram_lm: NgramLm,
    ngram_lm_scale: float,
    beam: int = 4,
) -> List[int]:
    """Modified beam search with token-level shallow fusion of an n-gram LM.

    The LM score (scaled by `ngram_lm_scale`) is added only when ranking
    hypotheses; each hypothesis itself stores just the transducer score.

    Args:
      model: The transducer model.
      encoder_out: A tensor of shape (T, encoder_out_dim) for one utterance.
      ngram_lm: The token-level n-gram language model.
      ngram_lm_scale: Scale applied to the LM score during ranking.
      beam: Number of active hypotheses to keep.

    Returns:
      The decoded token IDs of the best hypothesis, excluding the
      context-size prefix of blanks.
    """
    encoder_out = model.joiner.encoder_proj(encoder_out)

    lm_scale = ngram_lm_scale

    assert encoder_out.ndim == 2, encoder_out.shape
    blank_id = model.decoder.blank_id
    unk_id = getattr(model, "unk_id", blank_id)
    context_size = model.decoder.context_size
    device = next(model.parameters()).device

    B = HypothesisList()
    B.add(
        Hypothesis(
            ys=[blank_id] * context_size,
            log_prob=torch.zeros(1, dtype=torch.float32, device=device),
            state_cost=NgramLmStateCost(ngram_lm),
        )
    )

    T = encoder_out.shape[0]
    for t in range(T):
        current_encoder_out = encoder_out[t : t + 1]
        A = list(B)
        B = HypothesisList()

        ys_log_probs = torch.cat(
            [
                hyp.log_prob.reshape(1, 1) + hyp.state_cost.lm_score * lm_scale
                for hyp in A
            ]
        )  # (num_hyps, 1)

        decoder_input = torch.tensor(
            [hyp.ys[-context_size:] for hyp in A],
            device=device,
            dtype=torch.int64,
        )  # (num_hyps, context_size)
        decoder_out = model.decoder(decoder_input, need_pad=False).squeeze(1)
        decoder_out = model.joiner.decoder_proj(decoder_out)

        # decoder_out is of shape (num_hyps, joiner_dim)
        current_encoder_out = current_encoder_out.repeat(len(A), 1)
        # current_encoder_out is of shape (num_hyps, encoder_out_dim)
        logits = model.joiner(
            current_encoder_out,
            decoder_out,
            project_input=False,
        )  # (num_hyps, vocab_size)
        log_probs = logits.log_softmax(dim=-1)  # (num_hyps, vocab_size)
        log_probs.add_(ys_log_probs)

        vocab_size = log_probs.size(-1)
        log_probs = log_probs.reshape(-1)
        topk_log_probs, topk_indexes = log_probs.topk(beam)

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
            topk_token_indexes = (topk_indexes % vocab_size).tolist()

        for k in range(len(topk_hyp_indexes)):
            hyp_idx = topk_hyp_indexes[k]
            hyp = A[hyp_idx]
            new_ys = hyp.ys[:]

            new_token = topk_token_indexes[k]
            if new_token not in (blank_id, unk_id):
                new_ys.append(new_token)
                state_cost = hyp.state_cost.forward_one_step(new_token)
            else:
                state_cost = hyp.state_cost

            # We only keep AM scores in new_hyp.log_prob
            new_log_prob = (
                topk_log_probs[k] - hyp.state_cost.lm_score * lm_scale
            )

            new_hyp = Hypothesis(
                ys=new_ys, log_prob=new_log_prob, state_cost=state_cost
            )
            B.add(new_hyp)

    best_hyp = B.get_most_probable(length_norm=True)
    return best_hyp.ys[context_size:]
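For reference, a minimal usage sketch of the function above, mirroring its call site in `lstm_transducer_stateless2/decode.py`. Here `model`, `sp`, `ngram_lm`, `encoder_out` of shape (N, T, C), and `encoder_out_lens` are assumed to be prepared as in that script, and the scale value is illustrative:

```python
# Usage sketch (assumptions: model, sp, ngram_lm, encoder_out, and
# encoder_out_lens are set up as in lstm_transducer_stateless2/decode.py).
# modified_beam_search2 decodes one utterance at a time.
hyps = []
for i in range(encoder_out.size(0)):
    # Drop padding frames for utterance i.
    encoder_out_i = encoder_out[i, : encoder_out_lens[i]]
    tokens = modified_beam_search2(
        model=model,
        encoder_out=encoder_out_i,
        ngram_lm=ngram_lm,
        ngram_lm_scale=0.4,  # illustrative; exposed as params.ngram_lm_scale
        beam=4,
    )
    hyps.append(sp.decode(tokens).split())
```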
2 changes: 2 additions & 0 deletions icefall/__init__.py
@@ -65,3 +65,5 @@
    subsequent_chunk_mask,
    write_error_stats,
)

from .ngram_lm import NgramLm, NgramLmStateCost
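These two exports are the LM interface the new search uses: `NgramLm` wraps the FST, and one `NgramLmStateCost` per hypothesis tracks its LM state. A minimal sketch of stepping it by hand, using only the members this PR exercises (`forward_one_step` and `lm_score`); the path, `backoff_id`, and token IDs are assumptions matching the PR's defaults:

```python
from icefall import NgramLm, NgramLmStateCost

# Assumed path and backoff_id, matching the hard-coded values in decode.py.
ngram_lm = NgramLm(
    "data/lang_bpe_500/bigram.fst.txt",
    backoff_id=500,
    is_binary=False,
)

# forward_one_step returns a fresh state object (as used in
# modified_beam_search2), so beam hypotheses can diverge safely.
state = NgramLmStateCost(ngram_lm)
for token in [23, 7, 151]:  # toy BPE token IDs
    state = state.forward_one_step(token)
    print(token, state.lm_score)  # running LM log-score of the prefix
```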