Skip to content

Commit b10c64c

Browse files
rasmith (Randall Smith) and ProExpertProg
authored
[ROCm][Bugfix][Model] Fix illegal memory access when running qwen3_moe models with rms_norm (Qwen3-235B-A22B, Qwen3-30B-A3B, etc.) (#26192)
Signed-off-by: Randall Smith <ransmith@amd.com>
Signed-off-by: Randall Smith <Randall.Smith@amd.com>
Signed-off-by: rasmith <Randall.Smith@amd.com>
Co-authored-by: Randall Smith <ransmith@amd.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
1 parent 0925b28 commit b10c64c

File tree

1 file changed

+16
-8
lines changed

1 file changed

+16
-8
lines changed

csrc/layernorm_kernels.cu

Lines changed: 16 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -364,18 +364,26 @@ void rms_norm(torch::Tensor& out, // [..., hidden_size]
364364
TORCH_CHECK(weight.is_contiguous());
365365

366366
int hidden_size = input.size(-1);
367-
int num_tokens = input.numel() / hidden_size;
368-
int64_t input_stride = input.stride(-2);
367+
368+
// We cannot just use `input.stride(-2)` if the tensor is not row-major.
369+
// Instead, we use a 2d view to get the second-innermost stride.
370+
// That way the dimensions (except the last one) can be arbitrarily permuted.
371+
torch::Tensor input_view = input.view({-1, hidden_size});
372+
373+
int num_tokens = input_view.numel() / hidden_size;
374+
int64_t input_stride = input_view.stride(-2);
369375

370376
dim3 grid(num_tokens);
371377
dim3 block(std::min(hidden_size, 1024));
372-
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
378+
const at::cuda::OptionalCUDAGuard device_guard(device_of(input_view));
373379
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
374-
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] {
375-
vllm::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
376-
out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), input_stride,
377-
weight.data_ptr<scalar_t>(), epsilon, num_tokens, hidden_size);
378-
});
380+
VLLM_DISPATCH_FLOATING_TYPES(
381+
input_view.scalar_type(), "rms_norm_kernel", [&] {
382+
vllm::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
383+
out.data_ptr<scalar_t>(), input_view.data_ptr<scalar_t>(),
384+
input_stride, weight.data_ptr<scalar_t>(), epsilon, num_tokens,
385+
hidden_size);
386+
});
379387
}
380388

381389
#define LAUNCH_FUSED_ADD_RMS_NORM(width) \

0 commit comments

Comments (0)