Skip to content

Commit cf2df82

Browse files
authored
fix: DeepSeek activation uninitialized data (#2128)
1 parent d56be0d commit cf2df82

File tree

1 file changed

+14
-4
lines changed

1 file changed

+14
-4
lines changed

csrc/trtllm_fused_moe_dev_kernel.cu

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -196,14 +196,16 @@ struct KernelTraits<1> {
196196

197197
////////////////////////////////////////////////////////////////////////////////////////////////////
198198

199+
constexpr int DEEP_SEEK_ACTIVATION_NUM_THREADS_PER_CTA = 128;
200+
199201
template <typename KernelParams>
200202
__global__ void activationDeepSeekKernel(KernelParams params) {
201203
using Type = typename KernelParams::Type;
202204
int32_t constexpr NumTokensPerCta = KernelParams::NumTokensPerCta;
203205
using KernelTraits = KernelTraits<NumTokensPerCta>;
204206
using MaxOp = typename KernelTraits::MaxOp;
205207
using PackedType = typename KernelTraits::PackedType;
206-
using BlockReduce = cub::BlockReduce<PackedType, 128>;
208+
using BlockReduce = cub::BlockReduce<PackedType, DEEP_SEEK_ACTIVATION_NUM_THREADS_PER_CTA>;
207209

208210
__shared__ float s_scaleOutArr[NumTokensPerCta];
209211
__shared__ typename BlockReduce::TempStorage tempStorage;
@@ -235,6 +237,15 @@ __global__ void activationDeepSeekKernel(KernelParams params) {
235237
tokenCtaIdx += gridDim.y * NumTokensPerCta) {
236238
for (int hiddenIdx = threadIdx.x + blockDim.x * blockIdx.x; hiddenIdx < params.innerDim / 2;
237239
hiddenIdx += blockDim.x * gridDim.x) {
240+
#pragma unroll
241+
for (int tokenInCtaIdx = 0; tokenInCtaIdx < NumTokensPerCta; tokenInCtaIdx++) {
242+
scale1Arr[tokenInCtaIdx] = 0.0f;
243+
scale2Arr[tokenInCtaIdx] = 0.0f;
244+
dataX1Arr[tokenInCtaIdx] = 0.0f;
245+
dataX2Arr[tokenInCtaIdx] = 0.0f;
246+
outArr[tokenInCtaIdx] = 0.0f;
247+
absOutArr[tokenInCtaIdx] = 0.0f;
248+
}
238249
#pragma unroll
239250
for (int tokenInCtaIdx = 0; tokenInCtaIdx < NumTokensPerCta; tokenInCtaIdx++) {
240251
int const tokenIdx = tokenCtaIdx + tokenInCtaIdx;
@@ -328,7 +339,6 @@ void run(Data const& data, void* stream) {
328339
if (data.mUseDeepSeekFp8) {
329340
constexpr int NUM_ELTS_PER_LOAD = 1;
330341
constexpr int NUM_ELTS_PER_SF = 128;
331-
int const NUM_THREADS_PER_CTA = 128;
332342

333343
int device{-1};
334344
cudaGetDevice(&device);
@@ -355,8 +365,8 @@ void run(Data const& data, void* stream) {
355365

356366
const dim3 grid(gridSizeX, gridSizeY, data.topK);
357367

358-
LAUNCH_ACTIVATION(data, activationDeepSeekKernel, numTokensPerCta, grid, NUM_THREADS_PER_CTA, 0,
359-
stream);
368+
LAUNCH_ACTIVATION(data, activationDeepSeekKernel, numTokensPerCta, grid,
369+
DEEP_SEEK_ACTIVATION_NUM_THREADS_PER_CTA, 0, stream);
360370
} else {
361371
int const numThreads = 256;
362372
const dim3 grid(data.innerDim / 128, data.topK, data.numTokens);

0 commit comments

Comments
 (0)