[cli] support context biasing with ac automaton #2128

Merged · 2 commits · Nov 7, 2023
20 changes: 16 additions & 4 deletions wenet/cli/model.py
@@ -24,10 +24,12 @@
 from wenet.utils.file_utils import read_symbol_table
 from wenet.transformer.search import (attention_rescoring,
                                       ctc_prefix_beam_search, DecodeResult)
+from wenet.utils.context_graph import ContextGraph


 class Model:
-    def __init__(self, model_dir: str, gpu: int = -1):
+    def __init__(self, model_dir: str, gpu: int = -1, beam: int = 5,
+                 context_path: str = None, context_score: float = 6.0):
         model_path = os.path.join(model_dir, 'final.zip')
         units_path = os.path.join(model_dir, 'units.txt')
         self.model = torch.jit.load(model_path)
@@ -40,6 +42,12 @@ def __init__(self, model_dir: str, gpu: int = -1):
         self.model = self.model.to(self.device)
         self.symbol_table = read_symbol_table(units_path)
         self.char_dict = {v: k for k, v in self.symbol_table.items()}
+        self.beam = beam
+        if context_path is not None:
+            self.context_graph = ContextGraph(context_path, self.symbol_table,
+                                              context_score=context_score)
+        else:
+            self.context_graph = None

     def compute_feats(self, audio_file: str) -> torch.Tensor:
         waveform, sample_rate = torchaudio.load(audio_file, normalize=False)
@@ -68,7 +76,8 @@ def _decode(self,
         ctc_probs = self.model.ctc_activation(encoder_out)
         if label is None:
             ctc_prefix_results = ctc_prefix_beam_search(
-                ctc_probs, encoder_lens, 2)
+                ctc_probs, encoder_lens, self.beam,
+                context_graph=self.context_graph)
         else:  # force align mode, construct ctc prefix result from alignment
             label_t = self.tokenize(label)
             alignment = force_align(ctc_probs.squeeze(0),
@@ -131,7 +140,10 @@ def align(self, audio_file: str, label: str) -> dict:

 def load_model(language: str = None,
                model_dir: str = None,
-               gpu: int = -1) -> Model:
+               gpu: int = -1,
+               beam: int = 5,
+               context_path: str = None,
+               context_score: float = 6.0) -> Model:
     if model_dir is None:
         model_dir = Hub.get_model_by_lang(language)
-    return Model(model_dir, gpu)
+    return Model(model_dir, gpu, beam, context_path, context_score)
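Note: a minimal usage sketch of the extended Python API (not part of this diff; the transcribe() call, the 'chinese' language tag, and the file names are assumptions for illustration):

    from wenet.cli.model import load_model

    # 'hotwords.txt' is a hypothetical context list file.
    model = load_model(language='chinese',
                       beam=5,
                       context_path='hotwords.txt',
                       context_score=6.0)
    result = model.transcribe('test.wav')  # decode with context biasing on
    print(result)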
9 changes: 8 additions & 1 deletion wenet/cli/transcribe.py
@@ -50,6 +50,12 @@ def get_args():
     parser.add_argument('--paraformer',
                         action='store_true',
                         help='whether to use the best chinese model')
+    parser.add_argument('--beam', type=int, default=5,
+                        help="beam size")
+    parser.add_argument('--context_path', type=str, default=None,
+                        help='context list file')
+    parser.add_argument('--context_score', type=float, default=6.0,
+                        help='context score')
     args = parser.parse_args()
     return args

@@ -60,7 +66,8 @@ def main():
     if args.paraformer:
         model = load_paraformer(args.model_dir, args.gpu)
     else:
-        model = load_model(args.language, args.model_dir, args.gpu)
+        model = load_model(args.language, args.model_dir, args.gpu,
+                           args.beam, args.context_path, args.context_score)
     if args.align:
         result = model.align(args.audio_file, args.label)
     else:
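Note: the file passed via --context_path is assumed to hold one biasing phrase per line (each phrase is mapped to token ids through the symbol table loaded from units.txt), for example:

    中国移动
    WeNet

With such a file in place, the new options combine with the existing ones, e.g. --context_path hotwords.txt --context_score 6.0.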
63 changes: 58 additions & 5 deletions wenet/transformer/search.py
@@ -23,6 +23,7 @@
 from wenet.utils.ctc_utils import remove_duplicates_and_blank
 from wenet.utils.mask import (make_pad_mask, mask_finished_preds,
                               mask_finished_scores, subsequent_mask)
+from wenet.utils.context_graph import ContextGraph, ContextState


 class DecodeResult:
@@ -62,14 +63,19 @@ def __init__(self,
                  s: float = float('-inf'),
                  ns: float = float('-inf'),
                  v_s: float = float('-inf'),
-                 v_ns: float = float('-inf')):
+                 v_ns: float = float('-inf'),
+                 context_state: ContextState = None,
+                 context_score: float = 0.0):
         self.s = s  # blank ending score
         self.ns = ns  # non-blank ending score
         self.v_s = v_s  # viterbi blank ending score
         self.v_ns = v_ns  # viterbi non-blank ending score
         self.cur_token_prob = float('-inf')  # prob of current token
         self.times_s = []  # times of viterbi blank path
         self.times_ns = []  # times of viterbi non-blank path
         self.context_state = context_state
         self.context_score = context_score
         self.has_context = False

     def score(self):
         return log_add(self.s, self.ns)
@@ -80,6 +86,20 @@ def viterbi_score(self):
     def times(self):
         return self.times_s if self.v_s > self.v_ns else self.times_ns

+    def total_score(self):
+        return self.score() + self.context_score
+
+    def copy_context(self, prefix_score):
+        self.context_score = prefix_score.context_score
+        self.context_state = prefix_score.context_state
+
+    def update_context(self, context_graph, prefix_score, word_id):
+        self.copy_context(prefix_score)
+        (score, context_state) = context_graph.forward_one_step(
+            prefix_score.context_state, word_id)
+        self.context_score += score
+        self.context_state = context_state
+

 def ctc_greedy_search(ctc_probs: torch.Tensor,
                       ctc_lens: torch.Tensor) -> List[DecodeResult]:
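Note: total_score() is what the second beam prune below sorts by: the acoustic prefix score log_add(s, ns) plus the running context bonus. A tiny numeric illustration (invented values, assuming wenet is importable):

    from wenet.transformer.search import PrefixScore

    ps = PrefixScore(s=-1.0, ns=float('-inf'))
    ps.context_score = 12.0   # e.g. two matched hotword tokens at +6.0 each
    print(ps.score())         # -1.0, acoustic-only
    print(ps.total_score())   # 11.0 = score() + context_score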
@@ -99,7 +119,8 @@ def ctc_greedy_search(ctc_probs: torch.Tensor,


 def ctc_prefix_beam_search(ctc_probs: torch.Tensor, ctc_lens: torch.Tensor,
-                           beam_size: int) -> List[DecodeResult]:
+                           beam_size: int, context_graph: ContextGraph = None,
+                           ) -> List[DecodeResult]:
     """
     Returns:
         List[List[List[int]]]: nbest result for each utterance
@@ -110,7 +131,14 @@ def ctc_prefix_beam_search(ctc_probs: torch.Tensor, ctc_lens: torch.Tensor,
     for i in range(batch_size):
         ctc_prob = ctc_probs[i]
         num_t = ctc_lens[i]
-        cur_hyps = [(tuple(), PrefixScore(0.0, -float('inf'), 0.0, 0.0))]
+        cur_hyps = [(tuple(),
+                     PrefixScore(s=0.0,
+                                 ns=-float('inf'),
+                                 v_s=0.0,
+                                 v_ns=0.0,
+                                 context_state=None if context_graph is None
+                                 else context_graph.root,
+                                 context_score=0.0))]
         # 2. CTC beam search step by step
         for t in range(0, num_t):
             logp = ctc_prob[t]  # (vocab_size,)
@@ -129,6 +157,10 @@ def ctc_prefix_beam_search(ctc_probs: torch.Tensor, ctc_lens: torch.Tensor,
                                                prefix_score.score() + prob)
                         next_score.v_s = prefix_score.viterbi_score() + prob
                         next_score.times_s = prefix_score.times().copy()
+                        # prefix not changed, copy the context from the prefix
+                        if context_graph and not next_score.has_context:
+                            next_score.copy_context(prefix_score)
+                            next_score.has_context = True
                     elif u == last:
                         # Update *uu -> *u;
                         next_score1 = next_hyps[prefix]
@@ -141,6 +173,9 @@ def ctc_prefix_beam_search(ctc_probs: torch.Tensor, ctc_lens: torch.Tensor,
                                 next_score1.times_ns = prefix_score.times_ns.copy(
                                 )
                                 next_score1.times_ns[-1] = t
+                        if context_graph and not next_score1.has_context:
+                            next_score1.copy_context(prefix_score)
+                            next_score1.has_context = True

                         # Update *u-u -> *uu, - is for blank
                         n_prefix = prefix + (u, )
@@ -152,6 +187,10 @@ def ctc_prefix_beam_search(ctc_probs: torch.Tensor, ctc_lens: torch.Tensor,
                             next_score2.cur_token_prob = prob
                             next_score2.times_ns = prefix_score.times_s.copy()
                             next_score2.times_ns.append(t)
+                        if context_graph and not next_score2.has_context:
+                            next_score2.update_context(context_graph,
+                                                       prefix_score, u)
+                            next_score2.has_context = True
                     else:
                         n_prefix = prefix + (u, )
                         next_score = next_hyps[n_prefix]
@@ -164,14 +203,28 @@ def ctc_prefix_beam_search(ctc_probs: torch.Tensor, ctc_lens: torch.Tensor,
                             next_score.cur_token_prob = prob
                             next_score.times_ns = prefix_score.times().copy()
                             next_score.times_ns.append(t)
+                        if context_graph and not next_score.has_context:
+                            next_score.update_context(context_graph,
+                                                      prefix_score, u)
+                            next_score.has_context = True

             # 2.2 Second beam prune
             next_hyps = sorted(next_hyps.items(),
-                               key=lambda x: x[1].score(),
+                               key=lambda x: x[1].total_score(),
                                reverse=True)
             cur_hyps = next_hyps[:beam_size]

+        # Back off the context score/state when the context is only
+        # partially matched at the last time step.
+        if context_graph is not None:
+            for i, hyp in enumerate(cur_hyps):
+                context_score, new_context_state = context_graph.finalize(
+                    hyp[1].context_state)
+                cur_hyps[i][1].context_score = context_score
+                cur_hyps[i][1].context_state = new_context_state
+
         nbest = [y[0] for y in cur_hyps]
-        nbest_scores = [y[1].score() for y in cur_hyps]
+        nbest_scores = [y[1].total_score() for y in cur_hyps]
         nbest_times = [y[1].times() for y in cur_hyps]
         best = nbest[0]
         best_score = nbest_scores[0]
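Note: the decoder above relies only on a small ContextGraph surface: a root state, forward_one_step(state, token) returning a score adjustment plus the next state, and finalize(state) to back off matches left unfinished at the end of decoding. The sketch below is an illustrative Aho-Corasick-style trie showing one consistent way to implement that surface with delta-style per-token bonuses; it is NOT wenet's actual wenet/utils/context_graph.py, which is built from the context file and the symbol table:

    # Illustrative stand-in for the ContextGraph/ContextState surface used
    # by ctc_prefix_beam_search above; takes pre-tokenized phrases directly.
    from collections import deque
    from typing import Dict, List, Tuple


    class ContextState:
        def __init__(self, node_score: float = 0.0):
            self.next: Dict[int, "ContextState"] = {}  # trie arcs by token id
            self.fail: "ContextState" = None           # Aho-Corasick fail link
            self.node_score = node_score  # bonus accumulated along this path
            self.is_end = False           # a full context phrase ends here


    class ContextGraph:
        def __init__(self, phrases: List[List[int]],
                     context_score: float = 6.0):
            self.root = ContextState()
            self.root.fail = self.root
            for tokens in phrases:
                node = self.root
                for tok in tokens:
                    if tok not in node.next:
                        node.next[tok] = ContextState(node.node_score +
                                                      context_score)
                    node = node.next[tok]
                node.is_end = True
            # Standard BFS construction of Aho-Corasick fail links.
            for child in self.root.next.values():
                child.fail = self.root
            queue = deque(self.root.next.values())
            while queue:
                node = queue.popleft()
                for tok, child in node.next.items():
                    fail = node.fail
                    while tok not in fail.next and fail is not self.root:
                        fail = fail.fail
                    child.fail = fail.next.get(tok, self.root)
                    queue.append(child)

        def forward_one_step(self, state: ContextState,
                             token: int) -> Tuple[float, ContextState]:
            # Returns a score *delta* relative to `state`, which is how
            # PrefixScore.update_context accumulates context_score above.
            origin = state
            while token not in state.next and state is not self.root:
                state = state.fail
            nxt = state.next.get(token, self.root)
            if nxt.is_end:
                # Full phrase matched: keep the bonus, restart matching.
                return nxt.node_score - origin.node_score, self.root
            return nxt.node_score - origin.node_score, nxt

        def finalize(self, state: ContextState) -> Tuple[float, ContextState]:
            # Take back the bonus of a partial match left open at the end.
            return -state.node_score, self.root

For example, with one hotword tokenized as [3, 4] and a +6.0 per-token bonus:

    g = ContextGraph([[3, 4]], context_score=6.0)
    d1, s1 = g.forward_one_step(g.root, 3)  # d1 == +6.0, partial match
    d2, s2 = g.forward_one_step(s1, 4)      # d2 == +6.0, full match banked
    d3, s3 = g.finalize(s1)                 # d3 == -6.0, partial match backed off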