diff --git a/paddle/phi/kernels/funcs/layer_norm_impl.cu.h b/paddle/phi/kernels/funcs/layer_norm_impl.cu.h index 3c873f5796bd34..4eae698648996b 100644 --- a/paddle/phi/kernels/funcs/layer_norm_impl.cu.h +++ b/paddle/phi/kernels/funcs/layer_norm_impl.cu.h @@ -550,7 +550,9 @@ __inline__ __device__ void cuLoadAddStridedInputs(const int64_t i1_block, for (int k = 0; k < VPT; ++k) { const int i2 = i2_off + k; const int64_t load_idx = i1 * n2 + i2; - const int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k; + const int64_t write_idx = + static_cast(thr_load_row_off) * row_stride + thr_load_col_off + + k; if (i2 < n2) { U curr_input = static_cast(input[load_idx]); U curr_dout = static_cast(dout[load_idx]); diff --git a/paddle/phi/kernels/gpu/layer_norm_kernel.cu b/paddle/phi/kernels/gpu/layer_norm_kernel.cu index 648bb6cee273b2..f621d5ed5b952c 100644 --- a/paddle/phi/kernels/gpu/layer_norm_kernel.cu +++ b/paddle/phi/kernels/gpu/layer_norm_kernel.cu @@ -159,7 +159,7 @@ struct LayerNormDataWriter { temp_dst[j] = static_cast((buffer[i * VecSize + j] - row_mean) * row_inv_var); } - v_dst[threadIdx.x + blockDim.x * i] = temp_dst; + v_dst[threadIdx.x + static_cast(blockDim.x) * i] = temp_dst; } } else { const VecScaleT *__restrict__ v_scale = @@ -168,7 +168,7 @@ struct LayerNormDataWriter { reinterpret_cast(bias); if (valid_scale && valid_bias) { for (int i = 0; i < write_times; ++i) { - int idx = threadIdx.x + blockDim.x * i; + int64_t idx = threadIdx.x + static_cast(blockDim.x) * i; VecT temp_dst; VecScaleT temp_v_scale = v_scale[idx]; VecScaleT temp_v_bias = v_bias[idx]; @@ -184,7 +184,7 @@ struct LayerNormDataWriter { } else { if (valid_scale) { for (int i = 0; i < write_times; ++i) { - int idx = threadIdx.x + blockDim.x * i; + int64_t idx = threadIdx.x + static_cast(blockDim.x) * i; VecT temp_dst; VecScaleT temp_v_scale = v_scale[idx]; #pragma unroll @@ -232,19 +232,19 @@ struct LayerNormDataWriter { if ((!valid_scale) && (!valid_bias)) { if (threadIdx.x < last_tid_idx) { for (int i = 0; i < cols_this_thread; ++i) { - row_dst[threadIdx.x + last_tid_idx * i] = + row_dst[threadIdx.x + static_cast(last_tid_idx) * i] = (buffer[i] - row_mean) * row_inv_var; } } else { for (int i = 0; i < cols_this_thread; ++i) { - row_dst[last_tid_idx * write_times + i] = + row_dst[static_cast(last_tid_idx) * write_times + i] = (buffer[i] - row_mean) * row_inv_var; } } } else if (valid_scale && valid_bias) { if (threadIdx.x < last_tid_idx) { for (int i = 0; i < cols_this_thread; ++i) { - int idx = threadIdx.x + last_tid_idx * i; + int64_t idx = threadIdx.x + static_cast(last_tid_idx) * i; row_dst[idx] = static_cast(static_cast(scale[idx]) * (buffer[i] - row_mean) * row_inv_var + @@ -252,7 +252,7 @@ struct LayerNormDataWriter { } } else { for (int i = 0; i < cols_this_thread; ++i) { - int idx = last_tid_idx * write_times + i; + int64_t idx = static_cast(last_tid_idx) * write_times + i; row_dst[idx] = static_cast(static_cast(scale[idx]) * (buffer[i] - row_mean) * row_inv_var + @@ -263,13 +263,13 @@ struct LayerNormDataWriter { if (valid_scale) { if (threadIdx.x < last_tid_idx) { for (int i = 0; i < cols_this_thread; ++i) { - int idx = threadIdx.x + last_tid_idx * i; + int64_t idx = threadIdx.x + static_cast(last_tid_idx) * i; row_dst[idx] = static_cast(static_cast(scale[idx]) * (buffer[i] - row_mean) * row_inv_var); } } else { for (int i = 0; i < cols_this_thread; ++i) { - int idx = last_tid_idx * write_times + i; + int64_t idx = static_cast(last_tid_idx) * write_times + i; row_dst[idx] = static_cast(static_cast(scale[idx]) * (buffer[i] - row_mean) * row_inv_var); } @@ -277,13 +277,13 @@ struct LayerNormDataWriter { } else { if (threadIdx.x < last_tid_idx) { for (int i = 0; i < cols_this_thread; ++i) { - int idx = threadIdx.x + last_tid_idx * i; + int64_t idx = threadIdx.x + static_cast(last_tid_idx) * i; row_dst[idx] = static_cast((buffer[i] - row_mean) * row_inv_var + static_cast(bias[idx])); } } else { for (int i = 0; i < cols_this_thread; ++i) { - int idx = last_tid_idx * write_times + i; + int64_t idx = static_cast(last_tid_idx) * write_times + i; row_dst[idx] = static_cast((buffer[i] - row_mean) * row_inv_var + static_cast(bias[idx])); }