diff --git a/examples/asr/transcribe_speech_parallel.py b/examples/asr/transcribe_speech_parallel.py index 74019d7668f0..f14df284c6b1 100644 --- a/examples/asr/transcribe_speech_parallel.py +++ b/examples/asr/transcribe_speech_parallel.py @@ -32,6 +32,15 @@ predict_ds.batch_size=16 \ output_path=/tmp/ +Example for Hybrid-CTC/RNNT models with non-tarred datasets: + +python transcribe_speech_parallel.py \ + model=stt_en_fastconformer_hybrid_large \ + decoder_type=ctc \ + predict_ds.manifest_filepath=/dataset/manifest_file.json \ + predict_ds.batch_size=16 \ + output_path=/tmp/ + Example for tarred datasets: python transcribe_speech_parallel.py \ @@ -73,7 +82,7 @@ from nemo.collections.asr.data.audio_to_text_dataset import ASRPredictionWriter from nemo.collections.asr.metrics.rnnt_wer import RNNTDecodingConfig from nemo.collections.asr.metrics.wer import word_error_rate -from nemo.collections.asr.models import ASRModel +from nemo.collections.asr.models import ASRModel, EncDecHybridRNNTCTCModel from nemo.collections.asr.models.configs.asr_models_config import ASRDatasetConfig from nemo.core.config import TrainerConfig, hydra_runner from nemo.utils import logging @@ -92,6 +101,10 @@ class ParallelTranscriptionConfig: # decoding strategy for RNNT models rnnt_decoding: RNNTDecodingConfig = RNNTDecodingConfig() + + # decoder for hybrid models, must be one of 'ctc', 'rnnt' if not None + decoder_type: Optional[str] = None + trainer: TrainerConfig = TrainerConfig(devices=-1, accelerator="gpu", strategy="ddp") @@ -137,6 +150,9 @@ def main(cfg: ParallelTranscriptionConfig): ) model = ASRModel.from_pretrained(model_name=cfg.model, map_location="cpu") + if isinstance(model, EncDecHybridRNNTCTCModel) and cfg.decoder_type is not None: + model.change_decoding_strategy(decoder_type=cfg.decoder_type) + trainer = ptl.Trainer(**cfg.trainer) cfg.predict_ds.return_sample_id = True diff --git a/nemo/collections/asr/data/audio_to_text_dataset.py b/nemo/collections/asr/data/audio_to_text_dataset.py index 14e8dea19651..d5dcc8be4847 100644 --- a/nemo/collections/asr/data/audio_to_text_dataset.py +++ b/nemo/collections/asr/data/audio_to_text_dataset.py @@ -713,6 +713,7 @@ def write_on_batch_end( item = {} sample = self.dataset.get_manifest_sample(sample_id) item["audio_filepath"] = sample.audio_file + item["offset"] = sample.offset item["duration"] = sample.duration item["text"] = sample.text_raw item["pred_text"] = transcribed_text