@@ -196,14 +196,16 @@ struct KernelTraits<1> {
196196
197197// //////////////////////////////////////////////////////////////////////////////////////////////////
198198
199+ constexpr int DEEP_SEEK_ACTIVATION_NUM_THREADS_PER_CTA = 128 ;
200+
199201template <typename KernelParams>
200202__global__ void activationDeepSeekKernel (KernelParams params) {
201203 using Type = typename KernelParams::Type;
202204 int32_t constexpr NumTokensPerCta = KernelParams::NumTokensPerCta;
203205 using KernelTraits = KernelTraits<NumTokensPerCta>;
204206 using MaxOp = typename KernelTraits::MaxOp;
205207 using PackedType = typename KernelTraits::PackedType;
206- using BlockReduce = cub::BlockReduce<PackedType, 128 >;
208+ using BlockReduce = cub::BlockReduce<PackedType, DEEP_SEEK_ACTIVATION_NUM_THREADS_PER_CTA >;
207209
208210 __shared__ float s_scaleOutArr[NumTokensPerCta];
209211 __shared__ typename BlockReduce::TempStorage tempStorage;
@@ -235,6 +237,15 @@ __global__ void activationDeepSeekKernel(KernelParams params) {
235237 tokenCtaIdx += gridDim .y * NumTokensPerCta) {
236238 for (int hiddenIdx = threadIdx .x + blockDim .x * blockIdx .x ; hiddenIdx < params.innerDim / 2 ;
237239 hiddenIdx += blockDim .x * gridDim .x ) {
240+ #pragma unroll
241+ for (int tokenInCtaIdx = 0 ; tokenInCtaIdx < NumTokensPerCta; tokenInCtaIdx++) {
242+ scale1Arr[tokenInCtaIdx] = 0 .0f ;
243+ scale2Arr[tokenInCtaIdx] = 0 .0f ;
244+ dataX1Arr[tokenInCtaIdx] = 0 .0f ;
245+ dataX2Arr[tokenInCtaIdx] = 0 .0f ;
246+ outArr[tokenInCtaIdx] = 0 .0f ;
247+ absOutArr[tokenInCtaIdx] = 0 .0f ;
248+ }
238249#pragma unroll
239250 for (int tokenInCtaIdx = 0 ; tokenInCtaIdx < NumTokensPerCta; tokenInCtaIdx++) {
240251 int const tokenIdx = tokenCtaIdx + tokenInCtaIdx;
@@ -328,7 +339,6 @@ void run(Data const& data, void* stream) {
328339 if (data.mUseDeepSeekFp8 ) {
329340 constexpr int NUM_ELTS_PER_LOAD = 1 ;
330341 constexpr int NUM_ELTS_PER_SF = 128 ;
331- int const NUM_THREADS_PER_CTA = 128 ;
332342
333343 int device{-1 };
334344 cudaGetDevice (&device);
@@ -355,8 +365,8 @@ void run(Data const& data, void* stream) {
355365
356366 const dim3 grid (gridSizeX, gridSizeY, data.topK );
357367
358- LAUNCH_ACTIVATION (data, activationDeepSeekKernel, numTokensPerCta, grid, NUM_THREADS_PER_CTA, 0 ,
359- stream);
368+ LAUNCH_ACTIVATION (data, activationDeepSeekKernel, numTokensPerCta, grid,
369+ DEEP_SEEK_ACTIVATION_NUM_THREADS_PER_CTA, 0 , stream);
360370 } else {
361371 int const numThreads = 256 ;
362372 const dim3 grid (data.innerDim / 128 , data.topK , data.numTokens );
0 commit comments