
Commit 5d6d1ad

[KERNEL] Sampler. CUDA kernel for applying repetition penalty (#18437)
1 parent 1409ef9 commit 5d6d1ad

File tree

7 files changed: +218 -9 lines changed

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -242,6 +242,7 @@ set(VLLM_EXT_SRC
   "csrc/activation_kernels.cu"
   "csrc/layernorm_kernels.cu"
   "csrc/layernorm_quant_kernels.cu"
+  "csrc/sampler.cu"
   "csrc/cuda_view.cu"
   "csrc/quantization/gptq/q_gemm.cu"
   "csrc/quantization/compressed_tensors/int8_quant_kernels.cu"

csrc/ops.h

Lines changed: 5 additions & 0 deletions
@@ -92,6 +92,11 @@ void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
 void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
                         torch::Tensor& weight, double epsilon);
 
+void apply_repetition_penalties_(torch::Tensor& logits,
+                                 const torch::Tensor& prompt_mask,
+                                 const torch::Tensor& output_mask,
+                                 const torch::Tensor& repetition_penalties);
+
 void rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& input,
                                torch::Tensor& weight, torch::Tensor& scale,
                                double epsilon);

csrc/sampler.cu

Lines changed: 86 additions & 0 deletions
@@ -0,0 +1,86 @@

#include "dispatch_utils.h"

#include <torch/cuda.h>
#include <c10/cuda/CUDAGuard.h>

#ifndef USE_ROCM
  #include <cub/cub.cuh>
#else
  #include <hipcub/hipcub.hpp>
#endif

namespace vllm {

template <typename scalar_t>
__global__ void apply_repetition_penalties_kernel(
    scalar_t* __restrict__ logits,         // [num_seqs, vocab_size]
    const bool* __restrict__ prompt_mask,  // [num_seqs, vocab_size]
    const bool* __restrict__ output_mask,  // [num_seqs, vocab_size]
    const scalar_t* __restrict__ repetition_penalties,  // [num_seqs]
    const int num_seqs, const int vocab_size, const int tile_size) {
  // Each block handles one sequence and a tile of vocab
  const int seq_idx = blockIdx.x;
  if (seq_idx >= num_seqs) return;

  const int tile_start = blockIdx.y * tile_size;
  const int tile_end = min(tile_start + tile_size, vocab_size);

  // Load repetition penalty for this sequence
  const scalar_t penalty = repetition_penalties[seq_idx];

  // Each thread processes multiple vocab items within the tile
  for (int vocab_idx = tile_start + threadIdx.x; vocab_idx < tile_end;
       vocab_idx += blockDim.x) {
    const int64_t idx = static_cast<int64_t>(seq_idx) * vocab_size + vocab_idx;
    const bool is_repeated = prompt_mask[idx] || output_mask[idx];
    if (is_repeated) {
      scalar_t logit = logits[idx];
      if (logit > 0) {
        logits[idx] = logit / penalty;
      } else {
        logits[idx] = logit * penalty;
      }
    }
  }
}

}  // namespace vllm

void apply_repetition_penalties_(
    torch::Tensor& logits,             // [num_seqs, vocab_size], in-place
    const torch::Tensor& prompt_mask,  // [num_seqs, vocab_size]
    const torch::Tensor& output_mask,  // [num_seqs, vocab_size]
    const torch::Tensor& repetition_penalties) {  // [num_seqs]
  TORCH_CHECK(logits.is_contiguous());
  TORCH_CHECK(prompt_mask.is_contiguous());
  TORCH_CHECK(output_mask.is_contiguous());
  TORCH_CHECK(repetition_penalties.is_contiguous());

  int vocab_size = logits.size(-1);
  int num_seqs = logits.size(0);

  // Get number of SMs on the current device
  int sms = 0;
  cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount,
                         logits.get_device());

  // Compute tile_num and tile_size
  int tile_num =
      std::min(vocab_size, std::max(1, (sms + num_seqs - 1) / num_seqs));
  int tile_size = (vocab_size + tile_num - 1) / tile_num;

  // Each block handles one sequence and a tile of vocab
  dim3 grid(num_seqs, tile_num);
  dim3 block(std::min(tile_size, 1024));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(logits));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(
      logits.scalar_type(), "apply_repetition_penalties_kernel", [&] {
        vllm::apply_repetition_penalties_kernel<scalar_t>
            <<<grid, block, 0, stream>>>(
                logits.data_ptr<scalar_t>(), prompt_mask.data_ptr<bool>(),
                output_mask.data_ptr<bool>(),
                repetition_penalties.data_ptr<scalar_t>(), num_seqs, vocab_size,
                tile_size);
      });
}
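
The launch geometry above aims to put roughly one block on each SM even for small batches: each sequence's vocab row is split into tile_num tiles, one block per (sequence, tile), and each block strides its threads across its tile. A minimal Python sketch of the same arithmetic; the 132-SM device is an assumed example (e.g. an H100), not something the commit specifies:

# Sketch of the kernel's grid/block sizing, mirroring the C++ host code above.
def launch_geometry(num_seqs: int, vocab_size: int, sms: int = 132):
    # Roughly sms / num_seqs tiles per sequence, clamped to [1, vocab_size].
    tile_num = min(vocab_size, max(1, (sms + num_seqs - 1) // num_seqs))
    # Each tile covers a ceil(vocab_size / tile_num) slice of the vocab.
    tile_size = (vocab_size + tile_num - 1) // tile_num
    grid = (num_seqs, tile_num)   # one block per (sequence, tile)
    block = min(tile_size, 1024)  # threads per block, capped at 1024
    return grid, block, tile_size

# One sequence over Qwen's 151936-token vocab still fills all 132 SMs:
print(launch_geometry(1, 151936))  # ((1, 132), 1024, 1152)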

csrc/torch_bindings.cpp

Lines changed: 7 additions & 0 deletions
@@ -170,6 +170,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "float epsilon) -> ()");
   ops.impl("fused_add_rms_norm", torch::kCUDA, &fused_add_rms_norm);
 
+  // Apply repetition penalties to logits in-place
+  ops.def(
+      "apply_repetition_penalties_(Tensor! logits, Tensor prompt_mask, "
+      "Tensor output_mask, Tensor repetition_penalties) -> ()");
+  ops.impl("apply_repetition_penalties_", torch::kCUDA,
+           &apply_repetition_penalties_);
+
   // Layernorm-quant
   // Apply Root Mean Square (RMS) Normalization to the input tensor.
   ops.def(
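
Once the schema above is registered, the op is reachable from Python under torch.ops._C. A hedged usage sketch, assuming a CUDA build of vLLM whose compiled extension has been loaded (importing vllm does this); the tensor values are illustrative only:

import torch
import vllm  # assumed CUDA build; importing registers the _C ops

num_seqs, vocab_size = 2, 8
logits = torch.randn(num_seqs, vocab_size, device="cuda")
prompt_mask = torch.zeros(num_seqs, vocab_size, dtype=torch.bool, device="cuda")
output_mask = torch.zeros_like(prompt_mask)
prompt_mask[0, 3] = True  # pretend token 3 appeared in sequence 0's prompt
penalties = torch.full((num_seqs,), 1.05, device="cuda")

# "Tensor! logits" in the schema marks logits as mutated in-place;
# the op returns () and writes the penalized values back into logits.
torch.ops._C.apply_repetition_penalties_(logits, prompt_mask, output_mask,
                                         penalties)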
tests/kernels/test_apply_repetition_penalties.py

Lines changed: 76 additions & 0 deletions
@@ -0,0 +1,76 @@

# SPDX-License-Identifier: Apache-2.0
import pytest
import torch

from tests.kernels.utils import opcheck
from vllm._custom_ops import (apply_repetition_penalties_cuda,
                              apply_repetition_penalties_torch)
from vllm.platforms import current_platform

NUM_SEQS = [1, 2, 3, 4, 8, 13, 17, 32, 37, 256, 1023, 1024, 1025]
# [stress, stress, stress, Qwen, llama 4]
VOCAB_SIZES = [17, 256, 1019, 151936, 202048]
REPETITION_PENALTY_VALUES = [1.05]
SEEDS = [0]
DTYPES = [torch.float32, torch.float16]


@pytest.mark.parametrize("num_seqs", NUM_SEQS)
@pytest.mark.parametrize("vocab_size", VOCAB_SIZES)
@pytest.mark.parametrize("repetition_penalty", REPETITION_PENALTY_VALUES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.skipif(not current_platform.is_cuda(),
                    reason="This test checks the CUDA kernel")
@torch.inference_mode()
def test_apply_repetition_penalties(
    num_seqs: int,
    vocab_size: int,
    repetition_penalty: float,
    dtype: torch.dtype,
    seed: int,
) -> None:
    """
    Test the apply_repetition_penalties custom op
    against a reference implementation.
    """
    current_platform.seed_everything(seed)
    torch.set_default_device("cuda:0")

    # Create test data
    logits = torch.randn(num_seqs, vocab_size, dtype=dtype)

    # Create masks with some random tokens marked as repeated
    prompt_mask = torch.zeros(num_seqs, vocab_size, dtype=torch.bool)
    output_mask = torch.zeros(num_seqs, vocab_size, dtype=torch.bool)

    # Mark some tokens as repeated in prompt and output
    prompt_indices = torch.randint(0, vocab_size,
                                   (num_seqs, max(1, vocab_size // 200)))
    output_indices = torch.randint(0, vocab_size,
                                   (num_seqs, max(1, vocab_size // 200)))

    for i in range(num_seqs):
        prompt_mask[i, prompt_indices[i]] = True
        output_mask[i, output_indices[i]] = True

    # Create repetition penalties tensor
    repetition_penalties = torch.full((num_seqs, ),
                                      repetition_penalty,
                                      dtype=dtype)

    # Run both implementations
    logits_torch = logits.clone()
    logits_cuda = logits.clone()

    apply_repetition_penalties_torch(logits_torch, prompt_mask, output_mask,
                                     repetition_penalties)
    apply_repetition_penalties_cuda(logits_cuda, prompt_mask, output_mask,
                                    repetition_penalties)

    # Compare the CUDA output to the reference
    torch.testing.assert_close(logits_torch, logits_cuda, rtol=1e-3, atol=1e-3)

    # Test the operator by applying the opcheck utility
    opcheck(torch.ops._C.apply_repetition_penalties_,
            (logits.clone(), prompt_mask, output_mask, repetition_penalties))

vllm/_custom_ops.py

Lines changed: 39 additions & 0 deletions
@@ -282,6 +282,45 @@ def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
     torch.ops._C.fused_add_rms_norm(input, residual, weight, epsilon)
 
 
+def apply_repetition_penalties_torch(
+        logits: torch.Tensor, prompt_mask: torch.Tensor,
+        output_mask: torch.Tensor, repetition_penalties: torch.Tensor) -> None:
+    repetition_penalties = repetition_penalties.unsqueeze(dim=1).repeat(
+        1, logits.size(1))
+    # If token appears in prompt or output, apply, otherwise use 1.0 for no-op.
+    penalties = torch.where(prompt_mask | output_mask, repetition_penalties,
+                            1.0)
+    # If logits are positive, divide by penalty, otherwise multiply by penalty.
+    scaling = torch.where(logits > 0, 1.0 / penalties, penalties)
+    logits *= scaling
+
+
+def apply_repetition_penalties_cuda(
+        logits: torch.Tensor, prompt_mask: torch.Tensor,
+        output_mask: torch.Tensor, repetition_penalties: torch.Tensor) -> None:
+    torch.ops._C.apply_repetition_penalties_(logits, prompt_mask, output_mask,
+                                             repetition_penalties)
+
+
+def apply_repetition_penalties(logits: torch.Tensor, prompt_mask: torch.Tensor,
+                               output_mask: torch.Tensor,
+                               repetition_penalties: torch.Tensor) -> None:
+    """Apply repetition penalties to logits in-place.
+
+    Args:
+        logits: The logits tensor of shape [num_seqs, vocab_size].
+        prompt_mask: A boolean tensor indicating which tokens appear in the prompt.
+        output_mask: A boolean tensor indicating which tokens appear in the output.
+        repetition_penalties: The repetition penalties of shape (num_seqs, ).
+    """
+    if current_platform.is_cuda() and logits.is_contiguous():
+        apply_repetition_penalties_cuda(logits, prompt_mask, output_mask,
+                                        repetition_penalties)
+    else:
+        apply_repetition_penalties_torch(logits, prompt_mask, output_mask,
+                                         repetition_penalties)
+
+
 def advance_step_flashattn(num_seqs: int, num_queries: int, block_size: int,
                            input_tokens: torch.Tensor,
                            sampled_token_ids: torch.Tensor,
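
The sign-dependent update is the standard repetition-penalty rule: positive logits are divided by the penalty and negative logits are multiplied by it, so a repeated token always loses probability. A small worked example against the Torch reference (CPU tensors suffice for the pure-Torch path):

import torch

from vllm._custom_ops import apply_repetition_penalties_torch

logits = torch.tensor([[2.0, -2.0, 1.0]])
mask = torch.tensor([[True, True, False]])
no_mask = torch.zeros_like(mask)
penalties = torch.tensor([2.0])

apply_repetition_penalties_torch(logits, mask, no_mask, penalties)
# Token 0: 2.0 / 2.0 = 1.0; token 1: -2.0 * 2.0 = -4.0; token 2: unmasked, unchanged.
print(logits)  # tensor([[ 1., -4.,  1.]])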

vllm/model_executor/layers/utils.py

Lines changed: 4 additions & 9 deletions
@@ -50,16 +50,11 @@ def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
         vocab_size, num_seqs)
     output_bin_counts, output_mask = get_token_bin_counts_and_mask(
         output_tokens_tensor, vocab_size, num_seqs)
-    repetition_penalties = repetition_penalties.unsqueeze(dim=1).repeat(
-        1, vocab_size)
 
-    # If token appears in prompt or output, apply, otherwise use 1.0 for no-op.
-    penalties = torch.where(prompt_mask | output_mask, repetition_penalties,
-                            1.0)
-
-    # If logits are positive, divide by penalty, otherwise multiply by penalty.
-    scaling = torch.where(logits > 0, 1.0 / penalties, penalties)
-    logits *= scaling
+    # Apply repetition penalties as a custom op
+    from vllm._custom_ops import apply_repetition_penalties
+    apply_repetition_penalties(logits, prompt_mask, output_mask,
+                               repetition_penalties)
 
     # We follow the definition in OpenAI API.
     # Refer to https://platform.openai.com/docs/api-reference/parameter-details
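
The masks consumed here come from get_token_bin_counts_and_mask. As an illustration only, a presence mask over the vocabulary can be built from token ids as below; this helper is hypothetical and skips details the real function handles (e.g. padding token ids):

import torch

def token_presence_mask(token_ids: torch.Tensor, vocab_size: int) -> torch.Tensor:
    """token_ids: [num_seqs, seq_len] int64 -> bool mask [num_seqs, vocab_size]."""
    mask = torch.zeros(token_ids.size(0), vocab_size, dtype=torch.bool,
                       device=token_ids.device)
    mask.scatter_(1, token_ids, True)  # mark every id that occurs in each row
    return mask

ids = torch.tensor([[1, 3, 3], [0, 2, 4]])
print(token_presence_mask(ids, vocab_size=5))
# tensor([[False,  True, False,  True, False],
#         [ True, False,  True, False,  True]])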
