Fix GreedyBatchedCTCInfer regression from GreedyCTCInfer. (#9347)
* Fix GreedyBatchedCTCInfer regression from GreedyCTCInfer.

decoder_lengths is allowed to be on the CPU even when decoder_output is on
the GPU, matching the behavior of GreedyCTCInfer. Although that behavior is
unintentional, existing code depends on it, including our Jupyter notebooks.
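
The situation the fix has to keep working is easy to reproduce outside NeMo: the log-probabilities live on the GPU while the lengths tensor stays on the CPU. Below is a minimal plain-PyTorch sketch of that mismatch and of the normalization the decoder now applies; the helper name normalize_lengths and the tensor shapes are illustrative only and are not part of the NeMo API.

```python
import torch

def normalize_lengths(decoder_output: torch.Tensor, decoder_lengths: torch.Tensor) -> torch.Tensor:
    # Copy the lengths onto the log-probs' device; a no-op if they already match.
    return decoder_lengths.to(decoder_output.device)

B, T, V = 4, 20, 8
device = "cuda" if torch.cuda.is_available() else "cpu"
log_probs = torch.randn(B, T, V, device=device)  # decoder_output, possibly on GPU
lengths = torch.randint(1, T, (B,))               # decoder_lengths left on CPU, as the notebooks do
lengths = normalize_lengths(log_probs, lengths)
assert lengths.device == log_probs.device
```

This is the same pattern as the decoder_lengths.to(decoder_output.device) line added in the diff below.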

Signed-off-by: Daniel Galvez <dgalvez@nvidia.com>

* Apply isort and black reformatting

Signed-off-by: titu1994 <titu1994@users.noreply.github.com>

---------

Signed-off-by: Daniel Galvez <dgalvez@nvidia.com>
Signed-off-by: titu1994 <titu1994@users.noreply.github.com>
Co-authored-by: Somshubra Majumdar <titu1994@gmail.com>
Co-authored-by: titu1994 <titu1994@users.noreply.github.com>
Co-authored-by: Nithin Rao <nithinrao.koluguri@gmail.com>
4 people authored May 30, 2024
1 parent b6595cb commit aed9d07
Showing 2 changed files with 76 additions and 7 deletions.
nemo/collections/asr/parts/submodules/ctc_greedy_decoding.py (12 changes: 11 additions & 1 deletion)
@@ -394,7 +394,17 @@ def forward(

         if decoder_lengths is None:
             logging.warning(_DECODER_LENGTHS_NONE_WARNING, mode=logging_mode.ONCE)
-            decoder_lengths = torch.tensor([decoder_output.shape[1]], dtype=torch.long).expand(decoder_output.shape[0])
+            decoder_lengths = torch.tensor(
+                [decoder_output.shape[1]], dtype=torch.long, device=decoder_output.device
+            ).expand(decoder_output.shape[0])
+
+        # GreedyCTCInfer::forward(), by accident, works with
+        # decoder_lengths on either CPU or GPU when decoder_output is
+        # on GPU. For the sake of backwards compatibility, we also
+        # allow decoder_lengths to be on the CPU device. In this case,
+        # we simply copy the decoder_lengths from CPU to GPU. If both
+        # tensors are already on the same device, this is a no-op.
+        decoder_lengths = decoder_lengths.to(decoder_output.device)
 
         if decoder_output.ndim == 2:
             hypotheses = self._greedy_decode_labels_batched(decoder_output, decoder_lengths)
tests/collections/asr/decoding/test_ctc_decoding.py (71 changes: 65 additions & 6 deletions)
@@ -200,8 +200,41 @@ def test_subword_decoding_greedy_forward_hypotheses(self, tmp_tokenizer, alignme
     @pytest.mark.parametrize('timestamps', [False, True])
     @pytest.mark.parametrize('preserve_frame_confidence', [False, True])
     @pytest.mark.parametrize('length_is_none', [False, True])
+    @pytest.mark.parametrize(
+        "logprobs_device",
+        [
+            torch.device("cpu"),
+            pytest.param(
+                torch.device("cuda"),
+                marks=pytest.mark.skipif(
+                    not torch.cuda.is_available(),
+                    reason='CUDA required for test.',
+                ),
+            ),
+        ],
+    )
+    @pytest.mark.parametrize(
+        "length_device",
+        [
+            torch.device("cpu"),
+            pytest.param(
+                torch.device("cuda"),
+                marks=pytest.mark.skipif(
+                    not torch.cuda.is_available(),
+                    reason='CUDA required for test.',
+                ),
+            ),
+        ],
+    )
     def test_batched_decoding_logprobs(
-        self, tmp_tokenizer, alignments, timestamps, preserve_frame_confidence, length_is_none
+        self,
+        tmp_tokenizer,
+        alignments,
+        timestamps,
+        preserve_frame_confidence,
+        length_is_none,
+        logprobs_device,
+        length_device,
     ):
         cfg = CTCBPEDecodingConfig(
             strategy='greedy',
@@ -217,15 +250,15 @@ def test_batched_decoding_logprobs(
         torch.manual_seed(1)
         B, T = 4, 20
         V = unbatched_decoding.tokenizer.tokenizer.vocab_size + 1
-        input_signal = torch.randn(size=(B, T, V))
+        input_signal = torch.randn(size=(B, T, V), device=logprobs_device)
         # Set the blank index to a very high probability to make sure
         # that we always handle at least a few blanks.
         input_signal[:, 0, unbatched_decoding.tokenizer.tokenizer.vocab_size] = 1000
         input_signal[:, 1, unbatched_decoding.tokenizer.tokenizer.vocab_size] = 1000
         if length_is_none:
             length = None
         else:
-            length = torch.randint(low=1, high=T, size=[B])
+            length = torch.randint(low=1, high=T, size=[B], device=length_device)
 
         with torch.inference_mode():
             hyps, _ = unbatched_decoding.ctc_decoder_predictions_tensor(
@@ -249,7 +282,33 @@
     @pytest.mark.unit
     @pytest.mark.parametrize('timestamps', [False, True])
     @pytest.mark.parametrize('length_is_none', [False, True])
-    def test_batched_decoding_labels(self, tmp_tokenizer, timestamps, length_is_none):
+    @pytest.mark.parametrize(
+        "labels_device",
+        [
+            torch.device("cpu"),
+            pytest.param(
+                torch.device("cuda"),
+                marks=pytest.mark.skipif(
+                    not torch.cuda.is_available(),
+                    reason='CUDA required for test.',
+                ),
+            ),
+        ],
+    )
+    @pytest.mark.parametrize(
+        "length_device",
+        [
+            torch.device("cpu"),
+            pytest.param(
+                torch.device("cuda"),
+                marks=pytest.mark.skipif(
+                    not torch.cuda.is_available(),
+                    reason='CUDA required for test.',
+                ),
+            ),
+        ],
+    )
+    def test_batched_decoding_labels(self, tmp_tokenizer, timestamps, length_is_none, labels_device, length_device):
         cfg = CTCBPEDecodingConfig(strategy='greedy', compute_timestamps=timestamps)
         unbatched_decoding = CTCBPEDecoding(decoding_cfg=cfg, tokenizer=tmp_tokenizer)
         cfg.strategy = 'greedy_batched'
@@ -258,15 +317,15 @@ def test_batched_decoding_labels(self, tmp_tokenizer, timestamps, length_is_none
         torch.manual_seed(1)
         B, T = 4, 20
         V = unbatched_decoding.tokenizer.tokenizer.vocab_size + 1
-        input_labels = torch.randint(V, size=(B, T))
+        input_labels = torch.randint(V, size=(B, T), device=labels_device)
         # Set some indices to blank to make sure that we always handle
         # at least a few blanks.
         input_labels[:, 0] = unbatched_decoding.tokenizer.tokenizer.vocab_size
         input_labels[:, 1] = unbatched_decoding.tokenizer.tokenizer.vocab_size
         if length_is_none:
             length = None
         else:
-            length = torch.randint(low=1, high=T, size=[B])
+            length = torch.randint(low=1, high=T, size=[B], device=length_device)
 
         with torch.inference_mode():
             hyps, _ = unbatched_decoding.ctc_decoder_predictions_tensor(
