diff --git a/paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h b/paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h index d53a24a57e3cc..aa613dd3f5ce0 100644 --- a/paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h +++ b/paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h @@ -156,9 +156,9 @@ __global__ void FusedLayernormResidualDropoutBias( } /* -* @brief layernorm(residual + dropout(x)); + * @brief layernorm(residual + dropout(x)); * Conditions: - * (1) The number of cols is 1024; + * (1) The number of cols is 768/1024/4096; * (2) layer_norm scale and bias is not null; * (3) linear bias is null; * @param @@ -166,6 +166,7 @@ __global__ void FusedLayernormResidualDropoutBias( * cols: 1024 * x_: [rows, cols], inputs * residual_:[rows, cols] + * bias_: [cols], linear bias, can be null * gamma_: [cols]: layernorm scale, not null * beta_: [cols], layernorm bias, not null * mask_out_: [rows, cols], dropout result @@ -173,7 +174,7 @@ __global__ void FusedLayernormResidualDropoutBias( * y_: [rows, cols], layernorm result * mean_out_: [rows]: layernorm means * var_out_: [rows]: layernorm vars -*/ + */ template < typename T, typename U, typename ScaleT = U, typename MaskType = uint8_t, int VecSize = 8, int WARPS_M = 4, int WARPS_N = 1, int BYTES_PER_LDG = 16, @@ -182,14 +183,16 @@ template < int THREADS_PER_CTA = WARPS_M *THREADS_PER_ROW, int ROWS_PER_CTA = WARPS_M, int ELTS_PER_ROW_PER_CTA = THREADS_PER_ROW *VecSize, int LDGS = ELTS_PER_ROW / ELTS_PER_ROW_PER_CTA> -__global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_fwd_1024_kernel( +__global__ __launch_bounds__(THREADS_PER_CTA) void fused_fast_ln_fwd_kernel( int rows, int cols, uint64_t seed, const float dropout_prob, const bool is_upscale_in_train, const bool is_test, const uint64_t increment, const float epsilon, const T *__restrict__ x_ptr, - const T *__restrict__ residual_ptr, const ScaleT *__restrict__ gamma_ptr, - const ScaleT *__restrict__ beta_ptr, MaskType *__restrict__ mask_out_ptr, - U *__restrict__ mean_out_ptr, U *__restrict__ var_out_ptr, - T *__restrict__ residual_out_ptr, T *__restrict__ y_ptr) { + const T *__restrict__ residual_ptr, const T *__restrict__ bias_ptr, + const ScaleT *__restrict__ gamma_ptr, const ScaleT *__restrict__ beta_ptr, + MaskType *__restrict__ mask_out_ptr, U *__restrict__ mean_out_ptr, + U *__restrict__ var_out_ptr, T *__restrict__ residual_out_ptr, + T *__restrict__ y_ptr) { + __shared__ U smem[WARPS_M * WARPS_N]; using Vec = phi::AlignedVector; using Vec_scale = phi::AlignedVector; using MaskStoreT = phi::AlignedVector; @@ -204,12 +207,22 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_fwd_1024_kernel( const int c = warp_n * THREADS_PER_WARP + lane; // lane const int r = bidx * ROWS_PER_CTA + warp_m; // row id - int idx = r * LN_NUM_COLS + c; + int idx = r * ELTS_PER_ROW + c; curandStatePhilox4_32_10_t state; curand_init(seed, idx, increment, &state); T factor = GetFactor(dropout_prob, is_upscale_in_train, is_test); + // bias + Vec bias[LDGS]; + if (bias_ptr != nullptr) { +#pragma unroll + for (int it = 0, col = c; it < LDGS; it++) { + phi::Load(bias_ptr + col * VecSize, &bias[it]); + col += THREADS_PER_ROW; + } + } + Vec_scale gamma[LDGS]; Vec_scale beta[LDGS]; #pragma unroll @@ -219,14 +232,14 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_fwd_1024_kernel( col += THREADS_PER_ROW; } - constexpr U rn = 1.f / U(LN_NUM_COLS); + constexpr U rn = 1.f / U(ELTS_PER_ROW); for (int row = r; row < rows; row += gridDim.x * ROWS_PER_CTA) { Vec x[LDGS]; Vec residual[LDGS]; #pragma unroll for (int it = 0, col = c; it < LDGS; it++) { - phi::Load(x_ptr + row * LN_NUM_COLS + col * VecSize, &x[it]); - phi::Load(residual_ptr + row * LN_NUM_COLS + col * VecSize, + phi::Load(x_ptr + row * ELTS_PER_ROW + col * VecSize, &x[it]); + phi::Load(residual_ptr + row * ELTS_PER_ROW + col * VecSize, &residual[it]); col += THREADS_PER_ROW; } @@ -255,14 +268,28 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_fwd_1024_kernel( // 4 * 8 U xf[LDGS * VecSize]; + if (bias_ptr != nullptr) { #pragma unroll - for (int it = 0; it < LDGS; it++) { + for (int it = 0; it < LDGS; it++) { #pragma unroll - for (int jt = 0; jt < VecSize; jt++) { - // dropout(x) + residual - x[it][jt] = x[it][jt] * static_cast(mask_vec[it][jt]) * factor + - residual[it][jt]; - xf[it * VecSize + jt] = U(x[it][jt]); + for (int jt = 0; jt < VecSize; jt++) { + // dropout(x) + residual + x[it][jt] = (x[it][jt] + bias[it][jt]) * + static_cast(mask_vec[it][jt]) * factor + + residual[it][jt]; + xf[it * VecSize + jt] = U(x[it][jt]); + } + } + } else { +#pragma unroll + for (int it = 0; it < LDGS; it++) { +#pragma unroll + for (int jt = 0; jt < VecSize; jt++) { + // dropout(x) + residual + x[it][jt] = x[it][jt] * static_cast(mask_vec[it][jt]) * factor + + residual[it][jt]; + xf[it * VecSize + jt] = U(x[it][jt]); + } } } @@ -270,9 +297,9 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_fwd_1024_kernel( #pragma unroll for (int it = 0, col = c; it < LDGS; it++) { phi::Store( - x[it], residual_out_ptr + row * LN_NUM_COLS + col * VecSize); + x[it], residual_out_ptr + row * ELTS_PER_ROW + col * VecSize); phi::Store( - mask_vec[it], mask_out_ptr + row * LN_NUM_COLS + col * VecSize); + mask_vec[it], mask_out_ptr + row * ELTS_PER_ROW + col * VecSize); col += THREADS_PER_ROW; } @@ -289,6 +316,22 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_fwd_1024_kernel( for (int it = 1; it < THREADS_PER_WARP; it *= 2) { mu_local += __shfl_xor_sync(uint32_t(-1), mu_local, it); } + if (WARPS_N > 1) { + if (lane == 0) { + smem[warp_m * WARPS_N + warp_n] = mu_local; + } + __syncthreads(); + if (tidx == 0) { + mu_local = 0.f; +#pragma unroll + for (int it = 0; it < WARPS_N; ++it) { + mu_local += smem[warp_m * WARPS_N + it]; + } + smem[warp_m] = mu_local; + } + __syncthreads(); + mu_local = smem[warp_m]; + } mu_local *= rn; if (lane == 0) { mean_out_ptr[row] = mu_local; @@ -308,6 +351,22 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_fwd_1024_kernel( for (int it = 1; it < THREADS_PER_WARP; it *= 2) { var_local += __shfl_xor_sync(uint32_t(-1), var_local, it); } + if (WARPS_N > 1) { + if (lane == 0) { + smem[warp_m * WARPS_N + warp_n] = var_local; + } + __syncthreads(); + if (tidx == 0) { + var_local = 0.f; +#pragma unroll + for (int it = 0; it < WARPS_N; ++it) { + var_local += smem[warp_m * WARPS_N + it]; + } + smem[warp_m] = var_local; + } + __syncthreads(); + var_local = smem[warp_m]; + } U rsigma = rsqrtf(var_local * rn + epsilon); if (lane == 0) { // Note: the stored var is different for paddle(ln) and apex (fast ln). @@ -332,7 +391,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_fwd_1024_kernel( #pragma unroll for (int it = 0, col = c; it < LDGS; it++) { - phi::Store(x[it], y_ptr + row * LN_NUM_COLS + col * VecSize); + phi::Store(x[it], y_ptr + row * ELTS_PER_ROW + col * VecSize); col += THREADS_PER_ROW; } } @@ -390,12 +449,37 @@ void LaunchLayernormResidualDropoutBias( return; } - bool can_call_1024_kernel = false; - if (cols == 1024 && scale != nullptr && layernorm_bias != nullptr && - bias == nullptr) { - can_call_1024_kernel = true; +#define LAUNCH_FUSED_FAST_LN_KERNEL_BASE(cols) \ + case (cols): { \ + constexpr int WARPS_N = cols < 1024 ? 1 : (cols / 1024); \ + constexpr int WARPS_M = 4 / WARPS_N; \ + const int THREADS_PER_WARP = 32; \ + const int BYTES_PER_LDG = 16; \ + const int VecSize = BYTES_PER_LDG / sizeof(T); \ + const int THREADS_PER_CTA = WARPS_N * THREADS_PER_WARP * WARPS_M; \ + const int ROWS_PER_CTA = WARPS_M; \ + const int grid = \ + static_cast(std::ceil(rows / static_cast(ROWS_PER_CTA))); \ + fused_fast_ln_fwd_kernel< \ + T, U, LayerNormScaleBiasT, uint8_t, \ + VecSize, WARPS_M, WARPS_N, BYTES_PER_LDG, \ + cols><<>>( \ + rows, cols, seed, dropout_prob, is_upscale_in_train, is_test, \ + increment, epsilon, src, residual, bias, scale, layernorm_bias, \ + mask_data, mean, var, dst, layernorm_dst); \ + } break + +#define LAUNCH_FUSED_FAST_LN_KERNEL \ + LAUNCH_FUSED_FAST_LN_KERNEL_BASE(768); \ + LAUNCH_FUSED_FAST_LN_KERNEL_BASE(1024); \ + LAUNCH_FUSED_FAST_LN_KERNEL_BASE(4096) + + bool can_call_fast_ln_kernel = false; + if ((cols == 768 || cols == 1024 || cols == 4096) && scale != nullptr && + layernorm_bias != nullptr) { + can_call_fast_ln_kernel = true; } - VLOG(6) << "can_call_1024_kernel = " << can_call_1024_kernel; + VLOG(6) << "can_call_fast_ln_kernel = " << can_call_fast_ln_kernel; const int VecSize = MAX_CACHE_BYTES / sizeof(T); if (cols % VecSize != 0) { @@ -407,26 +491,15 @@ void LaunchLayernormResidualDropoutBias( epsilon, src, residual, bias, scale, layernorm_bias, mask_data, dst, layernorm_dst, mean, var); } else { - if (can_call_1024_kernel) { - const int WARPS_M = 4; - const int WARPS_N = 1; - const int THREADS_PER_WARP = 32; - const int BYTES_PER_LDG = 16; - const int VecSize = BYTES_PER_LDG / sizeof(T); - - const int THREADS_PER_CTA = WARPS_N * THREADS_PER_WARP * WARPS_M; - const int ROWS_PER_CTA = WARPS_M; - - // Note: the grid can not exceed max_grid of the gpu. - const int grid = - static_cast(std::ceil(rows / static_cast(ROWS_PER_CTA))); - fused_ln_fwd_1024_kernel< - T, U, LayerNormScaleBiasT, uint8_t, - VecSize, WARPS_M, WARPS_N, - BYTES_PER_LDG><<>>( - rows, cols, seed, dropout_prob, is_upscale_in_train, is_test, - increment, epsilon, src, residual, scale, layernorm_bias, mask_data, - mean, var, dst, layernorm_dst); + if (can_call_fast_ln_kernel) { + switch (cols) { + LAUNCH_FUSED_FAST_LN_KERNEL; + default: + PADDLE_THROW(platform::errors::InvalidArgument( + "Only when column is equal to 768/1024/4096 is supported for " + "now")); + break; + } } else { int blockDim = GetDesiredBlockDim(cols / VecSize); FusedLayernormResidualDropoutBias<