Skip to content

Commit

Permalink
[GPU] Fix sdpa opt accuracy (openvinotoolkit#27262)
Browse files Browse the repository at this point in the history
### Details:
 - Fix accuracy for sdpa_opt: accumulate QK results in `SOFTMAX_ACCUMULATOR_TYPE` (SLM `qk_local` buffer, per-sequence `acc`/`qk_val` arrays, and the `mad()` operands) instead of the lower-precision `INPUT0_TYPE`/`OUTPUT_TYPE`, avoiding precision loss during the QK dot-product accumulation

### Tickets:
 - 154583
  • Loading branch information
yeonbok authored and CuriousPanCake committed Nov 6, 2024
1 parent 7867f87 commit 6b32b4f
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions src/plugins/intel_gpu/src/kernel_selector/cl_kernels/sdpa_opt.cl
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ KERNEL(sdpa_opt)(
// SLM for query inputs
__local INPUT0_TYPE query_local[HEAD_SIZE * TARGET_SEQ_LEN_BLOCK_SIZE];
// SLM for intermediate QK results
__local OUTPUT_TYPE qk_local[SEQ_LEN_PARTITION_SIZE * TARGET_SEQ_LEN_BLOCK_SIZE];
__local SOFTMAX_ACCUMULATOR_TYPE qk_local[SEQ_LEN_PARTITION_SIZE * TARGET_SEQ_LEN_BLOCK_SIZE];
// SLM buffers for SoftMax calculation and qk_max/qk_sums results aggregation across all WG
__local SOFTMAX_ACCUMULATOR_TYPE qk_max_vals[SUBGROUPS_PER_WG * TARGET_SEQ_LEN_BLOCK_SIZE];
__local SOFTMAX_ACCUMULATOR_TYPE qk_sum_vals[SUBGROUPS_PER_WG * TARGET_SEQ_LEN_BLOCK_SIZE];
Expand Down Expand Up @@ -259,7 +259,7 @@ KERNEL(sdpa_opt)(
uint key_offset = INPUT1_GET_INDEX(b_idx, b1_idx, start_partition_idx + seq_len, 0);
#endif

INPUT0_TYPE acc[TARGET_SEQ_LEN_BLOCK_SIZE] = {INPUT0_VAL_ZERO};
SOFTMAX_ACCUMULATOR_TYPE acc[TARGET_SEQ_LEN_BLOCK_SIZE] = {SOFTMAX_ACCUMULATOR_VAL_ZERO};

#if IS_KV_COMPRESSED
const uint comp_offset = GET_COMPRESSION_INDEX(KEY_COMPRESSION_SCALE, b_idx, b1_idx / BROADCAST_GROUP_SIZE, start_partition_idx + seq_len, 0);
Expand Down Expand Up @@ -294,7 +294,7 @@ KERNEL(sdpa_opt)(
}

unroll_for(uint i = 0; i < KEY_BLOCK_SIZE; i++) {
acc[seq_idx] = mad(query_vals_reg[i], key_vals[i], acc[seq_idx]);
acc[seq_idx] = mad(TO_SOFTMAX_ACCUMULATOR_TYPE(query_vals_reg[i]), TO_SOFTMAX_ACCUMULATOR_TYPE(key_vals[i]), acc[seq_idx]);
}

query_offset += HEAD_SIZE;
Expand Down Expand Up @@ -326,7 +326,7 @@ KERNEL(sdpa_opt)(
}

unroll_for(uint i = 0; i < KEY_BLOCK_SIZE; i++) {
acc[seq_idx] = mad(query_vals_reg[i], key_vals[i], acc[seq_idx]);
acc[seq_idx] = mad(TO_SOFTMAX_ACCUMULATOR_TYPE(query_vals_reg[i]), TO_SOFTMAX_ACCUMULATOR_TYPE(key_vals[i]), acc[seq_idx]);
}

query_offset += HEAD_SIZE;
Expand Down Expand Up @@ -358,7 +358,7 @@ KERNEL(sdpa_opt)(
}

unroll_for(uint i = 0; i < KEY_BLOCK_SIZE; i++) {
acc[seq_idx] = mad(query_vals_reg[i], key_vals[i], acc[seq_idx]);
acc[seq_idx] = mad(TO_SOFTMAX_ACCUMULATOR_TYPE(query_vals_reg[i]), TO_SOFTMAX_ACCUMULATOR_TYPE(key_vals[i]), acc[seq_idx]);
}

query_offset += HEAD_SIZE;
Expand Down Expand Up @@ -389,7 +389,7 @@ KERNEL(sdpa_opt)(
query_vals_reg = query_local[query_offset + i * SUBGROUP_SIZE];
}

acc[seq_idx] = mad(query_vals_reg, key_vals, acc[seq_idx]);
acc[seq_idx] = mad(TO_SOFTMAX_ACCUMULATOR_TYPE(query_vals_reg), TO_SOFTMAX_ACCUMULATOR_TYPE(key_vals), acc[seq_idx]);
query_offset += HEAD_SIZE;
}
}
Expand All @@ -405,7 +405,7 @@ KERNEL(sdpa_opt)(
// Wait until all SG finishes their calculations and apply scale and attention mask to the results
barrier(CLK_LOCAL_MEM_FENCE);

INPUT0_TYPE qk_val[TARGET_SEQ_LEN_BLOCK_SIZE];
SOFTMAX_ACCUMULATOR_TYPE qk_val[TARGET_SEQ_LEN_BLOCK_SIZE];
const uint seq_idx_end = 1;
for (uint seq_idx = 0; seq_idx < seq_idx_end; seq_idx++) {
// Iterate over all values QK values in SLM and apply scale and attention mask
Expand Down

0 comments on commit 6b32b4f

Please sign in to comment.