diff --git a/paddle/phi/kernels/funcs/weight_only_gemv.cu b/paddle/phi/kernels/funcs/weight_only_gemv.cu index 3efa196adef15b..aeccf6f2370cd3 100644 --- a/paddle/phi/kernels/funcs/weight_only_gemv.cu +++ b/paddle/phi/kernels/funcs/weight_only_gemv.cu @@ -729,7 +729,7 @@ __global__ void weight_only_batched_gemv_multi_warp(const int8_t* qweight, *reinterpret_cast(in_v + y), v); } - accumulator[b] += v.x + v.y; + accumulator[b] = accumulator[b] + static_cast(v.x + v.y); } else { #pragma unroll for (int x = 0; x < NPerBlock / 2; ++x) {