Commit 97a9e5f

Fix fused_feedforward for big tensor (#74362)
1 parent b4a021f · commit 97a9e5f

File tree

1 file changed (+12, -10)

paddle/phi/kernels/fusion/gpu/fused_dropout_act_bias.h

Lines changed: 12 additions & 10 deletions
@@ -83,19 +83,19 @@ __global__ void FusedDropoutActBias(
8383
const int quant_round_type = 1,
8484
const float quant_max_bound = 127.0,
8585
const float quant_min_bound = -127.0) {
86-
int col_id = blockDim.x * blockIdx.x + threadIdx.x;
87-
int row_id = blockIdx.y;
88-
int idx = row_id * cols + col_id;
86+
int64_t col_id = static_cast<int64_t>(blockDim.x) * blockIdx.x + threadIdx.x;
87+
int64_t row_id = static_cast<int64_t>(blockIdx.y);
88+
int64_t idx = row_id * cols + col_id;
8989

9090
GPURAND(StatePhilox4_32_10_t) state;
9191
GPURAND(_init)(seed, idx, increment, &state);
9292

9393
const T factor =
9494
phi::fusion::GetFactor<T>(dropout_prob, is_upscale_in_train, is_test);
9595

96-
for (int r = row_id; r < rows; r += blockDim.y * gridDim.y) {
97-
for (int i = col_id * VecSize; i < cols;
98-
i += blockDim.x * gridDim.x * VecSize) {
96+
for (int64_t r = row_id; r < rows; r += blockDim.y * gridDim.y) {
97+
for (int64_t i = col_id * VecSize; i < cols;
98+
i += static_cast<int64_t>(blockDim.x) * gridDim.x * VecSize) {
9999
phi::fusion::FusedResidualDropoutBiasOneThread<T,
100100
MaskType,
101101
VecSize,
@@ -311,12 +311,13 @@ __global__ void FusedDropoutActGrad(Functor act_grad,
311311
const T factor,
312312
const int64_t size,
313313
T *dx) {
314-
int64_t idx = blockDim.x * blockIdx.x + threadIdx.x;
314+
int64_t idx = static_cast<int64_t>(blockDim.x) * blockIdx.x + threadIdx.x;
315315

316316
using LoadT = phi::AlignedVector<T, VecSize>;
317317
using StoreT = phi::AlignedVector<T, VecSize>;
318318
using MaskLoadT = phi::AlignedVector<MaskType, VecSize>;
319-
for (int i = idx * VecSize; i < size; i += blockDim.x * gridDim.x * VecSize) {
319+
for (int64_t i = idx * VecSize; i < size;
320+
i += static_cast<int64_t>(blockDim.x) * gridDim.x * VecSize) {
320321
LoadT dout_vec;
321322
LoadT src_vec;
322323
MaskLoadT mask_vec;
@@ -359,7 +360,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void FusedDropoutActBiasGrad(
359360
const int64_t cols,
360361
T *dx,
361362
T *dbias) {
362-
int64_t col_id = blockIdx.x * blockDim.x + threadIdx.x;
363+
int64_t col_id = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
363364

364365
using LoadT = phi::AlignedVector<T, VecSize>;
365366
using StoreT = phi::AlignedVector<T, VecSize>;
@@ -368,7 +369,8 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void FusedDropoutActBiasGrad(
368369
// calculate the dx and temporary sum
369370
if (col_id * VecSize < cols) {
370371
for (int row_id = threadIdx.y; row_id < rows; row_id += blockDim.y) {
371-
int index = row_id * cols + col_id * VecSize;
372+
int64_t index = static_cast<int64_t>(row_id) * cols +
373+
static_cast<int64_t>(col_id) * VecSize;
372374
LoadT dout_vec;
373375
LoadT src_vec;
374376
LoadT bias_vec;
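
For context, a minimal standalone sketch (not part of the commit) of the overflow these casts prevent: blockDim.x and blockIdx.x are 32-bit unsigned int, so their product is computed in 32-bit arithmetic, and converting a result above INT_MAX into a plain int yields a negative index. Casting one operand to int64_t first forces the whole expression into 64-bit arithmetic, which is the pattern applied throughout the diff. The launch coordinates and tensor sizes below are hypothetical, chosen only to trip the 32-bit path; the snippet compiles as host code with g++ or nvcc:

#include <cstdint>
#include <cstdio>

int main() {
  // Hypothetical launch coordinates: 256-thread blocks, block 10,000,000.
  // The true global column id is 256 * 10,000,000 + 7 = 2,560,000,007,
  // which exceeds INT_MAX (2,147,483,647).
  const unsigned int block_dim = 256, block_idx = 10000000, thread_idx = 7;

  // Old pattern: the product is 32-bit unsigned arithmetic; converting
  // the out-of-range result to int produces a negative garbage index.
  int col_id_32 = block_dim * block_idx + thread_idx;

  // Fixed pattern (as in the commit): casting one operand to int64_t
  // promotes the whole expression to 64-bit before anything can wrap.
  int64_t col_id_64 =
      static_cast<int64_t>(block_dim) * block_idx + thread_idx;

  printf("32-bit col_id: %d\n", col_id_32);               // -1734967289
  printf("64-bit col_id: %lld\n", (long long)col_id_64);  // 2560000007

  // The flattened row-major index overflows the same way: row_id * cols
  // can exceed INT_MAX, so the result must stay in an int64_t rather
  // than being truncated back into an int.
  const int64_t cols = 40000;
  const int row_id = 60000;  // 60000 * 40000 = 2.4e9 > INT_MAX
  int64_t idx = static_cast<int64_t>(row_id) * cols;
  printf("64-bit idx:    %lld\n", (long long)idx);        // 2400000000
  return 0;
}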
