Skip to content

Commit

Permalink
[cli] support context biasing with ac automaton (#2128)
Browse files Browse the repository at this point in the history
* [cli] support context biasing with ac automaton

* [cli] fix lint

---------

Co-authored-by: user01 <user01@user01deMacBook-Pro.local>
  • Loading branch information
cdliang11 and user01 authored Nov 7, 2023
1 parent 3c7f291 commit 5faf24b
Show file tree
Hide file tree
Showing 4 changed files with 294 additions and 61 deletions.
20 changes: 16 additions & 4 deletions wenet/cli/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,12 @@
from wenet.utils.file_utils import read_symbol_table
from wenet.transformer.search import (attention_rescoring,
ctc_prefix_beam_search, DecodeResult)
from wenet.utils.context_graph import ContextGraph


class Model:
def __init__(self, model_dir: str, gpu: int = -1):
def __init__(self, model_dir: str, gpu: int = -1, beam: int = 5,
context_path: str = None, context_score: float = 6.0):
model_path = os.path.join(model_dir, 'final.zip')
units_path = os.path.join(model_dir, 'units.txt')
self.model = torch.jit.load(model_path)
Expand All @@ -40,6 +42,12 @@ def __init__(self, model_dir: str, gpu: int = -1):
self.model = self.model.to(self.device)
self.symbol_table = read_symbol_table(units_path)
self.char_dict = {v: k for k, v in self.symbol_table.items()}
self.beam = beam
if context_path is not None:
self.context_graph = ContextGraph(context_path, self.symbol_table,
context_score=context_score)
else:
self.context_graph = None

def compute_feats(self, audio_file: str) -> torch.Tensor:
waveform, sample_rate = torchaudio.load(audio_file, normalize=False)
Expand Down Expand Up @@ -68,7 +76,8 @@ def _decode(self,
ctc_probs = self.model.ctc_activation(encoder_out)
if label is None:
ctc_prefix_results = ctc_prefix_beam_search(
ctc_probs, encoder_lens, 2)
ctc_probs, encoder_lens, self.beam,
context_graph=self.context_graph)
else: # force align mode, construct ctc prefix result from alignment
label_t = self.tokenize(label)
alignment = force_align(ctc_probs.squeeze(0),
Expand Down Expand Up @@ -131,7 +140,10 @@ def align(self, audio_file: str, label: str) -> dict:

def load_model(language: str = None,
model_dir: str = None,
gpu: int = -1) -> Model:
gpu: int = -1,
beam: int = 5,
context_path: str = None,
context_score: float = 6.0) -> Model:
if model_dir is None:
model_dir = Hub.get_model_by_lang(language)
return Model(model_dir, gpu)
return Model(model_dir, gpu, beam, context_path, context_score)
9 changes: 8 additions & 1 deletion wenet/cli/transcribe.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,12 @@ def get_args():
parser.add_argument('--paraformer',
action='store_true',
help='whether to use the best chinese model')
parser.add_argument('--beam', type=int, default=5,
help="beam size")
parser.add_argument('--context_path', type=str, default=None,
help='context list file')
parser.add_argument('--context_score', type=float, default=6.0,
help='context score')
args = parser.parse_args()
return args

Expand All @@ -60,7 +66,8 @@ def main():
if args.paraformer:
model = load_paraformer(args.model_dir, args.gpu)
else:
model = load_model(args.language, args.model_dir, args.gpu)
model = load_model(args.language, args.model_dir, args.gpu,
args.beam, args.context_path, args.context_score)
if args.align:
result = model.align(args.audio_file, args.label)
else:
Expand Down
63 changes: 58 additions & 5 deletions wenet/transformer/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from wenet.utils.ctc_utils import remove_duplicates_and_blank
from wenet.utils.mask import (make_pad_mask, mask_finished_preds,
mask_finished_scores, subsequent_mask)
from wenet.utils.context_graph import ContextGraph, ContextState


class DecodeResult:
Expand Down Expand Up @@ -62,14 +63,19 @@ def __init__(self,
s: float = float('-inf'),
ns: float = float('-inf'),
v_s: float = float('-inf'),
v_ns: float = float('-inf')):
v_ns: float = float('-inf'),
context_state: ContextState = None,
context_score: float = 0.0):
self.s = s # blank_ending_score
self.ns = ns # none_blank_ending_score
self.v_s = v_s # viterbi blank ending score
self.v_ns = v_ns # viterbi none blank ending score
self.cur_token_prob = float('-inf') # prob of current token
self.times_s = [] # times of viterbi blank path
self.times_ns = [] # times of viterbi none blank path
self.context_state = context_state
self.context_score = context_score
self.has_context = False

def score(self):
return log_add(self.s, self.ns)
Expand All @@ -80,6 +86,20 @@ def viterbi_score(self):
def times(self):
return self.times_s if self.v_s > self.v_ns else self.times_ns

def total_score(self):
    """Return the prefix score combined with the context-biasing score.

    The result is ``self.score()`` plus ``self.context_score``.
    """
    base_score = self.score()
    return base_score + self.context_score

