diff --git a/nemo/collections/asr/metrics/bleu.py b/nemo/collections/asr/metrics/bleu.py
index e76aa2915760..fda4c68a7a19 100644
--- a/nemo/collections/asr/metrics/bleu.py
+++ b/nemo/collections/asr/metrics/bleu.py
@@ -116,7 +116,7 @@ def __init__(
                 encoder_hidden_states=predictions,
                 encoder_input_mask=predictions_mask,
                 decoder_input_ids=input_ids,
-                return_all_hypotheses=False,
+                return_hypotheses=False,
             )
         else:
             raise TypeError(f"WER metric does not support decoding of type {type(self.decoding)}")
diff --git a/nemo/collections/asr/metrics/wer.py b/nemo/collections/asr/metrics/wer.py
index c6eaeaf70849..43bb795fe42d 100644
--- a/nemo/collections/asr/metrics/wer.py
+++ b/nemo/collections/asr/metrics/wer.py
@@ -280,7 +280,7 @@ def __init__(
                 encoder_hidden_states=predictions,
                 encoder_input_mask=predictions_mask,
                 decoder_input_ids=input_ids,
-                return_all_hypotheses=False,
+                return_hypotheses=False,
             )
         else:
             raise TypeError(f"WER metric does not support decoding of type {type(self.decoding)}")
diff --git a/nemo/collections/asr/models/aed_multitask_models.py b/nemo/collections/asr/models/aed_multitask_models.py
index 6b4f20800c6c..81a8567881b5 100644
--- a/nemo/collections/asr/models/aed_multitask_models.py
+++ b/nemo/collections/asr/models/aed_multitask_models.py
@@ -457,7 +457,7 @@ def transcribe(
         self,
         audio: Union[str, List[str], np.ndarray, DataLoader],
         batch_size: int = 4,
-        return_all_hypotheses: bool = False,
+        return_hypotheses: bool = False,
         num_workers: int = 0,
         channel_selector: Optional[ChannelSelectorType] = None,
         augmentor: DictConfig = None,
@@ -506,7 +506,7 @@ def transcribe(
         if override_config is None:
             trcfg = MultiTaskTranscriptionConfig(
                 batch_size=batch_size,
-                return_all_hypotheses=return_all_hypotheses,
+                return_hypotheses=return_hypotheses,
                 num_workers=num_workers,
                 channel_selector=channel_selector,
                 augmentor=augmentor,
@@ -955,7 +955,7 @@ def _transcribe_output_processing(self, outputs, trcfg: MultiTaskTranscriptionCo
             encoder_hidden_states=enc_states,
             encoder_input_mask=enc_mask,
             decoder_input_ids=decoder_input_ids,
-            return_all_hypotheses=trcfg.return_all_hypotheses,
+            return_hypotheses=trcfg.return_hypotheses,
         )
 
         del enc_states, enc_mask, decoder_input_ids
@@ -1086,7 +1086,7 @@ def predict_step(
             encoder_hidden_states=enc_states,
             encoder_input_mask=enc_mask,
             decoder_input_ids=batch.prompt,
-            return_all_hypotheses=False,
+            return_hypotheses=False,
         )
         if batch.cuts:
             return list(zip(batch.cuts, text))
diff --git a/nemo/collections/asr/models/ctc_models.py b/nemo/collections/asr/models/ctc_models.py
index 949085727f63..2337504c308b 100644
--- a/nemo/collections/asr/models/ctc_models.py
+++ b/nemo/collections/asr/models/ctc_models.py
@@ -616,7 +616,7 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0):
         transcribed_texts = self.wer.decoding.ctc_decoder_predictions_tensor(
             decoder_outputs=log_probs,
             decoder_lengths=encoded_len,
-            return_all_hypotheses=False,
+            return_hypotheses=False,
         )
 
         if isinstance(sample_id, torch.Tensor):
@@ -711,7 +711,7 @@ def _transcribe_output_processing(self, outputs, trcfg: TranscribeConfig) -> Gen
         hypotheses = self.decoding.ctc_decoder_predictions_tensor(
             logits,
             decoder_lengths=logits_len,
-            return_all_hypotheses=trcfg.return_all_hypotheses,
+            return_hypotheses=trcfg.return_hypotheses,
         )
         if trcfg.return_hypotheses:
             if logits.is_cuda:
diff --git a/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py b/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py
index c612ebebe3ae..75caf4bccb28 100644
--- a/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py
+++ b/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py
@@ -211,7 +211,7 @@ def _transcribe_output_processing(
             hypotheses = self.ctc_decoding.ctc_decoder_predictions_tensor(
                 logits,
                 encoded_len,
-                return_all_hypotheses=trcfg.return_all_hypotheses,
+                return_hypotheses=trcfg.return_hypotheses,
             )
 
             logits = logits.cpu()
@@ -504,7 +504,7 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0):
         del signal
 
         best_hyp_text = self.decoding.rnnt_decoder_predictions_tensor(
-            encoder_output=encoded, encoded_lengths=encoded_len, return_all_hypotheses=False
+            encoder_output=encoded, encoded_lengths=encoded_len, return_hypotheses=False
         )
         if isinstance(sample_id, torch.Tensor):
             sample_id = sample_id.cpu().detach().numpy()
diff --git a/nemo/collections/asr/models/rnnt_models.py b/nemo/collections/asr/models/rnnt_models.py
index 89ccbb816885..a0a491d50927 100644
--- a/nemo/collections/asr/models/rnnt_models.py
+++ b/nemo/collections/asr/models/rnnt_models.py
@@ -816,7 +816,7 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0):
         del signal
 
         best_hyp_text = self.decoding.rnnt_decoder_predictions_tensor(
-            encoder_output=encoded, encoded_lengths=encoded_len, return_all_hypotheses=False
+            encoder_output=encoded, encoded_lengths=encoded_len, return_hypotheses=False
         )
 
         if isinstance(sample_id, torch.Tensor):
@@ -946,7 +946,7 @@ def _transcribe_output_processing(
             hyp = self.decoding.rnnt_decoder_predictions_tensor(
                 encoded,
                 encoded_len,
-                return_all_hypotheses=trcfg.return_all_hypotheses,
+                return_hypotheses=trcfg.return_hypotheses,
                 partial_hypotheses=trcfg.partial_hypothesis,
             )
             # cleanup memory
diff --git a/nemo/collections/asr/parts/mixins/mixins.py b/nemo/collections/asr/parts/mixins/mixins.py
index 52f72cd8ccf2..577b6393248c 100644
--- a/nemo/collections/asr/parts/mixins/mixins.py
+++ b/nemo/collections/asr/parts/mixins/mixins.py
@@ -705,7 +705,7 @@ def conformer_stream_step(
             best_hyp = self.decoding.rnnt_decoder_predictions_tensor(
                 encoder_output=encoded,
                 encoded_lengths=encoded_len,
-                return_all_hypotheses=True,
+                return_hypotheses=True,
                 partial_hypotheses=previous_hypotheses,
             )
             greedy_predictions = [hyp.y_sequence for hyp in best_hyp]
diff --git a/nemo/collections/asr/parts/mixins/transcription.py b/nemo/collections/asr/parts/mixins/transcription.py
index 52822db6baf3..72c1d0241766 100644
--- a/nemo/collections/asr/parts/mixins/transcription.py
+++ b/nemo/collections/asr/parts/mixins/transcription.py
@@ -57,7 +57,7 @@ class InternalTranscribeConfig:
 @dataclass
 class TranscribeConfig:
     batch_size: int = 4
-    return_all_hypotheses: bool = False
+    return_hypotheses: bool = False
     num_workers: Optional[int] = None
     channel_selector: ChannelSelectorType = None
     augmentor: Optional[DictConfig] = None
@@ -69,13 +69,6 @@ class TranscribeConfig:
 
     _internal: Optional[InternalTranscribeConfig] = None
 
-    @property
-    def return_hypotheses(self):
-        return self.return_all_hypotheses
-
-    @return_hypotheses.setter
-    def return_hypotheses(self, value):
-        self.return_all_hypotheses = value
 
 
 def get_value_from_transcription_config(trcfg, key, default):
diff --git a/nemo/collections/asr/parts/submodules/ctc_decoding.py b/nemo/collections/asr/parts/submodules/ctc_decoding.py
index 7b88cc453dec..4b2dd5216800 100644
--- a/nemo/collections/asr/parts/submodules/ctc_decoding.py
+++ b/nemo/collections/asr/parts/submodules/ctc_decoding.py
@@ -359,7 +359,7 @@ def ctc_decoder_predictions_tensor(
         self,
         decoder_outputs: torch.Tensor,
         decoder_lengths: torch.Tensor = None,
         fold_consecutive: bool = True,
-        return_all_hypotheses: bool = False,
+        return_hypotheses: bool = False,
     ) -> Union[List[Hypothesis], List[List[Hypothesis]]]:
         """
         Decodes a sequence of labels to words
@@ -372,7 +372,7 @@ def ctc_decoder_predictions_tensor(
                 of the sequence in the padded `predictions` tensor.
             fold_consecutive: Bool, determine whether to perform "ctc collapse",
                 folding consecutive tokens into a single token.
-            return_all_hypotheses: Bool flag whether to return just the decoding predictions of the model
+            return_hypotheses: Bool flag whether to return just the decoding predictions of the model
                 or a Hypothesis object that holds information such as the decoded `text`,
                 the `alignment` of emited by the CTC Model, and the `length` of the sequence (if available).
                 May also contain the log-probabilities of the decoder (if this method is called via
@@ -426,7 +426,7 @@ def ctc_decoder_predictions_tensor(
 
                 all_hypotheses.append(decoded_hyps)
 
-            if return_all_hypotheses:
+            if return_hypotheses:
                 return all_hypotheses  # type: list[list[Hypothesis]]
 
             # alaptev: The line below might contain a bug. Do we really want all_hyp_text to be flat?
@@ -444,7 +444,7 @@ def ctc_decoder_predictions_tensor(
             # If computing timestamps
             if self.compute_timestamps is True:
                 # greedy decoding, can get high-level confidence scores
-                if return_all_hypotheses and (self.preserve_word_confidence or self.preserve_token_confidence):
+                if return_hypotheses and (self.preserve_word_confidence or self.preserve_token_confidence):
                     hypotheses = self.compute_confidence(hypotheses)
                 else:
                     # remove unused token_repetitions from Hypothesis.text
@@ -454,7 +454,7 @@ def ctc_decoder_predictions_tensor(
                 for hyp_idx in range(len(hypotheses)):
                     hypotheses[hyp_idx] = self.compute_ctc_timestamps(hypotheses[hyp_idx], timestamp_type)
 
-            if return_all_hypotheses:
+            if return_hypotheses:
                 return hypotheses
 
             return [Hypothesis(h.score, h.y_sequence, h.text) for h in hypotheses]
diff --git a/nemo/collections/asr/parts/submodules/multitask_decoding.py b/nemo/collections/asr/parts/submodules/multitask_decoding.py
index 9eb60b561b1a..0f1e9439d7ab 100644
--- a/nemo/collections/asr/parts/submodules/multitask_decoding.py
+++ b/nemo/collections/asr/parts/submodules/multitask_decoding.py
@@ -214,7 +214,7 @@ def decode_predictions_tensor(
         self,
         encoder_hidden_states: torch.Tensor,
         encoder_input_mask: torch.Tensor,
         decoder_input_ids: Optional[torch.Tensor] = None,
-        return_all_hypotheses: bool = False,
+        return_hypotheses: bool = False,
         partial_hypotheses: Optional[List[Hypothesis]] = None,
     ) -> Union[List[Hypothesis], List[List[Hypothesis]]]:
         """
@@ -260,7 +260,7 @@ def decode_predictions_tensor(
                 hypotheses.append(decoded_hyps[0])  # best hypothesis
                 all_hypotheses.append(decoded_hyps)
 
-            if return_all_hypotheses:
+            if return_hypotheses:
                 return all_hypotheses
 
             all_hyp = [[Hypothesis(h.score, h.y_sequence, h.text) for h in hh] for hh in all_hypotheses]
@@ -269,7 +269,7 @@ def decode_predictions_tensor(
 
         else:
             hypotheses = self.decode_hypothesis(prediction_list)
-            if return_all_hypotheses:
+            if return_hypotheses:
                 # greedy decoding, can get high-level confidence scores
                 if self.preserve_frame_confidence and (
                     self.preserve_word_confidence or self.preserve_token_confidence
diff --git a/nemo/collections/asr/parts/submodules/rnnt_decoding.py b/nemo/collections/asr/parts/submodules/rnnt_decoding.py
index 45b62ccf08c5..daa9a91b2c6b 100644
--- a/nemo/collections/asr/parts/submodules/rnnt_decoding.py
+++ b/nemo/collections/asr/parts/submodules/rnnt_decoding.py
@@ -492,7 +492,7 @@ def rnnt_decoder_predictions_tensor(
         self,
         encoder_output: torch.Tensor,
         encoded_lengths: torch.Tensor,
-        return_all_hypotheses: bool = False,
+        return_hypotheses: bool = False,
         partial_hypotheses: Optional[List[Hypothesis]] = None,
     ) -> Union[List[Hypothesis], List[List[Hypothesis]]]:
         """
@@ -541,7 +541,7 @@ def rnnt_decoder_predictions_tensor(
                 hypotheses.append(decoded_hyps[0])  # best hypothesis
                 all_hypotheses.append(decoded_hyps)
 
-            if return_all_hypotheses:
+            if return_hypotheses:
                 return all_hypotheses  # type: list[list[Hypothesis]]
 
             all_hyp = [[Hypothesis(h.score, h.y_sequence, h.text) for h in hh] for hh in all_hypotheses]
@@ -556,7 +556,7 @@ def rnnt_decoder_predictions_tensor(
                 for hyp_idx in range(len(hypotheses)):
                     hypotheses[hyp_idx] = self.compute_rnnt_timestamps(hypotheses[hyp_idx], timestamp_type)
 
-            if return_all_hypotheses:
+            if return_hypotheses:
                 # greedy decoding, can get high-level confidence scores
                 if self.preserve_frame_confidence and (
                     self.preserve_word_confidence or self.preserve_token_confidence
diff --git a/nemo/collections/multimodal/speech_cv/models/visual_ctc_models.py b/nemo/collections/multimodal/speech_cv/models/visual_ctc_models.py
index 925b6d79d9d8..5c420b6e7e3a 100644
--- a/nemo/collections/multimodal/speech_cv/models/visual_ctc_models.py
+++ b/nemo/collections/multimodal/speech_cv/models/visual_ctc_models.py
@@ -209,7 +209,7 @@ def transcribe(
                     current_hypotheses = self.decoding.ctc_decoder_predictions_tensor(
                         logits,
                         decoder_lengths=logits_len,
-                        return_all_hypotheses=return_hypotheses,
+                        return_hypotheses=return_hypotheses,
                     )
 
                     if return_hypotheses:
diff --git a/nemo/collections/multimodal/speech_cv/models/visual_hybrid_rnnt_ctc_models.py b/nemo/collections/multimodal/speech_cv/models/visual_hybrid_rnnt_ctc_models.py
index ec379bdb3fb8..2a70dbc021eb 100644
--- a/nemo/collections/multimodal/speech_cv/models/visual_hybrid_rnnt_ctc_models.py
+++ b/nemo/collections/multimodal/speech_cv/models/visual_hybrid_rnnt_ctc_models.py
@@ -178,7 +178,7 @@ def transcribe(
                     best_hyp = self.ctc_decoding.ctc_decoder_predictions_tensor(
                         logits,
                         encoded_len,
-                        return_all_hypotheses=return_hypotheses,
+                        return_hypotheses=return_hypotheses,
                     )
                     if return_hypotheses:
                         # dump log probs per file
@@ -468,7 +468,7 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0):
         del signal
 
         best_hyp_text = self.decoding.rnnt_decoder_predictions_tensor(
-            encoder_output=encoded, encoded_lengths=encoded_len, return_all_hypotheses=False
+            encoder_output=encoded, encoded_lengths=encoded_len, return_hypotheses=False
         )
 
         sample_id = sample_id.cpu().detach().numpy()
diff --git a/nemo/collections/multimodal/speech_cv/models/visual_rnnt_models.py b/nemo/collections/multimodal/speech_cv/models/visual_rnnt_models.py
index 17b88970e30f..7efdb15f4e3c 100644
--- a/nemo/collections/multimodal/speech_cv/models/visual_rnnt_models.py
+++ b/nemo/collections/multimodal/speech_cv/models/visual_rnnt_models.py
@@ -289,7 +289,7 @@ def transcribe(
                     best_hyp = self.decoding.rnnt_decoder_predictions_tensor(
                         encoded,
                         encoded_len,
-                        return_all_hypotheses=return_hypotheses,
+                        return_hypotheses=return_hypotheses,
                         partial_hypotheses=partial_hypothesis,
                     )
 
@@ -725,7 +725,7 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0):
         del signal
 
         best_hyp_text = self.decoding.rnnt_decoder_predictions_tensor(
-            encoder_output=encoded, encoded_lengths=encoded_len, return_all_hypotheses=False
+            encoder_output=encoded, encoded_lengths=encoded_len, return_hypotheses=False
         )
 
         sample_id = sample_id.cpu().detach().numpy()
diff --git a/nemo/collections/tts/g2p/models/ctc.py b/nemo/collections/tts/g2p/models/ctc.py
index b8b54bd3bc71..9f7dc2ae333f 100644
--- a/nemo/collections/tts/g2p/models/ctc.py
+++ b/nemo/collections/tts/g2p/models/ctc.py
@@ -357,7 +357,7 @@ def _infer(
             )
 
             preds_str = self.decoding.ctc_decoder_predictions_tensor(
-                log_probs, decoder_lengths=encoded_len, return_all_hypotheses=False
+                log_probs, decoder_lengths=encoded_len, return_hypotheses=False
             )
             all_preds.extend(preds_str)
 
diff --git a/tests/collections/asr/decoding/test_ctc_decoding.py b/tests/collections/asr/decoding/test_ctc_decoding.py
index 38103342b31a..e122dd5a3fdd 100644
--- a/tests/collections/asr/decoding/test_ctc_decoding.py
+++ b/tests/collections/asr/decoding/test_ctc_decoding.py
@@ -126,7 +126,7 @@ def test_char_decoding_greedy_forward(
 
         with torch.no_grad():
             hypotheses = decoding.ctc_decoder_predictions_tensor(
-                input_signal, length, fold_consecutive=True, return_all_hypotheses=False
+                input_signal, length, fold_consecutive=True, return_hypotheses=False
             )
 
             texts = [hyp.text for hyp in hypotheses]
@@ -148,7 +148,7 @@ def test_char_decoding_greedy_forward_hypotheses(self, alignments, timestamps):
 
         with torch.no_grad():
             hyps = decoding.ctc_decoder_predictions_tensor(
-                input_signal, length, fold_consecutive=True, return_all_hypotheses=True
+                input_signal, length, fold_consecutive=True, return_hypotheses=True
             )
 
             for idx, hyp in enumerate(hyps):
@@ -179,7 +179,7 @@ def test_subword_decoding_greedy_forward(self, tmp_tokenizer):
 
         with torch.no_grad():
             hypotheses = decoding.ctc_decoder_predictions_tensor(
-                input_signal, length, fold_consecutive=True, return_all_hypotheses=False
+                input_signal, length, fold_consecutive=True, return_hypotheses=False
             )
 
             texts = [hyp.text for hyp in hypotheses]
@@ -200,7 +200,7 @@ def test_subword_decoding_greedy_forward_hypotheses(self, tmp_tokenizer, alignme
 
         with torch.no_grad():
             hyps = decoding.ctc_decoder_predictions_tensor(
-                input_signal, length, fold_consecutive=True, return_all_hypotheses=True
+                input_signal, length, fold_consecutive=True, return_hypotheses=True
             )
 
             for idx, hyp in enumerate(hyps):
@@ -286,11 +286,11 @@ def test_batched_decoding_logprobs(
 
         with torch.inference_mode():
             hyps = unbatched_decoding.ctc_decoder_predictions_tensor(
-                input_signal, length, fold_consecutive=True, return_all_hypotheses=True
+                input_signal, length, fold_consecutive=True, return_hypotheses=True
             )
 
             batched_hyps = batched_decoding.ctc_decoder_predictions_tensor(
-                input_signal, length, fold_consecutive=True, return_all_hypotheses=True
+                input_signal, length, fold_consecutive=True, return_hypotheses=True
             )
 
         assert len(hyps) == len(batched_hyps) == B
@@ -353,11 +353,11 @@ def test_batched_decoding_labels(self, tmp_tokenizer, timestamps, length_is_none
 
         with torch.inference_mode():
             hyps = unbatched_decoding.ctc_decoder_predictions_tensor(
-                input_labels, length, fold_consecutive=True, return_all_hypotheses=True
+                input_labels, length, fold_consecutive=True, return_hypotheses=True
             )
 
             batched_hyps = batched_decoding.ctc_decoder_predictions_tensor(
-                input_labels, length, fold_consecutive=True, return_all_hypotheses=True
+                input_labels, length, fold_consecutive=True, return_hypotheses=True
             )
 
         assert len(hyps) == len(batched_hyps) == B
diff --git a/tests/collections/asr/decoding/test_rnnt_decoding.py b/tests/collections/asr/decoding/test_rnnt_decoding.py
index 7eff1fd81f3e..a239eb27c2d3 100644
--- a/tests/collections/asr/decoding/test_rnnt_decoding.py
+++ b/tests/collections/asr/decoding/test_rnnt_decoding.py
@@ -438,7 +438,7 @@ def test_subword_decoding_compute_timestamps(self, test_data_dir, decoding_strat
             decoding_cfg=cfg, decoder=model.decoder, joint=model.joint, tokenizer=model.tokenizer
         )
 
-        hyps = decoding.rnnt_decoder_predictions_tensor(encoded, encoded_len, return_all_hypotheses=True)
+        hyps = decoding.rnnt_decoder_predictions_tensor(encoded, encoded_len, return_hypotheses=True)
 
         check_subword_timestamps(hyps[0], decoding)
 
@@ -473,7 +473,7 @@ def test_char_decoding_compute_timestamps(self, test_data_dir, decoding_strategy
 
         decoding = RNNTDecoding(decoding_cfg=cfg, decoder=model.decoder, joint=model.joint, vocabulary=vocab)
 
-        hyps = decoding.rnnt_decoder_predictions_tensor(encoded, encoded_len, return_all_hypotheses=True)
+        hyps = decoding.rnnt_decoder_predictions_tensor(encoded, encoded_len, return_hypotheses=True)
 
         check_char_timestamps(hyps[0], decoding)
 
diff --git a/tests/collections/asr/test_asr_metrics.py b/tests/collections/asr/test_asr_metrics.py
index 48a5ff401e8f..a87622d60a07 100644
--- a/tests/collections/asr/test_asr_metrics.py
+++ b/tests/collections/asr/test_asr_metrics.py
@@ -219,7 +219,7 @@ def test_wer_metric_return_hypothesis(self, batch_dim_index, test_wer_bpe):
 
         # pass batchsize 1 tensor, get back list of length 1 Hypothesis
         wer.decoding.preserve_alignments = True
-        hyp = wer.decoding.ctc_decoder_predictions_tensor(tensor, return_all_hypotheses=True)
+        hyp = wer.decoding.ctc_decoder_predictions_tensor(tensor, return_hypotheses=True)
         hyp = hyp[0]
         assert isinstance(hyp, Hypothesis)
 
@@ -233,7 +233,7 @@ def test_wer_metric_return_hypothesis(self, batch_dim_index, test_wer_bpe):
         length = torch.tensor([tensor.shape[1 - batch_dim_index]], dtype=torch.long)
 
         # pass batchsize 1 tensor, get back list of length 1 Hypothesis [add length info]
-        hyp = wer.decoding.ctc_decoder_predictions_tensor(tensor, decoder_lengths=length, return_all_hypotheses=True)
+        hyp = wer.decoding.ctc_decoder_predictions_tensor(tensor, decoder_lengths=length, return_hypotheses=True)
         hyp = hyp[0]
         assert isinstance(hyp, Hypothesis)
         assert hyp.length == 3
@@ -251,7 +251,7 @@ def test_wer_metric_subword_return_hypothesis(self, batch_dim_index, test_wer_bp
 
         # pass batchsize 1 tensor, get back list of length 1 Hypothesis
         wer.decoding.preserve_alignments = True
-        hyp = wer.decoding.ctc_decoder_predictions_tensor(tensor, return_all_hypotheses=True)
+        hyp = wer.decoding.ctc_decoder_predictions_tensor(tensor, return_hypotheses=True)
         hyp = hyp[0]
         assert isinstance(hyp, Hypothesis)
 
@@ -265,7 +265,7 @@ def test_wer_metric_subword_return_hypothesis(self, batch_dim_index, test_wer_bp
         length = torch.tensor([tensor.shape[1 - batch_dim_index]], dtype=torch.long)
 
         # pass batchsize 1 tensor, get back list of length 1 Hypothesis [add length info]
-        hyp = wer.decoding.ctc_decoder_predictions_tensor(tensor, decoder_lengths=length, return_all_hypotheses=True)
+        hyp = wer.decoding.ctc_decoder_predictions_tensor(tensor, decoder_lengths=length, return_hypotheses=True)
         hyp = hyp[0]
         assert isinstance(hyp, Hypothesis)
         assert hyp.length == 3
@@ -389,7 +389,7 @@ def test_char_decoding_logprobs(self):
         decoding_cfg = CTCDecodingConfig()
         decoding = CTCDecoding(decoding_cfg, vocabulary=self.vocabulary)
 
-        hyp = decoding.ctc_decoder_predictions_tensor(decoder_outputs, decoder_lens, return_all_hypotheses=True)
+        hyp = decoding.ctc_decoder_predictions_tensor(decoder_outputs, decoder_lens, return_hypotheses=True)
         hyp = hyp[0]  # type: Hypothesis
         assert isinstance(hyp.y_sequence, torch.Tensor)
         assert hyp.length == torch.tensor(T, dtype=torch.int32)
@@ -401,7 +401,7 @@ def test_char_decoding_logprobs(self):
         decoding_cfg = CTCDecodingConfig(preserve_alignments=True, compute_timestamps=True)
         decoding = CTCDecoding(decoding_cfg, vocabulary=self.vocabulary)
 
-        hyp = decoding.ctc_decoder_predictions_tensor(decoder_outputs, decoder_lens, return_all_hypotheses=True)
+        hyp = decoding.ctc_decoder_predictions_tensor(decoder_outputs, decoder_lens, return_hypotheses=True)
         hyp = hyp[0]  # type: Hypothesis
         assert isinstance(hyp.y_sequence, torch.Tensor)
         assert hyp.length == torch.tensor(T, dtype=torch.int32)
@@ -420,7 +420,7 @@ def test_subword_decoding_logprobs(self):
         decoding_cfg = CTCBPEDecodingConfig()
         decoding = CTCBPEDecoding(decoding_cfg, tokenizer=self.char_tokenizer)
 
-        hyp = decoding.ctc_decoder_predictions_tensor(decoder_outputs, decoder_lens, return_all_hypotheses=True)
+        hyp = decoding.ctc_decoder_predictions_tensor(decoder_outputs, decoder_lens, return_hypotheses=True)
         hyp = hyp[0]  # type: Hypothesis
         assert isinstance(hyp.y_sequence, torch.Tensor)
         assert hyp.length == torch.tensor(T, dtype=torch.int32)
@@ -432,7 +432,7 @@ def test_subword_decoding_logprobs(self):
         decoding_cfg = CTCBPEDecodingConfig(preserve_alignments=True, compute_timestamps=True)
         decoding = CTCBPEDecoding(decoding_cfg, tokenizer=self.char_tokenizer)
 
-        hyp = decoding.ctc_decoder_predictions_tensor(decoder_outputs, decoder_lens, return_all_hypotheses=True)
+        hyp = decoding.ctc_decoder_predictions_tensor(decoder_outputs, decoder_lens, return_hypotheses=True)
         hyp = hyp[0]  # type: Hypothesis
         assert isinstance(hyp.y_sequence, torch.Tensor)
         assert hyp.length == torch.tensor(T, dtype=torch.int32)
@@ -451,7 +451,7 @@ def test_char_decoding_labels(self):
         decoding_cfg = CTCDecodingConfig()
         decoding = CTCDecoding(decoding_cfg, vocabulary=self.vocabulary)
 
-        hyp = decoding.ctc_decoder_predictions_tensor(decoder_outputs, decoder_lens, return_all_hypotheses=True)
+        hyp = decoding.ctc_decoder_predictions_tensor(decoder_outputs, decoder_lens, return_hypotheses=True)
         hyp = hyp[0]  # type: Hypothesis
         assert isinstance(hyp.y_sequence, torch.Tensor)
         assert hyp.length == torch.tensor(T, dtype=torch.int32)
@@ -465,13 +465,13 @@ def test_char_decoding_labels(self):
 
         # Cannot compute alignments from labels
         with pytest.raises(ValueError):
-            _ = decoding.ctc_decoder_predictions_tensor(decoder_outputs, decoder_lens, return_all_hypotheses=True)
+            _ = decoding.ctc_decoder_predictions_tensor(decoder_outputs, decoder_lens, return_hypotheses=True)
 
         # Preserve timestamps
         decoding_cfg = CTCDecodingConfig(preserve_alignments=False, compute_timestamps=True)
         decoding = CTCDecoding(decoding_cfg, vocabulary=self.vocabulary)
 
-        hyp = decoding.ctc_decoder_predictions_tensor(decoder_outputs, decoder_lens, return_all_hypotheses=True)
+        hyp = decoding.ctc_decoder_predictions_tensor(decoder_outputs, decoder_lens, return_hypotheses=True)
         hyp = hyp[0]  # type: Hypothesis
         assert isinstance(hyp.y_sequence, torch.Tensor)
         assert hyp.length == torch.tensor(T, dtype=torch.int32)
@@ -490,7 +490,7 @@ def test_subword_decoding_logprobs(self):
         decoding_cfg = CTCBPEDecodingConfig()
         decoding = CTCBPEDecoding(decoding_cfg, tokenizer=self.char_tokenizer)
 
-        hyp = decoding.ctc_decoder_predictions_tensor(decoder_outputs, decoder_lens, return_all_hypotheses=True)
+        hyp = decoding.ctc_decoder_predictions_tensor(decoder_outputs, decoder_lens, return_hypotheses=True)
         hyp = hyp[0]  # type: Hypothesis
         assert isinstance(hyp.y_sequence, torch.Tensor)
         assert hyp.length == torch.tensor(T, dtype=torch.int32)
@@ -502,7 +502,7 @@ def test_subword_decoding_logprobs(self):
         decoding_cfg = CTCBPEDecodingConfig(preserve_alignments=True, compute_timestamps=True)
         decoding = CTCBPEDecoding(decoding_cfg, tokenizer=self.char_tokenizer)
 
-        hyp = decoding.ctc_decoder_predictions_tensor(decoder_outputs, decoder_lens, return_all_hypotheses=True)
+        hyp = decoding.ctc_decoder_predictions_tensor(decoder_outputs, decoder_lens, return_hypotheses=True)
         hyp = hyp[0]  # type: Hypothesis
         assert isinstance(hyp.y_sequence, torch.Tensor)
         assert hyp.length == torch.tensor(T, dtype=torch.int32)
@@ -521,7 +521,7 @@ def test_subword_decoding_labels(self):
         decoding_cfg = CTCBPEDecodingConfig()
         decoding = CTCBPEDecoding(decoding_cfg, tokenizer=self.char_tokenizer)
 
-        hyp = decoding.ctc_decoder_predictions_tensor(decoder_outputs, decoder_lens, return_all_hypotheses=True)
+        hyp = decoding.ctc_decoder_predictions_tensor(decoder_outputs, decoder_lens, return_hypotheses=True)
         hyp = hyp[0]  # type: Hypothesis
         assert isinstance(hyp.y_sequence, torch.Tensor)
         assert hyp.length == torch.tensor(T, dtype=torch.int32)
@@ -535,13 +535,13 @@ def test_subword_decoding_labels(self):
 
         # Cannot compute alignments from labels
         with pytest.raises(ValueError):
-            _ = decoding.ctc_decoder_predictions_tensor(decoder_outputs, decoder_lens, return_all_hypotheses=True)
+            _ = decoding.ctc_decoder_predictions_tensor(decoder_outputs, decoder_lens, return_hypotheses=True)
 
         # Preserve timestamps
         decoding_cfg = CTCBPEDecodingConfig(preserve_alignments=False, compute_timestamps=True)
         decoding = CTCBPEDecoding(decoding_cfg, tokenizer=self.char_tokenizer)
 
-        hyp = decoding.ctc_decoder_predictions_tensor(decoder_outputs, decoder_lens, return_all_hypotheses=True)
+        hyp = decoding.ctc_decoder_predictions_tensor(decoder_outputs, decoder_lens, return_hypotheses=True)
         hyp = hyp[0]  # type: Hypothesis
         assert isinstance(hyp.y_sequence, torch.Tensor)
         assert hyp.length == torch.tensor(T, dtype=torch.int32)
diff --git a/tutorials/asr/ASR_Context_Biasing.ipynb b/tutorials/asr/ASR_Context_Biasing.ipynb
index 7a62bfc1dbf2..6c551e00b2bf 100644
--- a/tutorials/asr/ASR_Context_Biasing.ipynb
+++ b/tutorials/asr/ASR_Context_Biasing.ipynb
@@ -259,10 +259,6 @@
    "execution_count": null,
    "id": "d34ee0ba",
    "metadata": {
-    "collapsed": true,
-    "jupyter": {
-     "outputs_hidden": true
-    },
     "scrolled": true
    },
    "outputs": [],
@@ -322,7 +318,7 @@
    "\n",
    "for idx, ref in enumerate(ref_text):\n",
    "    ref = ref.split()\n",
-   "    hyp = recog_results[idx].split()\n",
+   "    hyp = recog_results[idx].text.split()\n",
    "    texterrors_ali = texterrors.align_texts(ref, hyp, False)\n",
    "    ali = []\n",
    "    for i in range(len(texterrors_ali[0])):\n",
@@ -898,7 +894,7 @@
    "    print(f\"[ref text]: {target_transcripts[idx]}\")\n",
    "else:\n",
    "    # if no spotted words, use standard greedy predictions\n",
-   "    pred_text = ctc_model.wer.decoding.ctc_decoder_predictions_tensor(greedy_predicts)[0]"
+   "    pred_text = ctc_model.wer.decoding.ctc_decoder_predictions_tensor(greedy_predicts)[0].text"
 ]
 },
 {
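
For reviewers, a minimal sketch of the renamed keyword from the caller's side, mirroring the updated tests above. The toy vocabulary, tensor shapes, and random logits are illustrative assumptions; only `CTCDecoding`, `CTCDecodingConfig`, `ctc_decoder_predictions_tensor`, and the `return_hypotheses` flag come from this diff.

```python
import torch

from nemo.collections.asr.parts.submodules.ctc_decoding import CTCDecoding, CTCDecodingConfig

# Illustrative toy setup: a small character vocabulary; the CTC blank is the extra class.
vocabulary = ["a", "b", "c", " "]
decoding = CTCDecoding(CTCDecodingConfig(), vocabulary=vocabulary)

B, T = 2, 10
log_probs = torch.randn(B, T, len(vocabulary) + 1).log_softmax(dim=-1)  # [B, T, V + blank]
lengths = torch.tensor([T, T - 2], dtype=torch.long)

# After this change the keyword is `return_hypotheses` (was `return_all_hypotheses`).
# False (the default) returns simplified Hypothesis objects (score, y_sequence, text),
# so callers still read `.text` from each element, as the updated tests do.
best = decoding.ctc_decoder_predictions_tensor(log_probs, lengths, return_hypotheses=False)
print([hyp.text for hyp in best])

# True returns the full Hypothesis objects, with alignments/timestamps when the decoding
# config preserves them, or N-best lists of hypotheses under beam search.
full = decoding.ctc_decoder_predictions_tensor(log_probs, lengths, return_hypotheses=True)
```

Note that `return_hypotheses=False` still yields `Hypothesis` objects rather than plain strings, which is why the tutorial cells above now append `.text` to indexed results.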