@@ -54,15 +54,10 @@ static inline __device__ uint16_t extractBinIdx(float x) {
5454 return 511 - (tmp.u16 >> 7 );
5555}
5656
57- template <int kNumThreadsPerBlock = 512 >
58- static __global__ void topKPerRow (const float * logits, const int * rowStarts,
59- const int * rowEnds, int * outIndices,
60- int stride0, int stride1) {
61- // The number of bins in the histogram.
62- static constexpr int kNumBins = 512 ;
63-
64- // The top-k width.
65- static constexpr int kTopK = 2048 ;
57+ template <int kNumThreadsPerBlock = 512 , int kNumBins = 512 , int kTopK = 2048 >
58+ __device__ void topKPerRowJob (const float * logits, const int rowStart,
59+ const int rowEnd, const int rowIdx,
60+ int * outIndices, int stride0, int stride1) {
6661 // The number of elements per thread for the final top-k sort.
6762 static constexpr int kNumTopKItemsPerThread = kTopK / kNumThreadsPerBlock ;
6863 // The class to sort the elements during the final top-k sort.
@@ -108,10 +103,6 @@ static __global__ void topKPerRow(const float* logits, const int* rowStarts,
108103 // Shared memory counter to register the candidates for the final phase.
109104 __shared__ int smemFinalDstIdx[1 ];
110105
111- // The row computed by this block.
112- int rowIdx = blockIdx .x ;
113- // The range of logits within the row.
114- int rowStart = rowStarts[rowIdx], rowEnd = rowEnds[rowIdx];
115106 // The length of the row.
116107 int rowLen = rowEnd - rowStart;
117108
@@ -260,6 +251,49 @@ static __global__ void topKPerRow(const float* logits, const int* rowStarts,
260251 }
261252}
262253
template <int kNumThreadsPerBlock = 512>
static __global__ void topKPerRow(const float* logits, const int* rowStarts,
                                  const int* rowEnds, int* outIndices,
                                  int stride0, int stride1) {
  // Histogram bin count and top-k width forwarded to the shared per-row job.
  static constexpr int kNumBins = 512;
  static constexpr int kTopK = 2048;

  // Each block is responsible for exactly one row; its logit range comes
  // from the explicit per-row start/end arrays.
  const int row = blockIdx.x;
  topKPerRowJob<kNumThreadsPerBlock, kNumBins, kTopK>(
      logits, rowStarts[row], rowEnds[row], row, outIndices, stride0, stride1);
}
274+
template <int kNumThreadsPerBlock = 512>
static __global__ void topKPerRowDecode(const float* logits, const int* seqLens,
                                        int* outIndices, int stride0,
                                        int stride1, int next_n) {
  // Histogram bin count and top-k width forwarded to the shared per-row job.
  static constexpr int kNumBins = 512;
  static constexpr int kTopK = 2048;

  // Each block handles one row. Rows are grouped next_n per sequence: the
  // row's logit range is [0, seqLen - next_n + pos + 1), where pos is the
  // row's position within its sequence group.
  const int row = blockIdx.x;
  const int pos = row % next_n;
  const int seqLen = seqLens[row / next_n];
  const int end = seqLen + pos - next_n + 1;

  topKPerRowJob<kNumThreadsPerBlock, kNumBins, kTopK>(
      logits, /*rowStart=*/0, end, row, outIndices, stride0, stride1);
}
296+
263297} // namespace vllm
264298
265299void apply_repetition_penalties_ (
@@ -303,6 +337,20 @@ void apply_repetition_penalties_(
303337 });
304338}
305339
void top_k_per_row_decode(const torch::Tensor& logits, int64_t next_n,
                          const torch::Tensor& seqLens, torch::Tensor& indices,
                          int64_t numRows, int64_t stride0, int64_t stride1) {
  // Launches the decode-time top-k kernel: one block of 512 threads per row,
  // with each row's logit range derived from seqLens and next_n inside the
  // kernel. logits must be float32 and seqLens/indices int32 (data_ptr<T>
  // enforces the dtype at runtime).
  constexpr int kNumThreadsPerBlock = 512;

  // Guard the empty case: a grid dimension of 0 is an invalid CUDA launch
  // configuration and would surface as a launch error rather than a no-op.
  if (numRows <= 0) {
    return;
  }

  // Enqueue on the current PyTorch CUDA stream so the kernel stays ordered
  // with surrounding torch operations.
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  vllm::topKPerRowDecode<kNumThreadsPerBlock>
      <<<numRows, kNumThreadsPerBlock, 0, stream>>>(
          logits.data_ptr<float>(), seqLens.data_ptr<int>(),
          indices.data_ptr<int>(), static_cast<int>(stride0),
          static_cast<int>(stride1), static_cast<int>(next_n));
}
353+
306354void top_k_per_row (const torch::Tensor& logits, const torch::Tensor& rowStarts,
307355 const torch::Tensor& rowEnds, torch::Tensor& indices,
308356 int64_t numRows, int64_t stride0, int64_t stride1) {
0 commit comments