
Commit

Revert "Fix GreedyBatchedCTCInfer regression from GreedyCTCInfer. (#9347
Browse files Browse the repository at this point in the history
)" (#9351)

This reverts commit aed9d07.
titu1994 authored May 30, 2024
1 parent aed9d07 commit f397086
Showing 2 changed files with 7 additions and 76 deletions.
12 changes: 1 addition & 11 deletions nemo/collections/asr/parts/submodules/ctc_greedy_decoding.py
@@ -394,17 +394,7 @@ def forward(

         if decoder_lengths is None:
             logging.warning(_DECODER_LENGTHS_NONE_WARNING, mode=logging_mode.ONCE)
-            decoder_lengths = torch.tensor(
-                [decoder_output.shape[1]], dtype=torch.long, device=decoder_output.device
-            ).expand(decoder_output.shape[0])
-
-        # GreedyCTCInfer::forward(), by accident, works with
-        # decoder_lengths on either CPU or GPU when decoder_output is
-        # on GPU. For the sake of backwards compatibility, we also
-        # allow decoder_lengths to be on the CPU device. In this case,
-        # we simply copy the decoder_lengths from CPU to GPU. If both
-        # tensors are already on the same device, this is a no-op.
-        decoder_lengths = decoder_lengths.to(decoder_output.device)
+            decoder_lengths = torch.tensor([decoder_output.shape[1]], dtype=torch.long).expand(decoder_output.shape[0])

         if decoder_output.ndim == 2:
             hypotheses = self._greedy_decode_labels_batched(decoder_output, decoder_lengths)
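For readers skimming the revert, the practical difference is in how `decoder_lengths` is derived when the caller passes `None`. Below is a minimal, self-contained sketch (plain PyTorch, not NeMo code; the helper names `lengths_pre_revert` and `lengths_post_revert` are hypothetical) contrasting the path this revert removes with the one it restores:

```python
import torch


def lengths_pre_revert(decoder_output: torch.Tensor, decoder_lengths=None) -> torch.Tensor:
    # Path removed by this revert: default lengths are created on
    # decoder_output's device, and caller-supplied lengths are moved there too.
    if decoder_lengths is None:
        decoder_lengths = torch.tensor(
            [decoder_output.shape[1]], dtype=torch.long, device=decoder_output.device
        ).expand(decoder_output.shape[0])
    return decoder_lengths.to(decoder_output.device)


def lengths_post_revert(decoder_output: torch.Tensor, decoder_lengths=None) -> torch.Tensor:
    # Path restored by this revert: default lengths are created on the default
    # (CPU) device, and caller-supplied lengths are used as-is.
    if decoder_lengths is None:
        decoder_lengths = torch.tensor(
            [decoder_output.shape[1]], dtype=torch.long
        ).expand(decoder_output.shape[0])
    return decoder_lengths


if __name__ == "__main__":
    logits = torch.randn(4, 20, 29)  # (batch, time, vocab), device-agnostic example
    print(lengths_pre_revert(logits).shape, lengths_post_revert(logits).shape)
```

The explicit device handling was introduced by the reverted commit (#9347); after this revert, the default lengths are again built without specifying a device.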
71 changes: 6 additions & 65 deletions tests/collections/asr/decoding/test_ctc_decoding.py
@@ -200,41 +200,8 @@ def test_subword_decoding_greedy_forward_hypotheses(self, tmp_tokenizer, alignme
     @pytest.mark.parametrize('timestamps', [False, True])
     @pytest.mark.parametrize('preserve_frame_confidence', [False, True])
     @pytest.mark.parametrize('length_is_none', [False, True])
-    @pytest.mark.parametrize(
-        "logprobs_device",
-        [
-            torch.device("cpu"),
-            pytest.param(
-                torch.device("cuda"),
-                marks=pytest.mark.skipif(
-                    not torch.cuda.is_available(),
-                    reason='CUDA required for test.',
-                ),
-            ),
-        ],
-    )
-    @pytest.mark.parametrize(
-        "length_device",
-        [
-            torch.device("cpu"),
-            pytest.param(
-                torch.device("cuda"),
-                marks=pytest.mark.skipif(
-                    not torch.cuda.is_available(),
-                    reason='CUDA required for test.',
-                ),
-            ),
-        ],
-    )
     def test_batched_decoding_logprobs(
-        self,
-        tmp_tokenizer,
-        alignments,
-        timestamps,
-        preserve_frame_confidence,
-        length_is_none,
-        logprobs_device,
-        length_device,
+        self, tmp_tokenizer, alignments, timestamps, preserve_frame_confidence, length_is_none
     ):
         cfg = CTCBPEDecodingConfig(
             strategy='greedy',
@@ -250,15 +217,15 @@ def test_batched_decoding_logprobs(
         torch.manual_seed(1)
         B, T = 4, 20
         V = unbatched_decoding.tokenizer.tokenizer.vocab_size + 1
-        input_signal = torch.randn(size=(B, T, V), device=logprobs_device)
+        input_signal = torch.randn(size=(B, T, V))
         # Set the blank index to a very high probability to make sure
         # that we always handle at least a few blanks.
         input_signal[:, 0, unbatched_decoding.tokenizer.tokenizer.vocab_size] = 1000
         input_signal[:, 1, unbatched_decoding.tokenizer.tokenizer.vocab_size] = 1000
         if length_is_none:
             length = None
         else:
-            length = torch.randint(low=1, high=T, size=[B], device=length_device)
+            length = torch.randint(low=1, high=T, size=[B])

         with torch.inference_mode():
             hyps, _ = unbatched_decoding.ctc_decoder_predictions_tensor(
@@ -282,33 +249,7 @@
     @pytest.mark.unit
     @pytest.mark.parametrize('timestamps', [False, True])
     @pytest.mark.parametrize('length_is_none', [False, True])
-    @pytest.mark.parametrize(
-        "labels_device",
-        [
-            torch.device("cpu"),
-            pytest.param(
-                torch.device("cuda"),
-                marks=pytest.mark.skipif(
-                    not torch.cuda.is_available(),
-                    reason='CUDA required for test.',
-                ),
-            ),
-        ],
-    )
-    @pytest.mark.parametrize(
-        "length_device",
-        [
-            torch.device("cpu"),
-            pytest.param(
-                torch.device("cuda"),
-                marks=pytest.mark.skipif(
-                    not torch.cuda.is_available(),
-                    reason='CUDA required for test.',
-                ),
-            ),
-        ],
-    )
-    def test_batched_decoding_labels(self, tmp_tokenizer, timestamps, length_is_none, labels_device, length_device):
+    def test_batched_decoding_labels(self, tmp_tokenizer, timestamps, length_is_none):
         cfg = CTCBPEDecodingConfig(strategy='greedy', compute_timestamps=timestamps)
         unbatched_decoding = CTCBPEDecoding(decoding_cfg=cfg, tokenizer=tmp_tokenizer)
         cfg.strategy = 'greedy_batched'
@@ -317,15 +258,15 @@ def test_batched_decoding_labels(self, tmp_tokenizer, timestamps, length_is_none
         torch.manual_seed(1)
         B, T = 4, 20
         V = unbatched_decoding.tokenizer.tokenizer.vocab_size + 1
-        input_labels = torch.randint(V, size=(B, T), device=labels_device)
+        input_labels = torch.randint(V, size=(B, T))
         # Set some indices to blank to make sure that we always handle
         # at least a few blanks.
         input_labels[:, 0] = unbatched_decoding.tokenizer.tokenizer.vocab_size
         input_labels[:, 1] = unbatched_decoding.tokenizer.tokenizer.vocab_size
         if length_is_none:
             length = None
         else:
-            length = torch.randint(low=1, high=T, size=[B], device=length_device)
+            length = torch.randint(low=1, high=T, size=[B])

         with torch.inference_mode():
             hyps, _ = unbatched_decoding.ctc_decoder_predictions_tensor(
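For context on what these tests exercise, here is a condensed, hypothetical sketch of the comparison pattern visible in the hunks above: decode the same random log-probabilities with the 'greedy' and 'greedy_batched' strategies and compare the resulting transcripts. The imports are the real NeMo classes used by the test file, but the helper `compare_greedy_strategies`, its arguments, and the exact call and return handling of `ctc_decoder_predictions_tensor` are assumptions based on the visible test code, not a copy of the hidden lines.

```python
import torch

# NeMo classes used by the tests above; everything below the imports is an
# illustrative sketch, not code from the test file.
from nemo.collections.asr.parts.submodules.ctc_decoding import (
    CTCBPEDecoding,
    CTCBPEDecodingConfig,
)


def compare_greedy_strategies(tokenizer, B: int = 4, T: int = 20) -> None:
    # Build one decoder per strategy from the same config, as the tests do.
    cfg = CTCBPEDecodingConfig(strategy='greedy')
    unbatched_decoding = CTCBPEDecoding(decoding_cfg=cfg, tokenizer=tokenizer)
    cfg.strategy = 'greedy_batched'
    batched_decoding = CTCBPEDecoding(decoding_cfg=cfg, tokenizer=tokenizer)

    torch.manual_seed(1)
    V = tokenizer.tokenizer.vocab_size + 1  # +1 for the CTC blank token
    input_signal = torch.randn(size=(B, T, V))
    # Force a couple of frames toward the blank index so blank handling is exercised.
    input_signal[:, 0, tokenizer.tokenizer.vocab_size] = 1000
    input_signal[:, 1, tokenizer.tokenizer.vocab_size] = 1000
    length = torch.randint(low=1, high=T, size=[B])

    with torch.inference_mode():
        hyps_a, _ = unbatched_decoding.ctc_decoder_predictions_tensor(input_signal, length)
        hyps_b, _ = batched_decoding.ctc_decoder_predictions_tensor(input_signal, length)

    # The underlying expectation: both strategies agree on the transcripts.
    for a, b in zip(hyps_a, hyps_b):
        text_a = a.text if hasattr(a, "text") else a
        text_b = b.text if hasattr(b, "text") else b
        assert text_a == text_b
```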
