From 3d409575f0613e91fb5c905f1262865eea4c4ed1 Mon Sep 17 00:00:00 2001
From: Pan Zhaowu
Date: Fri, 29 Aug 2025 00:16:13 +0800
Subject: [PATCH] fix possible big tensor problem

---
 .../gpu/fused_transpose_split_quant_kernel.cu | 21 ++++++++++++---------
 1 file changed, 12 insertions(+), 9 deletions(-)

diff --git a/paddle/phi/kernels/fusion/gpu/fused_transpose_split_quant_kernel.cu b/paddle/phi/kernels/fusion/gpu/fused_transpose_split_quant_kernel.cu
index 16503aa32f263d..e1d122833ff7fe 100644
--- a/paddle/phi/kernels/fusion/gpu/fused_transpose_split_quant_kernel.cu
+++ b/paddle/phi/kernels/fusion/gpu/fused_transpose_split_quant_kernel.cu
@@ -42,7 +42,8 @@ __device__ void BlockLoad(const InT* input,
     const uint32_t local_off_M = threadIdx.y + i * 16;
     const uint32_t off_m = blockIdx.x * 128 + local_off_M;
     const uint32_t off_k = blockIdx.y * 128 + threadIdx.x * VecSize;
-    const size_t offset = off_m * K + off_k;
+    const size_t offset =
+        static_cast<size_t>(off_m) * static_cast<size_t>(K) + off_k;
 
     float scale;
     if constexpr (need_dequant) {
@@ -53,15 +54,17 @@
 
 #pragma unroll
     for (uint32_t j = 0; j < 4; j += VecSize) {
-      const size_t idx = offset + j * 32;
-      using LoadT = VecType<InT, VecSize>;
-      LoadT data = *reinterpret_cast<const LoadT*>(input + idx);
+      if (off_k + j * 32 < K) {
+        const size_t idx = offset + j * 32;
+        using LoadT = VecType<InT, VecSize>;
+        LoadT data = *reinterpret_cast<const LoadT*>(input + idx);
 #pragma unroll
-      for (uint32_t k = 0; k < VecSize; k++) {
-        if constexpr (need_dequant) {
-          x[i][j + k] = __float2bfloat16(static_cast<float>(data[k]) * scale);
-        } else {
-          x[i][j + k] = (*reinterpret_cast<__nv_bfloat16*>(&data[k]));
+        for (uint32_t k = 0; k < VecSize; k++) {
+          if constexpr (need_dequant) {
+            x[i][j + k] = __float2bfloat16(static_cast<float>(data[k]) * scale);
+          } else {
+            x[i][j + k] = (*reinterpret_cast<__nv_bfloat16*>(&data[k]));
+          }
         }
       }
     }
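
Reviewer note (illustrative only, not part of the patch): assuming off_m and K are 32-bit integers, the old expression off_m * K was evaluated in 32-bit arithmetic and could silently wrap for large tensors before the result was widened to the 64-bit offset; the widening casts fix that. The new if (off_k + j * 32 < K) guard additionally appears to keep the vectorized load inside the row for the tail block in the K dimension. A minimal host-side sketch of the overflow, with hypothetical sizes and assuming a 64-bit size_t, follows.

#include <cstdint>
#include <cstdio>

int main() {
  // Hypothetical sizes, chosen only so that off_m * K exceeds UINT32_MAX.
  const uint32_t off_m = 40000;   // row index of a large tensor
  const uint32_t K = 131072;      // row width
  const uint32_t off_k = 0;

  // Pre-patch behaviour: the multiplication is done in 32-bit arithmetic and
  // wraps around before the result is stored into the 64-bit size_t.
  const size_t bad_offset = off_m * K + off_k;

  // Post-patch behaviour: widening the operands first preserves the value.
  const size_t good_offset =
      static_cast<size_t>(off_m) * static_cast<size_t>(K) + off_k;

  std::printf("wrapped: %zu\ncorrect: %zu\n", bad_offset, good_offset);
  return 0;
}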