Returned previous names of return_hypotheses
Signed-off-by: Ssofja <sofiakostandian@gmail.com>
Ssofja committed Feb 7, 2025
1 parent 8549980 commit 3987d43
Showing 19 changed files with 59 additions and 70 deletions.
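
In short, the commit reverts every decoding entry point to the keyword return_hypotheses. A minimal usage sketch of the restored call shape (the model class and checkpoint name below are illustrative assumptions, not part of this commit):

import nemo.collections.asr as nemo_asr

# Hypothetical checkpoint; any ASR model exposing transcribe() takes the same flag.
model = nemo_asr.models.EncDecCTCModel.from_pretrained("stt_en_conformer_ctc_small")

# The keyword is return_hypotheses again; True keeps full Hypothesis objects.
hyps = model.transcribe(["audio.wav"], batch_size=4, return_hypotheses=True)
print(hyps[0].text)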
2 changes: 1 addition & 1 deletion nemo/collections/asr/metrics/bleu.py
@@ -116,7 +116,7 @@ def __init__(
encoder_hidden_states=predictions,
encoder_input_mask=predictions_mask,
decoder_input_ids=input_ids,
-return_all_hypotheses=False,
+return_hypotheses=False,
)
else:
raise TypeError(f"WER metric does not support decoding of type {type(self.decoding)}")
2 changes: 1 addition & 1 deletion nemo/collections/asr/metrics/wer.py
@@ -280,7 +280,7 @@ def __init__(
encoder_hidden_states=predictions,
encoder_input_mask=predictions_mask,
decoder_input_ids=input_ids,
-return_all_hypotheses=False,
+return_hypotheses=False,
)
else:
raise TypeError(f"WER metric does not support decoding of type {type(self.decoding)}")
8 changes: 4 additions & 4 deletions nemo/collections/asr/models/aed_multitask_models.py
@@ -457,7 +457,7 @@ def transcribe(
self,
audio: Union[str, List[str], np.ndarray, DataLoader],
batch_size: int = 4,
-return_all_hypotheses: bool = False,
+return_hypotheses: bool = False,
num_workers: int = 0,
channel_selector: Optional[ChannelSelectorType] = None,
augmentor: DictConfig = None,
@@ -506,7 +506,7 @@ def transcribe(
if override_config is None:
trcfg = MultiTaskTranscriptionConfig(
batch_size=batch_size,
-return_all_hypotheses=return_all_hypotheses,
+return_hypotheses=return_hypotheses,
num_workers=num_workers,
channel_selector=channel_selector,
augmentor=augmentor,
@@ -955,7 +955,7 @@ def _transcribe_output_processing(self, outputs, trcfg: MultiTaskTranscriptionCo
encoder_hidden_states=enc_states,
encoder_input_mask=enc_mask,
decoder_input_ids=decoder_input_ids,
-return_all_hypotheses=trcfg.return_all_hypotheses,
+return_hypotheses=trcfg.return_hypotheses,
)

del enc_states, enc_mask, decoder_input_ids
@@ -1086,7 +1086,7 @@ def predict_step(
encoder_hidden_states=enc_states,
encoder_input_mask=enc_mask,
decoder_input_ids=batch.prompt,
-return_all_hypotheses=False,
+return_hypotheses=False,
)
if batch.cuts:
return list(zip(batch.cuts, text))
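
The multitask path above threads the flag through MultiTaskTranscriptionConfig, so callers that override the transcription config set the same field. A hedged sketch, assuming model is an already loaded multitask ASR model and that the config class is importable from this module:

from nemo.collections.asr.models.aed_multitask_models import MultiTaskTranscriptionConfig

# Same field name as the keyword argument; override_config replaces the per-call kwargs.
cfg = MultiTaskTranscriptionConfig(batch_size=4, return_hypotheses=True)
hyps = model.transcribe(audio=["audio.wav"], override_config=cfg)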
4 changes: 2 additions & 2 deletions nemo/collections/asr/models/ctc_models.py
@@ -616,7 +616,7 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0):
transcribed_texts = self.wer.decoding.ctc_decoder_predictions_tensor(
decoder_outputs=log_probs,
decoder_lengths=encoded_len,
-return_all_hypotheses=False,
+return_hypotheses=False,
)

if isinstance(sample_id, torch.Tensor):
@@ -711,7 +711,7 @@ def _transcribe_output_processing(self, outputs, trcfg: TranscribeConfig) -> Gen
hypotheses = self.decoding.ctc_decoder_predictions_tensor(
logits,
decoder_lengths=logits_len,
-return_all_hypotheses=trcfg.return_all_hypotheses,
+return_hypotheses=trcfg.return_hypotheses,
)
if trcfg.return_hypotheses:
if logits.is_cuda:
4 changes: 2 additions & 2 deletions nemo/collections/asr/models/hybrid_rnnt_ctc_models.py
@@ -211,7 +211,7 @@ def _transcribe_output_processing(
hypotheses = self.ctc_decoding.ctc_decoder_predictions_tensor(
logits,
encoded_len,
-return_all_hypotheses=trcfg.return_all_hypotheses,
+return_hypotheses=trcfg.return_hypotheses,
)
logits = logits.cpu()

@@ -504,7 +504,7 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0):
del signal

best_hyp_text = self.decoding.rnnt_decoder_predictions_tensor(
-encoder_output=encoded, encoded_lengths=encoded_len, return_all_hypotheses=False
+encoder_output=encoded, encoded_lengths=encoded_len, return_hypotheses=False
)
if isinstance(sample_id, torch.Tensor):
sample_id = sample_id.cpu().detach().numpy()
4 changes: 2 additions & 2 deletions nemo/collections/asr/models/rnnt_models.py
@@ -816,7 +816,7 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0):
del signal

best_hyp_text = self.decoding.rnnt_decoder_predictions_tensor(
-encoder_output=encoded, encoded_lengths=encoded_len, return_all_hypotheses=False
+encoder_output=encoded, encoded_lengths=encoded_len, return_hypotheses=False
)

if isinstance(sample_id, torch.Tensor):
@@ -946,7 +946,7 @@ def _transcribe_output_processing(
hyp = self.decoding.rnnt_decoder_predictions_tensor(
encoded,
encoded_len,
-return_all_hypotheses=trcfg.return_all_hypotheses,
+return_hypotheses=trcfg.return_hypotheses,
partial_hypotheses=trcfg.partial_hypothesis,
)
# cleanup memory
2 changes: 1 addition & 1 deletion nemo/collections/asr/parts/mixins/mixins.py
@@ -705,7 +705,7 @@ def conformer_stream_step(
best_hyp = self.decoding.rnnt_decoder_predictions_tensor(
encoder_output=encoded,
encoded_lengths=encoded_len,
-return_all_hypotheses=True,
+return_hypotheses=True,
partial_hypotheses=previous_hypotheses,
)
greedy_predictions = [hyp.y_sequence for hyp in best_hyp]
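
In the streaming step above, return_hypotheses=True is what keeps full Hypothesis objects alive between chunks, since the next call seeds the decoder from them via partial_hypotheses. A sketch of the loop body as it reads after this commit (the surrounding variables come from conformer_stream_step itself):

# Full Hypothesis objects are needed so y_sequence and decoder state
# survive into the next streaming chunk.
best_hyp = self.decoding.rnnt_decoder_predictions_tensor(
    encoder_output=encoded,
    encoded_lengths=encoded_len,
    return_hypotheses=True,
    partial_hypotheses=previous_hypotheses,
)
greedy_predictions = [hyp.y_sequence for hyp in best_hyp]
previous_hypotheses = best_hyp  # fed back on the next call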
9 changes: 1 addition & 8 deletions nemo/collections/asr/parts/mixins/transcription.py
@@ -57,7 +57,7 @@ class InternalTranscribeConfig:
@dataclass
class TranscribeConfig:
batch_size: int = 4
-return_all_hypotheses: bool = False
+return_hypotheses: bool = False
num_workers: Optional[int] = None
channel_selector: ChannelSelectorType = None
augmentor: Optional[DictConfig] = None
@@ -69,13 +69,6 @@ class TranscribeConfig:

_internal: Optional[InternalTranscribeConfig] = None

-@property
-def return_hypotheses(self):
-    return self.return_all_hypotheses
-
-@return_hypotheses.setter
-def return_hypotheses(self, value):
-    self.return_all_hypotheses = value


def get_value_from_transcription_config(trcfg, key, default):
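
The deleted property and setter existed only to alias return_hypotheses onto the renamed field; with the dataclass field named return_hypotheses again, attribute access hits the field directly and the shim has nothing left to do. A reduced sketch of the resulting behavior (most config fields omitted):

from dataclasses import dataclass

@dataclass
class TranscribeConfig:
    batch_size: int = 4
    return_hypotheses: bool = False  # plain field again; no aliasing property required

cfg = TranscribeConfig(return_hypotheses=True)
assert cfg.return_hypotheses is True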
10 changes: 5 additions & 5 deletions nemo/collections/asr/parts/submodules/ctc_decoding.py
@@ -359,7 +359,7 @@ def ctc_decoder_predictions_tensor(
decoder_outputs: torch.Tensor,
decoder_lengths: torch.Tensor = None,
fold_consecutive: bool = True,
-return_all_hypotheses: bool = False,
+return_hypotheses: bool = False,
) -> Union[List[Hypothesis], List[List[Hypothesis]]]:
"""
Decodes a sequence of labels to words
@@ -372,7 +372,7 @@ def ctc_decoder_predictions_tensor(
of the sequence in the padded `predictions` tensor.
fold_consecutive: Bool, determine whether to perform "ctc collapse", folding consecutive tokens
into a single token.
-return_all_hypotheses: Bool flag whether to return just the decoding predictions of the model
+return_hypotheses: Bool flag whether to return just the decoding predictions of the model
or a Hypothesis object that holds information such as the decoded `text`,
the `alignment` emitted by the CTC Model, and the `length` of the sequence (if available).
May also contain the log-probabilities of the decoder (if this method is called via
@@ -426,7 +426,7 @@ def ctc_decoder_predictions_tensor(

all_hypotheses.append(decoded_hyps)

-if return_all_hypotheses:
+if return_hypotheses:
return all_hypotheses # type: list[list[Hypothesis]]

# alaptev: The line below might contain a bug. Do we really want all_hyp_text to be flat?
@@ -444,7 +444,7 @@ def ctc_decoder_predictions_tensor(
# If computing timestamps
if self.compute_timestamps is True:
# greedy decoding, can get high-level confidence scores
-if return_all_hypotheses and (self.preserve_word_confidence or self.preserve_token_confidence):
+if return_hypotheses and (self.preserve_word_confidence or self.preserve_token_confidence):
hypotheses = self.compute_confidence(hypotheses)
else:
# remove unused token_repetitions from Hypothesis.text
@@ -454,7 +454,7 @@ def ctc_decoder_predictions_tensor(
for hyp_idx in range(len(hypotheses)):
hypotheses[hyp_idx] = self.compute_ctc_timestamps(hypotheses[hyp_idx], timestamp_type)

-if return_all_hypotheses:
+if return_hypotheses:
return hypotheses

return [Hypothesis(h.score, h.y_sequence, h.text) for h in hypotheses]
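
As the hunks above show, the flag now selects how rich the return value is: False reduces each best hypothesis to (score, y_sequence, text), while True returns full Hypothesis objects, and for beam search a list of candidate lists per utterance. A hedged call sketch, assuming decoding is an already constructed CTCDecoding instance and using dummy log-probs:

import torch

B, T, V = 4, 50, 29  # illustrative batch, time, and vocabulary sizes
log_probs = torch.randn(B, T, V).log_softmax(dim=-1)
lengths = torch.full((B,), T, dtype=torch.long)

# Flat list of best hypotheses, stripped to score/y_sequence/text.
best = decoding.ctc_decoder_predictions_tensor(
    log_probs, decoder_lengths=lengths, return_hypotheses=False
)
texts = [hyp.text for hyp in best]

# Full Hypothesis objects; beam strategies return a list of lists instead.
full = decoding.ctc_decoder_predictions_tensor(
    log_probs, decoder_lengths=lengths, return_hypotheses=True
)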
6 changes: 3 additions & 3 deletions nemo/collections/asr/parts/submodules/multitask_decoding.py
@@ -214,7 +214,7 @@ def decode_predictions_tensor(
encoder_hidden_states: torch.Tensor,
encoder_input_mask: torch.Tensor,
decoder_input_ids: Optional[torch.Tensor] = None,
-return_all_hypotheses: bool = False,
+return_hypotheses: bool = False,
partial_hypotheses: Optional[List[Hypothesis]] = None,
) -> Union[List[Hypothesis], List[List[Hypothesis]]]:
"""
@@ -260,7 +260,7 @@ def decode_predictions_tensor(
hypotheses.append(decoded_hyps[0]) # best hypothesis
all_hypotheses.append(decoded_hyps)

-if return_all_hypotheses:
+if return_hypotheses:
return all_hypotheses

all_hyp = [[Hypothesis(h.score, h.y_sequence, h.text) for h in hh] for hh in all_hypotheses]
@@ -269,7 +269,7 @@ def decode_predictions_tensor(
else:
hypotheses = self.decode_hypothesis(prediction_list)

-if return_all_hypotheses:
+if return_hypotheses:
# greedy decoding, can get high-level confidence scores
if self.preserve_frame_confidence and (
self.preserve_word_confidence or self.preserve_token_confidence
6 changes: 3 additions & 3 deletions nemo/collections/asr/parts/submodules/rnnt_decoding.py
@@ -492,7 +492,7 @@ def rnnt_decoder_predictions_tensor(
self,
encoder_output: torch.Tensor,
encoded_lengths: torch.Tensor,
-return_all_hypotheses: bool = False,
+return_hypotheses: bool = False,
partial_hypotheses: Optional[List[Hypothesis]] = None,
) -> Union[List[Hypothesis], List[List[Hypothesis]]]:
"""
@@ -541,7 +541,7 @@ def rnnt_decoder_predictions_tensor(
hypotheses.append(decoded_hyps[0]) # best hypothesis
all_hypotheses.append(decoded_hyps)

-if return_all_hypotheses:
+if return_hypotheses:
return all_hypotheses # type: list[list[Hypothesis]]

all_hyp = [[Hypothesis(h.score, h.y_sequence, h.text) for h in hh] for hh in all_hypotheses]
@@ -556,7 +556,7 @@ def rnnt_decoder_predictions_tensor(
for hyp_idx in range(len(hypotheses)):
hypotheses[hyp_idx] = self.compute_rnnt_timestamps(hypotheses[hyp_idx], timestamp_type)

-if return_all_hypotheses:
+if return_hypotheses:
# greedy decoding, can get high-level confidence scores
if self.preserve_frame_confidence and (
self.preserve_word_confidence or self.preserve_token_confidence
@@ -209,7 +209,7 @@ def transcribe(
current_hypotheses = self.decoding.ctc_decoder_predictions_tensor(
logits,
decoder_lengths=logits_len,
-return_all_hypotheses=return_hypotheses,
+return_hypotheses=return_hypotheses,
)

if return_hypotheses:
@@ -178,7 +178,7 @@ def transcribe(
best_hyp = self.ctc_decoding.ctc_decoder_predictions_tensor(
logits,
encoded_len,
-return_all_hypotheses=return_hypotheses,
+return_hypotheses=return_hypotheses,
)
if return_hypotheses:
# dump log probs per file
@@ -468,7 +468,7 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0):
del signal

best_hyp_text = self.decoding.rnnt_decoder_predictions_tensor(
-encoder_output=encoded, encoded_lengths=encoded_len, return_all_hypotheses=False
+encoder_output=encoded, encoded_lengths=encoded_len, return_hypotheses=False
)

sample_id = sample_id.cpu().detach().numpy()
@@ -289,7 +289,7 @@ def transcribe(
best_hyp = self.decoding.rnnt_decoder_predictions_tensor(
encoded,
encoded_len,
-return_all_hypotheses=return_hypotheses,
+return_hypotheses=return_hypotheses,
partial_hypotheses=partial_hypothesis,
)

@@ -725,7 +725,7 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0):
del signal

best_hyp_text = self.decoding.rnnt_decoder_predictions_tensor(
-encoder_output=encoded, encoded_lengths=encoded_len, return_all_hypotheses=False
+encoder_output=encoded, encoded_lengths=encoded_len, return_hypotheses=False
)

sample_id = sample_id.cpu().detach().numpy()
2 changes: 1 addition & 1 deletion nemo/collections/tts/g2p/models/ctc.py
@@ -357,7 +357,7 @@ def _infer(
)

preds_str = self.decoding.ctc_decoder_predictions_tensor(
-log_probs, decoder_lengths=encoded_len, return_all_hypotheses=False
+log_probs, decoder_lengths=encoded_len, return_hypotheses=False
)
all_preds.extend(preds_str)

16 changes: 8 additions & 8 deletions tests/collections/asr/decoding/test_ctc_decoding.py
@@ -126,7 +126,7 @@ def test_char_decoding_greedy_forward(

with torch.no_grad():
hypotheses = decoding.ctc_decoder_predictions_tensor(
-input_signal, length, fold_consecutive=True, return_all_hypotheses=False
+input_signal, length, fold_consecutive=True, return_hypotheses=False
)
texts = [hyp.text for hyp in hypotheses]

@@ -148,7 +148,7 @@ def test_char_decoding_greedy_forward_hypotheses(self, alignments, timestamps):

with torch.no_grad():
hyps = decoding.ctc_decoder_predictions_tensor(
-input_signal, length, fold_consecutive=True, return_all_hypotheses=True
+input_signal, length, fold_consecutive=True, return_hypotheses=True
)

for idx, hyp in enumerate(hyps):
@@ -179,7 +179,7 @@ def test_subword_decoding_greedy_forward(self, tmp_tokenizer):

with torch.no_grad():
hypotheses = decoding.ctc_decoder_predictions_tensor(
-input_signal, length, fold_consecutive=True, return_all_hypotheses=False
+input_signal, length, fold_consecutive=True, return_hypotheses=False
)
texts = [hyp.text for hyp in hypotheses]

@@ -200,7 +200,7 @@ def test_subword_decoding_greedy_forward_hypotheses(self, tmp_tokenizer, alignme

with torch.no_grad():
hyps = decoding.ctc_decoder_predictions_tensor(
-input_signal, length, fold_consecutive=True, return_all_hypotheses=True
+input_signal, length, fold_consecutive=True, return_hypotheses=True
)

for idx, hyp in enumerate(hyps):
@@ -286,11 +286,11 @@ def test_batched_decoding_logprobs(

with torch.inference_mode():
hyps = unbatched_decoding.ctc_decoder_predictions_tensor(
-input_signal, length, fold_consecutive=True, return_all_hypotheses=True
+input_signal, length, fold_consecutive=True, return_hypotheses=True
)

batched_hyps = batched_decoding.ctc_decoder_predictions_tensor(
-input_signal, length, fold_consecutive=True, return_all_hypotheses=True
+input_signal, length, fold_consecutive=True, return_hypotheses=True
)

assert len(hyps) == len(batched_hyps) == B
@@ -353,11 +353,11 @@ def test_batched_decoding_labels(self, tmp_tokenizer, timestamps, length_is_none

with torch.inference_mode():
hyps = unbatched_decoding.ctc_decoder_predictions_tensor(
-input_labels, length, fold_consecutive=True, return_all_hypotheses=True
+input_labels, length, fold_consecutive=True, return_hypotheses=True
)

batched_hyps = batched_decoding.ctc_decoder_predictions_tensor(
-input_labels, length, fold_consecutive=True, return_all_hypotheses=True
+input_labels, length, fold_consecutive=True, return_hypotheses=True
)

assert len(hyps) == len(batched_hyps) == B
4 changes: 2 additions & 2 deletions tests/collections/asr/decoding/test_rnnt_decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,7 +438,7 @@ def test_subword_decoding_compute_timestamps(self, test_data_dir, decoding_strat
decoding_cfg=cfg, decoder=model.decoder, joint=model.joint, tokenizer=model.tokenizer
)

-hyps = decoding.rnnt_decoder_predictions_tensor(encoded, encoded_len, return_all_hypotheses=True)
+hyps = decoding.rnnt_decoder_predictions_tensor(encoded, encoded_len, return_hypotheses=True)

check_subword_timestamps(hyps[0], decoding)

@@ -473,7 +473,7 @@ def test_char_decoding_compute_timestamps(self, test_data_dir, decoding_strategy

decoding = RNNTDecoding(decoding_cfg=cfg, decoder=model.decoder, joint=model.joint, vocabulary=vocab)

-hyps = decoding.rnnt_decoder_predictions_tensor(encoded, encoded_len, return_all_hypotheses=True)
+hyps = decoding.rnnt_decoder_predictions_tensor(encoded, encoded_len, return_hypotheses=True)

check_char_timestamps(hyps[0], decoding)
