diff --git a/nemo/collections/asr/parts/submodules/ctc_greedy_decoding.py b/nemo/collections/asr/parts/submodules/ctc_greedy_decoding.py
index a7f57c82279a..74204cf73d8e 100644
--- a/nemo/collections/asr/parts/submodules/ctc_greedy_decoding.py
+++ b/nemo/collections/asr/parts/submodules/ctc_greedy_decoding.py
@@ -394,7 +394,17 @@ def forward(
 
         if decoder_lengths is None:
             logging.warning(_DECODER_LENGTHS_NONE_WARNING, mode=logging_mode.ONCE)
-            decoder_lengths = torch.tensor([decoder_output.shape[1]], dtype=torch.long).expand(decoder_output.shape[0])
+            decoder_lengths = torch.tensor(
+                [decoder_output.shape[1]], dtype=torch.long, device=decoder_output.device
+            ).expand(decoder_output.shape[0])
+
+        # GreedyCTCInfer::forward(), by accident, works with
+        # decoder_lengths on either CPU or GPU when decoder_output is
+        # on GPU. For the sake of backwards compatibility, we also
+        # allow decoder_lengths to be on the CPU device. In this case,
+        # we simply copy the decoder_lengths from CPU to GPU. If both
+        # tensors are already on the same device, this is a no-op.
+        decoder_lengths = decoder_lengths.to(decoder_output.device)
 
         if decoder_output.ndim == 2:
             hypotheses = self._greedy_decode_labels_batched(decoder_output, decoder_lengths)
diff --git a/tests/collections/asr/decoding/test_ctc_decoding.py b/tests/collections/asr/decoding/test_ctc_decoding.py
index a42d61f051ad..580344fed395 100644
--- a/tests/collections/asr/decoding/test_ctc_decoding.py
+++ b/tests/collections/asr/decoding/test_ctc_decoding.py
@@ -200,8 +200,41 @@ def test_subword_decoding_greedy_forward_hypotheses(self, tmp_tokenizer, alignme
     @pytest.mark.parametrize('timestamps', [False, True])
     @pytest.mark.parametrize('preserve_frame_confidence', [False, True])
     @pytest.mark.parametrize('length_is_none', [False, True])
+    @pytest.mark.parametrize(
+        "logprobs_device",
+        [
+            torch.device("cpu"),
+            pytest.param(
+                torch.device("cuda"),
+                marks=pytest.mark.skipif(
+                    not torch.cuda.is_available(),
+                    reason='CUDA required for test.',
+                ),
+            ),
+        ],
+    )
+    @pytest.mark.parametrize(
+        "length_device",
+        [
+            torch.device("cpu"),
+            pytest.param(
+                torch.device("cuda"),
+                marks=pytest.mark.skipif(
+                    not torch.cuda.is_available(),
+                    reason='CUDA required for test.',
+                ),
+            ),
+        ],
+    )
     def test_batched_decoding_logprobs(
-        self, tmp_tokenizer, alignments, timestamps, preserve_frame_confidence, length_is_none
+        self,
+        tmp_tokenizer,
+        alignments,
+        timestamps,
+        preserve_frame_confidence,
+        length_is_none,
+        logprobs_device,
+        length_device,
     ):
         cfg = CTCBPEDecodingConfig(
             strategy='greedy',
@@ -217,7 +250,7 @@ def test_batched_decoding_logprobs(
         torch.manual_seed(1)
         B, T = 4, 20
         V = unbatched_decoding.tokenizer.tokenizer.vocab_size + 1
-        input_signal = torch.randn(size=(B, T, V))
+        input_signal = torch.randn(size=(B, T, V), device=logprobs_device)
         # Set the blank index to a very high probability to make sure
         # that we always handle at least a few blanks.
         input_signal[:, 0, unbatched_decoding.tokenizer.tokenizer.vocab_size] = 1000
@@ -225,7 +258,7 @@ def test_batched_decoding_logprobs(
         if length_is_none:
             length = None
         else:
-            length = torch.randint(low=1, high=T, size=[B])
+            length = torch.randint(low=1, high=T, size=[B], device=length_device)
 
         with torch.inference_mode():
             hyps, _ = unbatched_decoding.ctc_decoder_predictions_tensor(
@@ -249,7 +282,33 @@
     @pytest.mark.unit
     @pytest.mark.parametrize('timestamps', [False, True])
     @pytest.mark.parametrize('length_is_none', [False, True])
-    def test_batched_decoding_labels(self, tmp_tokenizer, timestamps, length_is_none):
+    @pytest.mark.parametrize(
+        "labels_device",
+        [
+            torch.device("cpu"),
+            pytest.param(
+                torch.device("cuda"),
+                marks=pytest.mark.skipif(
+                    not torch.cuda.is_available(),
+                    reason='CUDA required for test.',
+                ),
+            ),
+        ],
+    )
+    @pytest.mark.parametrize(
+        "length_device",
+        [
+            torch.device("cpu"),
+            pytest.param(
+                torch.device("cuda"),
+                marks=pytest.mark.skipif(
+                    not torch.cuda.is_available(),
+                    reason='CUDA required for test.',
+                ),
+            ),
+        ],
+    )
+    def test_batched_decoding_labels(self, tmp_tokenizer, timestamps, length_is_none, labels_device, length_device):
         cfg = CTCBPEDecodingConfig(strategy='greedy', compute_timestamps=timestamps)
         unbatched_decoding = CTCBPEDecoding(decoding_cfg=cfg, tokenizer=tmp_tokenizer)
         cfg.strategy = 'greedy_batched'
@@ -258,7 +317,7 @@ def test_batched_decoding_labels(self, tmp_tokenizer, timestamps, length_is_none
         torch.manual_seed(1)
         B, T = 4, 20
         V = unbatched_decoding.tokenizer.tokenizer.vocab_size + 1
-        input_labels = torch.randint(V, size=(B, T))
+        input_labels = torch.randint(V, size=(B, T), device=labels_device)
         # Set some indices to blank to make sure that we always handle
         # at least a few blanks.
         input_labels[:, 0] = unbatched_decoding.tokenizer.tokenizer.vocab_size
@@ -266,7 +325,7 @@ def test_batched_decoding_labels(self, tmp_tokenizer, timestamps, length_is_none
         if length_is_none:
             length = None
         else:
-            length = torch.randint(low=1, high=T, size=[B])
+            length = torch.randint(low=1, high=T, size=[B], device=length_device)
 
         with torch.inference_mode():
             hyps, _ = unbatched_decoding.ctc_decoder_predictions_tensor(
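For reference, a minimal sketch of the situation this patch fixes: with the batched greedy strategy, log-probs on the GPU can now be decoded with a lengths tensor that was left on the CPU, since the default lengths are created on decoder_output's device and any user-provided lengths are copied over with .to(). The snippet mirrors the parametrized test above; `tokenizer` stands in for the `tmp_tokenizer` fixture and is assumed to be a valid NeMo BPE tokenizer, and a CUDA device is assumed to be available.

import torch

from nemo.collections.asr.parts.submodules.ctc_decoding import CTCBPEDecoding, CTCBPEDecodingConfig

# `tokenizer` is assumed to exist (the tests build one via the `tmp_tokenizer`
# fixture); constructing it is out of scope for this sketch.
cfg = CTCBPEDecodingConfig(strategy='greedy_batched')
decoding = CTCBPEDecoding(decoding_cfg=cfg, tokenizer=tokenizer)

B, T = 4, 20
V = tokenizer.tokenizer.vocab_size + 1

# Log-probs live on the GPU, while the lengths tensor stays on the CPU.
# After this patch, the batched greedy decoder copies the lengths to the
# log-probs' device (a no-op when both are already on the same device).
log_probs = torch.randn(B, T, V, device='cuda')
lengths = torch.randint(low=1, high=T, size=[B])  # CPU tensor

with torch.inference_mode():
    hyps, _ = decoding.ctc_decoder_predictions_tensor(log_probs, lengths)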