Skip to content

Commit b771dce

Browse files
authored
Fix possible big tensor problem in fused_transpose_split_quant (#74964)
1 parent a492585 commit b771dce

File tree

1 file changed

+12
-9
lines changed

1 file changed

+12
-9
lines changed

paddle/phi/kernels/fusion/gpu/fused_transpose_split_quant_kernel.cu

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,8 @@ __device__ void BlockLoad(const InT* input,
4242
const uint32_t local_off_M = threadIdx.y + i * 16;
4343
const uint32_t off_m = blockIdx.x * 128 + local_off_M;
4444
const uint32_t off_k = blockIdx.y * 128 + threadIdx.x * VecSize;
45-
const size_t offset = off_m * K + off_k;
45+
const size_t offset =
46+
static_cast<size_t>(off_m) * static_cast<size_t>(K) + off_k;
4647

4748
float scale;
4849
if constexpr (need_dequant) {
@@ -53,15 +54,17 @@ __device__ void BlockLoad(const InT* input,
5354

5455
#pragma unroll
5556
for (uint32_t j = 0; j < 4; j += VecSize) {
56-
const size_t idx = offset + j * 32;
57-
using LoadT = VecType<InT, VecSize>;
58-
LoadT data = *reinterpret_cast<const LoadT*>(input + idx);
57+
if (off_k + j * 32 < K) {
58+
const size_t idx = offset + j * 32;
59+
using LoadT = VecType<InT, VecSize>;
60+
LoadT data = *reinterpret_cast<const LoadT*>(input + idx);
5961
#pragma unroll
60-
for (uint32_t k = 0; k < VecSize; k++) {
61-
if constexpr (need_dequant) {
62-
x[i][j + k] = __float2bfloat16(static_cast<float>(data[k]) * scale);
63-
} else {
64-
x[i][j + k] = (*reinterpret_cast<__nv_bfloat16*>(&data[k]));
62+
for (uint32_t k = 0; k < VecSize; k++) {
63+
if constexpr (need_dequant) {
64+
x[i][j + k] = __float2bfloat16(static_cast<float>(data[k]) * scale);
65+
} else {
66+
x[i][j + k] = (*reinterpret_cast<__nv_bfloat16*>(&data[k]));
67+
}
6568
}
6669
}
6770
}

0 commit comments

Comments
 (0)