# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import math
from typing import Dict, List, Tuple, Type, Union

import numpy as np
import torch
from omegaconf import OmegaConf

import nemo.collections.asr as nemo_asr
from nemo.collections.asr.metrics.wer import WER
from nemo.collections.asr.models import EncDecCTCModel, EncDecCTCModelBPE
from nemo.collections.asr.parts.preprocessing.segment import get_samples
from nemo.collections.asr.parts.submodules.ctc_decoding import (
    CTCBPEDecoding,
    CTCBPEDecodingConfig,
    CTCDecoding,
    CTCDecodingConfig,
)
from nemo.collections.asr.parts.utils.speaker_utils import audio_rttm_map, get_uniqname_from_filepath
from nemo.collections.asr.parts.utils.streaming_utils import AudioFeatureIterator, FrameBatchASR
from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec
from nemo.utils import logging

__all__ = ['ASRDecoderTimeStamps']
try:
    from pyctcdecode import build_ctcdecoder

    PYCTCDECODE = True
except ImportError:
    PYCTCDECODE = False


def if_none_get_default(param, default_value):
    # Tuple indexing on a bool: returns `default_value` when `param` is None, otherwise `param`.
    return (param, default_value)[param is None]


class WERBPE_TS(WER):
    """
    A WER class modified to generate word timestamps from CTC logits.
    The methods of the WER class are modified to record a timestamp whenever a BPE
    token is appended to the output list.
    This class supports CTC-based ASR models that use BPE tokenizers.
    Please refer to the definition of the WERBPE class for more information.
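
    Example (illustrative sketch; `asr_model` and `log_probs`, a [T, V] log-probability
    tensor, are placeholders for a loaded CTC-BPE NeMo model and its output, and the
    0.04 s stride is an assumption, not a value defined here):

        wer_ts = WERBPE_TS(tokenizer=asr_model.tokenizer)
        greedy = log_probs.argmax(dim=-1).unsqueeze(0)  # [1, T] token ids
        hyps, char_ts, word_ts = wer_ts.ctc_decoder_predictions_tensor_with_ts(
            0.04, greedy, predictions_len=torch.tensor([greedy.shape[1]])
        )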
"""

    def __init__(
        self,
        tokenizer: TokenizerSpec,
        batch_dim_index=0,
        use_cer=False,
        ctc_decode=None,
        log_prediction=True,
        dist_sync_on_step=False,
    ):
        if ctc_decode is not None:
            logging.warning(f'`ctc_decode` was set to {ctc_decode}. Note that this is ignored.')

        decoding_cfg = CTCBPEDecodingConfig(batch_dim_index=batch_dim_index)
        decoding = CTCBPEDecoding(decoding_cfg, tokenizer=tokenizer)
        super().__init__(decoding, use_cer, log_prediction, dist_sync_on_step)

    def ctc_decoder_predictions_tensor_with_ts(
        self, time_stride, predictions: torch.Tensor, predictions_len: torch.Tensor = None
    ) -> Tuple[List[str], List[List[float]], List[List[List[float]]]]:
        hypotheses, timestamps, word_timestamps = [], [], []
        # The '⁇' string should be removed since it causes an error during string split.
        unk = '⁇'
        prediction_cpu_tensor = predictions.long().cpu()
        # iterate over batch
        self.time_stride = time_stride
        for ind in range(prediction_cpu_tensor.shape[self.decoding.batch_dim_index]):
            prediction = prediction_cpu_tensor[ind].detach().numpy().tolist()
            if predictions_len is not None:
                prediction = prediction[: predictions_len[ind]]

            # CTC decoding procedure
            decoded_prediction, char_ts, timestamp_list = [], [], []
            previous = self.decoding.blank_id
            for pdx, p in enumerate(prediction):
                if (p != previous or previous == self.decoding.blank_id) and p != self.decoding.blank_id:
                    decoded_prediction.append(p)
                    char_ts.append(round(pdx * self.time_stride, 2))
                    timestamp_list.append(round(pdx * self.time_stride, 2))
                previous = p

            hypothesis = self.decode_tokens_to_str_with_ts(decoded_prediction)
            hypothesis = hypothesis.replace(unk, '')
            word_ts, word_seq = self.get_ts_from_decoded_prediction(decoded_prediction, hypothesis, char_ts)

            hypotheses.append(" ".join(word_seq))
            timestamps.append(timestamp_list)
            word_timestamps.append(word_ts)
        return hypotheses, timestamps, word_timestamps

    def decode_tokens_to_str_with_ts(self, tokens: List[int]) -> str:
        hypothesis = self.decoding.tokenizer.ids_to_text(tokens)
        return hypothesis

    def decode_ids_to_tokens_with_ts(self, tokens: List[int]) -> List[str]:
        token_list = self.decoding.tokenizer.ids_to_tokens(tokens)
        return token_list

    def get_ts_from_decoded_prediction(
        self, decoded_prediction: List[int], hypothesis: str, char_ts: List[float]
    ) -> Tuple[List[List[float]], List[str]]:
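        """
        Stitch BPE tokens into words and assign each word a [start, end] timestamp.

        Worked example (hypothetical values, assuming a 0.04 s time stride):
        tokens ['▁hi', '▁there'] with char_ts [0.12, 0.40] yield
        word_seq ['hi', 'there'] and word_ts [[0.12, 0.16], [0.40, 0.44]],
        since each word's end time is its last token's start time plus one stride.
        """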
        decoded_char_list = self.decoding.tokenizer.ids_to_tokens(decoded_prediction)
        stt_idx, end_idx = 0, len(decoded_char_list) - 1
        stt_ch_idx, end_ch_idx = 0, 0
        space = '▁'
        word_ts, word_seq = [], []
        word_open_flag = False
        for idx, ch in enumerate(decoded_char_list):
            # If the symbol is a space and not the end of the utterance, move on.
            if idx != end_idx and (space == ch and space in decoded_char_list[idx + 1]):
                continue

            # If the token does not contain a space (i.e., it starts a word), keep counting.
            if (idx == stt_idx or space == decoded_char_list[idx - 1] or (space in ch and len(ch) > 1)) and (
                ch != space
            ):
                _stt = char_ts[idx]
                stt_ch_idx = idx
                word_open_flag = True

            # If this char has `word_open_flag=True` and meets any one of the following conditions:
            # (1) last token, (2) unknown token, (3) start symbol in the following token,
            # close the `word_open_flag` and add the word to the `word_seq` list.
            close_cond = idx == end_idx or ch in ['<unk>'] or space in decoded_char_list[idx + 1]
            if (word_open_flag and ch != space) and close_cond:
                _end = round(char_ts[idx] + self.time_stride, 2)
                end_ch_idx = idx
                word_open_flag = False
                word_ts.append([_stt, _end])
                stitched_word = ''.join(decoded_char_list[stt_ch_idx : end_ch_idx + 1]).replace(space, '')
                word_seq.append(stitched_word)

        assert len(word_ts) == len(hypothesis.split()), "Text hypothesis does not match word timestamps."
        return word_ts, word_seq


class WER_TS(WER):
    """
    A WER class modified to generate character timestamps from CTC logits.
    The methods of the WER class are modified to record a timestamp whenever a character
    is appended to the output list.
    This class supports CTC-based ASR models that use character-level tokens.
    Please refer to the definition of the WER class for more information.
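
    Example (illustrative sketch; `asr_model` and `log_probs`, a [T, V] log-probability
    tensor, are placeholders for a loaded character-level CTC NeMo model and its output):

        wer_ts = WER_TS(vocabulary=asr_model.decoder.vocabulary)
        greedy = log_probs.argmax(dim=-1).unsqueeze(0)  # [1, T] token ids
        hyps, char_ts = wer_ts.ctc_decoder_predictions_tensor_with_ts(
            greedy, predictions_len=torch.tensor([greedy.shape[1]])
        )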
"""

    def __init__(
        self,
        vocabulary,
        batch_dim_index=0,
        use_cer=False,
        ctc_decode=None,
        log_prediction=True,
        dist_sync_on_step=False,
    ):
        if ctc_decode is not None:
            logging.warning(f'`ctc_decode` was set to {ctc_decode}. Note that this is ignored.')

        decoding_cfg = CTCDecodingConfig(batch_dim_index=batch_dim_index)
        decoding = CTCDecoding(decoding_cfg, vocabulary=vocabulary)
        super().__init__(decoding, use_cer, log_prediction, dist_sync_on_step)

    def decode_tokens_to_str_with_ts(self, tokens: List[int], timestamps: List[int]) -> Tuple[str, List[int]]:
        """
        Take frame-level tokens and a timestamp list and collect the timestamps for
        the start and end of each word.
        """
        token_list, timestamp_list = self.decode_ids_to_tokens_with_ts(tokens, timestamps)
        hypothesis = ''.join(self.decoding.decode_ids_to_tokens(tokens))
        return hypothesis, timestamp_list

    def decode_ids_to_tokens_with_ts(self, tokens: List[int], timestamps: List[int]) -> Tuple[List[str], List[int]]:
        token_list, timestamp_list = [], []
        for i, c in enumerate(tokens):
            if c != self.decoding.blank_id:
                token_list.append(self.decoding.labels_map[c])
                timestamp_list.append(timestamps[i])
        return token_list, timestamp_list

    def ctc_decoder_predictions_tensor_with_ts(
        self,
        predictions: torch.Tensor,
        predictions_len: torch.Tensor = None,
    ) -> Tuple[List[str], List[List[int]]]:
        """
        A shortened version of the original function ctc_decoder_predictions_tensor().
        Replaces the decode_tokens_to_str() function with decode_tokens_to_str_with_ts().
        """
        hypotheses, timestamps = [], []
        prediction_cpu_tensor = predictions.long().cpu()
        for ind in range(prediction_cpu_tensor.shape[self.decoding.batch_dim_index]):
            prediction = prediction_cpu_tensor[ind].detach().numpy().tolist()
            if predictions_len is not None:
                prediction = prediction[: predictions_len[ind]]

            # CTC decoding procedure with timestamps
            decoded_prediction, decoded_timing_list = [], []
            previous = self.decoding.blank_id
            for pdx, p in enumerate(prediction):
                if (p != previous or previous == self.decoding.blank_id) and p != self.decoding.blank_id:
                    decoded_prediction.append(p)
                    decoded_timing_list.append(pdx)
                previous = p

            text, timestamp_list = self.decode_tokens_to_str_with_ts(decoded_prediction, decoded_timing_list)
            hypotheses.append(text)
            timestamps.append(timestamp_list)
        return hypotheses, timestamps


def get_wer_feat_logit(audio_file_path, asr, frame_len, tokens_per_chunk, delay, model_stride_in_secs):
    """
    Run buffered inference on a single audio file and return the hypothesis,
    the merged greedy token sequence, and the corresponding log probabilities.
    Normalization is done per buffer in frame_bufferer.
    """
    asr.reset()
    asr.read_audio_file_and_return(audio_file_path, delay, model_stride_in_secs)
    hyp, tokens, log_prob = asr.transcribe_with_ts(tokens_per_chunk, delay)
    return hyp, tokens, log_prob


class FrameBatchASRLogits(FrameBatchASR):
    """
    A class for streaming frame-based ASR.
    Inherits from FrameBatchASR and adds the capability of returning the logit output.
    Please refer to FrameBatchASR for more detailed information.
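
    Example (illustrative sketch; `asr_model`, 'audio.wav', and the numeric settings
    are placeholders, not values defined in this module):

        frame_asr = FrameBatchASRLogits(asr_model, frame_len=5.0, total_buffer=25.0, batch_size=16)
        frame_asr.reset()
        frame_asr.read_audio_file_and_return('audio.wav', delay=375, model_stride_in_secs=0.04)
        hyp, tokens, log_probs = frame_asr.transcribe_with_ts(tokens_per_chunk=125, delay=375)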
"""

    def __init__(
        self,
        asr_model: Type[EncDecCTCModelBPE],
        frame_len: float = 1.6,
        total_buffer: float = 4.0,
        batch_size: int = 4,
    ):
        super().__init__(asr_model, frame_len, total_buffer, batch_size)
        self.all_logprobs = []

    def clear_buffer(self):
        self.all_logprobs = []
        self.all_preds = []

    def read_audio_file_and_return(self, audio_filepath: str, delay: float, model_stride_in_secs: float):
        samples = get_samples(audio_filepath)
        samples = np.pad(samples, (0, int(delay * model_stride_in_secs * self.asr_model._cfg.sample_rate)))
        frame_reader = AudioFeatureIterator(samples, self.frame_len, self.raw_preprocessor, self.asr_model.device)
        self.set_frame_reader(frame_reader)

    @torch.no_grad()
    def _get_batch_preds(self, keep_logits):
        device = self.asr_model.device
        for batch in iter(self.data_loader):
            feat_signal, feat_signal_len = batch
            feat_signal, feat_signal_len = feat_signal.to(device), feat_signal_len.to(device)
            log_probs, encoded_len, predictions = self.asr_model(
                processed_signal=feat_signal, processed_signal_length=feat_signal_len
            )
            preds = torch.unbind(predictions)
            for pred in preds:
                self.all_preds.append(pred.cpu().numpy())
            # Always keep logits in FrameBatchASRLogits
            _ = keep_logits
            log_probs_tup = torch.unbind(log_probs)
            for log_prob in log_probs_tup:
                self.all_logprobs.append(log_prob)
            del log_probs, log_probs_tup
            del encoded_len
            del predictions

    def transcribe_with_ts(
        self,
        tokens_per_chunk: int,
        delay: int,
    ):
        self.infer_logits()
        self.unmerged = []
        self.part_logprobs = []
        for idx, pred in enumerate(self.all_preds):
            decoded = pred.tolist()
            _stt, _end = len(decoded) - 1 - delay, len(decoded) - 1 - delay + tokens_per_chunk
            self.unmerged += decoded[_stt:_end]
            self.part_logprobs.append(self.all_logprobs[idx][_stt:_end, :])
        self.unmerged_logprobs = torch.cat(self.part_logprobs, 0)
        assert (
            len(self.unmerged) == self.unmerged_logprobs.shape[0]
        ), "Unmerged decoded result and log prob lengths are different."
        return self.greedy_merge(self.unmerged), self.unmerged, self.unmerged_logprobs


class ASRDecoderTimeStamps:
    """
    A class designed for extracting word timestamps during the ASR decoding process.
    This class contains setups for several NeMo ASR models, such as QuartzNet, CitriNet, and Conformer-CTC.
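
    Example (illustrative sketch; the config file name is a placeholder, not defined
    in this module):

        cfg = OmegaConf.load('diar_infer_meeting.yaml')
        decoder_ts = ASRDecoderTimeStamps(cfg.diarizer)
        asr_model = decoder_ts.set_asr_model()
        words_dict, word_ts_dict = decoder_ts.run_ASR(asr_model)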
"""

    def __init__(self, cfg_diarizer):
        self.manifest_filepath = cfg_diarizer.manifest_filepath
        self.params = cfg_diarizer.asr.parameters
        self.ctc_decoder_params = cfg_diarizer.asr.ctc_decoder_parameters
        self.ASR_model_name = cfg_diarizer.asr.model_path
        self.nonspeech_threshold = self.params.asr_based_vad_threshold
        self.root_path = None
        self.run_ASR = None
        self.encdec_class = None
        self.AUDIO_RTTM_MAP = audio_rttm_map(self.manifest_filepath)
        self.audio_file_list = [value['audio_filepath'] for _, value in self.AUDIO_RTTM_MAP.items()]

    def set_asr_model(self):
        """
        Initialize the parameters for the given ASR model.
        Currently, the following NGC models are supported:

            stt_en_quartznet15x5,
            stt_en_citrinet*,
            stt_en_conformer_ctc*

        To assign the proper decoding function for generating timestamp output,
        the name of the .nemo file should include an architecture name such as
        'quartznet', 'conformer', or 'citrinet'.

        decoder_delay_in_sec is the amount of delay that is compensated during word timestamp extraction.
        word_ts_anchor_offset is the reference point of a word, used for matching the word with diarization labels.
        Each ASR model has a different optimal decoder delay and word timestamp anchor offset.
        To obtain an optimized diarization result with ASR, decoder_delay_in_sec and word_ts_anchor_offset
        need to be searched on a development set.
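
        Hypothetical config snippet showing where these parameters live (leaving a
        value as `null` falls back to the per-model defaults set below):

            asr:
              model_path: stt_en_conformer_ctc_large.nemo
              parameters:
                decoder_delay_in_sec: null   # e.g., 0.08 for Conformer
                word_ts_anchor_offset: null  # e.g., 0.12 for Conformer
                asr_batch_size: null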
"""
        if 'quartznet' in self.ASR_model_name.lower():
            self.run_ASR = self.run_ASR_QuartzNet_CTC
            self.encdec_class = EncDecCTCModel
            self.decoder_delay_in_sec = if_none_get_default(self.params['decoder_delay_in_sec'], 0.04)
            self.word_ts_anchor_offset = if_none_get_default(self.params['word_ts_anchor_offset'], 0.12)
            self.asr_batch_size = if_none_get_default(self.params['asr_batch_size'], 4)
            self.model_stride_in_secs = 0.02

        elif 'fastconformer' in self.ASR_model_name.lower():
            self.run_ASR = self.run_ASR_BPE_CTC
            self.encdec_class = EncDecCTCModelBPE
            self.decoder_delay_in_sec = if_none_get_default(self.params['decoder_delay_in_sec'], 0.08)
            self.word_ts_anchor_offset = if_none_get_default(self.params['word_ts_anchor_offset'], 0.12)
            self.asr_batch_size = if_none_get_default(self.params['asr_batch_size'], 16)
            self.model_stride_in_secs = 0.08
            # FastConformer requires buffered inference and the parameters for buffered processing.
            self.chunk_len_in_sec = 15
            self.total_buffer_in_secs = 30

        elif 'conformer' in self.ASR_model_name.lower():
            self.run_ASR = self.run_ASR_BPE_CTC
            self.encdec_class = EncDecCTCModelBPE
            self.decoder_delay_in_sec = if_none_get_default(self.params['decoder_delay_in_sec'], 0.08)
            self.word_ts_anchor_offset = if_none_get_default(self.params['word_ts_anchor_offset'], 0.12)
            self.asr_batch_size = if_none_get_default(self.params['asr_batch_size'], 16)
            self.model_stride_in_secs = 0.04
            # Conformer requires buffered inference and the parameters for buffered processing.
            self.chunk_len_in_sec = 5
            self.total_buffer_in_secs = 25

        elif 'citrinet' in self.ASR_model_name.lower():
            self.run_ASR = self.run_ASR_CitriNet_CTC
            self.encdec_class = EncDecCTCModelBPE
            self.decoder_delay_in_sec = if_none_get_default(self.params['decoder_delay_in_sec'], 0.16)
            self.word_ts_anchor_offset = if_none_get_default(self.params['word_ts_anchor_offset'], 0.2)
            self.asr_batch_size = if_none_get_default(self.params['asr_batch_size'], 4)
            self.model_stride_in_secs = 0.08

        else:
            raise ValueError(f"Cannot find the ASR model class for: {self.ASR_model_name}")

        if self.ASR_model_name.endswith('.nemo'):
            asr_model = self.encdec_class.restore_from(restore_path=self.ASR_model_name)
        else:
            asr_model = self.encdec_class.from_pretrained(model_name=self.ASR_model_name, strict=False)

        if self.ctc_decoder_params['pretrained_language_model']:
            if not PYCTCDECODE:
                raise ImportError(
                    'An LM for beam search decoding is provided but pyctcdecode is not installed. '
                    'Install pyctcdecode from PyPI: pip install pyctcdecode'
                )
            self.beam_search_decoder = self.load_LM_for_CTC_decoder(asr_model)
        else:
            self.beam_search_decoder = None

        asr_model.eval()
        return asr_model

    def load_LM_for_CTC_decoder(self, asr_model: Type[Union[EncDecCTCModel, EncDecCTCModelBPE]]):
        """
        Load a language model for the CTC decoder (pyctcdecode).
        Note that only EncDecCTCModel and EncDecCTCModelBPE models can use pyctcdecode.
        """
        kenlm_model = self.ctc_decoder_params['pretrained_language_model']
        logging.info(f"Loading language model : {self.ctc_decoder_params['pretrained_language_model']}")

        if 'EncDecCTCModelBPE' in str(type(asr_model)):
            vocab = asr_model.tokenizer.tokenizer.get_vocab()
            labels = list(vocab.keys())
            labels[0] = "<unk>"
        elif 'EncDecCTCModel' in str(type(asr_model)):
            labels = asr_model.decoder.vocabulary
        else:
            raise ValueError(f"Cannot find a vocabulary or tokenizer for: {self.ASR_model_name}")

        decoder = build_ctcdecoder(
            labels, kenlm_model, alpha=self.ctc_decoder_params['alpha'], beta=self.ctc_decoder_params['beta']
        )
        return decoder

    def run_ASR_QuartzNet_CTC(self, asr_model: Type[EncDecCTCModel]) -> Tuple[Dict, Dict]:
        """
        Launch a QuartzNet ASR model and collect logits, timestamps, and text output.

        Args:
            asr_model (class):
                The loaded NeMo ASR model.

        Returns:
            words_dict (dict):
                Dictionary containing the sequence of words from the hypothesis.
            word_ts_dict (dict):
                Dictionary containing the timestamps of words.
        """
        words_dict, word_ts_dict = {}, {}

        wer_ts = WER_TS(
            vocabulary=asr_model.decoder.vocabulary,
            batch_dim_index=0,
            use_cer=asr_model._cfg.get('use_cer', False),
            ctc_decode=True,
            dist_sync_on_step=True,
            log_prediction=asr_model._cfg.get("log_prediction", False),
        )

        with torch.amp.autocast(asr_model.device.type):
            transcript_hyps_list = asr_model.transcribe(
                self.audio_file_list, batch_size=self.asr_batch_size, return_hypotheses=True
            )  # type: List[nemo_asr.parts.Hypothesis]
            transcript_logits_list = [hyp.alignments for hyp in transcript_hyps_list]
            for idx, logit_np in enumerate(transcript_logits_list):
                logit_np = logit_np.cpu().numpy()
                uniq_id = get_uniqname_from_filepath(self.audio_file_list[idx])
                if self.beam_search_decoder:
                    logging.info(
                        f"Running beam-search decoder on {uniq_id} with LM {self.ctc_decoder_params['pretrained_language_model']}"
                    )
                    hyp_words, word_ts = self.run_pyctcdecode(logit_np)
                else:
                    log_prob = torch.from_numpy(logit_np)
                    logits_len = torch.from_numpy(np.array([log_prob.shape[0]]))
                    greedy_predictions = log_prob.argmax(dim=-1, keepdim=False).unsqueeze(0)
                    text, char_ts = wer_ts.ctc_decoder_predictions_tensor_with_ts(
                        greedy_predictions, predictions_len=logits_len
                    )
                    trans, char_ts_in_feature_frame_idx = self.clean_trans_and_TS(text[0], char_ts[0])
                    spaces_in_sec, hyp_words = self._get_spaces(
                        trans, char_ts_in_feature_frame_idx, self.model_stride_in_secs
                    )
                    word_ts = self.get_word_ts_from_spaces(
                        char_ts_in_feature_frame_idx, spaces_in_sec, end_stamp=logit_np.shape[0]
                    )
                word_ts = self.align_decoder_delay(word_ts, self.decoder_delay_in_sec)
                assert len(hyp_words) == len(word_ts), "Words and word timestamp list lengths do not match."
                words_dict[uniq_id] = hyp_words
                word_ts_dict[uniq_id] = word_ts

        return words_dict, word_ts_dict

    @staticmethod
    def clean_trans_and_TS(trans: str, char_ts: List[int]) -> Tuple[str, List[int]]:
        """
        Remove spaces at the beginning and the end of the transcript.
        char_ts needs to be changed and synced accordingly.

        Args:
            trans (str):
                String containing the character output.
            char_ts (list):
                List containing the timestamps (int) for each character.

        Returns:
            trans (str):
                String containing the cleaned character output.
            char_ts (list):
                List containing the cleaned timestamps for each character.
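
        Worked example (hypothetical values): trans=' hi ' with char_ts=[3, 4, 5, 6]
        returns ('hi', [4, 5]); one timestamp is dropped on each side for each
        stripped space.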
"""
        assert (len(trans) > 0) and (len(char_ts) > 0)
        assert len(trans) == len(char_ts)

        trans = trans.lstrip()
        diff_L = len(char_ts) - len(trans)
        char_ts = char_ts[diff_L:]

        trans = trans.rstrip()
        diff_R = len(char_ts) - len(trans)
        if diff_R > 0:
            char_ts = char_ts[: -1 * diff_R]
        return trans, char_ts

    def _get_spaces(self, trans: str, char_ts: List[int], time_stride: float) -> Tuple[List[List[float]], List[str]]:
        """
        Collect the space symbols with a list of words.

        Args:
            trans (str):
                String containing the character output.
            char_ts (list):
                List containing the timestamps of the characters.
            time_stride (float):
                The size of the model's stride, in seconds.

        Returns:
            spaces_in_sec (list):
                List containing the start and end times of the spaces, in seconds.
            word_list (list):
                List containing the words from ASR inference.
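
        Worked example (hypothetical values, assuming a 0.02 s stride):
        trans='hi there' with char_ts=[10, 12, 15, 17, 19, 21, 23, 25]
        returns spaces_in_sec=[[0.3, 0.32]] and word_list=['hi', 'there'].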
"""
        blank = ' '
        spaces_in_sec, word_list = [], []
        stt_idx = 0

        assert (len(trans) > 0) and (len(char_ts) > 0), "Transcript and char_ts lengths should not be 0."
        assert len(trans) == len(char_ts), "Transcript and timestamp lengths do not match."

        # If there is a blank, update the timestamps of the space and the word.
        for k, s in enumerate(trans):
            if s == blank:
                spaces_in_sec.append(
                    [round(char_ts[k] * time_stride, 2), round((char_ts[k + 1] - 1) * time_stride, 2)]
                )
                word_list.append(trans[stt_idx:k])
                stt_idx = k + 1

        # Add the last word
        if len(trans) > stt_idx and trans[stt_idx] != blank:
            word_list.append(trans[stt_idx:])
        return spaces_in_sec, word_list

    def run_ASR_CitriNet_CTC(self, asr_model: Type[EncDecCTCModelBPE]) -> Tuple[Dict, Dict]:
        """
        Launch a CitriNet ASR model and collect logits, timestamps, and text output.

        Args:
            asr_model (class):
                The loaded NeMo ASR model.

        Returns:
            words_dict (dict):
                Dictionary containing the sequence of words from the hypothesis.
            word_ts_dict (dict):
                Dictionary containing the timestamps of the hypothesis words.
        """
        words_dict, word_ts_dict = {}, {}

        werbpe_ts = WERBPE_TS(
            tokenizer=asr_model.tokenizer,
            batch_dim_index=0,
            use_cer=asr_model._cfg.get('use_cer', False),
            ctc_decode=True,
            dist_sync_on_step=True,
            log_prediction=asr_model._cfg.get("log_prediction", False),
        )

        with torch.amp.autocast(asr_model.device.type):
            transcript_hyps_list = asr_model.transcribe(
                self.audio_file_list, batch_size=self.asr_batch_size, return_hypotheses=True
            )  # type: List[nemo_asr.parts.Hypothesis]
            transcript_logits_list = [hyp.alignments for hyp in transcript_hyps_list]
            for idx, logit_np in enumerate(transcript_logits_list):
                logit_np = logit_np.cpu().numpy()
                uniq_id = get_uniqname_from_filepath(self.audio_file_list[idx])
                if self.beam_search_decoder:
                    logging.info(
                        f"Running beam-search decoder with LM {self.ctc_decoder_params['pretrained_language_model']}"
                    )
                    hyp_words, word_ts = self.run_pyctcdecode(logit_np)
                else:
                    log_prob = torch.from_numpy(logit_np)
                    greedy_predictions = log_prob.argmax(dim=-1, keepdim=False).unsqueeze(0)
                    logits_len = torch.from_numpy(np.array([log_prob.shape[0]]))
                    text, char_ts, word_ts = werbpe_ts.ctc_decoder_predictions_tensor_with_ts(
                        self.model_stride_in_secs, greedy_predictions, predictions_len=logits_len
                    )
                    hyp_words, word_ts = text[0].split(), word_ts[0]
                word_ts = self.align_decoder_delay(word_ts, self.decoder_delay_in_sec)
                assert len(hyp_words) == len(word_ts), "Words and word timestamp list lengths do not match."
                words_dict[uniq_id] = hyp_words
                word_ts_dict[uniq_id] = word_ts

        return words_dict, word_ts_dict

    def set_buffered_infer_params(self, asr_model: Type[EncDecCTCModelBPE]) -> Tuple[float, float, float]:
        """
        Prepare the parameters for buffered inference.
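
        Worked example (using the Conformer defaults set in set_asr_model:
        total_buffer_in_secs=25, chunk_len_in_sec=5, model_stride_in_secs=0.04):
        onset_delay = ceil((25 - 5) / 2 / 0.04) + 1 = 251 frames,
        mid_delay = ceil((5 + (25 - 5) / 2) / 0.04) = 375 frames,
        tokens_per_chunk = ceil(5 / 0.04) = 125 frames.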
"""
        cfg = copy.deepcopy(asr_model._cfg)
        OmegaConf.set_struct(cfg.preprocessor, False)

        # some changes for streaming scenario
        cfg.preprocessor.dither = 0.0
        cfg.preprocessor.pad_to = 0
        cfg.preprocessor.normalize = "None"

        preprocessor = nemo_asr.models.EncDecCTCModelBPE.from_config_dict(cfg.preprocessor)
        preprocessor.to(asr_model.device)

        # Disable config overwriting
        OmegaConf.set_struct(cfg.preprocessor, True)

        onset_delay = (
            math.ceil(((self.total_buffer_in_secs - self.chunk_len_in_sec) / 2) / self.model_stride_in_secs) + 1
        )
        mid_delay = math.ceil(
            (self.chunk_len_in_sec + (self.total_buffer_in_secs - self.chunk_len_in_sec) / 2)
            / self.model_stride_in_secs
        )
        tokens_per_chunk = math.ceil(self.chunk_len_in_sec / self.model_stride_in_secs)

        return onset_delay, mid_delay, tokens_per_chunk

    def run_ASR_BPE_CTC(self, asr_model: Type[EncDecCTCModelBPE]) -> Tuple[Dict, Dict]:
        """
        Launch a CTC-BPE based ASR model and collect logits, timestamps, and text output.

        Args:
            asr_model (class):
                The loaded NeMo ASR model.

        Returns:
            words_dict (dict):
                Dictionary containing the sequence of words from the hypothesis.
            word_ts_dict (dict):
                Dictionary containing the timestamps of words.
        """
        torch.manual_seed(0)
        torch.set_grad_enabled(False)
        words_dict, word_ts_dict = {}, {}

        werbpe_ts = WERBPE_TS(
            tokenizer=asr_model.tokenizer,
            batch_dim_index=0,
            use_cer=asr_model._cfg.get('use_cer', False),
            ctc_decode=True,
            dist_sync_on_step=True,
            log_prediction=asr_model._cfg.get("log_prediction", False),
        )

        frame_asr = FrameBatchASRLogits(
            asr_model=asr_model,
            frame_len=self.chunk_len_in_sec,
            total_buffer=self.total_buffer_in_secs,
            batch_size=self.asr_batch_size,
        )

        onset_delay, mid_delay, tokens_per_chunk = self.set_buffered_infer_params(asr_model)
        onset_delay_in_sec = round(onset_delay * self.model_stride_in_secs, 2)

        with torch.amp.autocast(asr_model.device.type):
            logging.info(f"Running ASR model {self.ASR_model_name}")

            for idx, audio_file_path in enumerate(self.audio_file_list):
                uniq_id = get_uniqname_from_filepath(audio_file_path)
                logging.info(f"[{idx + 1}/{len(self.audio_file_list)}] FrameBatchASR: {audio_file_path}")
                frame_asr.clear_buffer()

                hyp, greedy_predictions_list, log_prob = get_wer_feat_logit(
                    audio_file_path,
                    frame_asr,
                    self.chunk_len_in_sec,
                    tokens_per_chunk,
                    mid_delay,
                    self.model_stride_in_secs,
                )
                if self.beam_search_decoder:
                    logging.info(
                        f"Running beam-search decoder with LM {self.ctc_decoder_params['pretrained_language_model']}"
                    )
                    log_prob = log_prob.unsqueeze(0).cpu().numpy()[0]
                    hyp_words, word_ts = self.run_pyctcdecode(log_prob, onset_delay_in_sec=onset_delay_in_sec)
                else:
                    logits_len = torch.from_numpy(np.array([len(greedy_predictions_list)]))
                    greedy_predictions_list = greedy_predictions_list[onset_delay:]
                    greedy_predictions = torch.from_numpy(np.array(greedy_predictions_list)).unsqueeze(0)
                    text, char_ts, word_ts = werbpe_ts.ctc_decoder_predictions_tensor_with_ts(
                        self.model_stride_in_secs, greedy_predictions, predictions_len=logits_len
                    )
                    hyp_words, word_ts = text[0].split(), word_ts[0]

                word_ts = self.align_decoder_delay(word_ts, self.decoder_delay_in_sec)
                assert len(hyp_words) == len(word_ts), "Words and word timestamp list lengths do not match."
                words_dict[uniq_id] = hyp_words
                word_ts_dict[uniq_id] = word_ts

        return words_dict, word_ts_dict

    def get_word_ts_from_spaces(
        self, char_ts: List[float], spaces_in_sec: List[List[float]], end_stamp: float
    ) -> List[List[float]]:
        """
        Derive word timestamps from the space positions in the decoded prediction.

        Args:
            char_ts (list):
                List containing the timestamp for each character.
            spaces_in_sec (list):
                List containing the start and end times of each space token.
            end_stamp (float):
                The end time of the session, in seconds.

        Returns:
            word_timestamps (list):
                List containing the timestamps for the resulting words.
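
        Worked example (hypothetical values, continuing the _get_spaces example
        with a 0.02 s stride): with char_ts=[10, 12, 15, 17, 19, 21, 23, 25] and
        spaces_in_sec=[[0.3, 0.32]], the first word spans [0.2, 0.3] and the last
        word spans [0.32, end_stamp_in_sec].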
"""
        end_stamp = min(end_stamp, (char_ts[-1] + 2))
        start_stamp_in_sec = round(char_ts[0] * self.model_stride_in_secs, 2)
        end_stamp_in_sec = round(end_stamp * self.model_stride_in_secs, 2)

        # In case of a one-word output with no space information.
        if len(spaces_in_sec) == 0:
            word_timestamps = [[start_stamp_in_sec, end_stamp_in_sec]]
        else:
            # word_timestamps_middle should be an empty list if len(spaces_in_sec) == 1.
            word_timestamps_middle = [
                [
                    round(spaces_in_sec[k][1], 2),
                    round(spaces_in_sec[k + 1][0], 2),
                ]
                for k in range(len(spaces_in_sec) - 1)
            ]
            word_timestamps = (
                [[start_stamp_in_sec, round(spaces_in_sec[0][0], 2)]]
                + word_timestamps_middle
                + [[round(spaces_in_sec[-1][1], 2), end_stamp_in_sec]]
            )
        return word_timestamps

    def run_pyctcdecode(
        self, logprob: np.ndarray, onset_delay_in_sec: float = 0, beam_width: int = 32
    ) -> Tuple[List[str], List[List[float]]]:
        """
        Launch pyctcdecode with the loaded pretrained language model.

        Args:
            logprob (np.ndarray):
                The log probabilities from the ASR model inference, in numpy array format.
            onset_delay_in_sec (float):
                The amount of delay that needs to be compensated for in the timestamp outputs from pyctcdecode.
            beam_width (int):
                The beam width parameter for beam search decoding. Note that the value
                in `ctc_decoder_params['beam_width']` is used instead.

        Returns:
            hyp_words (list):
                List containing the words in the hypothesis.
            word_ts (list):
                List containing the word timestamps from the decoder.
        """
        beams = self.beam_search_decoder.decode_beams(logprob, beam_width=self.ctc_decoder_params['beam_width'])
        word_ts_beam, words_beam = [], []
        for idx, (word, _) in enumerate(beams[0][2]):
            ts = self.get_word_ts_from_wordframes(idx, beams[0][2], self.model_stride_in_secs, onset_delay_in_sec)
            word_ts_beam.append(ts)
            words_beam.append(word)
        hyp_words, word_ts = words_beam, word_ts_beam
        return hyp_words, word_ts

    @staticmethod
    def get_word_ts_from_wordframes(
        idx: int,
        word_frames: List[Tuple[str, Tuple[int, int]]],
        frame_duration: float,
        onset_delay: float,
        word_block_delay: float = 2.25,
    ):
        """
        Extract word timestamps from the word frames generated by pyctcdecode.
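
        Worked example (hypothetical values): with frame_duration=0.04, onset_delay=0,
        and word_frames[idx] = ('hi', (10, 14)), the offset is -2.25 * 0.04 = -0.09,
        so the returned span is [0.31, 0.47].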
"""
        offset = -1 * word_block_delay * frame_duration - onset_delay
        frame_begin = word_frames[idx][1][0]
        if frame_begin == -1:
            frame_begin = word_frames[idx - 1][1][1] if idx != 0 else 0
        frame_end = word_frames[idx][1][1]
        return [
            round(max(frame_begin * frame_duration + offset, 0), 2),
            round(max(frame_end * frame_duration + offset, 0), 2),
        ]

    @staticmethod
    def align_decoder_delay(word_ts, decoder_delay_in_sec: float):
        """
        Subtract `decoder_delay_in_sec` from the word timestamp output.
        """
        for k in range(len(word_ts)):
            word_ts[k] = [
                round(word_ts[k][0] - decoder_delay_in_sec, 2),
                round(word_ts[k][1] - decoder_delay_in_sec, 2),
            ]
        return word_ts