diff --git a/nemo/collections/asr/parts/utils/streaming_utils.py b/nemo/collections/asr/parts/utils/streaming_utils.py
index 17d2d15f4daa..622b4fe57478 100644
--- a/nemo/collections/asr/parts/utils/streaming_utils.py
+++ b/nemo/collections/asr/parts/utils/streaming_utils.py
@@ -700,6 +700,7 @@ def __init__(
         )
 
         self.asr_model = asr_model
+        self.decoder = asr_model.decoder
 
         self.batch_size = batch_size
         self.all_logits = []
@@ -716,6 +717,7 @@ def __init__(
         self.frame_buffers = []
         self.reset()
         cfg = copy.deepcopy(asr_model._cfg)
+        self.cfg = cfg
         self.frame_len = frame_len
         OmegaConf.set_struct(cfg.preprocessor, False)
 
@@ -725,6 +727,7 @@ def __init__(
         cfg.preprocessor.normalize = "None"
         self.raw_preprocessor = EncDecCTCModelBPE.from_config_dict(cfg.preprocessor)
         self.raw_preprocessor.to(asr_model.device)
+        self.preprocessor = self.raw_preprocessor
 
     def reset(self):
         """
@@ -750,17 +753,17 @@ def set_frame_reader(self, frame_reader):
         self.frame_bufferer.set_frame_reader(frame_reader)
 
     @torch.no_grad()
-    def infer_logits(self):
+    def infer_logits(self, keep_logits=False):
         frame_buffers = self.frame_bufferer.get_buffers_batch()
 
         while len(frame_buffers) > 0:
             self.frame_buffers += frame_buffers[:]
             self.data_layer.set_signal(frame_buffers[:])
-            self._get_batch_preds()
+            self._get_batch_preds(keep_logits)
             frame_buffers = self.frame_bufferer.get_buffers_batch()
 
     @torch.no_grad()
-    def _get_batch_preds(self):
+    def _get_batch_preds(self, keep_logits=False):
         device = self.asr_model.device
         for batch in iter(self.data_loader):
@@ -772,19 +775,32 @@ def _get_batch_preds(self):
             preds = torch.unbind(predictions)
             for pred in preds:
                 self.all_preds.append(pred.cpu().numpy())
-            del log_probs
+            if keep_logits:
+                log_probs = torch.unbind(log_probs)
+                for log_prob in log_probs:
+                    self.all_logits.append(log_prob.cpu())
+            else:
+                del log_probs
             del encoded_len
             del predictions
 
-    def transcribe(
-        self, tokens_per_chunk: int, delay: int,
-    ):
-        self.infer_logits()
+    def transcribe(self, tokens_per_chunk: int, delay: int, keep_logits=False):
+        self.infer_logits(keep_logits)
         self.unmerged = []
         for pred in self.all_preds:
             decoded = pred.tolist()
             self.unmerged += decoded[len(decoded) - 1 - delay : len(decoded) - 1 - delay + tokens_per_chunk]
-        return self.greedy_merge(self.unmerged)
+        hypothesis = self.greedy_merge(self.unmerged)
+        if not keep_logits:
+            return hypothesis
+
+        all_logits = []
+        for log_prob in self.all_logits:
+            T = log_prob.shape[0]
+            log_prob = log_prob[T - 1 - delay : T - 1 - delay + tokens_per_chunk, :]
+            all_logits.append(log_prob)
+        all_logits = torch.concat(all_logits, 0)
+        return hypothesis, all_logits
 
     def greedy_merge(self, preds):
         decoded_prediction = []
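The new `keep_logits` flag changes what `FrameBatchASR.transcribe` returns: with `keep_logits=False` it still returns only the merged text, while with `keep_logits=True` it returns a `(text, logits)` tuple in which each chunk's log-probabilities are trimmed with the same `delay`/`tokens_per_chunk` window as the token predictions and concatenated along time. A minimal usage sketch (the pretrained model name, audio path, and the assumed 8x downsampling factor are illustrative, not part of this diff; the method calls mirror the ones used in `data_prep.py` below):

```python
import math

from nemo.collections.asr.models import EncDecCTCModelBPE
from nemo.collections.asr.parts.utils.streaming_utils import FrameBatchASR

# Assumed example model and chunk sizes (matching this PR's defaults).
asr_model = EncDecCTCModelBPE.from_pretrained("stt_en_citrinet_1024_gamma_0_25")
chunk_len_in_secs, total_buffer_in_secs = 1.6, 4.0

model_stride_in_secs = asr_model.cfg.preprocessor.window_stride * 8  # 8x downsampling assumed
tokens_per_chunk = math.ceil(chunk_len_in_secs / model_stride_in_secs)
mid_delay = math.ceil((chunk_len_in_secs + (total_buffer_in_secs - chunk_len_in_secs) / 2) / model_stride_in_secs)

frame_asr = FrameBatchASR(
    asr_model=asr_model, frame_len=chunk_len_in_secs, total_buffer=total_buffer_in_secs, batch_size=32
)
frame_asr.reset()
frame_asr.read_audio_file("audio.wav", mid_delay, model_stride_in_secs)  # hypothetical audio path

# With keep_logits=True, transcribe returns the text plus the trimmed, time-concatenated log-probs.
hyp, logits = frame_asr.transcribe(tokens_per_chunk, mid_delay, keep_logits=True)
print(hyp, logits.shape)  # logits: [T, vocab size + 1 blank] log-probabilities
```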
diff --git a/tools/nemo_forced_aligner/README.md b/tools/nemo_forced_aligner/README.md
index 9d02177e1694..35ee78ffecb0 100644
--- a/tools/nemo_forced_aligner/README.md
+++ b/tools/nemo_forced_aligner/README.md
@@ -46,6 +46,12 @@ Call the `align.py` script, specifying the parameters as follows:
 
 * **[OPTIONAL]** `minimum_timestamp_duration`: a float indicating a minimum duration (in seconds) for timestamps in the CTM. If any line in the CTM has a duration lower than the `minimum_timestamp_duration`, it will be enlarged from the middle outwards until it meets the minimum_timestamp_duration, or reaches the beginning or end of the audio file. Note that this may cause timestamps to overlap. (Default: 0, i.e. no modifications to predicted duration).
 
+* **[OPTIONAL]** `use_buffered_chunked_streaming`: a flag indicating whether to use buffered chunked streaming inference to obtain the log-probabilities. Note that only CTC models (e.g. `stt_en_citrinet_1024_gamma_0_25`) with a `per_feature` preprocessor are supported. The two parameters below are only used if this option is set to `True`. (Default: False).
+
+* **[OPTIONAL]** `chunk_len_in_secs`: the chunk size (in seconds) for buffered chunked streaming inference. (Default: 1.6).
+
+* **[OPTIONAL]** `total_buffer_in_secs`: the buffer size (in seconds, i.e. the chunk plus its left and right context) for buffered chunked streaming inference. (Default: 4.0).
+
 # Input manifest file format
 By default, NFA needs to be provided with a 'manifest' file where each line specifies the absolute "audio_filepath" and "text" of each utterance that you wish to produce alignments for, like the format below:
 ```json
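The README restriction above (CTC models with `per_feature` normalization) corresponds to the check performed in `align.py` below, and it can be verified up front on the model you plan to pass in. A small sketch (the loading code is standard NeMo, not part of this diff; the model name is the README's example):

```python
import nemo.collections.asr as nemo_asr

# Load the CTC model you intend to align with (example name from the README above).
asr_model = nemo_asr.models.ASRModel.from_pretrained("stt_en_citrinet_1024_gamma_0_25")

# NFA's buffered chunked streaming path assumes "per_feature" feature normalization.
normalize = asr_model.cfg.preprocessor.normalize
print(f"preprocessor.normalize = {normalize}")
assert normalize == "per_feature", "use_buffered_chunked_streaming is only supported for per_feature models"
```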
diff --git a/tools/nemo_forced_aligner/align.py b/tools/nemo_forced_aligner/align.py
index 56627614c18d..e688060f529d 100644
--- a/tools/nemo_forced_aligner/align.py
+++ b/tools/nemo_forced_aligner/align.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import copy
+import math
 import os
 from dataclasses import dataclass, is_dataclass
 from typing import Optional
@@ -29,11 +31,11 @@
 from utils.viterbi_decoding import viterbi_decoding
 
 from nemo.collections.asr.models.ctc_models import EncDecCTCModel
+from nemo.collections.asr.parts.utils.streaming_utils import FrameBatchASR
 from nemo.collections.asr.parts.utils.transcribe_utils import setup_model
 from nemo.core.config import hydra_runner
 from nemo.utils import logging
 
-
 """
 Align the utterances in manifest_filepath.
 Results are saved in ctm files in output_dir.
@@ -82,6 +84,16 @@
         line in the CTM has a duration lower than the `minimum_timestamp_duration`, it will be enlarged from the
         middle outwards until it meets the minimum_timestamp_duration, or reaches the beginning or end of the audio
         file. Note that this may cause timestamps to overlap.
+
+    use_buffered_chunked_streaming: False, if set to True, the logits for alignment are obtained with
+        buffered chunked streaming inference. This flag is useful when aligning large audio files.
+        However, chunked streaming inference currently does not support batch inference, which means
+        that even if you set batch_size > 1, the audio files will be transcribed one by one instead of
+        as a whole batch.
+    chunk_len_in_secs: float chunk length in seconds.
+    total_buffer_in_secs: float length of buffer (chunk + left and right padding) in seconds.
+    chunk_batch_size: int batch size for buffered chunk inference, which will cut one audio file into
+        segments and do inference on chunk_batch_size segments at a time.
 """
 
 
@@ -104,6 +116,12 @@ class AlignmentConfig:
     minimum_timestamp_duration: float = 0
     audio_filepath_parts_in_utt_id: int = 1
 
+    # Buffered chunked streaming configs
+    use_buffered_chunked_streaming: bool = False
+    chunk_len_in_secs: float = 1.6
+    total_buffer_in_secs: float = 4.0
+    chunk_batch_size: int = 32
+
 
 @hydra_runner(config_name="AlignmentConfig", schema=AlignmentConfig)
 def main(cfg: AlignmentConfig):
@@ -194,6 +212,41 @@ def main(cfg: AlignmentConfig):
             "This may cause the alignments for some tokens/words/additional segments to be overlapping."
         )
 
+    buffered_chunk_params = {}
+    if cfg.use_buffered_chunked_streaming:
+        model_cfg = copy.deepcopy(model._cfg)
+
+        OmegaConf.set_struct(model_cfg.preprocessor, False)
+        # some changes for streaming scenario
+        model_cfg.preprocessor.dither = 0.0
+        model_cfg.preprocessor.pad_to = 0
+
+        if model_cfg.preprocessor.normalize != "per_feature":
+            logging.error(
+                "Only EncDecCTCModelBPE models trained with per_feature normalization are supported currently"
+            )
+        # Disable config overwriting
+        OmegaConf.set_struct(model_cfg.preprocessor, True)
+
+        feature_stride = model_cfg.preprocessor['window_stride']
+        model_stride_in_secs = feature_stride * cfg.model_downsample_factor
+        total_buffer = cfg.total_buffer_in_secs
+        chunk_len = float(cfg.chunk_len_in_secs)
+        tokens_per_chunk = math.ceil(chunk_len / model_stride_in_secs)
+        mid_delay = math.ceil((chunk_len + (total_buffer - chunk_len) / 2) / model_stride_in_secs)
+        logging.info(f"tokens_per_chunk is {tokens_per_chunk}, mid_delay is {mid_delay}")
+
+        model = FrameBatchASR(
+            asr_model=model,
+            frame_len=chunk_len,
+            total_buffer=cfg.total_buffer_in_secs,
+            batch_size=cfg.chunk_batch_size,
+        )
+        buffered_chunk_params = {
+            "delay": mid_delay,
+            "model_stride_in_secs": model_stride_in_secs,
+            "tokens_per_chunk": tokens_per_chunk,
+        }
+
     # get start and end line IDs of batches
     starts, ends = get_batch_starts_ends(cfg.manifest_filepath, cfg.batch_size)
@@ -217,7 +270,12 @@ def main(cfg: AlignmentConfig):
             segment_info_batch,
             pred_text_batch,
         ) = get_batch_tensors_and_boundary_info(
-            manifest_lines_batch, model, cfg.additional_ctm_grouping_separator, cfg.align_using_pred_text,
+            manifest_lines_batch,
+            model,
+            cfg.additional_ctm_grouping_separator,
+            cfg.align_using_pred_text,
+            cfg.use_buffered_chunked_streaming,
+            buffered_chunk_params,
         )
 
         if cfg.align_using_pred_text:
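For intuition, with the defaults added above (`chunk_len_in_secs=1.6`, `total_buffer_in_secs=4.0`) and a typical 10 ms feature window stride with `model_downsample_factor=8` (both assumed here; `align.py` reads them from the preprocessor config and the existing NFA config field), the derived quantities come out as follows:

```python
import math

# Assumed values; align.py reads these from the preprocessor config and cfg.model_downsample_factor.
window_stride = 0.01         # seconds per feature frame
model_downsample_factor = 8  # e.g. an 8x-downsampling CTC encoder
chunk_len = 1.6              # cfg.chunk_len_in_secs
total_buffer = 4.0           # cfg.total_buffer_in_secs

model_stride_in_secs = window_stride * model_downsample_factor  # 0.08 s per encoder output frame
tokens_per_chunk = math.ceil(chunk_len / model_stride_in_secs)  # ceil(1.6 / 0.08) = 20
mid_delay = math.ceil((chunk_len + (total_buffer - chunk_len) / 2) / model_stride_in_secs)
# mid_delay = ceil((1.6 + 1.2) / 0.08) = 35: each chunk's outputs are read starting mid_delay
# frames back from the end of the buffer, so the chunk sits between its left and right context.

print(tokens_per_chunk, mid_delay)  # 20 35
```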
diff --git a/tools/nemo_forced_aligner/utils/data_prep.py b/tools/nemo_forced_aligner/utils/data_prep.py
index c7ed101ac9e7..c506bee0d818 100644
--- a/tools/nemo_forced_aligner/utils/data_prep.py
+++ b/tools/nemo_forced_aligner/utils/data_prep.py
@@ -17,6 +17,7 @@
 
 import soundfile as sf
 import torch
+from tqdm.auto import tqdm
 from utils.constants import BLANK_TOKEN, SPACE_TOKEN, V_NEGATIVE_NUM
 
@@ -140,8 +141,10 @@ def get_y_and_boundary_info_for_utt(text, model, separator):
     segments = [seg.strip() for seg in segments]
 
     if hasattr(model, 'tokenizer'):
-
-        BLANK_ID = len(model.decoder.vocabulary)  # TODO: check
+        if hasattr(model, 'blank_id'):
+            BLANK_ID = model.blank_id
+        else:
+            BLANK_ID = len(model.decoder.vocabulary)  # TODO: check
 
         y_token_ids_with_blanks = [BLANK_ID]
         token_info = [{"text": BLANK_TOKEN, "s_start": 0, "s_end": 0,}]
@@ -283,7 +286,14 @@ def get_y_and_boundary_info_for_utt(text, model, separator):
         raise RuntimeError("Cannot get tokens of this model.")
 
 
-def get_batch_tensors_and_boundary_info(manifest_lines_batch, model, separator, align_using_pred_text):
+def get_batch_tensors_and_boundary_info(
+    manifest_lines_batch,
+    model,
+    separator,
+    align_using_pred_text,
+    use_buffered_chunked_streaming=False,
+    buffered_chunk_params={},
+):
     """
     Returns:
         log_probs, y, T, U (y and U are s.t. every other token is a blank) - these are the tensors we will need
@@ -299,16 +309,28 @@ def get_batch_tensors_and_boundary_info(manifest_lines_batch, model, separator,
     # and (optionally) the predicted ASR text from the hypotheses
     audio_filepaths_batch = [line["audio_filepath"] for line in manifest_lines_batch]
     B = len(audio_filepaths_batch)
-    with torch.no_grad():
-        hypotheses = model.transcribe(audio_filepaths_batch, return_hypotheses=True, batch_size=B)
-
     log_probs_list_batch = []
     T_list_batch = []
     pred_text_batch = []
-    for hypothesis in hypotheses:
-        log_probs_list_batch.append(hypothesis.y_sequence)
-        T_list_batch.append(hypothesis.y_sequence.shape[0])
-        pred_text_batch.append(hypothesis.text)
+
+    if not use_buffered_chunked_streaming:
+        with torch.no_grad():
+            hypotheses = model.transcribe(audio_filepaths_batch, return_hypotheses=True, batch_size=B)
+        for hypothesis in hypotheses:
+            log_probs_list_batch.append(hypothesis.y_sequence)
+            T_list_batch.append(hypothesis.y_sequence.shape[0])
+            pred_text_batch.append(hypothesis.text)
+    else:
+        delay = buffered_chunk_params["delay"]
+        model_stride_in_secs = buffered_chunk_params["model_stride_in_secs"]
+        tokens_per_chunk = buffered_chunk_params["tokens_per_chunk"]
+        for l in tqdm(audio_filepaths_batch, desc="Sample:"):
+            model.reset()
+            model.read_audio_file(l, delay, model_stride_in_secs)
+            hyp, logits = model.transcribe(tokens_per_chunk, delay, keep_logits=True)
+            log_probs_list_batch.append(logits)
+            T_list_batch.append(logits.shape[0])
+            pred_text_batch.append(hyp)
 
     # we loop over every line in the manifest that is in our current batch,
     # and record the y (list of tokens, including blanks), U (list of lengths of y) and
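The `self.decoder = asr_model.decoder` attribute added to `FrameBatchASR` in the first file, together with the `hasattr(model, 'blank_id')` branch above, is what lets `data_prep.py` accept either a bare CTC model or the streaming wrapper. A condensed illustration of that blank-id lookup (the helper name is hypothetical; the logic is the one in `get_y_and_boundary_info_for_utt`, where NeMo CTC models use `len(vocabulary)` as the blank index):

```python
def resolve_blank_id(model) -> int:
    """Return the CTC blank index for a bare CTC model or a FrameBatchASR wrapper (hypothetical helper)."""
    if hasattr(model, "blank_id"):
        # A FrameBatchASR wrapper already exposes blank_id for the wrapped CTC model.
        return model.blank_id
    # Bare CTC model: the blank is the last output index, i.e. len(decoder.vocabulary).
    return len(model.decoder.vocabulary)
```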