From 6b32b4fda43c5e24ff6bccef8011b6df892c0218 Mon Sep 17 00:00:00 2001
From: Taylor Yeonbok Lee
Date: Tue, 29 Oct 2024 11:16:15 -0700
Subject: [PATCH] [GPU] Fix sdpa opt accuracy (#27262)

### Details:
- Fix accuracy of sdpa_opt: accumulate the Q*K dot products and stage the
  intermediate QK results in SLM in SOFTMAX_ACCUMULATOR_TYPE instead of the
  narrower INPUT0_TYPE/OUTPUT_TYPE, widening both mad() operands accordingly.

### Tickets:
- 154583
---
 .../src/kernel_selector/cl_kernels/sdpa_opt.cl | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/sdpa_opt.cl b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/sdpa_opt.cl
index 8e6be800f37cf0..c114332f393c0e 100644
--- a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/sdpa_opt.cl
+++ b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/sdpa_opt.cl
@@ -190,7 +190,7 @@ KERNEL(sdpa_opt)(
     // SLM for query inputs
     __local INPUT0_TYPE query_local[HEAD_SIZE * TARGET_SEQ_LEN_BLOCK_SIZE];
     // SLM for intermediate QK results
-    __local OUTPUT_TYPE qk_local[SEQ_LEN_PARTITION_SIZE * TARGET_SEQ_LEN_BLOCK_SIZE];
+    __local SOFTMAX_ACCUMULATOR_TYPE qk_local[SEQ_LEN_PARTITION_SIZE * TARGET_SEQ_LEN_BLOCK_SIZE];
     // SLM buffers for SoftMax calculation and qk_max/qk_sums results aggregation across all WG
     __local SOFTMAX_ACCUMULATOR_TYPE qk_max_vals[SUBGROUPS_PER_WG * TARGET_SEQ_LEN_BLOCK_SIZE];
     __local SOFTMAX_ACCUMULATOR_TYPE qk_sum_vals[SUBGROUPS_PER_WG * TARGET_SEQ_LEN_BLOCK_SIZE];
@@ -259,7 +259,7 @@ KERNEL(sdpa_opt)(
             uint key_offset = INPUT1_GET_INDEX(b_idx, b1_idx, start_partition_idx + seq_len, 0);
 #endif
 
-            INPUT0_TYPE acc[TARGET_SEQ_LEN_BLOCK_SIZE] = {INPUT0_VAL_ZERO};
+            SOFTMAX_ACCUMULATOR_TYPE acc[TARGET_SEQ_LEN_BLOCK_SIZE] = {SOFTMAX_ACCUMULATOR_VAL_ZERO};
 
 #if IS_KV_COMPRESSED
             const uint comp_offset = GET_COMPRESSION_INDEX(KEY_COMPRESSION_SCALE, b_idx, b1_idx / BROADCAST_GROUP_SIZE, start_partition_idx + seq_len, 0);
@@ -294,7 +294,7 @@ KERNEL(sdpa_opt)(
                 }
 
                 unroll_for(uint i = 0; i < KEY_BLOCK_SIZE; i++) {
-                    acc[seq_idx] = mad(query_vals_reg[i], key_vals[i], acc[seq_idx]);
+                    acc[seq_idx] = mad(TO_SOFTMAX_ACCUMULATOR_TYPE(query_vals_reg[i]), TO_SOFTMAX_ACCUMULATOR_TYPE(key_vals[i]), acc[seq_idx]);
                 }
 
                 query_offset += HEAD_SIZE;
@@ -326,7 +326,7 @@ KERNEL(sdpa_opt)(
                 }
 
                 unroll_for(uint i = 0; i < KEY_BLOCK_SIZE; i++) {
-                    acc[seq_idx] = mad(query_vals_reg[i], key_vals[i], acc[seq_idx]);
+                    acc[seq_idx] = mad(TO_SOFTMAX_ACCUMULATOR_TYPE(query_vals_reg[i]), TO_SOFTMAX_ACCUMULATOR_TYPE(key_vals[i]), acc[seq_idx]);
                 }
 
                 query_offset += HEAD_SIZE;
@@ -358,7 +358,7 @@ KERNEL(sdpa_opt)(
                 }
 
                 unroll_for(uint i = 0; i < KEY_BLOCK_SIZE; i++) {
-                    acc[seq_idx] = mad(query_vals_reg[i], key_vals[i], acc[seq_idx]);
+                    acc[seq_idx] = mad(TO_SOFTMAX_ACCUMULATOR_TYPE(query_vals_reg[i]), TO_SOFTMAX_ACCUMULATOR_TYPE(key_vals[i]), acc[seq_idx]);
                 }
 
                 query_offset += HEAD_SIZE;
@@ -389,7 +389,7 @@ KERNEL(sdpa_opt)(
                     query_vals_reg = query_local[query_offset + i * SUBGROUP_SIZE];
                 }
 
-                acc[seq_idx] = mad(query_vals_reg, key_vals, acc[seq_idx]);
+                acc[seq_idx] = mad(TO_SOFTMAX_ACCUMULATOR_TYPE(query_vals_reg), TO_SOFTMAX_ACCUMULATOR_TYPE(key_vals), acc[seq_idx]);
                 query_offset += HEAD_SIZE;
             }
         }
@@ -405,7 +405,7 @@ KERNEL(sdpa_opt)(
         // Wait until all SG finishes their calculations and apply scale and attention mask to the results
         barrier(CLK_LOCAL_MEM_FENCE);
 
-        INPUT0_TYPE qk_val[TARGET_SEQ_LEN_BLOCK_SIZE];
+        SOFTMAX_ACCUMULATOR_TYPE qk_val[TARGET_SEQ_LEN_BLOCK_SIZE];
         const uint seq_idx_end = 1;
         for (uint seq_idx = 0; seq_idx < seq_idx_end; seq_idx++) {
             // Iterate over all values QK values in SLM and apply scale and attention mask
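
For context, the pattern all seven hunks converge on is sketched below. This is a minimal illustration, not the generated kernel: the macro bindings shown (INPUT0_TYPE as half, SOFTMAX_ACCUMULATOR_TYPE as float, TO_SOFTMAX_ACCUMULATOR_TYPE as convert_float) and the `qk_dot` helper are assumptions for the sketch; in the real kernel these macros are emitted as jit constants by the kernel selector.

```c
#pragma OPENCL EXTENSION cl_khr_fp16 : enable

// Assumed macro bindings for illustration only; the real kernel receives
// these from the kernel-selector jit constants.
#define INPUT0_TYPE half
#define SOFTMAX_ACCUMULATOR_TYPE float
#define SOFTMAX_ACCUMULATOR_VAL_ZERO 0.0f
#define TO_SOFTMAX_ACCUMULATOR_TYPE(x) convert_float(x)

// Hypothetical helper: dot product of one query row and one key row.
// Widening both mad() operands keeps the whole multiply-add chain in the
// softmax accumulator precision instead of rounding every partial sum
// back to half.
SOFTMAX_ACCUMULATOR_TYPE qk_dot(const __local INPUT0_TYPE* query_row,
                                const __global INPUT0_TYPE* key_row,
                                uint head_size) {
    SOFTMAX_ACCUMULATOR_TYPE acc = SOFTMAX_ACCUMULATOR_VAL_ZERO;
    for (uint i = 0; i < head_size; i++) {
        acc = mad(TO_SOFTMAX_ACCUMULATOR_TYPE(query_row[i]),
                  TO_SOFTMAX_ACCUMULATOR_TYPE(key_row[i]),
                  acc);
    }
    return acc;
}
```

With fp16 inputs, accumulating a HEAD_SIZE-long dot product directly in INPUT0_TYPE lets rounding error grow with the partial sums, and staging the QK results in SLM in OUTPUT_TYPE rounds them again before softmax; keeping both the accumulator and the SLM buffer in SOFTMAX_ACCUMULATOR_TYPE, as the patch does, avoids both losses.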