Fixing an issue with confidence ensembles #7004

Merged
examples/asr/transcribe_speech.py (6 additions, 2 deletions)

@@ -130,6 +130,8 @@ class TranscriptionConfig:

     # Set to True to output greedy timestamp information (only supported models)
     compute_timestamps: bool = False
+    # Set to True to return full alignment information
+    preserve_alignment: bool = False

     # Set to True to output language ID information
     compute_langs: bool = False
@@ -230,6 +232,8 @@ def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List[Hypothesis
     # we will adjust this flag if the model does not support it
     compute_timestamps = cfg.compute_timestamps
     compute_langs = cfg.compute_langs
+    # has to be True if timestamps are required
+    preserve_alignment = True if cfg.compute_timestamps else cfg.preserve_alignment

     # Check whether model and decoder type match
     if isinstance(asr_model, EncDecCTCModel):
@@ -252,7 +256,7 @@ def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List[Hypothesis
             decoding_cfg = cfg.rnnt_decoding if cfg.decoder_type == 'rnnt' else cfg.ctc_decoding
             decoding_cfg.compute_timestamps = cfg.compute_timestamps  # both ctc and rnnt support it
             if 'preserve_alignments' in decoding_cfg:
-                decoding_cfg.preserve_alignments = cfg.compute_timestamps
+                decoding_cfg.preserve_alignments = preserve_alignment
             if 'compute_langs' in decoding_cfg:
                 decoding_cfg.compute_langs = cfg.compute_langs
             if hasattr(asr_model, 'cur_decoder'):
@@ -267,7 +271,7 @@ def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List[Hypothesis
             cfg.rnnt_decoding.compute_timestamps = cfg.compute_timestamps
             cfg.rnnt_decoding.compute_langs = cfg.compute_langs

             if 'preserve_alignments' in cfg.rnnt_decoding:
-                cfg.rnnt_decoding.preserve_alignments = cfg.compute_timestamps
+                cfg.rnnt_decoding.preserve_alignments = preserve_alignment

             asr_model.change_decoding_strategy(cfg.rnnt_decoding)
         else:
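
Taken together, these changes let callers request full alignment information independently of timestamps, while timestamps still force alignments on. A minimal sketch of that flag resolution, assuming a trimmed-down stand-in for the real TranscriptionConfig (the helper name below is illustrative, not part of the script):

    from dataclasses import dataclass


    @dataclass
    class TranscriptionConfig:
        # trimmed-down stand-in for the full config in transcribe_speech.py
        compute_timestamps: bool = False
        preserve_alignment: bool = False


    def resolve_preserve_alignment(cfg: TranscriptionConfig) -> bool:
        # timestamps need the alignments, so they force the flag to True;
        # otherwise the user's explicit preserve_alignment setting is used
        return True if cfg.compute_timestamps else cfg.preserve_alignment


    assert resolve_preserve_alignment(TranscriptionConfig(compute_timestamps=True))
    assert resolve_preserve_alignment(TranscriptionConfig(preserve_alignment=True))
    assert not resolve_preserve_alignment(TranscriptionConfig())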
nemo/collections/asr/models/confidence_ensemble.py (5 additions, 4 deletions)

@@ -106,6 +106,11 @@ def get_filtered_logprobs(hypothesis: Hypothesis, exclude_blank: bool) -> torch.
         filtered_logprobs = logprobs[:1]
     else:
         filtered_logprobs = logprobs
+
+    # need to make sure logprobs are always normalized, so checking if they sum up to 1
+    if not torch.allclose(filtered_logprobs[0].exp().sum(), torch.tensor(1.0)):
+        filtered_logprobs = torch.log_softmax(filtered_logprobs, dim=1)
+
     return filtered_logprobs


@@ -217,10 +222,6 @@ def update_decoding_parameters(self, decoding_cfg: DictConfig):
         with open_dict(decoding_cfg):
             decoding_cfg.temperature = self.cfg.temperature
             decoding_cfg.preserve_alignments = True
-            if 'confidence_cfg' in decoding_cfg:
-                decoding_cfg.confidence_cfg.preserve_frame_confidence = True
-            else:
-                decoding_cfg.confidence_cfg = ConfidenceConfig(preserve_frame_confidence=True)

     def setup_training_data(self, train_data_config: Union[DictConfig, Dict]):
         """Pass-through to the ensemble models.
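
The guard added to get_filtered_logprobs() keeps downstream confidence estimation working on proper log-probabilities even when a decoder returns unnormalized scores: if the first frame does not sum to 1 in probability space, the whole tensor is renormalized with log_softmax. A small self-contained illustration of that check, with an illustrative helper name and tensor shape:

    import torch


    def ensure_log_probs(scores: torch.Tensor) -> torch.Tensor:
        """Return `scores` ([frames, vocab]) as normalized log-probabilities."""
        # if the first frame does not sum to 1 in probability space,
        # the input was not log-softmax normalized, so normalize it here
        if not torch.allclose(scores[0].exp().sum(), torch.tensor(1.0)):
            scores = torch.log_softmax(scores, dim=1)
        return scores


    raw_logits = torch.randn(4, 10)          # unnormalized scores
    logprobs = ensure_log_probs(raw_logits)  # proper log-probabilities
    assert torch.allclose(logprobs.exp().sum(dim=1), torch.ones(4))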
scripts/confidence_ensembles/build_ensemble.py (2 additions, 4 deletions)

@@ -458,7 +458,7 @@ def find_best_confidence(
     return best_conf_spec.to_confidence_config(), best_pipe


-@hydra_runner(schema=BuildEnsembleConfig)
+@hydra_runner(config_name="BuildEnsembleConfig", schema=BuildEnsembleConfig)
 def main(cfg: BuildEnsembleConfig):
     # silencing all messages from nemo/ptl to avoid dumping tons of configs to the stdout
     logging.getLogger('pytorch_lightning').setLevel(logging.CRITICAL)
@@ -471,12 +471,10 @@ def main(cfg: BuildEnsembleConfig):
     pl.seed_everything(cfg.random_seed)
     cfg.transcription.random_seed = None  # seed is already applied
     cfg.transcription.return_transcriptions = True
     # that sets preserve_alignment to True
-    cfg.transcription.compute_timestamps = True
+    cfg.transcription.preserve_alignment = True
     cfg.transcription.ctc_decoding.temperature = cfg.temperature
     cfg.transcription.rnnt_decoding.temperature = cfg.temperature
-    # this ensures that generated output is after log-softmax for consistency with CTC
-    cfg.transcription.rnnt_decoding.confidence_cfg.preserve_frame_confidence = True

     train_confidences = []
     dev_confidences = []
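
For reference, a rough plain-Hydra analogue of the decorator change: passing config_name alongside schema amounts to registering the structured config under a name the runner can resolve, so command-line overrides are validated against the dataclass. This is a sketch under that assumption; the trimmed dataclass and script are illustrative, not NeMo's hydra_runner internals:

    from dataclasses import dataclass

    import hydra
    from hydra.core.config_store import ConfigStore
    from omegaconf import DictConfig


    @dataclass
    class BuildEnsembleConfig:
        # trimmed-down stand-in for the real dataclass in build_ensemble.py
        temperature: float = 1.0
        random_seed: int = 0


    # store the schema under the same name so that config_name can resolve to it
    ConfigStore.instance().store(name="BuildEnsembleConfig", node=BuildEnsembleConfig)


    @hydra.main(config_path=None, config_name="BuildEnsembleConfig", version_base=None)
    def main(cfg: DictConfig) -> None:
        # overrides such as `temperature=1.5` are type-checked against the schema
        print(cfg.temperature, cfg.random_seed)


    if __name__ == "__main__":
        main()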
scripts/confidence_ensembles/test_confidence_ensembles.py (1 addition, 1 deletion)

@@ -113,4 +113,4 @@ def test_confidence_ensemble(tmp_path, build_args):
     )

     results = speech_to_text_eval.main(eval_cfg)
-    assert results.metric_value < 0.15  # relaxed check for better than 15% WER
+    assert results.metric_value < 0.20  # relaxed check for better than 20% WER