def copy_context(self, prefix_score):
    """Take over the context score and context state of ``prefix_score``.

    Args:
        prefix_score: object exposing ``context_score`` and
            ``context_state`` attributes; both are copied onto ``self``.
    """
    self.context_score, self.context_state = (
        prefix_score.context_score,
        prefix_score.context_state,
    )

def update_context(self, context_graph, prefix_score, word_id):
    """Advance the context state by one token, starting from ``prefix_score``.

    The context score/state of ``prefix_score`` is copied onto ``self``
    first. Then ``context_graph.forward_one_step`` is called with the
    prefix's context state and ``word_id``; the score it returns is added
    to ``self.context_score`` and the state it returns replaces
    ``self.context_state``.

    Args:
        context_graph: object providing ``forward_one_step(state, word_id)``
            that returns a ``(score, state)`` pair.
        prefix_score: the hypothesis being extended.
        word_id: the token appended to the prefix.
    """
    self.copy_context(prefix_score)
    step_score, next_state = context_graph.forward_one_step(
        prefix_score.context_state, word_id)
    self.context_state = next_state
    # Keep the augmented assignment so in-place semantics stay unchanged.
    self.context_score += step_score


def ctc_greedy_search(ctc_probs: torch.Tensor,
ctc_lens: torch.Tensor) -> List[DecodeResult]:
Expand All @@ -99,7 +119,8 @@ def ctc_greedy_search(ctc_probs: torch.Tensor,


def ctc_prefix_beam_search(ctc_probs: torch.Tensor, ctc_lens: torch.Tensor,
beam_size: int) -> List[DecodeResult]:
beam_size: int, context_graph: ContextGraph = None,
) -> List[DecodeResult]:
"""
Returns:
List[List[List[int]]]: nbest result for each utterance
Expand All @@ -110,7 +131,14 @@ def ctc_prefix_beam_search(ctc_probs: torch.Tensor, ctc_lens: torch.Tensor,
for i in range(batch_size):
ctc_prob = ctc_probs[i]
num_t = ctc_lens[i]
cur_hyps = [(tuple(), PrefixScore(0.0, -float('inf'), 0.0, 0.0))]
cur_hyps = [(tuple(),
PrefixScore(s=0.0,
ns=-float('inf'),
v_s=0.0,
v_ns=0.0,
context_state=None if context_graph is None
else context_graph.root,
context_score=0.0))]
# 2. CTC beam search step by step
for t in range(0, num_t):
logp = ctc_prob[t] # (vocab_size,)
Expand All @@ -129,6 +157,10 @@ def ctc_prefix_beam_search(ctc_probs: torch.Tensor, ctc_lens: torch.Tensor,
prefix_score.score() + prob)
next_score.v_s = prefix_score.viterbi_score() + prob
next_score.times_s = prefix_score.times().copy()
# prefix not changed, copy the context from prefix
if context_graph and not next_score.has_context:
next_score.copy_context(prefix_score)
next_score.has_context = True
elif u == last:
# Update *uu -> *u;
next_score1 = next_hyps[prefix]
Expand All @@ -141,6 +173,9 @@ def ctc_prefix_beam_search(ctc_probs: torch.Tensor, ctc_lens: torch.Tensor,
next_score1.times_ns = prefix_score.times_ns.copy(
)
next_score1.times_ns[-1] = t
if context_graph and not next_score1.has_context:
next_score1.copy_context(prefix_score)
next_score1.has_context = True

# Update *u-u -> *uu, - is for blank
n_prefix = prefix + (u, )
Expand All @@ -152,6 +187,10 @@ def ctc_prefix_beam_search(ctc_probs: torch.Tensor, ctc_lens: torch.Tensor,
next_score2.cur_token_prob = prob
next_score2.times_ns = prefix_score.times_s.copy()
next_score2.times_ns.append(t)
if context_graph and not next_score2.has_context:
next_score2.update_context(context_graph,
prefix_score, u)
next_score2.has_context = True
else:
n_prefix = prefix + (u, )
next_score = next_hyps[n_prefix]
Expand All @@ -164,14 +203,28 @@ def ctc_prefix_beam_search(ctc_probs: torch.Tensor, ctc_lens: torch.Tensor,
next_score.cur_token_prob = prob
next_score.times_ns = prefix_score.times().copy()
next_score.times_ns.append(t)
if context_graph and not next_score.has_context:
next_score.update_context(context_graph,
prefix_score, u)
next_score.has_context = True

# 2.2 Second beam prune
next_hyps = sorted(next_hyps.items(),
key=lambda x: x[1].score(),
key=lambda x: x[1].total_score(),
reverse=True)
cur_hyps = next_hyps[:beam_size]

# We should back off the context score/state when the context is
# not fully matched at the last time step.
if context_graph is not None:
for i, hyp in enumerate(cur_hyps):
context_score, new_context_state = context_graph.finalize(
hyp[1].context_state)
cur_hyps[i][1].context_score = context_score
cur_hyps[i][1].context_state = new_context_state

nbest = [y[0] for y in cur_hyps]
nbest_scores = [y[1].score() for y in cur_hyps]
nbest_scores = [y[1].total_score() for y in cur_hyps]
nbest_times = [y[1].times() for y in cur_hyps]
best = nbest[0]
best_score = nbest_scores[0]
Expand Down
Loading

0 comments on commit 5faf24b

Please sign in to comment.