2 changes: 2 additions & 0 deletions csrc/sampler.cu
@@ -59,6 +59,8 @@ void apply_repetition_penalties_(
  int vocab_size = logits.size(-1);
  int num_seqs = logits.size(0);

  if (num_seqs == 0) return;

  // Get number of SMs on the current device
  int sms = 0;
  cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount,
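For context: launching a CUDA kernel with a zero-sized grid typically fails with an invalid-configuration error, which is presumably what this early return guards against, since with `num_seqs == 0` the grid derived from the batch size would be empty. A minimal sketch of how the empty case reaches the op from Python, assuming vLLM's `_C` extension is loaded (shapes and values mirror the new test below; this snippet is an illustration, not part of the diff):

```python
import torch
import vllm  # assumed: importing vllm registers the _C custom ops

vocab_size = 17

# An empty batch: shape (0, vocab_size). Before this change, the kernel
# launch for such inputs could fail; with the guard it is a no-op.
logits = torch.randn(0, vocab_size, device="cuda")
prompt_mask = torch.zeros(0, vocab_size, dtype=torch.bool, device="cuda")
output_mask = torch.zeros(0, vocab_size, dtype=torch.bool, device="cuda")
penalties = torch.full((0, ), 1.05, device="cuda")

# Returns immediately instead of attempting an empty-grid launch.
torch.ops._C.apply_repetition_penalties_(logits, prompt_mask, output_mask,
                                         penalties)
```

With the guard in place the call above simply leaves the (empty) logits untouched, matching the torch reference, which is exactly what the new test asserts.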
48 changes: 48 additions & 0 deletions tests/kernels/test_apply_repetition_penalties.py
@@ -75,3 +75,51 @@ def test_apply_repetition_penalties(
    # Test the operator by applying the opcheck utility
    opcheck(torch.ops._C.apply_repetition_penalties_,
            (logits.clone(), prompt_mask, output_mask, repetition_penalties))


@pytest.mark.skipif(not current_platform.is_cuda(),
                    reason="This test checks the CUDA kernel")
@torch.inference_mode()
def test_apply_repetition_penalties_zero_seqs() -> None:
"""
Test the apply_repetition_penalties custom op with num_seqs=0
against a reference implementation.
"""
num_seqs = 0
vocab_size = 17
repetition_penalty = 1.05
dtype = torch.float32
seed = 0

current_platform.seed_everything(seed)
torch.set_default_device("cuda:0")

# Create test data
logits = torch.randn(num_seqs, vocab_size, dtype=dtype)

# Create masks with some random tokens marked as repeated
prompt_mask = torch.zeros(num_seqs, vocab_size, dtype=torch.bool)
output_mask = torch.zeros(num_seqs, vocab_size, dtype=torch.bool)

# No tokens to mark as repeated since num_seqs=0

# Create repetition penalties tensor
repetition_penalties = torch.full((num_seqs, ),
repetition_penalty,
dtype=dtype)

# Run all three implementations
logits_torch = logits.clone()
logits_cuda = logits.clone()

apply_repetition_penalties_torch(logits_torch, prompt_mask, output_mask,
repetition_penalties)
apply_repetition_penalties_cuda(logits_cuda, prompt_mask, output_mask,
repetition_penalties)

# Compare all outputs to reference
torch.testing.assert_close(logits_torch, logits_cuda, rtol=1e-3, atol=1e-3)

# Test the operator by applying the opcheck utility
opcheck(torch.ops._C.apply_repetition_penalties_,
(logits.clone(), prompt_mask, output_mask, repetition_penalties))
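For context, `apply_repetition_penalties_torch` and `apply_repetition_penalties_cuda` are helpers defined earlier in this test file and are not shown in the diff. Based on vLLM's usual repetition-penalty semantics (positive logits are divided by the penalty, non-positive logits multiplied by it, for tokens seen in the prompt or output), the torch reference presumably looks roughly like this sketch; the exact helper may differ:

```python
import torch


def apply_repetition_penalties_torch(
        logits: torch.Tensor, prompt_mask: torch.Tensor,
        output_mask: torch.Tensor,
        repetition_penalties: torch.Tensor) -> None:
    """Sketch of a reference implementation; penalizes logits in place."""
    # Broadcast the per-sequence penalty across the vocabulary.
    penalties = repetition_penalties.unsqueeze(1).expand_as(logits)
    seen = prompt_mask | output_mask
    # Divide positive logits by the penalty and multiply non-positive
    # ones by it, but only for tokens seen in the prompt or output.
    logits[:] = torch.where(seen & (logits > 0), logits / penalties, logits)
    logits[:] = torch.where(seen & (logits <= 0), logits * penalties, logits)
```

With `num_seqs=0` every tensor here is empty, so both the reference and the guarded CUDA op reduce to no-ops, and the `assert_close` and `opcheck` calls in the test verify they agree on that degenerate case.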