diff --git a/csrc/sampler.cu b/csrc/sampler.cu
index ee5793dda0ef..b0cce2e98d22 100644
--- a/csrc/sampler.cu
+++ b/csrc/sampler.cu
@@ -59,6 +59,8 @@ void apply_repetition_penalties_(
   int vocab_size = logits.size(-1);
   int num_seqs = logits.size(0);
 
+  if (num_seqs == 0) return;
+
   // Get number of SMs on the current device
   int sms = 0;
   cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount,
diff --git a/tests/kernels/test_apply_repetition_penalties.py b/tests/kernels/test_apply_repetition_penalties.py
index 5df52dc42f0a..90380b872d6c 100644
--- a/tests/kernels/test_apply_repetition_penalties.py
+++ b/tests/kernels/test_apply_repetition_penalties.py
@@ -75,3 +75,51 @@ def test_apply_repetition_penalties(
     # Test the operator by applying the opcheck utility
     opcheck(torch.ops._C.apply_repetition_penalties_,
             (logits.clone(), prompt_mask, output_mask, repetition_penalties))
+
+
+@pytest.mark.skipif(not current_platform.is_cuda(),
+                    reason="This test checks the CUDA kernel")
+@torch.inference_mode()
+def test_apply_repetition_penalties_zero_seqs() -> None:
+    """
+    Test the apply_repetition_penalties custom op with num_seqs=0
+    against a reference implementation.
+    """
+    num_seqs = 0
+    vocab_size = 17
+    repetition_penalty = 1.05
+    dtype = torch.float32
+    seed = 0
+
+    current_platform.seed_everything(seed)
+    torch.set_default_device("cuda:0")
+
+    # Create test data
+    logits = torch.randn(num_seqs, vocab_size, dtype=dtype)
+
+    # Create empty masks
+    prompt_mask = torch.zeros(num_seqs, vocab_size, dtype=torch.bool)
+    output_mask = torch.zeros(num_seqs, vocab_size, dtype=torch.bool)
+
+    # No tokens to mark as repeated since num_seqs=0
+
+    # Create repetition penalties tensor
+    repetition_penalties = torch.full((num_seqs, ),
+                                      repetition_penalty,
+                                      dtype=dtype)
+
+    # Run both implementations
+    logits_torch = logits.clone()
+    logits_cuda = logits.clone()
+
+    apply_repetition_penalties_torch(logits_torch, prompt_mask, output_mask,
+                                     repetition_penalties)
+    apply_repetition_penalties_cuda(logits_cuda, prompt_mask, output_mask,
+                                    repetition_penalties)
+
+    # Compare the CUDA output to the PyTorch reference
+    torch.testing.assert_close(logits_torch, logits_cuda, rtol=1e-3, atol=1e-3)
+
+    # Test the operator by applying the opcheck utility
+    opcheck(torch.ops._C.apply_repetition_penalties_,
+            (logits.clone(), prompt_mask, output_mask, repetition_penalties))
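
For context, a minimal sketch of the failure mode the `num_seqs == 0` guard addresses, not part of the patch: without the early return, the kernel is launched even for an empty batch, and a grid dimension derived from a zero sequence count would presumably be an invalid CUDA launch configuration. The snippet assumes a CUDA build of vLLM with the `_C` extension loaded; everything outside the diff is an assumption for illustration.

```python
# Hypothetical repro sketch (assumes torch.ops._C.apply_repetition_penalties_
# is registered, as in the test above; not part of the patch).
import torch

num_seqs, vocab_size = 0, 17  # empty batch, as exercised by the new test
logits = torch.empty(num_seqs, vocab_size, device="cuda")
prompt_mask = torch.zeros(num_seqs, vocab_size, dtype=torch.bool, device="cuda")
output_mask = torch.zeros(num_seqs, vocab_size, dtype=torch.bool, device="cuda")
penalties = torch.full((num_seqs,), 1.05, device="cuda")

# With the early return in place, this call is a no-op on the empty batch
# instead of risking a zero-sized kernel launch.
torch.ops._C.apply_repetition_penalties_(logits, prompt_mask, output_mask,
                                         penalties)
```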