From 9f6178eb028d36b3ed1f5985e57b7cf160acf38a Mon Sep 17 00:00:00 2001
From: blzheng
Date: Wed, 16 Oct 2024 15:38:03 +0800
Subject: [PATCH] iakv: change attn_weights layout (#3248)

---
 .../kernels/MaskedMultiHeadAttentionKrnl.cpp  | 1064 +++++++++++++----
 csrc/cpu/vec/vec512/perf_kernel/add_softmax.h |   13 +-
 2 files changed, 818 insertions(+), 259 deletions(-)

diff --git a/csrc/cpu/aten/kernels/MaskedMultiHeadAttentionKrnl.cpp b/csrc/cpu/aten/kernels/MaskedMultiHeadAttentionKrnl.cpp
index 5ea08eb4a..f17fc3ee6 100644
--- a/csrc/cpu/aten/kernels/MaskedMultiHeadAttentionKrnl.cpp
+++ b/csrc/cpu/aten/kernels/MaskedMultiHeadAttentionKrnl.cpp
@@ -172,6 +172,29 @@ inline void reduce_head_half(
       k_cache_start);
   }
 }
+
+inline void reduce_head_half(
+    const at::Half* q_ptr_start,
+    int qStrideB,
+    int64_t kv_head_group_size,
+    const at::Half* k_ptr_start,
+    at::Half* attn_w_pos,
+    int attn_w_stride,
+    int64_t head_size,
+    int64_t beam_size) {
+  for (auto i = 0; i < kv_head_group_size; i++) {
+    for (auto b = 0; b < beam_size; b++) {
+      attn_w_pos[i * attn_w_stride + b] = 0;
+      reduce_head_half(
+          q_ptr_start + i * head_size + b * qStrideB,
+          k_ptr_start,
+          attn_w_pos + i * attn_w_stride + b,
+          head_size,
+          false,
+          nullptr);
+    }
+  }
+}
 #endif
 
 template <typename T>
@@ -196,6 +219,29 @@ inline void reduce_head(
     }
   }
 }
 
+template <typename T>
+inline void reduce_head(
+    const T* q_ptr_start,
+    int qStrideB,
+    int64_t kv_head_group_size,
+    const T* k_ptr_start,
+    float* attn_w_pos,
+    int attn_w_stride,
+    int64_t head_size,
+    int64_t beam_size) {
+  for (auto i = 0; i < kv_head_group_size; i++) {
+    for (auto b = 0; b < beam_size; b++) {
+      attn_w_pos[i * attn_w_stride + b] = 0;
+      reduce_head(
+          q_ptr_start + i * head_size + b * qStrideB,
+          k_ptr_start,
+          attn_w_pos + i * attn_w_stride + b,
+          head_size,
+          false,
+          nullptr);
+    }
+  }
+}
 /*
  *reduce the attention_weights with the value embedding by the dimension of
  *head_size for every head
@@ -247,6 +293,37 @@ inline void mul_attenion_weights_and_value_of_head(
   }
 }
 
+template <typename T, typename T1>
+inline void mul_attenion_weights_and_value_of_head(
+    float* attn_w,
+    int attn_w_stride,
+    const T* v_ptr_start,
+    T1* attn_out_start,
+    int attn_out_strideB,
+    int attn_out_strideH,
+    int kv_head_group_size,
+    int64_t head_size,
+    bool store_value,
+    T* v_cache_start,
+    uint8_t* flag_access,
+    int flag_access_stride,
+    int64_t beam_size) {
+  for (auto i = 0; i < kv_head_group_size; i++) {
+    for (auto b = 0; b < beam_size; b++) {
+      mul_attenion_weights_and_value_of_head(
+          attn_w[i * attn_w_stride + b],
+          v_ptr_start,
+          attn_out_start + i * attn_out_strideH + b * attn_out_strideB,
+          head_size,
+          store_value,
+          v_cache_start,
+          flag_access[b * flag_access_stride + i]);
+      if (flag_access[b * flag_access_stride + i] == 0)
+        flag_access[b * flag_access_stride + i] = 1;
+    }
+  }
+}
+
 #if defined(CPU_CAPABILITY_AVX512)
 template <>
 inline void mul_attenion_weights_and_value_of_head(
@@ -555,6 +632,35 @@ inline void mul_attenion_weights_and_value_of_head_half(
     flag_access[i] = 1;
   }
 }
+inline void mul_attenion_weights_and_value_of_head_half(
+    at::Half* attn_w,
+    int attn_w_stride,
+    const at::Half* v_ptr_start,
+    at::Half* attn_out_start,
+    int attn_out_strideB,
+    int attn_out_strideH,
+    int kv_head_group_size,
+    int64_t head_size,
+    bool store_value,
+    at::Half* v_cache_start,
+    uint8_t* flag_access,
+    int flag_access_stride,
+    int64_t beam_size) {
+  for (auto i = 0; i < kv_head_group_size; i++) {
+    for (auto b = 0; b < beam_size; b++) {
+      mul_attenion_weights_and_value_of_head_half(
+          attn_w[i * attn_w_stride + b],
+          v_ptr_start,
+          attn_out_start + i * attn_out_strideH + b * attn_out_strideB,
+          head_size,
+          store_value,
+          v_cache_start,
+          flag_access[b * flag_access_stride + i]);
+      if (flag_access[b * flag_access_stride + i] == 0)
+        flag_access[b * flag_access_stride + i] = 1;
+    }
+  }
+}
 #endif
 
 template <typename T>
@@ -633,12 +739,36 @@ scale_dot_product_for_indirect_access_kv_cache(
   auto bs = query.size(0);
   auto cur_len = query.size(1); // only process cur_len==1
   auto head_num = query.size(2);
+  auto head_size = query.size(3);
+  auto b_ptr = beam_idx.data_ptr<long>();
+  auto max_cache_size = beam_idx.size(0);
+  long new_beam_idx[beam_batch][offset + query.size(1) + 1] = {};
+  auto prompt_len = b_ptr[(max_cache_size - 2) * beam_batch];
+  auto prompt_bs = b_ptr[(max_cache_size - 1) * beam_batch];
+  auto beam_size = 1;
+  if (prompt_bs != 0) {
+    beam_size = beam_batch / prompt_bs;
+  }
+  auto need_update_beam_idx = offset > 0 and beam_size > 1;
   auto kv_head = key.size(2);
   auto group_size = head_num / kv_head;
-  auto head_size = query.size(3);
   auto seq_len = offset + cur_len;
   auto kc_token_stride = beam_batch * kv_head * head_size;
-  auto attn_weights = at::empty({bs, head_num, cur_len, seq_len}, at::kFloat);
+  at::Tensor attn_weights, attn_weights2;
+  bool chg_attn_w_layout = false;
+  auto target_bs = bs;
+  if (beam_size > 1 && prompt_len <= 2048 && prompt_bs > 20 &&
+      group_size == 1) {
+    chg_attn_w_layout = true;
+    attn_weights = at::empty(
+        {prompt_bs, head_num, cur_len, seq_len, beam_size}, at::kFloat);
+    attn_weights2 = at::empty(
+        {prompt_bs, head_num, cur_len, beam_size, seq_len}, at::kFloat);
+    target_bs = prompt_bs;
+  } else {
+    attn_weights = at::empty({bs, head_num, cur_len, seq_len}, at::kFloat);
+    attn_weights2 = attn_weights;
+  }
   query = query.contiguous();
   key = key.contiguous();
   auto q_ptr = query.data_ptr<T>();
@@ -656,6 +786,7 @@ scale_dot_product_for_indirect_access_kv_cache(
   auto v_cache_ptr = value_cache.data_ptr<T>();
   auto attn_out_ptr = attn_outs.data_ptr<T>();
   auto attn_w_ptr = attn_weights.data_ptr<float>();
+  auto attn_w_ptr2 = attn_weights2.data_ptr<float>();
 
   // stride information
   auto qStrideB = query.stride(0);
@@ -684,24 +815,14 @@ scale_dot_product_for_indirect_access_kv_cache(
   auto max_parallel_parts = thread_numbers * 4;
 
   auto target_block_size = 32L;
-  if (bs <= 32 and seq_len < 65536) {
+  if (target_bs <= 32 and seq_len < 65536) {
     target_block_size = 8L;
   }
-  auto kv_block_size = bs * head_num >= max_parallel_parts
+  auto kv_block_size = target_bs * head_num >= max_parallel_parts
       ? seq_len
      : std::max(seq_len / max_parallel_parts, 1L);
   kv_block_size = std::min(kv_block_size, target_block_size);
   auto kv_block_count = (seq_len + kv_block_size - 1) / kv_block_size;
-  auto b_ptr = beam_idx.data_ptr<long>();
-  auto max_cache_size = beam_idx.size(0);
-  long new_beam_idx[beam_batch][offset + query.size(1) + 1] = {};
-  auto prompt_len = b_ptr[(max_cache_size - 2) * beam_batch];
-  auto prompt_bs = b_ptr[(max_cache_size - 1) * beam_batch];
-  auto beam_size = 1;
-  if (prompt_bs != 0) {
-    beam_size = beam_batch / prompt_bs;
-  }
-  auto need_update_beam_idx = offset > 0 and beam_size > 1;
   if (need_update_beam_idx) {
     // according to last decoded token to get the target beam for the past
     for (int i = 0; i < bs; i++) {
@@ -717,7 +838,7 @@ scale_dot_product_for_indirect_access_kv_cache(
       "ipex::iakv_sdp::matmul(query, key)", c10::ArrayRef<c10::IValue>({}));
 #pragma omp parallel for collapse(3)
   for (auto block_id = 0; block_id < kv_block_count; block_id++) {
-    for (auto bsi = 0; bsi < bs; bsi += beam_size) {
+    for (auto bsi = 0; bsi < prompt_bs; bsi++) {
       for (auto head_group_start = 0; head_group_start < head_num;
            head_group_start += group_size) {
         auto k_start = block_id * kv_block_size;
@@ -725,46 +846,110 @@ scale_dot_product_for_indirect_access_kv_cache(
         auto query_ti = 0;
         // maping the query head to key/value head to support MGA/MQA
         auto kv_hi = head_group_start / group_size;
-        for (auto bbi = 0; bbi < beam_size; bbi++) {
-          auto bi = bsi + bbi;
+        if (chg_attn_w_layout) {
+          auto attn_w_stride =
+              (bsi * head_num + head_group_start) * attn_w_strideH;
           for (auto ti = k_start; ti < k_start + block_size; ti++) {
-            auto q_ptr_start =
-                q_ptr + bi * qStrideB + head_group_start * qStrideH;
-            auto attn_w_stride =
-                (bi * head_num + head_group_start) * attn_w_strideH;
-            auto attn_w_pos =
-                attn_w_ptr + attn_w_stride + query_ti * seq_len + ti;
-            attn_w_pos[0] = 0.0f;
-            auto beam = need_update_beam_idx && ti >= prompt_len
-                ? new_beam_idx[bi][ti]
-                : bsi;
             // caculate the innerproduct for the current token and store the
             // key
             if (ti == query_ti + offset) {
-              auto kc_head_start = k_cache_ptr + ti * kcStrideS +
-                  bi * kcStrideB + kv_hi * kcStrideH;
-              auto k_ptr_start = k_ptr + bi * kStrideB + kv_hi * kStrideH;
-              reduce_head(
-                  q_ptr_start,
-                  group_size,
-                  k_ptr_start,
-                  attn_w_pos,
-                  attn_w_strideH,
-                  head_size,
-                  true,
-                  kc_head_start);
+              for (auto bbi = 0; bbi < beam_size; bbi++) {
+                auto bi = bsi * beam_size + bbi;
+                auto q_ptr_start =
+                    q_ptr + bi * qStrideB + head_group_start * qStrideH;
+                auto attn_w_pos = attn_w_ptr + attn_w_stride +
+                    query_ti * seq_len + ti * beam_size + bbi;
+                auto kc_head_start = k_cache_ptr + ti * kcStrideS +
+                    bi * kcStrideB + kv_hi * kcStrideH;
+                auto k_ptr_start = k_ptr + bi * kStrideB + kv_hi * kStrideH;
+                reduce_head(
+                    q_ptr_start,
+                    group_size,
+                    k_ptr_start,
+                    attn_w_pos,
+                    attn_w_strideH,
+                    head_size,
+                    true,
+                    kc_head_start);
+              }
             } else { // caculate the innerproduct for the past token
-              auto kc_head_start = k_cache_ptr + ti * kcStrideS +
-                  beam * kcStrideB + kv_hi * kcStrideH;
-              reduce_head(
-                  q_ptr_start,
-                  group_size,
-                  kc_head_start,
-                  attn_w_pos,
-                  attn_w_strideH,
-                  head_size,
-                  false,
-                  nullptr);
+              auto bi = bsi * beam_size;
+              auto q_ptr_start =
+                  q_ptr + bi * qStrideB + head_group_start * qStrideH;
+              auto attn_w_pos = attn_w_ptr + attn_w_stride +
+                  query_ti * seq_len + ti * beam_size;
+              if (need_update_beam_idx && ti >= prompt_len) {
+                for (auto bbi = 0; bbi < beam_size; bbi++) {
+                  auto beam = new_beam_idx[bi + bbi][ti];
+                  auto kc_head_start = k_cache_ptr + ti * kcStrideS +
+                      beam * kcStrideB + kv_hi * kcStrideH;
+                  reduce_head(
+                      q_ptr_start + bbi * qStrideB,
+                      group_size,
+                      kc_head_start,
+                      attn_w_pos + bbi,
+                      attn_w_strideH,
+                      head_size,
+                      false,
+                      nullptr);
+                }
+              } else {
+                auto kc_head_start = k_cache_ptr + ti * kcStrideS +
+                    bi * kcStrideB + kv_hi * kcStrideH;
+                reduce_head(
+                    q_ptr_start,
+                    qStrideB,
+                    group_size,
+                    kc_head_start,
+                    attn_w_pos,
+                    attn_w_strideH,
+                    head_size,
+                    beam_size);
+              }
+            }
+          }
+        } else {
+          for (auto bbi = 0; bbi < beam_size; bbi++) {
+            auto bi = bsi * beam_size + bbi;
+            for (auto ti = k_start; ti < k_start + block_size; ti++) {
+              auto q_ptr_start =
+                  q_ptr + bi * qStrideB + head_group_start * qStrideH;
+              auto attn_w_stride =
+                  (bi * head_num + head_group_start) * attn_w_strideH;
+              auto attn_w_pos =
+                  attn_w_ptr + attn_w_stride + query_ti * seq_len + ti;
+              attn_w_pos[0] = 0.0f;
+              auto beam = need_update_beam_idx && ti >= prompt_len
+                  ? new_beam_idx[bi][ti]
+                  : bsi * beam_size;
+              // calculate the inner product for the current token and store
+              // the key
+              if (ti == query_ti + offset) {
+                auto kc_head_start = k_cache_ptr + ti * kcStrideS +
+                    bi * kcStrideB + kv_hi * kcStrideH;
+                auto k_ptr_start = k_ptr + bi * kStrideB + kv_hi * kStrideH;
+                reduce_head(
+                    q_ptr_start,
+                    group_size,
+                    k_ptr_start,
+                    attn_w_pos,
+                    attn_w_strideH,
+                    head_size,
+                    true,
+                    kc_head_start);
+              } else { // calculate the inner product for the past token
+                auto kc_head_start = k_cache_ptr + ti * kcStrideS +
+                    beam * kcStrideB + kv_hi * kcStrideH;
+                reduce_head(
+                    q_ptr_start,
+                    group_size,
+                    kc_head_start,
+                    attn_w_pos,
+                    attn_w_strideH,
+                    head_size,
+                    false,
+                    nullptr);
+              }
+            }
           }
         }
       }
     }
   }
@@ -776,56 +961,138 @@ scale_dot_product_for_indirect_access_kv_cache(
   {
     RECORD_FUNCTION(
         "ipex::iakv_sdp::div_add_softmax", c10::ArrayRef<c10::IValue>({}));
 #pragma omp parallel for collapse(2)
-    for (auto bi = 0; bi < bs; bi++) {
+    for (auto bsi = 0; bsi < prompt_bs; bsi++) {
       for (auto hi = 0; hi < head_num; hi++) {
         for (auto query_ti = 0; query_ti < cur_len; query_ti++) {
-          auto mask_ptr_start = mask_ptr + bi * mask_bs_stride +
-              (hi % mask_head_num) * mask_dim2 * seq_len;
-          auto attn_w_stride = (bi * head_num + hi) * cur_len * seq_len;
-          auto attn_w_query_start =
-              attn_w_ptr + attn_w_stride + query_ti * seq_len;
+          for (auto bbi = 0; bbi < beam_size; bbi++) {
+            auto bi = bsi * beam_size + bbi;
+            auto mask_ptr_start = mask_ptr + bi * mask_bs_stride +
+                (hi % mask_head_num) * mask_dim2 * seq_len;
           // div+add+softmax
 #if defined(CPU_CAPABILITY_AVX512)
-          for (auto qi = 0; qi < 1; qi++) {
             auto max_val = -100000.0f;
-            torch_ipex::cpu::kernel::
-                _dil_div_add_reduce_max_fusion_kernel(
-                    attn_w_query_start,
-                    mask_ptr_start + (query_ti % mask_dim2) * seq_len,
-                    scale_factor,
-                    seq_len,
-                    attn_w_query_start,
-                    max_val);
-
-            torch_ipex::cpu::kernel::_dil_exp_reduce_sum_fusion_kernel(
-                attn_w_query_start, seq_len, attn_w_query_start, max_val);
-            torch_ipex::cpu::kernel::_dil_normalization_kernel(
-                attn_w_query_start, max_val, seq_len, attn_w_query_start);
-          }
+            if (chg_attn_w_layout) {
+              auto attn_w_stride =
+                  (bsi * head_num + hi) * cur_len * seq_len * beam_size;
+              auto attn_w_query_start = attn_w_ptr + attn_w_stride +
+                  query_ti * seq_len * beam_size + bbi;
+              auto attn_w_query_start2 = attn_w_ptr2 + attn_w_stride +
+                  query_ti * beam_size * seq_len + bbi * seq_len;
+              __m512i decrement_sequence = _mm512_set_epi32(
+                  15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0);
+              __m512i beam_size_vector = _mm512_set1_epi32(beam_size);
+              int ti = 0;
+              for (ti = 0; ti <= seq_len - 16; ti += 16) {
+                __m512i ti_vector = _mm512_set1_epi32(ti);
+                __m512i index_sequence =
+                    _mm512_add_epi32(decrement_sequence, ti_vector);
+                __m512i index =
+                    _mm512_mullo_epi32(index_sequence, beam_size_vector);
+                __m512 data = _mm512_i32gather_ps(
+                    index, attn_w_query_start, sizeof(float));
+                _mm512_storeu_ps(attn_w_query_start2 + ti, data);
+              }
+              for (; ti < seq_len; ti++) {
+                attn_w_query_start2[ti] = attn_w_query_start[ti * beam_size];
+              }
+              torch_ipex::cpu::kernel::
+                  _dil_div_add_reduce_max_fusion_kernel(
+                      attn_w_query_start2,
+                      mask_ptr_start + (query_ti % mask_dim2) * seq_len,
+                      scale_factor,
+                      seq_len,
+                      attn_w_query_start2,
+                      max_val);
+              torch_ipex::cpu::kernel::_dil_exp_reduce_sum_fusion_kernel(
+                  attn_w_query_start2, seq_len, attn_w_query_start2, max_val);
+              torch_ipex::cpu::kernel::_dil_normalization_kernel(
+                  attn_w_query_start2, max_val, seq_len, attn_w_query_start2);
+              for (ti = 0; ti <= seq_len - 16; ti += 16) {
+                __m512i ti_vector = _mm512_set1_epi32(ti);
+                __m512i index_sequence =
+                    _mm512_add_epi32(decrement_sequence, ti_vector);
+                __m512i index =
+                    _mm512_mullo_epi32(index_sequence, beam_size_vector);
+                __m512 data = _mm512_loadu_ps(attn_w_query_start2 + ti);
+                _mm512_i32scatter_ps(
+                    attn_w_query_start, index, data, sizeof(float));
+              }
+              for (; ti < seq_len; ti++) {
+                attn_w_query_start[ti * beam_size] = attn_w_query_start2[ti];
+              }
+            } else {
+              auto attn_w_stride = (bi * head_num + hi) * cur_len * seq_len;
+              auto attn_w_query_start =
+                  attn_w_ptr + attn_w_stride + query_ti * seq_len;
+              torch_ipex::cpu::kernel::
+                  _dil_div_add_reduce_max_fusion_kernel(
+                      attn_w_query_start,
+                      mask_ptr_start + (query_ti % mask_dim2) * seq_len,
+                      scale_factor,
+                      seq_len,
+                      attn_w_query_start,
+                      max_val);
+              torch_ipex::cpu::kernel::_dil_exp_reduce_sum_fusion_kernel(
+                  attn_w_query_start, seq_len, attn_w_query_start, max_val);
+              torch_ipex::cpu::kernel::_dil_normalization_kernel(
+                  attn_w_query_start, max_val, seq_len, attn_w_query_start);
+            }
 #else
-          for (auto qi = 0; qi < 1; qi++) {
             auto max_val = -100000.0f;
-            // div+add and find max
-            for (auto si = 0; si < seq_len; si++) {
-              attn_w_query_start[si] = attn_w_query_start[si] / scale_factor +
-                  mask_ptr_start[(query_ti % mask_dim2) * seq_len + si];
-              if (attn_w_query_start[si] > max_val) {
-                max_val = attn_w_query_start[si];
-              }
-            }
-            // softmax
-            float sum = 0.0f;
-            // exp and sum
-            for (auto si = 0; si < seq_len; si++) {
-              attn_w_query_start[si] = exp(attn_w_query_start[si] - max_val);
-              sum += attn_w_query_start[si];
-            }
-            // normalization
-            for (auto si = 0; si < seq_len; si++) {
-              attn_w_query_start[si] = attn_w_query_start[si] / sum;
-            }
-          }
+            if (chg_attn_w_layout) {
+              auto attn_w_stride =
+                  (bsi * head_num + hi) * cur_len * seq_len * beam_size;
+              auto attn_w_query_start = attn_w_ptr + attn_w_stride +
+                  query_ti * seq_len * beam_size + bbi;
+              auto total_len = seq_len * beam_size;
+              // div+add and find max
+              for (auto si = 0; si < total_len; si += beam_size) {
+                attn_w_query_start[si] = attn_w_query_start[si] / scale_factor +
+                    mask_ptr_start[(query_ti % mask_dim2) * seq_len +
+                        si / beam_size];
+                if (attn_w_query_start[si] > max_val) {
+                  max_val = attn_w_query_start[si];
+                }
+              }
+              // softmax
+              float sum = 0.0f;
+              // exp and sum
+              for (auto si = 0; si < total_len; si += beam_size) {
+                attn_w_query_start[si] = exp(attn_w_query_start[si] - max_val);
+                sum += attn_w_query_start[si];
+              }
+              // normalization
+              for (auto si = 0; si < total_len; si += beam_size) {
+                attn_w_query_start[si] = attn_w_query_start[si] / sum;
+              }
+            } else {
+              auto attn_w_stride = (bi * head_num + hi) * cur_len * seq_len;
+              auto attn_w_query_start =
+                  attn_w_ptr + attn_w_stride + query_ti * seq_len;
+              // div+add and find max
+              for (auto si = 0; si < seq_len; si++) {
+                attn_w_query_start[si] = attn_w_query_start[si] / scale_factor +
+                    mask_ptr_start[(query_ti % mask_dim2) * seq_len + si];
+                if (attn_w_query_start[si] > max_val) {
+                  max_val = attn_w_query_start[si];
+                }
+              }
+              // softmax
+              float sum = 0.0f;
+              // exp and sum
+              for (auto si = 0; si < seq_len; si++) {
+                attn_w_query_start[si] = exp(attn_w_query_start[si] - max_val);
+                sum += attn_w_query_start[si];
+              }
+              // normalization
+              for (auto si = 0; si < seq_len; si++) {
+                attn_w_query_start[si] = attn_w_query_start[si] / sum;
+              }
+            }
 #endif
+          }
         }
       }
     }
   }
@@ -848,7 +1115,7 @@ scale_dot_product_for_indirect_access_kv_cache(
       c10::ArrayRef<c10::IValue>({}));
 #pragma omp parallel for collapse(3)
   for (auto block_id = 0; block_id < kv_block_count; block_id++) {
-    for (auto bsi = 0; bsi < bs; bsi += beam_size) {
+    for (auto bsi = 0; bsi < prompt_bs; bsi++) {
       for (auto hi = 0; hi < head_num; hi += group_size) {
         auto thread_id = 0;
         if (kv_block_size < seq_len)
@@ -858,58 +1125,149 @@ scale_dot_product_for_indirect_access_kv_cache(
         auto query_ti = 0;
         // maping the query head to key/value head to support MGA/MQA
         auto kv_hi = hi / group_size;
-        for (auto bbi = 0; bbi < beam_size; bbi++) {
-          auto bi = bsi + bbi;
+        if (chg_attn_w_layout) {
+          auto attn_w_stride = (bsi * head_num + hi) * attn_w_strideH;
           for (auto vi = v_start; vi < v_start + block_size; vi++) {
-            auto attn_w_stride = (bi * head_num + hi) * attn_w_strideH;
-            auto attn_w_query_start =
-                attn_w_ptr + attn_w_stride + query_ti * seq_len + vi;
-            // calculate weighted value and store the result to attn_outs[bs,
-            // head_num, cur_len, head_size]
-            auto attn_out_start = private_attn_out_ptr +
-                thread_id * attn_outs_stride_privT +
-                bi * attn_outs_stride_privB + hi * attn_outs_stride_privH;
-            auto flag_access_start = flag_access_ptr +
-                head_num * bs * thread_id + head_num * bi + hi;
-
-            auto beam = need_update_beam_idx && vi >= prompt_len
-                ? new_beam_idx[bi][vi]
-                : bsi;
-            // caculate the innerproduct for the current token and store the
-            // key
             if (vi == offset) {
-              auto v_cache_head_start = v_cache_ptr + vi * vcStrideS +
-                  bi * vcStrideB + kv_hi * vcStrideH;
-              auto v_ptr_start = v_ptr + bi * vStrideB + kv_hi * vStrideH;
-              mul_attenion_weights_and_value_of_head(
-                  attn_w_query_start,
-                  attn_w_strideH,
-                  v_ptr_start,
-                  attn_out_start,
-                  head_size,
-                  group_size,
-                  head_size,
-                  true,
-                  v_cache_head_start,
-                  flag_access_start);
+              for (auto bbi = 0; bbi < beam_size; bbi++) {
+                auto bi = bsi * beam_size + bbi;
+                auto attn_w_query_start = attn_w_ptr + attn_w_stride +
+                    query_ti * seq_len + vi * beam_size + bbi;
+                // calculate weighted value and store the result to
+                // attn_outs[bs, head_num, cur_len, head_size]
+                auto attn_out_start = private_attn_out_ptr +
+                    thread_id * attn_outs_stride_privT +
+                    bi * attn_outs_stride_privB + hi * attn_outs_stride_privH;
+                auto flag_access_start = flag_access_ptr +
+                    head_num * bs * thread_id + head_num * bi + hi;
+                auto v_cache_head_start = v_cache_ptr + vi * vcStrideS +
+                    bi * vcStrideB + kv_hi * vcStrideH;
+                auto v_ptr_start = v_ptr + bi * vStrideB + kv_hi * vStrideH;
+                mul_attenion_weights_and_value_of_head(
+                    attn_w_query_start,
+                    attn_w_strideH,
+                    v_ptr_start,
+                    attn_out_start,
+                    head_size,
+                    group_size,
+                    head_size,
+                    true,
+                    v_cache_head_start,
+                    flag_access_start);
+              }
             } else { // caculate the innerproduct for the past token
-              auto v_cache_head_start = v_cache_ptr + vi * vcStrideS +
-                  beam * vcStrideB + kv_hi * vcStrideH;
-              mul_attenion_weights_and_value_of_head(
-                  attn_w_query_start,
-                  attn_w_strideH,
-                  v_cache_head_start,
-                  attn_out_start,
-                  head_size,
-                  group_size,
-                  head_size,
-                  false,
-                  nullptr,
-                  flag_access_start);
+              if (need_update_beam_idx && vi >= prompt_len) {
+                for (auto bbi = 0; bbi < beam_size; bbi++) {
+                  auto bi = bsi * beam_size + bbi;
+                  auto attn_w_query_start = attn_w_ptr + attn_w_stride +
+                      query_ti * seq_len + vi * beam_size + bbi;
+                  // calculate weighted value and store the result to
+                  // attn_outs[bs, head_num, cur_len, head_size]
+                  auto attn_out_start = private_attn_out_ptr +
+                      thread_id * attn_outs_stride_privT +
+                      bi * attn_outs_stride_privB +
+                      hi * attn_outs_stride_privH;
+                  auto flag_access_start = flag_access_ptr +
+                      head_num * bs * thread_id + head_num * bi + hi;
+                  auto v_ptr_start = v_ptr + bi * vStrideB + kv_hi * vStrideH;
+                  auto beam = new_beam_idx[bi][vi];
+                  auto v_cache_head_start = v_cache_ptr + vi * vcStrideS +
+                      beam * vcStrideB + kv_hi * vcStrideH;
+                  mul_attenion_weights_and_value_of_head(
+                      attn_w_query_start,
+                      attn_w_strideH,
+                      v_cache_head_start,
+                      attn_out_start,
+                      head_size,
+                      group_size,
+                      head_size,
+                      false,
+                      nullptr,
+                      flag_access_start);
+                }
+              } else {
+                auto bi = bsi * beam_size;
+                auto attn_w_query_start = attn_w_ptr + attn_w_stride +
+                    query_ti * seq_len + vi * beam_size;
+                // calculate weighted value and store the result to
+                // attn_outs[bs, head_num, cur_len, head_size]
+                auto attn_out_start = private_attn_out_ptr +
+                    thread_id * attn_outs_stride_privT +
+                    bi * attn_outs_stride_privB + hi * attn_outs_stride_privH;
+                auto flag_access_start = flag_access_ptr +
+                    head_num * bs * thread_id + head_num * bi + hi;
+                auto v_cache_head_start = v_cache_ptr + vi * vcStrideS +
+                    bi * vcStrideB + kv_hi * vcStrideH;
+                mul_attenion_weights_and_value_of_head(
+                    attn_w_query_start,
+                    attn_w_strideH,
+                    v_cache_head_start,
+                    attn_out_start,
+                    attn_outs_stride_privB,
+                    head_size,
+                    group_size,
+                    head_size,
+                    false,
+                    nullptr,
+                    flag_access_start,
+                    head_num,
+                    beam_size);
+              }
+            }
+          }
+        } else {
+          for (auto bbi = 0; bbi < beam_size; bbi++) {
+            auto bi = bsi * beam_size + bbi;
+            for (auto vi = v_start; vi < v_start + block_size; vi++) {
+              auto attn_w_stride = (bi * head_num + hi) * attn_w_strideH;
+              auto attn_w_query_start =
+                  attn_w_ptr + attn_w_stride + query_ti * seq_len + vi;
+              // calculate weighted value and store the result to
+              // attn_outs[bs, head_num, cur_len, head_size]
+              auto attn_out_start = private_attn_out_ptr +
+                  thread_id * attn_outs_stride_privT +
+                  bi * attn_outs_stride_privB + hi * attn_outs_stride_privH;
+              auto flag_access_start = flag_access_ptr +
+                  head_num * bs * thread_id + head_num * bi + hi;
+
+              auto beam = need_update_beam_idx && vi >= prompt_len
+                  ? new_beam_idx[bi][vi]
+                  : bsi * beam_size;
+              // calculate the inner product for the current token and store
+              // the key
+              if (vi == offset) {
+                auto v_cache_head_start = v_cache_ptr + vi * vcStrideS +
+                    bi * vcStrideB + kv_hi * vcStrideH;
+                auto v_ptr_start = v_ptr + bi * vStrideB + kv_hi * vStrideH;
+                mul_attenion_weights_and_value_of_head(
+                    attn_w_query_start,
+                    attn_w_strideH,
+                    v_ptr_start,
+                    attn_out_start,
+                    head_size,
+                    group_size,
+                    head_size,
+                    true,
+                    v_cache_head_start,
+                    flag_access_start);
+              } else {
+                // calculate the inner product for the past token
+                auto v_cache_head_start = v_cache_ptr + vi * vcStrideS +
+                    beam * vcStrideB + kv_hi * vcStrideH;
+                mul_attenion_weights_and_value_of_head(
+                    attn_w_query_start,
+                    attn_w_strideH,
+                    v_cache_head_start,
+                    attn_out_start,
+                    head_size,
+                    group_size,
+                    head_size,
+                    false,
+                    nullptr,
+                    flag_access_start);
+              }
+            }
-          if (flag_access[thread_id][bi][hi] == 0)
-            flag_access[thread_id][bi][hi] = 1;
+          }
         }
       }
     }
   }
@@ -974,13 +1332,36 @@ scale_dot_product_for_indirect_access_kv_cache_half(
   auto bs = query.size(0);
   auto cur_len = query.size(1); // only process cur_len==1
   auto head_num = query.size(2);
+  auto head_size = query.size(3);
+  auto b_ptr = beam_idx.data_ptr<long>();
+  auto max_cache_size = beam_idx.size(0);
+  long new_beam_idx[beam_batch][offset + query.size(1) + 1] = {};
+  auto prompt_len = b_ptr[(max_cache_size - 2) * beam_batch];
+  auto prompt_bs = b_ptr[(max_cache_size - 1) * beam_batch];
+  auto beam_size = 1;
+  if (prompt_bs != 0) {
+    beam_size = beam_batch / prompt_bs;
+  }
+  auto need_update_beam_idx = offset > 0 and beam_size > 1;
   auto kv_head = key.size(2);
   auto group_size = head_num / kv_head;
-  auto head_size = query.size(3);
   auto seq_len = offset + cur_len;
   auto kc_token_stride = beam_batch * kv_head * head_size;
-  auto attn_weights =
-      at::empty({bs, head_num, cur_len, seq_len}, key.options());
+  at::Tensor attn_weights, attn_weights2;
+  bool chg_attn_w_layout = false;
+  auto target_bs = bs;
+  if (beam_size > 1 && prompt_len <= 2048 && prompt_bs > 20 &&
+      group_size == 1) {
+    chg_attn_w_layout = true;
+    attn_weights = at::empty(
+        {prompt_bs, head_num, cur_len, seq_len, beam_size}, key.options());
+    attn_weights2 = at::empty(
+        {prompt_bs, head_num, cur_len, beam_size, seq_len}, key.options());
+    target_bs = prompt_bs;
+  } else {
+    attn_weights = at::empty({bs, head_num, cur_len, seq_len}, key.options());
+    attn_weights2 = attn_weights;
+  }
   query = query.contiguous();
   key = key.contiguous();
   auto q_ptr = query.data_ptr<at::Half>();
@@ -998,6 +1379,7 @@ scale_dot_product_for_indirect_access_kv_cache_half(
   auto v_cache_ptr = value_cache.data_ptr<at::Half>();
   auto attn_out_ptr = attn_outs.data_ptr<at::Half>();
   auto attn_w_ptr = attn_weights.data_ptr<at::Half>();
+  auto attn_w_ptr2 = attn_weights2.data_ptr<at::Half>();
 
   // stride information
   auto qStrideB = query.stride(0);
@@ -1026,24 +1408,14 @@ scale_dot_product_for_indirect_access_kv_cache_half(
   auto max_parallel_parts = thread_numbers * 4;
 
   auto target_block_size = 32L;
-  if (bs <= 32 and seq_len < 65536) {
+  if (target_bs <= 32 and seq_len < 65536) {
     target_block_size = 8L;
   }
-  auto kv_block_size = bs * head_num >= max_parallel_parts
+  auto kv_block_size = target_bs * head_num >= max_parallel_parts
      ? seq_len
     : std::max(seq_len / max_parallel_parts, 1L);
   kv_block_size = std::min(kv_block_size, target_block_size);
   auto kv_block_count = (seq_len + kv_block_size - 1) / kv_block_size;
-  auto b_ptr = beam_idx.data_ptr<long>();
-  auto max_cache_size = beam_idx.size(0);
-  long new_beam_idx[beam_batch][offset + query.size(1) + 1] = {};
-  auto prompt_len = b_ptr[(max_cache_size - 2) * beam_batch];
-  auto prompt_bs = b_ptr[(max_cache_size - 1) * beam_batch];
-  auto beam_size = 1;
-  if (prompt_bs != 0) {
-    beam_size = beam_batch / prompt_bs;
-  }
-  auto need_update_beam_idx = offset > 0 and beam_size > 1;
   if (need_update_beam_idx) {
     // according to the last decoded token to get the target beam for the past
     // token
     for (int i = 0; i < bs; i++) {
@@ -1060,7 +1432,7 @@ scale_dot_product_for_indirect_access_kv_cache_half(
       "ipex::iakv_sdp::matmul(query, key)", c10::ArrayRef<c10::IValue>({}));
 #pragma omp parallel for collapse(3)
   for (auto block_id = 0; block_id < kv_block_count; block_id++) {
-    for (auto bsi = 0; bsi < bs; bsi += beam_size) {
+    for (auto bsi = 0; bsi < prompt_bs; bsi++) {
       for (auto head_group_start = 0; head_group_start < head_num;
            head_group_start += group_size) {
         auto k_start = block_id * kv_block_size;
@@ -1068,46 +1440,110 @@ scale_dot_product_for_indirect_access_kv_cache_half(
         auto query_ti = 0;
         // maping the query head to key/value head to support MGA/MQA
         auto kv_hi = head_group_start / group_size;
-        for (auto bbi = 0; bbi < beam_size; bbi++) {
-          auto bi = bsi + bbi;
+        if (chg_attn_w_layout) {
+          auto attn_w_stride =
+              (bsi * head_num + head_group_start) * attn_w_strideH;
           for (auto ti = k_start; ti < k_start + block_size; ti++) {
-            auto q_ptr_start =
-                q_ptr + bi * qStrideB + head_group_start * qStrideH;
-            auto attn_w_stride =
-                (bi * head_num + head_group_start) * attn_w_strideH;
-            auto attn_w_pos =
-                attn_w_ptr + attn_w_stride + query_ti * seq_len + ti;
-            attn_w_pos[0] = 0.0f;
-            auto beam = need_update_beam_idx && ti >= prompt_len
-                ? new_beam_idx[bi][ti]
-                : bi - bi % beam_size;
             // caculate the innerproduct for the current token and store the
             // key
            if (ti == query_ti + offset) {
-              auto kc_head_start = k_cache_ptr + ti * kcStrideS +
-                  bi * kcStrideB + kv_hi * kcStrideH;
-              auto k_ptr_start = k_ptr + bi * kStrideB + kv_hi * kStrideH;
-              reduce_head_half(
-                  q_ptr_start,
-                  group_size,
-                  k_ptr_start,
-                  attn_w_pos,
-                  attn_w_strideH,
-                  head_size,
-                  true,
-                  kc_head_start);
+              for (auto bbi = 0; bbi < beam_size; bbi++) {
+                auto bi = bsi * beam_size + bbi;
+                auto q_ptr_start =
+                    q_ptr + bi * qStrideB + head_group_start * qStrideH;
+                auto attn_w_pos = attn_w_ptr + attn_w_stride +
+                    query_ti * seq_len + ti * beam_size + bbi;
+                auto kc_head_start = k_cache_ptr + ti * kcStrideS +
+                    bi * kcStrideB + kv_hi * kcStrideH;
+                auto k_ptr_start = k_ptr + bi * kStrideB + kv_hi * kStrideH;
+                reduce_head_half(
+                    q_ptr_start,
+                    group_size,
+                    k_ptr_start,
+                    attn_w_pos,
+                    attn_w_strideH,
+                    head_size,
+                    true,
+                    kc_head_start);
+              }
             } else { // caculate the innerproduct for the past token
-              auto kc_head_start = k_cache_ptr + ti * kcStrideS +
-                  beam * kcStrideB + kv_hi * kcStrideH;
-              reduce_head_half(
-                  q_ptr_start,
-                  group_size,
-                  kc_head_start,
-                  attn_w_pos,
-                  attn_w_strideH,
-                  head_size,
-                  false,
-                  nullptr);
+              auto bi = bsi * beam_size;
+              auto q_ptr_start =
+                  q_ptr + bi * qStrideB + head_group_start * qStrideH;
+              auto attn_w_pos = attn_w_ptr + attn_w_stride +
+                  query_ti * seq_len + ti * beam_size;
+              if (need_update_beam_idx && ti >= prompt_len) {
+                for (auto bbi = 0; bbi < beam_size; bbi++) {
+                  auto beam = new_beam_idx[bi + bbi][ti];
+                  auto kc_head_start = k_cache_ptr + ti * kcStrideS +
+                      beam * kcStrideB + kv_hi * kcStrideH;
+                  reduce_head_half(
+                      q_ptr_start + bbi * qStrideB,
+                      group_size,
+                      kc_head_start,
+                      attn_w_pos + bbi,
+                      attn_w_strideH,
+                      head_size,
+                      false,
+                      nullptr);
+                }
+              } else {
+                auto kc_head_start = k_cache_ptr + ti * kcStrideS +
+                    bi * kcStrideB + kv_hi * kcStrideH;
+                reduce_head_half(
+                    q_ptr_start,
+                    qStrideB,
+                    group_size,
+                    kc_head_start,
+                    attn_w_pos,
+                    attn_w_strideH,
+                    head_size,
+                    beam_size);
+              }
+            }
+          }
+        } else {
+          for (auto bbi = 0; bbi < beam_size; bbi++) {
+            auto bi = bsi * beam_size + bbi;
+            for (auto ti = k_start; ti < k_start + block_size; ti++) {
+              auto q_ptr_start =
+                  q_ptr + bi * qStrideB + head_group_start * qStrideH;
+              auto attn_w_stride =
+                  (bi * head_num + head_group_start) * attn_w_strideH;
+              auto attn_w_pos =
+                  attn_w_ptr + attn_w_stride + query_ti * seq_len + ti;
+              attn_w_pos[0] = 0.0f;
+              auto beam = need_update_beam_idx && ti >= prompt_len
+                  ? new_beam_idx[bi][ti]
+                  : bsi * beam_size;
+              // calculate the inner product for the current token and store
+              // the key
+              if (ti == query_ti + offset) {
+                auto kc_head_start = k_cache_ptr + ti * kcStrideS +
+                    bi * kcStrideB + kv_hi * kcStrideH;
+                auto k_ptr_start = k_ptr + bi * kStrideB + kv_hi * kStrideH;
+                reduce_head_half(
+                    q_ptr_start,
+                    group_size,
+                    k_ptr_start,
+                    attn_w_pos,
+                    attn_w_strideH,
+                    head_size,
+                    true,
+                    kc_head_start);
+              } else { // calculate the inner product for the past token
+                auto kc_head_start = k_cache_ptr + ti * kcStrideS +
+                    beam * kcStrideB + kv_hi * kcStrideH;
+                reduce_head_half(
+                    q_ptr_start,
+                    group_size,
+                    kc_head_start,
+                    attn_w_pos,
+                    attn_w_strideH,
+                    head_size,
+                    false,
+                    nullptr);
+              }
+            }
           }
         }
       }
     }
   }
@@ -1119,29 +1555,59 @@ scale_dot_product_for_indirect_access_kv_cache_half(
     RECORD_FUNCTION(
         "ipex::iakv_sdp::div_add_softmax", c10::ArrayRef<c10::IValue>({}));
 #pragma omp parallel for collapse(2)
-    for (auto bi = 0; bi < bs; bi++) {
+    for (auto bsi = 0; bsi < prompt_bs; bsi++) {
      for (auto hi = 0; hi < head_num; hi++) {
        for (auto query_ti = 0; query_ti < cur_len; query_ti++) {
-        auto mask_ptr_start = mask_ptr + bi * mask_bs_stride +
-            (hi % mask_head_num) * mask_dim2 * seq_len;
-        auto attn_w_stride = (bi * head_num + hi) * cur_len * seq_len;
-        auto attn_w_query_start =
-            attn_w_ptr + attn_w_stride + query_ti * seq_len;
-        // div+add+softmax
-        for (auto qi = 0; qi < 1; qi++) {
+          for (auto bbi = 0; bbi < beam_size; bbi++) {
+            auto bi = bsi * beam_size + bbi;
+            auto mask_ptr_start = mask_ptr + bi * mask_bs_stride +
+                (hi % mask_head_num) * mask_dim2 * seq_len;
+
+            // div+add+softmax
           at::Half max_val = -100000.0f;
-          torch_ipex::cpu::kernel::_dil_div_add_reduce_max_fusion_kernel_half(
-              attn_w_query_start,
-              mask_ptr_start + (query_ti % mask_dim2) * seq_len,
-              scale_factor,
-              seq_len,
-              attn_w_query_start,
-              max_val);
-
-          torch_ipex::cpu::kernel::_dil_exp_reduce_sum_fusion_kernel_half(
-              attn_w_query_start, seq_len, attn_w_query_start, max_val);
-          torch_ipex::cpu::kernel::_dil_normalization_kernel_half(
-              attn_w_query_start, max_val, seq_len, attn_w_query_start);
+            if (chg_attn_w_layout) {
+              auto attn_w_stride =
+                  (bsi * head_num + hi) * cur_len * seq_len * beam_size;
+              auto attn_w_query_start = attn_w_ptr + attn_w_stride +
+                  query_ti * seq_len * beam_size + bbi;
+              auto attn_w_query_start2 = attn_w_ptr2 + attn_w_stride +
+                  query_ti * beam_size * seq_len + bbi * seq_len;
+              for (auto ti = 0; ti < seq_len; ti++) {
+                attn_w_query_start2[ti] = attn_w_query_start[ti * beam_size];
+              }
+              torch_ipex::cpu::kernel::
+                  _dil_div_add_reduce_max_fusion_kernel_half(
+                      attn_w_query_start2,
+                      mask_ptr_start + (query_ti % mask_dim2) * seq_len,
+                      scale_factor,
+                      seq_len,
+                      attn_w_query_start2,
+                      max_val);
+              torch_ipex::cpu::kernel::_dil_exp_reduce_sum_fusion_kernel_half(
+                  attn_w_query_start2, seq_len, attn_w_query_start2, max_val);
+              torch_ipex::cpu::kernel::_dil_normalization_kernel_half(
+                  attn_w_query_start2, max_val, seq_len, attn_w_query_start2);
+              for (auto ti = 0; ti < seq_len; ti++) {
+                attn_w_query_start[ti * beam_size] = attn_w_query_start2[ti];
+              }
+            } else {
+              auto attn_w_stride = (bi * head_num + hi) * cur_len * seq_len;
+              auto attn_w_query_start =
+                  attn_w_ptr + attn_w_stride + query_ti * seq_len;
+              torch_ipex::cpu::kernel::
+                  _dil_div_add_reduce_max_fusion_kernel_half(
+                      attn_w_query_start,
+                      mask_ptr_start + (query_ti % mask_dim2) * seq_len,
+                      scale_factor,
+                      seq_len,
+                      attn_w_query_start,
+                      max_val);
+              torch_ipex::cpu::kernel::_dil_exp_reduce_sum_fusion_kernel_half(
+                  attn_w_query_start, seq_len, attn_w_query_start, max_val);
+              torch_ipex::cpu::kernel::_dil_normalization_kernel_half(
+                  attn_w_query_start, max_val, seq_len, attn_w_query_start);
+            }
+          }
         }
       }
     }
   }
@@ -1165,7 +1631,7 @@ scale_dot_product_for_indirect_access_kv_cache_half(
       c10::ArrayRef<c10::IValue>({}));
 #pragma omp parallel for collapse(3)
   for (auto block_id = 0; block_id < kv_block_count; block_id++) {
-    for (auto bsi = 0; bsi < bs; bsi += beam_size) {
+    for (auto bsi = 0; bsi < prompt_bs; bsi++) {
       for (auto hi = 0; hi < head_num; hi += group_size) {
         auto thread_id = 0;
         if (kv_block_size < seq_len)
@@ -1175,57 +1641,149 @@ scale_dot_product_for_indirect_access_kv_cache_half(
         auto query_ti = 0;
         // maping the query head to key/value head to support MGA/MQA
         auto kv_hi = hi / group_size;
-        for (auto bbi = 0; bbi < beam_size; bbi++) {
-          auto bi = bsi + bbi;
+        if (chg_attn_w_layout) {
+          auto attn_w_stride = (bsi * head_num + hi) * attn_w_strideH;
           for (auto vi = v_start; vi < v_start + block_size; vi++) {
-            auto attn_w_stride = (bi * head_num + hi) * attn_w_strideH;
-            auto attn_w_query_start =
-                attn_w_ptr + attn_w_stride + query_ti * seq_len + vi;
-            // calculate weighted value and store the result to attn_outs[bs,
-            // head_num, cur_len, head_size]
-            auto attn_out_start = private_attn_out_ptr +
-                thread_id * attn_outs_stride_privT +
-                bi * attn_outs_stride_privB + hi * attn_outs_stride_privH;
-            auto flag_access_start = flag_access_ptr +
-                head_num * bs * thread_id + head_num * bi + hi;
-
-            auto beam = need_update_beam_idx && vi >= prompt_len
-                ? new_beam_idx[bi][vi]
-                : bi - bi % beam_size;
            // caculate the attention values for the current token
            if (vi == offset) {
-              auto v_cache_head_start = v_cache_ptr + vi * vcStrideS +
-                  bi * vcStrideB + kv_hi * vcStrideH;
-              auto v_ptr_start = v_ptr + bi * vStrideB + kv_hi * vStrideH;
-              mul_attenion_weights_and_value_of_head_half(
-                  attn_w_query_start,
-                  attn_w_strideH,
-                  v_ptr_start,
-                  attn_out_start,
-                  head_size,
-                  group_size,
-                  head_size,
-                  true,
-                  v_cache_head_start,
-                  flag_access_start);
+              for (auto bbi = 0; bbi < beam_size; bbi++) {
+                auto bi = bsi * beam_size + bbi;
+                auto attn_w_query_start = attn_w_ptr + attn_w_stride +
+                    query_ti * seq_len + vi * beam_size + bbi;
+                // calculate weighted value and store the result to
+                // attn_outs[bs, head_num, cur_len, head_size]
+                auto attn_out_start = private_attn_out_ptr +
+                    thread_id * attn_outs_stride_privT +
+                    bi * attn_outs_stride_privB + hi * attn_outs_stride_privH;
+                auto flag_access_start = flag_access_ptr +
+                    head_num * bs * thread_id + head_num * bi + hi;
+                auto v_cache_head_start = v_cache_ptr + vi * vcStrideS +
+                    bi * vcStrideB + kv_hi * vcStrideH;
+                auto v_ptr_start = v_ptr + bi * vStrideB + kv_hi * vStrideH;
+                mul_attenion_weights_and_value_of_head_half(
+                    attn_w_query_start,
+                    attn_w_strideH,
+                    v_ptr_start,
+                    attn_out_start,
+                    head_size,
+                    group_size,
+                    head_size,
+                    true,
+                    v_cache_head_start,
+                    flag_access_start);
+              }
            } else { // caculate the innerproduct for the past token
+              if (need_update_beam_idx && vi >= prompt_len) {
+                for (auto bbi = 0; bbi < beam_size; bbi++) {
+                  auto bi = bsi * beam_size + bbi;
+                  auto attn_w_query_start = attn_w_ptr + attn_w_stride +
+                      query_ti * seq_len + vi * beam_size + bbi;
+                  // calculate weighted value and store the result to
+                  // attn_outs[bs, head_num, cur_len, head_size]
+                  auto attn_out_start = private_attn_out_ptr +
+                      thread_id * attn_outs_stride_privT +
+                      bi * attn_outs_stride_privB +
+                      hi * attn_outs_stride_privH;
+                  auto flag_access_start = flag_access_ptr +
+                      head_num * bs * thread_id + head_num * bi + hi;
+                  auto v_ptr_start = v_ptr + bi * vStrideB + kv_hi * vStrideH;
+                  auto beam = new_beam_idx[bi][vi];
+                  auto v_cache_head_start = v_cache_ptr + vi * vcStrideS +
+                      beam * vcStrideB + kv_hi * vcStrideH;
+                  mul_attenion_weights_and_value_of_head_half(
+                      attn_w_query_start,
+                      attn_w_strideH,
+                      v_cache_head_start,
+                      attn_out_start,
+                      head_size,
+                      group_size,
+                      head_size,
+                      false,
+                      nullptr,
+                      flag_access_start);
+                }
+              } else {
+                auto bi = bsi * beam_size;
+                auto attn_w_query_start = attn_w_ptr + attn_w_stride +
+                    query_ti * seq_len + vi * beam_size;
+                // calculate weighted value and store the result to
+                // attn_outs[bs, head_num, cur_len, head_size]
+                auto attn_out_start = private_attn_out_ptr +
+                    thread_id * attn_outs_stride_privT +
+                    bi * attn_outs_stride_privB + hi * attn_outs_stride_privH;
+                auto flag_access_start = flag_access_ptr +
+                    head_num * bs * thread_id + head_num * bi + hi;
+                auto v_cache_head_start = v_cache_ptr + vi * vcStrideS +
+                    bi * vcStrideB + kv_hi * vcStrideH;
+                mul_attenion_weights_and_value_of_head_half(
+                    attn_w_query_start,
+                    attn_w_strideH,
+                    v_cache_head_start,
+                    attn_out_start,
+                    attn_outs_stride_privB,
+                    head_size,
+                    group_size,
+                    head_size,
+                    false,
+                    nullptr,
+                    flag_access_start,
+                    head_num,
+                    beam_size);
+              }
+            }
+          }
+        } else {
+          for (auto bbi = 0; bbi < beam_size; bbi++) {
+            auto bi = bsi * beam_size + bbi;
+            for (auto vi = v_start; vi < v_start + block_size; vi++) {
+              auto attn_w_stride = (bi * head_num + hi) * attn_w_strideH;
+              auto attn_w_query_start =
+                  attn_w_ptr + attn_w_stride + query_ti * seq_len + vi;
+              // calculate weighted value and store the result to
+              // attn_outs[bs, head_num, cur_len, head_size]
+              auto attn_out_start = private_attn_out_ptr +
+                  thread_id * attn_outs_stride_privT +
+                  bi * attn_outs_stride_privB + hi * attn_outs_stride_privH;
+              auto flag_access_start = flag_access_ptr +
+                  head_num * bs * thread_id + head_num * bi + hi;
+
+              auto beam = need_update_beam_idx && vi >= prompt_len
+                  ? new_beam_idx[bi][vi]
+                  : bsi * beam_size;
+              // calculate the attention values for the current token
+              if (vi == offset) {
+                auto v_cache_head_start = v_cache_ptr + vi * vcStrideS +
+                    bi * vcStrideB + kv_hi * vcStrideH;
+                auto v_ptr_start = v_ptr + bi * vStrideB + kv_hi * vStrideH;
+                mul_attenion_weights_and_value_of_head_half(
+                    attn_w_query_start,
+                    attn_w_strideH,
+                    v_ptr_start,
+                    attn_out_start,
+                    head_size,
+                    group_size,
+                    head_size,
+                    true,
+                    v_cache_head_start,
+                    flag_access_start);
+              } else {
+                // calculate the inner product for the past token
+                auto v_cache_head_start = v_cache_ptr + vi * vcStrideS +
+                    beam * vcStrideB + kv_hi * vcStrideH;
+                mul_attenion_weights_and_value_of_head_half(
+                    attn_w_query_start,
+                    attn_w_strideH,
+                    v_cache_head_start,
+                    attn_out_start,
+                    head_size,
+                    group_size,
+                    head_size,
+                    false,
+                    nullptr,
+                    flag_access_start);
+              }
+            }
-          if (flag_access[thread_id][bi][hi] == 0)
-            flag_access[thread_id][bi][hi] = 1;
+          }
         }
       }
     }
   }
diff --git a/csrc/cpu/vec/vec512/perf_kernel/add_softmax.h b/csrc/cpu/vec/vec512/perf_kernel/add_softmax.h
index ae971a63c..905c987aa 100644
--- a/csrc/cpu/vec/vec512/perf_kernel/add_softmax.h
+++ b/csrc/cpu/vec/vec512/perf_kernel/add_softmax.h
@@ -339,21 +339,21 @@ inline void _dil_normalization_kernel(
     const float& sum,
     const int& size,
     scalar_t* out) {
-  auto vec_sum = _mm512_set1_ps(sum);
+  auto vec_sum = _mm512_set1_ps(1.0 / sum);
   __m512 vec_a = {};
   __m512 vec_out = {};
 
   int i = 0;
   for (; i <= size - 16; i += 16) {
     auto vec_a = _mm512_loadu_ps(a + i);
-    auto vec_out = _mm512_div_ps(vec_a, vec_sum);
+    auto vec_out = _mm512_mul_ps(vec_a, vec_sum);
     _storeu(out + i, vec_out);
   }
 
   if (i < size) {
     __mmask16 mask = (1 << (size - i)) - 1;
     auto vec_a = _mm512_maskz_loadu_ps(mask, a + i);
-    auto vec_out = _mm512_div_ps(vec_a, vec_sum);
+    auto vec_out = _mm512_mul_ps(vec_a, vec_sum);
     _mask_storeu(out + i, vec_out, mask);
   }
 }
@@ -364,21 +364,22 @@ inline void _dil_normalization_kernel_half(
     const at::Half& sum,
     const int& size,
     at::Half* out) {
-  auto vec_sum = _mm512_set1_ph(*(_Float16*)&sum);
+  at::Half scale = 1.0 / sum;
+  auto vec_sum = _mm512_set1_ph(*(_Float16*)&scale);
   __m512h vec_a = {};
   __m512h vec_out = {};
 
   int i = 0;
   for (; i <= size - 32; i += 32) {
     auto vec_a = _loadu_half(a + i);
-    auto vec_out = _mm512_div_ph(vec_a, vec_sum);
+    auto vec_out = _mm512_mul_ph(vec_a, vec_sum);
     _storeu_Half(out + i, vec_out);
   }
 
   if (i < size) {
     __mmask32 mask = (1 << (size - i)) - 1;
     auto vec_a = _maskz_loadu(a + i, mask);
-    auto vec_out = _mm512_div_ph(vec_a, vec_sum);
+    auto vec_out = _mm512_mul_ph(vec_a, vec_sum);
     _mask_storeu(out + i, vec_out, mask);
   }
 }
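-- 
Editorial note, kept below the mail-signature marker so it stays outside the
patch itself: the core of this change is the attn_weights layout. When beam
search is active (beam_size > 1, prompt_len <= 2048, prompt_bs > 20,
group_size == 1), scores are stored as
[prompt_bs, head_num, cur_len, seq_len, beam_size] instead of
[bs, head_num, cur_len, seq_len], so the query/key dot products for all beams
of one prompt land in adjacent elements and the beam-indexed KV-cache lookups
become unit-stride stores; softmax then stages each beam's strided row through
a contiguous scratch tensor (attn_weights2, laid out [..., beam_size,
seq_len]) via gather/scatter. The standalone sketch below models only that
indexing with made-up sizes and names (it is not code from the patch), and it
reduces softmax to a plain sum-normalization; the multiply-by-reciprocal
mirrors the _dil_normalization_kernel change:

  #include <cstdio>
  #include <vector>

  int main() {
    const int prompt_bs = 2, beam_size = 4, head_num = 1, seq_len = 8;
    // New layout: scores[prompt_bs][head_num][seq_len][beam_size], flattened.
    std::vector<float> scores(prompt_bs * head_num * seq_len * beam_size);
    auto idx = [&](int b, int h, int t, int bm) {
      return ((b * head_num + h) * seq_len + t) * beam_size + bm;
    };
    // Score pass: for each past token t, write beam_size adjacent entries,
    // as the new reduce_head(..., qStrideB, ..., beam_size) overload does.
    for (int b = 0; b < prompt_bs; b++)
      for (int t = 0; t < seq_len; t++)
        for (int bm = 0; bm < beam_size; bm++)
          scores[idx(b, 0, t, bm)] = float(t + bm);
    // Softmax pass: gather one beam's strided row into contiguous scratch
    // (the attn_weights2 step), normalize there, then scatter back.
    std::vector<float> scratch(seq_len);
    const int b = 0, bm = 1;
    float sum = 0.f;
    for (int t = 0; t < seq_len; t++) {
      scratch[t] = scores[idx(b, 0, t, bm)];
      sum += scratch[t];
    }
    const float inv = 1.f / sum; // multiply by reciprocal instead of dividing
    for (int t = 0; t < seq_len; t++)
      scores[idx(b, 0, t, bm)] = scratch[t] * inv;
    float check = 0.f;
    for (int t = 0; t < seq_len; t++)
      check += scores[idx(b, 0, t, bm)];
    std::printf("normalized row sums to %f\n", check); // ~1.0
    return 0;
  }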