diff --git a/wenet/cli/model.py b/wenet/cli/model.py
index 65da988d4..1490476e6 100644
--- a/wenet/cli/model.py
+++ b/wenet/cli/model.py
@@ -34,7 +34,7 @@ def __init__(self, language: str):
         symbol_table = read_symbol_table(units_path)
         self.char_dict = {v: k for k, v in symbol_table.items()}
 
-    def transcribe(self, audio_file: str, token_times: bool = False):
+    def transcribe(self, audio_file: str, tokens_info: bool = False):
         waveform, sample_rate = torchaudio.load(audio_file, normalize=False)
         waveform = waveform.to(torch.float)
         feats = kaldi.fbank(waveform,
@@ -54,19 +54,21 @@ def transcribe(self, audio_file: str, token_times: bool = False):
         res = rescoring_results[0]
         result = {}
         result['rec'] = ''.join([self.char_dict[x] for x in res.tokens])
+        result['confidence'] = res.confidence
 
-        if token_times:
+        if tokens_info:
             frame_rate = self.model.subsampling_rate(
             ) * 0.01  # 0.01 seconds per frame
             max_duration = encoder_out.size(1) * frame_rate
             times = gen_timestamps_from_peak(res.times, max_duration,
                                              frame_rate, 1.0)
-            times_info = []
+            tokens_info = []
             for i, x in enumerate(res.tokens):
-                times_info.append({
+                tokens_info.append({
                     'token': self.char_dict[x],
                     'start': times[i][0],
-                    'end': times[i][1]
+                    'end': times[i][1],
+                    'confidence': res.tokens_confidence[i]
                 })
-            result['times'] = times_info
+            result['tokens'] = tokens_info
         return result
diff --git a/wenet/cli/transcribe.py b/wenet/cli/transcribe.py
index 8564c6e61..3d517c401 100644
--- a/wenet/cli/transcribe.py
+++ b/wenet/cli/transcribe.py
@@ -28,9 +28,10 @@ def get_args():
                         default='chinese',
                         help='language type')
     parser.add_argument('-t',
-                        '--gen_token_times',
+                        '--show_tokens_info',
                         action='store_true',
-                        help='whether to generate token times')
+                        help='whether to output token (word) level information'
+                        ', such as times/confidence')
     args = parser.parse_args()
     return args
 
@@ -38,7 +39,7 @@ def get_args():
 def main():
     args = get_args()
     model = Model(args.language)
-    result = model.transcribe(args.audio_file, args.gen_token_times)
+    result = model.transcribe(args.audio_file, args.show_tokens_info)
     print(result)
 
 
diff --git a/wenet/transformer/search.py b/wenet/transformer/search.py
index c5914a418..accca3338 100644
--- a/wenet/transformer/search.py
+++ b/wenet/transformer/search.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import math
 from collections import defaultdict
 from typing import List, Optional
 
@@ -311,25 +312,36 @@ def attention_rescoring(
         # Only use decoder score for rescoring
         best_score = -float('inf')
         best_index = 0
+        confidences = []
+        tokens_confidences = []
         for i, hyp in enumerate(hyps):
             score = 0.0
+            tc = []  # tokens confidences
             for j, w in enumerate(hyp):
-                score += decoder_out[i][j][w]
+                s = decoder_out[i][j][w]
+                score += s
+                tc.append(math.exp(s))
             score += decoder_out[i][len(hyp)][eos]
             # add right to left decoder score
             if reverse_weight > 0:
                 r_score = 0.0
                 for j, w in enumerate(hyp):
-                    r_score += r_decoder_out[i][len(hyp) - j - 1][w]
+                    s = r_decoder_out[i][len(hyp) - j - 1][w]
+                    r_score += s
+                    tc[j] = (tc[j] + math.exp(s)) / 2
                 r_score += r_decoder_out[i][len(hyp)][eos]
                 score = score * (1 - reverse_weight) + r_score * reverse_weight
+            confidences.append(math.exp(score / (len(hyp) + 1)))
             # add ctc score
             score += ctc_scores[i] * ctc_weight
             if score > best_score:
                 best_score = score
                 best_index = i
+            tokens_confidences.append(tc)
         results.append(
             DecodeResult(hyps[best_index],
                          best_score,
-                         times=ctc_prefix_results[b].nbest_times[best_index]))
+                         confidence=confidences[best_index],
+                         times=ctc_prefix_results[b].nbest_times[best_index],
+                         tokens_confidence=tokens_confidences[best_index]))
     return results
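For reference, the confidence math added to `attention_rescoring` reduces to a geometric mean over token posteriors: each token's confidence is `exp` of its decoder log-probability, and the utterance confidence is `exp` of the mean log-probability over the `len(hyp) + 1` scored positions (all tokens plus `eos`). Below is a minimal standalone sketch of that computation, not WeNet code; `token_logprobs` and `eos_logprob` are hypothetical stand-ins for one hypothesis's `decoder_out[i][j][w]` and `decoder_out[i][len(hyp)][eos]` scores, and the reverse decoder branch is omitted.

```python
import math


def confidence_from_logprobs(token_logprobs, eos_logprob):
    # Per-token confidence: exp of the token's log-probability.
    tokens_confidence = [math.exp(lp) for lp in token_logprobs]
    # Utterance confidence: exp of the mean log-probability over the
    # len(hyp) + 1 scored positions (tokens plus eos), i.e. the geometric
    # mean of the token and eos posteriors.
    score = sum(token_logprobs) + eos_logprob
    confidence = math.exp(score / (len(token_logprobs) + 1))
    return confidence, tokens_confidence


# Example: log-probs near 0 mean posteriors near 1, hence high confidence.
conf, tok_conf = confidence_from_logprobs([-0.11, -0.36, -0.02], -0.05)
print(round(conf, 3), [round(c, 3) for c in tok_conf])
```

When `reverse_weight > 0`, the diff additionally averages each token's posterior with the right-to-left decoder's estimate (`tc[j] = (tc[j] + math.exp(s)) / 2`) and blends the two utterance scores before exponentiating. Confidences are computed for every n-best hypothesis, but only those of `best_index` are attached to the `DecodeResult`; with `--show_tokens_info`, the CLI result dict then carries a top-level `confidence` plus a `tokens` list of `{'token', 'start', 'end', 'confidence'}` entries.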