From bccb46e10310a213ef924c87d6a7deef5a1703e1 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Thu, 4 May 2023 02:36:27 -0700 Subject: [PATCH] Patch transcribe and support offline transcribe for hybrid model (#6550) (#6559) Signed-off-by: fayejf Co-authored-by: fayejf <36722593+fayejf@users.noreply.github.com> --- examples/asr/transcribe_speech.py | 17 ++++++++++++++--- .../asr/parts/utils/transcribe_utils.py | 4 ++-- tools/asr_evaluator/conf/eval.yaml | 2 +- tools/asr_evaluator/utils.py | 3 ++- 4 files changed, 19 insertions(+), 7 deletions(-) diff --git a/examples/asr/transcribe_speech.py b/examples/asr/transcribe_speech.py index 3493fb28d81d..30700153e340 100644 --- a/examples/asr/transcribe_speech.py +++ b/examples/asr/transcribe_speech.py @@ -19,11 +19,11 @@ import pytorch_lightning as pl import torch -from omegaconf import OmegaConf +from omegaconf import OmegaConf, open_dict from nemo.collections.asr.metrics.rnnt_wer import RNNTDecodingConfig from nemo.collections.asr.metrics.wer import CTCDecodingConfig -from nemo.collections.asr.models.ctc_models import EncDecCTCModel +from nemo.collections.asr.models import EncDecCTCModel, EncDecHybridRNNTCTCModel from nemo.collections.asr.modules.conformer_encoder import ConformerChangeConfig from nemo.collections.asr.parts.utils.transcribe_utils import ( compute_output_filename, @@ -154,6 +154,9 @@ class TranscriptionConfig: def main(cfg: TranscriptionConfig) -> TranscriptionConfig: logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}') + for key in cfg: + cfg[key] = None if cfg[key] == 'None' else cfg[key] + if is_dataclass(cfg): cfg = OmegaConf.structured(cfg) @@ -223,7 +226,6 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig: decoding_cfg.preserve_alignments = cfg.compute_timestamps if 'compute_langs' in decoding_cfg: decoding_cfg.compute_langs = cfg.compute_langs - asr_model.change_decoding_strategy(decoding_cfg, decoder_type=cfg.decoder_type) # Check if ctc or rnnt model @@ -243,6 +245,15 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig: asr_model.change_decoding_strategy(cfg.ctc_decoding) + # Setup decoding config based on model type and decoder_type + with open_dict(cfg): + if isinstance(asr_model, EncDecCTCModel) or ( + isinstance(asr_model, EncDecHybridRNNTCTCModel) and cfg.decoder_type == "ctc" + ): + cfg.decoding = cfg.ctc_decoding + else: + cfg.decoding = cfg.rnnt_decoding + # prepare audio filepaths and decide wether it's partical audio filepaths, partial_audio = prepare_audio_data(cfg) diff --git a/nemo/collections/asr/parts/utils/transcribe_utils.py b/nemo/collections/asr/parts/utils/transcribe_utils.py index d59d453ba972..8cfe58523751 100644 --- a/nemo/collections/asr/parts/utils/transcribe_utils.py +++ b/nemo/collections/asr/parts/utils/transcribe_utils.py @@ -289,14 +289,14 @@ def write_transcription( if isinstance(transcriptions[0], rnnt_utils.Hypothesis): # List[rnnt_utils.Hypothesis] best_hyps = transcriptions - assert cfg.ctc_decoding.beam.return_best_hypothesis, "Works only with return_best_hypothesis=true" + assert cfg.decoding.beam.return_best_hypothesis, "Works only with return_best_hypothesis=true" elif isinstance(transcriptions[0], list) and isinstance( transcriptions[0][0], rnnt_utils.Hypothesis ): # List[List[rnnt_utils.Hypothesis]] NBestHypothesis best_hyps, beams = [], [] for hyps in transcriptions: best_hyps.append(hyps[0]) - if not cfg.ctc_decoding.beam.return_best_hypothesis: + if not cfg.decoding.beam.return_best_hypothesis: beam = [] for hyp in hyps: beam.append((hyp.text, hyp.score)) diff --git a/tools/asr_evaluator/conf/eval.yaml b/tools/asr_evaluator/conf/eval.yaml index 95e7c94b5b43..9129eddc49f1 100644 --- a/tools/asr_evaluator/conf/eval.yaml +++ b/tools/asr_evaluator/conf/eval.yaml @@ -13,7 +13,7 @@ engine: chunk_len_in_secs: 1.6 #null # Need to specify if use buffered inference (default for offline_by_chunked is 20) total_buffer_in_secs: 4 #null # Need to specify if use buffered inference (default for offline_by_chunked is 22) model_stride: 4 # Model downsampling factor, 8 for Citrinet models and 4 for Conformer models - + decoder_type: null # Used for hybrid CTC RNNT model only. Specify decoder_type *ctc* or *rnnt* for hybrid CTC RNNT model. test_ds: manifest_filepath: null sample_rate: 16000 diff --git a/tools/asr_evaluator/utils.py b/tools/asr_evaluator/utils.py index ad69b249f5db..c233376eb13a 100644 --- a/tools/asr_evaluator/utils.py +++ b/tools/asr_evaluator/utils.py @@ -154,7 +154,8 @@ def run_offline_inference(cfg: DictConfig) -> DictConfig: f"output_filename={cfg.output_filename} " f"batch_size={cfg.test_ds.batch_size} " f"random_seed={cfg.random_seed} " - f"eval_config_yaml={f.name} ", + f"eval_config_yaml={f.name} " + f"decoder_type={cfg.inference.decoder_type} ", shell=True, check=True, )