From 512073345cc0713e20138e3dd49669de508586d2 Mon Sep 17 00:00:00 2001 From: Igor Gitman Date: Mon, 10 Jul 2023 09:34:47 -0700 Subject: [PATCH] Fixing an issue with confidence ensembles (#6987) * Bug fix for the confidence ensembles Signed-off-by: Igor Gitman * Relax constraints for the test Signed-off-by: Igor Gitman --------- Signed-off-by: Igor Gitman --- examples/asr/transcribe_speech.py | 8 ++++++-- nemo/collections/asr/models/confidence_ensemble.py | 9 +++++---- scripts/confidence_ensembles/build_ensemble.py | 6 ++---- .../confidence_ensembles/test_confidence_ensembles.py | 2 +- 4 files changed, 14 insertions(+), 11 deletions(-) diff --git a/examples/asr/transcribe_speech.py b/examples/asr/transcribe_speech.py index 4ed3d92a6305..401755bc8275 100644 --- a/examples/asr/transcribe_speech.py +++ b/examples/asr/transcribe_speech.py @@ -130,6 +130,8 @@ class TranscriptionConfig: # Set to True to output greedy timestamp information (only supported models) compute_timestamps: bool = False + # set to True if need to return full alignment information + preserve_alignment: bool = False # Set to True to output language ID information compute_langs: bool = False @@ -230,6 +232,8 @@ def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List[Hypothesis # we will adjust this flag if the model does not support it compute_timestamps = cfg.compute_timestamps compute_langs = cfg.compute_langs + # has to be True if timestamps are required + preserve_alignment = True if cfg.compute_timestamps else cfg.preserve_alignment # Check whether model and decoder type match if isinstance(asr_model, EncDecCTCModel): @@ -252,7 +256,7 @@ def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List[Hypothesis decoding_cfg = cfg.rnnt_decoding if cfg.decoder_type == 'rnnt' else cfg.ctc_decoding decoding_cfg.compute_timestamps = cfg.compute_timestamps # both ctc and rnnt support it if 'preserve_alignments' in decoding_cfg: - decoding_cfg.preserve_alignments = cfg.compute_timestamps + decoding_cfg.preserve_alignments = preserve_alignment if 'compute_langs' in decoding_cfg: decoding_cfg.compute_langs = cfg.compute_langs if hasattr(asr_model, 'cur_decoder'): @@ -267,7 +271,7 @@ def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List[Hypothesis cfg.rnnt_decoding.compute_langs = cfg.compute_langs if 'preserve_alignments' in cfg.rnnt_decoding: - cfg.rnnt_decoding.preserve_alignments = cfg.compute_timestamps + cfg.rnnt_decoding.preserve_alignments = preserve_alignment asr_model.change_decoding_strategy(cfg.rnnt_decoding) else: diff --git a/nemo/collections/asr/models/confidence_ensemble.py b/nemo/collections/asr/models/confidence_ensemble.py index cd4738e7b97c..dd52d9a7010a 100644 --- a/nemo/collections/asr/models/confidence_ensemble.py +++ b/nemo/collections/asr/models/confidence_ensemble.py @@ -106,6 +106,11 @@ def get_filtered_logprobs(hypothesis: Hypothesis, exclude_blank: bool) -> torch. filtered_logprobs = logprobs[:1] else: filtered_logprobs = logprobs + + # need to make sure logprobs are always normalized, so checking if they sum up to 1 + if not torch.allclose(filtered_logprobs[0].exp().sum(), torch.tensor(1.0)): + filtered_logprobs = torch.log_softmax(filtered_logprobs, dim=1) + return filtered_logprobs @@ -217,10 +222,6 @@ def update_decoding_parameters(self, decoding_cfg: DictConfig): with open_dict(decoding_cfg): decoding_cfg.temperature = self.cfg.temperature decoding_cfg.preserve_alignments = True - if 'confidence_cfg' in decoding_cfg: - decoding_cfg.confidence_cfg.preserve_frame_confidence = True - else: - decoding_cfg.confidence_cfg = ConfidenceConfig(preserve_frame_confidence=True) def setup_training_data(self, train_data_config: Union[DictConfig, Dict]): """Pass-through to the ensemble models. diff --git a/scripts/confidence_ensembles/build_ensemble.py b/scripts/confidence_ensembles/build_ensemble.py index 07ceccb8b3d5..e953dec02b7a 100644 --- a/scripts/confidence_ensembles/build_ensemble.py +++ b/scripts/confidence_ensembles/build_ensemble.py @@ -458,7 +458,7 @@ def find_best_confidence( return best_conf_spec.to_confidence_config(), best_pipe -@hydra_runner(schema=BuildEnsembleConfig) +@hydra_runner(config_name="BuildEnsembleConfig", schema=BuildEnsembleConfig) def main(cfg: BuildEnsembleConfig): # silencing all messages from nemo/ptl to avoid dumping tons of configs to the stdout logging.getLogger('pytorch_lightning').setLevel(logging.CRITICAL) @@ -471,12 +471,10 @@ def main(cfg: BuildEnsembleConfig): pl.seed_everything(cfg.random_seed) cfg.transcription.random_seed = None # seed is already applied cfg.transcription.return_transcriptions = True - # that sets preserve_alignment to True - cfg.transcription.compute_timestamps = True + cfg.transcription.preserve_alignment = True cfg.transcription.ctc_decoding.temperature = cfg.temperature cfg.transcription.rnnt_decoding.temperature = cfg.temperature # this ensures that generated output is after log-softmax for consistency with CTC - cfg.transcription.rnnt_decoding.confidence_cfg.preserve_frame_confidence = True train_confidences = [] dev_confidences = [] diff --git a/scripts/confidence_ensembles/test_confidence_ensembles.py b/scripts/confidence_ensembles/test_confidence_ensembles.py index b665375c0c33..fa537529ab6b 100644 --- a/scripts/confidence_ensembles/test_confidence_ensembles.py +++ b/scripts/confidence_ensembles/test_confidence_ensembles.py @@ -113,4 +113,4 @@ def test_confidence_ensemble(tmp_path, build_args): ) results = speech_to_text_eval.main(eval_cfg) - assert results.metric_value < 0.15 # relaxed check for better than 15% WER + assert results.metric_value < 0.20 # relaxed check for better than 20% WER