
Commit 36a963f

IwakuraRein authored and dcampora committed
[Deepseek v3.2] Remove extra logics in indexer (vllm-project#26465)
Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>
Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com>
Signed-off-by: Lain <siyuanf@nvidia.com>
Co-authored-by: Daniel Campora <961215+dcampora@users.noreply.github.com>
Signed-off-by: 0xrushi <6279035+0xrushi@users.noreply.github.com>
1 parent d703e3e commit 36a963f

File tree: 5 files changed, +141 −40 lines changed


csrc/ops.h

Lines changed: 4 additions & 0 deletions

@@ -101,6 +101,10 @@ void top_k_per_row(const torch::Tensor& logits, const torch::Tensor& rowStarts,
                    const torch::Tensor& rowEnds, torch::Tensor& indices,
                    int64_t numRows, int64_t stride0, int64_t stride1);
 
+void top_k_per_row_decode(const torch::Tensor& logits, int64_t next_n,
+                          const torch::Tensor& seq_lens, torch::Tensor& indices,
+                          int64_t numRows, int64_t stride0, int64_t stride1);
+
 void rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& input,
                                torch::Tensor& weight, torch::Tensor& scale,
                                double epsilon);
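Note on the new entry point: unlike top_k_per_row, which takes explicit rowStarts/rowEnds tensors, top_k_per_row_decode derives each row's logit range from the per-sequence lengths and next_n, with the batch_size * next_n rows covering [0, seq_len - next_n + (row % next_n) + 1). Below is a minimal host-side reference sketch of that contract, not part of the commit: the function name is illustrative, logits are assumed row-major float, and the -1 padding mirrors the reference used in the test further down rather than any guaranteed kernel padding behavior.

import torch

def top_k_per_row_decode_ref(
    logits: torch.Tensor,    # [batch_size * next_n, vocab_size]
    next_n: int,
    seq_lens: torch.Tensor,  # [batch_size], int32
    top_k: int = 2048,
) -> torch.Tensor:
    num_rows = logits.shape[0]
    out = torch.full(
        (num_rows, top_k), -1, dtype=torch.int32, device=logits.device
    )
    for row in range(num_rows):
        # Row `row` belongs to sequence `row // next_n` and is the
        # (row % next_n)-th speculative position of that sequence.
        row_end = int(seq_lens[row // next_n]) - next_n + (row % next_n) + 1
        k = min(top_k, row_end)
        out[row, :k] = logits[row, :row_end].topk(k).indices.to(torch.int32)
    return out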

csrc/sampler.cu

Lines changed: 61 additions & 13 deletions

@@ -54,15 +54,10 @@ static inline __device__ uint16_t extractBinIdx(float x) {
   return 511 - (tmp.u16 >> 7);
 }
 
-template <int kNumThreadsPerBlock = 512>
-static __global__ void topKPerRow(const float* logits, const int* rowStarts,
-                                  const int* rowEnds, int* outIndices,
-                                  int stride0, int stride1) {
-  // The number of bins in the histogram.
-  static constexpr int kNumBins = 512;
-
-  // The top-k width.
-  static constexpr int kTopK = 2048;
+template <int kNumThreadsPerBlock = 512, int kNumBins = 512, int kTopK = 2048>
+__device__ void topKPerRowJob(const float* logits, const int rowStart,
+                              const int rowEnd, const int rowIdx,
+                              int* outIndices, int stride0, int stride1) {
   // The number of elements per thread for the final top-k sort.
   static constexpr int kNumTopKItemsPerThread = kTopK / kNumThreadsPerBlock;
   // The class to sort the elements during the final top-k sort.
@@ -108,10 +103,6 @@ static __global__ void topKPerRow(const float* logits, const int* rowStarts,
   // Shared memory counter to register the candidates for the final phase.
   __shared__ int smemFinalDstIdx[1];
 
-  // The row computed by this block.
-  int rowIdx = blockIdx.x;
-  // The range of logits within the row.
-  int rowStart = rowStarts[rowIdx], rowEnd = rowEnds[rowIdx];
   // The length of the row.
   int rowLen = rowEnd - rowStart;
 
@@ -260,6 +251,49 @@ static __global__ void topKPerRow(const float* logits, const int* rowStarts,
   }
 }
 
+template <int kNumThreadsPerBlock = 512>
+static __global__ void topKPerRow(const float* logits, const int* rowStarts,
+                                  const int* rowEnds, int* outIndices,
+                                  int stride0, int stride1) {
+  // The number of bins in the histogram.
+  static constexpr int kNumBins = 512;
+
+  // The top-k width.
+  static constexpr int kTopK = 2048;
+
+  // The row computed by this block.
+  int rowIdx = blockIdx.x;
+
+  // The range of logits within the row.
+  int rowStart = rowStarts[rowIdx];
+  int rowEnd = rowEnds[rowIdx];
+
+  topKPerRowJob<kNumThreadsPerBlock, kNumBins, kTopK>(
+      logits, rowStart, rowEnd, rowIdx, outIndices, stride0, stride1);
+}
+
+template <int kNumThreadsPerBlock = 512>
+static __global__ void topKPerRowDecode(const float* logits, const int* seqLens,
+                                        int* outIndices, int stride0,
+                                        int stride1, int next_n) {
+  // The number of bins in the histogram.
+  static constexpr int kNumBins = 512;
+
+  // The top-k width.
+  static constexpr int kTopK = 2048;
+
+  // The row computed by this block.
+  int rowIdx = blockIdx.x;
+
+  // The range of logits within the row.
+  int rowStart = 0;
+  int seq_len = seqLens[rowIdx / next_n];
+  int rowEnd = seq_len - next_n + (rowIdx % next_n) + 1;
+
+  topKPerRowJob<kNumThreadsPerBlock, kNumBins, kTopK>(
+      logits, rowStart, rowEnd, rowIdx, outIndices, stride0, stride1);
+}
+
 }  // namespace vllm
 
 void apply_repetition_penalties_(
@@ -303,6 +337,20 @@ void apply_repetition_penalties_(
   });
 }
 
+void top_k_per_row_decode(const torch::Tensor& logits, int64_t next_n,
+                          const torch::Tensor& seqLens, torch::Tensor& indices,
+                          int64_t numRows, int64_t stride0, int64_t stride1) {
+  // Compute the results on the device.
+  constexpr int kNumThreadsPerBlock = 512;
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  vllm::topKPerRowDecode<kNumThreadsPerBlock>
+      <<<numRows, kNumThreadsPerBlock, 0, stream>>>(
+          logits.data_ptr<float>(), seqLens.data_ptr<int>(),
+          indices.data_ptr<int>(), static_cast<int>(stride0),
+          static_cast<int>(stride1), static_cast<int>(next_n));
+}
+
 void top_k_per_row(const torch::Tensor& logits, const torch::Tensor& rowStarts,
                    const torch::Tensor& rowEnds, torch::Tensor& indices,
                    int64_t numRows, int64_t stride0, int64_t stride1) {
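A worked example of the decode row range computed in topKPerRowDecode: with next_n = 2 and seq_len = 10, the two rows of that sequence get rowEnd = 10 - 2 + 0 + 1 = 9 and rowEnd = 10 - 2 + 1 + 1 = 10, so the first speculative position ranks only the 9 tokens visible to it while the last sees the full sequence. rowStart is always 0 in decode, which is why this entry point needs no rowStarts tensor.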

csrc/torch_bindings.cpp

Lines changed: 6 additions & 0 deletions

@@ -189,6 +189,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "int stride1) -> ()");
   ops.impl("top_k_per_row", torch::kCUDA, &top_k_per_row);
 
+  ops.def(
+      "top_k_per_row_decode(Tensor logits, int next_n, "
+      "Tensor seq_lens, Tensor! indices, int numRows, "
+      "int stride0, int stride1) -> ()");
+  ops.impl("top_k_per_row_decode", torch::kCUDA, &top_k_per_row_decode);
+
   // Layernorm-quant
   // Apply Root Mean Square (RMS) Normalization to the input tensor.
   ops.def(

tests/kernels/test_top_k_per_row.py

Lines changed: 59 additions & 1 deletion

@@ -10,6 +10,8 @@
 # Test parameters
 NUM_ROWS = [1, 32, 2050]
 TOP_K_VALUES = [2048]
+BATCH_SIZE = [1, 2, 4, 2048, 4096]
+NEXT_N = [1, 2, 4, 8]
 
 
 def create_random_logits(
@@ -114,7 +116,7 @@ def test_top_k_per_row(
     logits = create_random_logits(row_starts, row_ends, vocab_size, torch.float32, 42)
 
     # Create output tensors
-    indices = torch.empty((num_rows, 2048), dtype=torch.int32, device="cuda")
+    indices = torch.empty((num_rows, top_k), dtype=torch.int32, device="cuda")
 
     # Run CUDA implementation
     torch.ops._C.top_k_per_row(
@@ -138,3 +140,59 @@ def test_top_k_per_row(
     assert compare_top_k_results(
         logits, indices, torch_indices, row_starts, row_ends, top_k
     ), "CUDA top_k_per_row results don't match torch.topk"
+
+
+@pytest.mark.parametrize("top_k", TOP_K_VALUES)
+@pytest.mark.parametrize("batch_size", BATCH_SIZE)
+@pytest.mark.parametrize("next_n", NEXT_N)
+@pytest.mark.skipif(not current_platform.is_cuda(), reason="This test requires CUDA")
+@torch.inference_mode()
+def test_top_k_per_row_decode(
+    top_k: int,
+    batch_size: int,
+    next_n: int,
+) -> None:
+    """
+    Test top_k_per_row with seq_lens tensor.
+    """
+    torch.set_default_device("cuda:0")
+
+    # Create test data
+    num_rows = batch_size * next_n
+    vocab_size = 20000
+    seq_lens = torch.randint(
+        vocab_size, (batch_size,), dtype=torch.int32, device="cuda"
+    )
+    row_starts = torch.zeros(num_rows, dtype=torch.int32, device="cuda")
+    row_indices = torch.arange(num_rows, device="cuda") // next_n
+    next_n_offset = torch.arange(num_rows, device="cuda") % next_n
+    row_ends = seq_lens[row_indices] - next_n + next_n_offset + 1
+    logits = create_random_logits(row_starts, row_ends, vocab_size, torch.float32, 42)
+
+    # Create output tensors
+    indices = torch.empty((num_rows, top_k), dtype=torch.int32, device="cuda")
+
+    # Run CUDA implementation
+    torch.ops._C.top_k_per_row_decode(
+        logits,
+        next_n,
+        seq_lens,
+        indices,
+        num_rows,
+        logits.stride(0),
+        logits.stride(1),
+    )
+
+    torch.cuda.synchronize()
+
+    # Run reference implementation
+    torch_indices = logits.topk(min(top_k, max(row_ends)), dim=-1)[1]
+    mask_lo = torch_indices >= 0
+    mask_hi = (torch_indices - (row_ends - row_starts)[:, None]) < 0
+    mask = mask_lo & mask_hi
+    torch_indices = torch_indices.masked_fill(~mask, -1)
+
+    # Compare results
+    assert compare_top_k_results(
+        logits, indices, torch_indices, row_starts, row_ends, top_k
+    ), "CUDA top_k_per_row results don't match torch.topk"

vllm/model_executor/models/deepseek_v2.py

Lines changed: 11 additions & 26 deletions

@@ -580,9 +580,9 @@ def sparse_attn_indexer(
         )
         num_rows = logits.shape[0]
         assert topk_tokens == 2048, "top_k_per_row assumes size 2048"
-        topk_indices = torch.empty(
-            num_rows, topk_tokens, dtype=torch.int32, device=logits.device
-        )
+        topk_indices = topk_indices_buffer[
+            chunk.token_start : chunk.token_end, :topk_tokens
+        ]
         torch.ops._C.top_k_per_row(
             logits,
             chunk.cu_seqlen_ks,
@@ -592,9 +592,6 @@
             logits.stride(0),
             logits.stride(1),
         )
-        topk_indices_buffer[
-            chunk.token_start : chunk.token_end, : topk_indices.shape[-1]
-        ] = topk_indices.to(dtype=torch.int32)
 
     if has_decode:
         decode_metadata = attn_metadata.decode
@@ -628,26 +625,14 @@
             decode_metadata.schedule_metadata,
             max_model_len=max_model_len,
         )
-        # padded query len
-        current_device = padded_q_fp8_decode_tokens.device
-        padded_num_tokens = batch_size * next_n
-        row_indices = torch.arange(padded_num_tokens, device=current_device) // next_n
-        next_n_offset = (
-            torch.arange(padded_num_tokens, device=padded_q_fp8_decode_tokens.device)
-            % next_n
-        )
-        index_end_pos = (
-            decode_metadata.seq_lens[row_indices] - next_n + next_n_offset + 1
-        ).unsqueeze(1)
         num_rows = logits.shape[0]
         assert topk_tokens == 2048, "top_k_per_row assumes size 2048"
-        topk_indices = torch.empty(
-            num_rows, topk_tokens, dtype=torch.int32, device=logits.device
-        )
-        torch.ops._C.top_k_per_row(
+        topk_indices = topk_indices_buffer[:num_decode_tokens, :topk_tokens]
+
+        torch.ops._C.top_k_per_row_decode(
             logits,
-            torch.zeros(num_rows, dtype=torch.int32, device=logits.device),
-            index_end_pos.to(dtype=torch.int32, device=logits.device),
+            next_n,
+            decode_metadata.seq_lens,
             topk_indices,
             num_rows,
             logits.stride(0),
@@ -660,9 +645,9 @@
             topk_indices.reshape(batch_size, -1, topk_indices.shape[-1]),
             decode_lens,
         )
-        topk_indices_buffer[:num_decode_tokens, : topk_indices.shape[-1]] = (
-            topk_indices.to(dtype=torch.int32)
-        )
+        topk_indices_buffer[:num_decode_tokens, : topk_indices.shape[-1]] = (
+            topk_indices
+        )
 
     return topk_indices_buffer
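The Python-side simplification hinges on slice views: topk_indices is now a view into topk_indices_buffer, so the Tensor!-annotated output of the custom op lands directly in the buffer, and both the intermediate torch.empty allocation and the copy-back (plus the host-side index_end_pos construction, now folded into the kernel) are gone. A minimal illustration of the view semantics, with hypothetical shapes not taken from the commit:

import torch

buffer = torch.zeros(8, 2048, dtype=torch.int32)
out = buffer[2:6, :2048]  # basic slicing returns a view, not a copy
out.fill_(7)              # an in-place write through the view...
assert bool(buffer[2:6].eq(7).all())  # ...is visible in the parent buffer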
