fix overflow in softmax_kernel when processing long seqlen and big batch_size (#524)
zhangxin81 authored Apr 19, 2023
1 parent c6ba315 commit adb21c3
Showing 1 changed file with 9 additions and 9 deletions.
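The offset this kernel computes is ((bi * head_num + hi) * q_length + qi) * k_length + ki; with 32-bit int indices the arithmetic wraps once batch_size * head_num * q_length * k_length exceeds INT_MAX (2^31 - 1). Below is a minimal host-side sketch of the failure mode, not part of the commit, using hypothetical sizes (batch 32, 32 heads, q_length = k_length = 4096) whose last offset is 2^34 - 1:

    #include <cstdint>
    #include <cstdio>

    int main() {
        const int head_num = 32, q_length = 4096, k_length = 4096;
        const int bi = 31, hi = 31, qi = 4095, ki = 4095;  // last element of the score tensor

        // The kernel's offset expression with all-int operands would exceed
        // INT_MAX here; signed overflow is undefined behavior in C++ and in
        // practice wraps to a garbage offset:
        //     int bad = ((bi * head_num + hi) * q_length + qi) * k_length + ki;

        // Promoting the leftmost operand to int64_t carries the whole chain of
        // multiplies and adds out in 64-bit arithmetic, which is what the patch
        // does by declaring bi/hi/qi/ki as int64_t.
        int64_t good = ((static_cast<int64_t>(bi) * head_num + hi) * q_length + qi) * k_length + ki;

        std::printf("offset = %lld, INT_MAX = %d\n", (long long)good, INT32_MAX);  // 17179869183 vs 2147483647
        return 0;
    }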
18 changes: 9 additions & 9 deletions src/fastertransformer/kernels/unfused_attention_kernels.cu
@@ -268,23 +268,23 @@ __global__ void softmax_kernel(T* attn_score,
     // attn_mask, [batch_size, q_length, k_length]
     // linear_bias_slopes, [num_heads]

-    const int bi = blockIdx.y;  // Batch index.
-    const int hi = blockIdx.z;  // Head index.
+    const int64_t bi = blockIdx.y;  // Batch index.
+    const int64_t hi = blockIdx.z;  // Head index.

     __shared__ float s_mean, s_max;

     const float linear_bias_slope = linear_bias_slopes != nullptr ? (float)linear_bias_slopes[hi] : 0.0f;

     // Loop along with Q dimension.
-    for (int qi = blockIdx.x; qi < q_length; qi += gridDim.x) {
+    for (int64_t qi = blockIdx.x; qi < q_length; qi += gridDim.x) {

         float data[ITEMS_PER_THREAD];
-        int qk_offset;
+        int64_t qk_offset;
         float local_max = -1e20f;

         // Loop along with K dimension.
-        for (int i = 0; blockDim.x * i + threadIdx.x < k_length; i++) {
-            int ki = blockDim.x * i + threadIdx.x;  // Index of K dimension.
+        for (int64_t i = 0; blockDim.x * i + threadIdx.x < k_length; i++) {
+            int64_t ki = blockDim.x * i + threadIdx.x;  // Index of K dimension.
             qk_offset = ((bi * head_num + hi) * q_length + qi) * k_length + ki;

             float qk_val = static_cast<float>(qk[qk_offset]);
@@ -297,7 +297,7 @@ __global__ void softmax_kernel(T* attn_score,
                 qk_bias += static_cast<float>(linear_bias_slope * (ki - qi));
             }

-            int mask_offset = (bi * q_length + qi) * k_length + ki;
+            int64_t mask_offset = (bi * q_length + qi) * k_length + ki;
             float mask_val = static_cast<float>(ldg(&attn_mask[mask_offset]));
             qk_bias += (1.0f - mask_val) * -10000.0f;

@@ -312,7 +312,7 @@ __global__ void softmax_kernel(T* attn_score,
         __syncthreads();

         float local_sum = 0;
-        for (int i = 0; blockDim.x * i + threadIdx.x < k_length; i++) {
+        for (int64_t i = 0; blockDim.x * i + threadIdx.x < k_length; i++) {
             data[i] = __expf(data[i] - s_max);
             local_sum += data[i];
         }
@@ -324,7 +324,7 @@ __global__ void softmax_kernel(T* attn_score,
         }
         __syncthreads();

-        for (int i = 0; blockDim.x * i + threadIdx.x < k_length; i++) {
+        for (int64_t i = 0; blockDim.x * i + threadIdx.x < k_length; i++) {
             qk_offset = ((bi * head_num + hi) * q_length + qi) * k_length + blockDim.x * i + threadIdx.x;
             attn_score[qk_offset] = (T)(data[i] * s_mean);
         }
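Note that only the index declarations change: under C++'s usual arithmetic conversions, once the leftmost operand of each offset expression (bi, qi, or i) is int64_t, every subsequent multiply and add is performed in 64-bit arithmetic, so head_num, q_length, and k_length can remain plain int parameters.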
