@@ -44,6 +44,245 @@ __global__ void apply_repetition_penalties_kernel(
   }
 }
 
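+// Map a logit to one of 512 histogram bins. The float is converted to half
+// precision and its bit pattern is remapped so that the resulting unsigned
+// value preserves the ordering of the original float; taking the top 9 bits
+// and inverting them makes bin 0 hold the largest logits and bin 511 the
+// smallest.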
+static inline __device__ uint16_t extractBinIdx(float x) {
+  union {
+    __half h;
+    uint16_t u16;
+  } tmp;
+  tmp.h = __float2half_rn(x);
+  tmp.u16 = (x < 0.f) ? (~tmp.u16 & 0xffff) : (tmp.u16 | 0x8000);
+  return 511 - (tmp.u16 >> 7);
+}
+
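+// Per-row top-k selection: one thread block processes one row. The block
+// builds a 512-bin histogram of the row's logits (bin 0 holding the largest
+// values), uses a prefix sum to find the bin where the cumulative count
+// crosses kTopK, copies elements from earlier bins directly into shared
+// memory, sorts only the candidates of that threshold bin to resolve the
+// boundary, and finally sorts the kTopK survivors before writing them out.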
+template <int kNumThreadsPerBlock = 512>
+static __global__ void topKPerRow(const float* logits, const int* rowStarts,
+                                  const int* rowEnds, int* outIndices,
+                                  float* outLogits, int stride0, int stride1) {
+  // The number of bins in the histogram.
+  static constexpr int kNumBins = 512;
+
+  // The top-k width.
+  static constexpr int kTopK = 2048;
+  // The number of elements per thread for the final top-k sort.
+  static constexpr int kNumTopKItemsPerThread = kTopK / kNumThreadsPerBlock;
+  // The class to sort the elements during the final top-k sort.
+  using TopKSort = cub::BlockRadixSort<float, kNumThreadsPerBlock,
+                                       kNumTopKItemsPerThread, int>;
+
+  // The number of slots for the final pass.
+  static constexpr int kNumFinalItems = 3072;
+  // The number of elements per thread for the final sort.
+  static constexpr int kNumFinalItemsPerThread =
+      kNumFinalItems / kNumThreadsPerBlock;
+  // The class to sort the elements during the final pass.
+  using FinalSort = cub::BlockRadixSort<float, kNumThreadsPerBlock,
+                                        kNumFinalItemsPerThread, int>;
+
+  // The class to compute the exclusive prefix-sum over the histogram.
+  using Scan = cub::BlockScan<int, kNumThreadsPerBlock>;
+
+  // Shared memory to compute the block scan.
+  __shared__ typename Scan::TempStorage smemScan;
+
+  // The structure to store the final items (for the final pass).
+  struct FinalItems {
+    // Shared memory to store the indices for the final pass.
+    int indices[kNumFinalItems];
+    // Shared memory to store the logits for the final pass.
+    float logits[kNumFinalItems];
+  };
+
+  // Shared memory to compute the block sort.
+  __shared__ union {
+    FinalItems items;
+    typename FinalSort::TempStorage finalSort;
+    typename TopKSort::TempStorage topKSort;
+  } smemFinal;
+
+  // Shared memory to store the histogram.
+  __shared__ int smemHistogram[kNumBins];
+  // Shared memory to store the selected indices.
+  __shared__ int smemIndices[kTopK];
+  // Shared memory to store the selected logits.
+  __shared__ float smemLogits[kTopK];
+  // Shared memory to store the threshold bin.
+  __shared__ int smemThresholdBinIdx[1];
+  // Shared memory counter to register the candidates for the final phase.
+  __shared__ int smemFinalDstIdx[1];
+
+  // The row computed by this block.
+  int rowIdx = blockIdx.x;
+  // The range of logits within the row.
+  int rowStart = rowStarts[rowIdx], rowEnd = rowEnds[rowIdx];
+  // The length of the row.
+  int rowLen = rowEnd - rowStart;
+
+  // Shortcut if the length of the row is smaller than Top-K. Indices are not
+  // sorted by their corresponding logit.
+  if (rowLen <= kTopK) {
+    for (int rowIt = threadIdx.x; rowIt < rowLen;
+         rowIt += kNumThreadsPerBlock) {
+      int idx = rowStart + rowIt;
+      outIndices[rowIdx * kTopK + rowIt] = idx - rowStart;
+      outLogits[rowIdx * kTopK + rowIt] =
+          logits[rowIdx * stride0 + idx * stride1];
+    }
+    for (int rowIt = rowLen + threadIdx.x; rowIt < kTopK;
+         rowIt += kNumThreadsPerBlock) {
+      outIndices[rowIdx * kTopK + rowIt] = -1;
+      outLogits[rowIdx * kTopK + rowIt] = -FLT_MAX;
+    }
+    return;
+  }
+
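+  // First pass: count how many logits of this row fall into each bin.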
+  // Clear the histogram.
+  if (threadIdx.x < kNumBins) {
+    smemHistogram[threadIdx.x] = 0;
+  }
+
+  // Make sure the histogram is ready.
+  __syncthreads();
+
+  // Fetch elements one-by-one.
+  for (int rowIt = rowStart + threadIdx.x; rowIt < rowEnd;
+       rowIt += kNumThreadsPerBlock) {
+    uint16_t idx = extractBinIdx(logits[rowIdx * stride0 + rowIt * stride1]);
+    atomicAdd(&smemHistogram[idx], 1);
+  }
+
+  // Make sure the histogram is ready.
+  __syncthreads();
+
+  // Read the values from SMEM.
+  int binCount{0};
+  if (threadIdx.x < kNumBins) {
+    binCount = smemHistogram[threadIdx.x];
+  }
+
+  // Make sure each thread has read its value.
+  __syncthreads();
+
+  // Compute the prefix sum.
+  int prefixSum{0}, totalSum{0};
+  Scan(smemScan).ExclusiveSum(binCount, prefixSum, totalSum);
+
+  // Update the histogram with the prefix sums.
+  if (threadIdx.x < kNumBins) {
+    smemHistogram[threadIdx.x] = prefixSum;
+  }
+
+  // Make sure the data is in shared memory.
+  __syncthreads();
+
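+  // The threshold bin is the bin at which the cumulative count, accumulated
+  // from the largest logits downwards, reaches kTopK: elements in earlier
+  // bins are guaranteed to be part of the top-k, while elements in the
+  // threshold bin itself are only candidates.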
+  // Find the last valid bin.
+  if (threadIdx.x < kNumBins) {
+    int nextPrefixSum =
+        threadIdx.x == kNumBins - 1 ? totalSum : smemHistogram[threadIdx.x + 1];
+    if (prefixSum < kTopK && nextPrefixSum >= kTopK) {
+      smemThresholdBinIdx[0] = threadIdx.x;
+    }
+  }
+
+  // Clear the counter to store the items for the final phase.
+  if (threadIdx.x == 0) {
+    smemFinalDstIdx[0] = 0;
+  }
+
+  // Make sure the data is in shared memory.
+  __syncthreads();
+
+  // The threshold bin.
+  int thresholdBinIdx = smemThresholdBinIdx[0];
+
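+  // Second pass: elements from bins before the threshold bin go straight
+  // into the output staging buffers; elements that land in the threshold bin
+  // are set aside (at most kNumFinalItems of them, any excess candidates are
+  // dropped) for the tie-breaking sort below.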
+  // Fetch elements one-by-one and populate the shared memory buffers.
+  for (int rowIt = rowStart + threadIdx.x; rowIt < rowEnd;
+       rowIt += kNumThreadsPerBlock) {
+    float logit = logits[rowIdx * stride0 + rowIt * stride1];
+    uint16_t idx = extractBinIdx(logit);
+    if (idx < thresholdBinIdx) {
+      int dstIdx = atomicAdd(&smemHistogram[idx], 1);
+      smemLogits[dstIdx] = logit;
+      smemIndices[dstIdx] = rowIt;
+    } else if (idx == thresholdBinIdx) {
+      int dstIdx = atomicAdd(&smemFinalDstIdx[0], 1);
+      if (dstIdx < kNumFinalItems) {
+        smemFinal.items.logits[dstIdx] = logit;
+        smemFinal.items.indices[dstIdx] = rowIt;
+      }
+    }
+  }
+
+  // Make sure the elements are in shared memory.
+  __syncthreads();
+
+  // The logits of the elements to be sorted in the final pass.
+  float finalLogits[kNumFinalItemsPerThread];
+  // The indices of the elements to be sorted in the final pass.
+  int finalIndices[kNumFinalItemsPerThread];
+
+// Init.
+#pragma unroll
+  for (int ii = 0; ii < kNumFinalItemsPerThread; ++ii) {
+    finalLogits[ii] = -FLT_MAX;
+  }
+
+// Read the elements from SMEM.
+#pragma unroll
+  for (int ii = 0; ii < kNumFinalItemsPerThread; ++ii) {
+    int srcIdx = ii * kNumThreadsPerBlock + threadIdx.x;
+    if (srcIdx < smemFinalDstIdx[0]) {
+      finalLogits[ii] = smemFinal.items.logits[srcIdx];
+      finalIndices[ii] = smemFinal.items.indices[srcIdx];
+    }
+  }
+
+  // Make sure the shared memory has been read.
+  __syncthreads();
+
+  // Sort the elements.
+  FinalSort(smemFinal.finalSort)
+      .SortDescendingBlockedToStriped(finalLogits, finalIndices);
+
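+  // After the second pass, smemHistogram[i] holds the end offset of bin i
+  // for every bin before the threshold bin, so the entry of the preceding
+  // bin gives the slot where the sorted threshold-bin candidates start.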
+  // Copy the data back to the shared memory storage.
+  int baseIdx = thresholdBinIdx > 0 ? smemHistogram[thresholdBinIdx - 1] : 0;
+#pragma unroll
+  for (int ii = 0; ii < kNumFinalItemsPerThread; ++ii) {
+    int srcIdx = ii * kNumThreadsPerBlock + threadIdx.x;
+    int dstIdx = baseIdx + srcIdx;
+    if (dstIdx < kTopK) {
+      smemLogits[dstIdx] = finalLogits[ii];
+      smemIndices[dstIdx] = finalIndices[ii];
+    }
+  }
+
+  // Make sure the data is in shared memory.
+  __syncthreads();
+
+  // The topK logits.
+  float topKLogits[kNumTopKItemsPerThread];
+  // The topK indices.
+  int topKIndices[kNumTopKItemsPerThread];
+
+// Load from shared memory.
+#pragma unroll
+  for (int ii = 0; ii < kNumTopKItemsPerThread; ++ii) {
+    topKLogits[ii] = smemLogits[ii * kNumThreadsPerBlock + threadIdx.x];
+    topKIndices[ii] = smemIndices[ii * kNumThreadsPerBlock + threadIdx.x];
+  }
+
+  // Sort the elements.
+  TopKSort(smemFinal.topKSort)
+      .SortDescendingBlockedToStriped(topKLogits, topKIndices);
+
+// Store to global memory.
+#pragma unroll
+  for (int ii = 0; ii < kNumTopKItemsPerThread; ++ii) {
+    int offset = rowIdx * kTopK + ii * kNumThreadsPerBlock + threadIdx.x;
+    outIndices[offset] = topKIndices[ii] - rowStart;
+    outLogits[offset] = topKLogits[ii];
+  }
+}
+
 }  // namespace vllm
 
 void apply_repetition_penalties_(
@@ -85,4 +324,20 @@ void apply_repetition_penalties_(
                 repetition_penalties.data_ptr<scalar_t>(), num_seqs, vocab_size,
                 tile_size);
       });
-}
+}
+
+void top_k_per_row(const torch::Tensor& logits, const torch::Tensor& rowStarts,
+                   const torch::Tensor& rowEnds, torch::Tensor& indices,
+                   torch::Tensor& values, int64_t numRows, int64_t stride0,
+                   int64_t stride1) {
+  // Compute the results on the device.
+  constexpr int kNumThreadsPerBlock = 512;
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
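+  // Launch one block of kNumThreadsPerBlock threads per row; the kernel
+  // writes kTopK (2048) indices and logits for each row.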
+  vllm::topKPerRow<kNumThreadsPerBlock>
+      <<<numRows, kNumThreadsPerBlock, 0, stream>>>(
+          logits.data_ptr<float>(), rowStarts.data_ptr<int>(),
+          rowEnds.data_ptr<int>(), indices.data_ptr<int>(),
+          values.data_ptr<float>(), static_cast<int>(stride0),
+          static_cast<int>(stride1));
+}