diff --git a/csrc/gpu/append_attention.cu b/csrc/gpu/append_attention.cu new file mode 100644 index 000000000000..48cd2b6ad605 --- /dev/null +++ b/csrc/gpu/append_attention.cu @@ -0,0 +1,792 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "append_attn/append_attention_kernel.h" +#include "append_attn/decoder_write_cache_with_rope_kernel.h" +#include "append_attn/encoder_write_cache_with_rope_kernel.h" + +template +std::vector AppendAttentionKernel( + const paddle::Tensor& qkv, + const paddle::Tensor& key_cache, + const paddle::Tensor& value_cache, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& padding_offsets, + const paddle::Tensor& cum_offsets, + const paddle::Tensor& block_tables, + const paddle::Tensor& encoder_batch_ids, + const paddle::Tensor& encoder_tile_ids_per_batch, + const paddle::Tensor& encoder_num_blocks, + const paddle::Tensor& kv_batch_ids, + const paddle::Tensor& kv_tile_ids_per_batch, + const paddle::Tensor& kv_num_blocks, + const paddle::Tensor& decoder_batch_ids, + const paddle::Tensor& decoder_tile_ids_per_batch, + const paddle::Tensor& decoder_num_blocks, + const paddle::Tensor& max_enc_len_this_time, + const paddle::Tensor& max_dec_len_this_time, + const paddle::optional& rotary_embs, + const paddle::optional& attn_mask, + const paddle::optional& qkv_bias, + const paddle::optional& qkv_out_scales, + const paddle::optional& cache_k_quant_scales, + const paddle::optional& cache_v_quant_scales, + const paddle::optional& cache_k_dequant_scales, + const paddle::optional& cache_v_dequant_scales, + const paddle::optional& cache_k_zp, + const paddle::optional& cache_v_zp, + const paddle::optional& out_linear_shifts, + const paddle::optional& out_linear_smooths, + const std::string& cache_quant_type_str, + const bool use_neox_rotary_style, + const int max_input_length, + const float out_linear_in_scale, + const int encoder_block_shape_q, + const int decoder_block_shape_q, + const int max_partition_size, + const int encoder_max_partition_size, + const int speculate_max_draft_token_num, + const bool causal, + const bool enable_prefill) { + typedef PDTraits traits_; + typedef typename traits_::DataType DataType_; + typedef typename traits_::data_t data_t; + + const auto& qkv_dims = qkv.dims(); + const auto& key_cache_dims = key_cache.dims(); + const int token_num = qkv_dims[0]; + const int kv_num_heads = key_cache_dims[1]; + const int head_dim = key_cache_dims[3]; + const int total_num_head = qkv_dims[qkv_dims.size() - 1] / head_dim; + const int num_heads = total_num_head - 2 * kv_num_heads; + + int encoder_num_blocks_data = encoder_num_blocks.data()[0]; + int kv_num_blocks_data = kv_num_blocks.data()[0]; + int decoder_num_blocks_data = decoder_num_blocks.data()[0]; + int max_enc_len_this_time_data = max_enc_len_this_time.data()[0]; + int max_dec_len_this_time_data = 
max_dec_len_this_time.data()[0]; + + auto main_stream = qkv.stream(); + static cudaEvent_t main_event; + static cudaEvent_t decoder_event; + static cudaStream_t decoder_stream; + static bool init_flag = false; + if (max_enc_len_this_time_data > 0 && max_dec_len_this_time_data > 0 && + !init_flag) { + cudaEventCreateWithFlags(&main_event, cudaEventDisableTiming); + cudaEventCreateWithFlags(&decoder_event, cudaEventDisableTiming); + cudaStreamCreateWithFlags(&decoder_stream, cudaStreamNonBlocking); + init_flag = true; + } + + paddle::Tensor qkv_out; + if (qkv_out_scales) { + qkv_out = GetEmptyTensor(qkv.dims(), D, qkv.place()); + } else { + qkv_out = qkv; + } + paddle::Tensor fmha_out; + if (out_linear_in_scale > 0.0) { + fmha_out = GetEmptyTensor( + {token_num, num_heads * head_dim}, paddle::DataType::INT8, qkv.place()); + } else { + fmha_out = + GetEmptyTensor({token_num, num_heads * head_dim}, D, qkv.place()); + } + + if (max_enc_len_this_time_data > 0) { + if (max_dec_len_this_time_data > 0) { + cudaEventRecord(main_event, main_stream); + } + if (qkv_out_scales) { + EncoderWriteCacheWithRopeKernel( + qkv, + seq_lens_this_time, + seq_lens_encoder, + seq_lens_decoder, + padding_offsets, + cum_offsets, + block_tables, + kv_batch_ids, + kv_tile_ids_per_batch, + rotary_embs, + qkv_out_scales, + qkv_bias, + cache_k_quant_scales, + cache_v_quant_scales, + cache_k_zp, + cache_v_zp, + cache_quant_type_str, + kv_num_blocks_data, + max_input_length, + num_heads, + kv_num_heads, + head_dim, + use_neox_rotary_style, + main_stream, + &qkv_out, + const_cast(&key_cache), + const_cast(&value_cache)); + } else { + EncoderWriteCacheWithRopeKernel( + qkv_out, + seq_lens_this_time, + seq_lens_encoder, + seq_lens_decoder, + padding_offsets, + cum_offsets, + block_tables, + kv_batch_ids, + kv_tile_ids_per_batch, + rotary_embs, + qkv_out_scales, + qkv_bias, + cache_k_quant_scales, + cache_v_quant_scales, + cache_k_zp, + cache_v_zp, + cache_quant_type_str, + kv_num_blocks_data, + max_input_length, + num_heads, + kv_num_heads, + head_dim, + use_neox_rotary_style, + main_stream, + &qkv_out, + const_cast(&key_cache), + const_cast(&value_cache)); + } + if (out_linear_in_scale > 0.0) { + CascadeAppendAttentionKernel( + qkv_out, + key_cache, + value_cache, + attn_mask, + cache_k_dequant_scales, + cache_v_dequant_scales, + cache_k_zp, + cache_v_zp, + out_linear_shifts, + out_linear_smooths, + seq_lens_this_time, + seq_lens_decoder, + seq_lens_encoder, + padding_offsets, + cum_offsets, + block_tables, + encoder_batch_ids, + encoder_tile_ids_per_batch, + cache_quant_type_str, + encoder_num_blocks_data, + encoder_block_shape_q, + max_input_length, + max_enc_len_this_time_data, + num_heads, + kv_num_heads, + head_dim, + out_linear_in_scale, + max_partition_size, + encoder_max_partition_size, + speculate_max_draft_token_num, + causal, + false, + enable_prefill, + main_stream, + &fmha_out); + } else { + CascadeAppendAttentionKernel( + qkv_out, + key_cache, + value_cache, + attn_mask, + cache_k_dequant_scales, + cache_v_dequant_scales, + cache_k_zp, + cache_v_zp, + out_linear_shifts, + out_linear_smooths, + seq_lens_this_time, + seq_lens_decoder, + seq_lens_encoder, + padding_offsets, + cum_offsets, + block_tables, + encoder_batch_ids, + encoder_tile_ids_per_batch, + cache_quant_type_str, + encoder_num_blocks_data, + encoder_block_shape_q, + max_input_length, + max_enc_len_this_time_data, + num_heads, + kv_num_heads, + head_dim, + out_linear_in_scale, + max_partition_size, + encoder_max_partition_size, + 
speculate_max_draft_token_num, + causal, + false, + enable_prefill, + main_stream, + &fmha_out); + } + } + + if (max_dec_len_this_time_data > 0) { + cudaStream_t exec_stream; + if (max_enc_len_this_time_data > 0) { + cudaStreamWaitEvent(decoder_stream, main_event); + exec_stream = decoder_stream; + } else { + exec_stream = main_stream; + } + + if (qkv_out_scales) { + DecoderWriteCacheWithRoPEKernel( + qkv, // [token_num, num_heads, head_dim] + seq_lens_decoder, + seq_lens_encoder, + padding_offsets, + cum_offsets, + block_tables, + rotary_embs, + qkv_out_scales, + qkv_bias, + cache_k_quant_scales, + cache_v_quant_scales, + cache_k_zp, + cache_v_zp, + cache_quant_type_str, + use_neox_rotary_style, + max_input_length, + num_heads, + kv_num_heads, + head_dim, + exec_stream, + &qkv_out, + const_cast(&key_cache), + const_cast(&value_cache)); + } else { + DecoderWriteCacheWithRoPEKernel( + qkv_out, // [token_num, num_heads, head_dim] + seq_lens_decoder, + seq_lens_encoder, + padding_offsets, + cum_offsets, + block_tables, + rotary_embs, + qkv_out_scales, + qkv_bias, + cache_k_quant_scales, + cache_v_quant_scales, + cache_k_zp, + cache_v_zp, + cache_quant_type_str, + use_neox_rotary_style, + max_input_length, + num_heads, + kv_num_heads, + head_dim, + exec_stream, + &qkv_out, + const_cast(&key_cache), + const_cast(&value_cache)); + } + + if (out_linear_in_scale > 0.0) { + CascadeAppendAttentionKernel( + qkv_out, + key_cache, + value_cache, + attn_mask, + cache_k_dequant_scales, + cache_v_dequant_scales, + cache_k_zp, + cache_v_zp, + out_linear_shifts, + out_linear_smooths, + seq_lens_this_time, + seq_lens_decoder, + seq_lens_encoder, + padding_offsets, + cum_offsets, + block_tables, + decoder_batch_ids, + decoder_tile_ids_per_batch, + cache_quant_type_str, + decoder_num_blocks_data, + decoder_block_shape_q, + max_input_length, + max_dec_len_this_time_data + 1, + num_heads, + kv_num_heads, + head_dim, + out_linear_in_scale, + max_partition_size, + encoder_max_partition_size, + speculate_max_draft_token_num, + causal, + true, + enable_prefill, + exec_stream, + &fmha_out); + } else { + CascadeAppendAttentionKernel( + qkv_out, + key_cache, + value_cache, + attn_mask, + cache_k_dequant_scales, + cache_v_dequant_scales, + cache_k_zp, + cache_v_zp, + out_linear_shifts, + out_linear_smooths, + seq_lens_this_time, + seq_lens_decoder, + seq_lens_encoder, + padding_offsets, + cum_offsets, + block_tables, + decoder_batch_ids, + decoder_tile_ids_per_batch, + cache_quant_type_str, + decoder_num_blocks_data, + decoder_block_shape_q, + max_input_length, + max_dec_len_this_time_data + 1, + num_heads, + kv_num_heads, + head_dim, + out_linear_in_scale, + max_partition_size, + encoder_max_partition_size, + speculate_max_draft_token_num, + causal, + true, + enable_prefill, + exec_stream, + &fmha_out); + } + if (max_enc_len_this_time_data > 0) { + cudaEventRecord(decoder_event, exec_stream); + cudaStreamWaitEvent(main_stream, decoder_event); + } + } + + return {fmha_out, qkv_out}; +} + +std::vector AppendAttention( + const paddle::Tensor& qkv, + const paddle::Tensor& key_cache, + const paddle::Tensor& value_cache, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& padding_offsets, + const paddle::Tensor& cum_offsets, + const paddle::Tensor& block_tables, + const paddle::Tensor& encoder_batch_ids, + const paddle::Tensor& encoder_tile_ids_per_batch, + const paddle::Tensor& encoder_num_blocks, + const paddle::Tensor& 
kv_batch_ids, + const paddle::Tensor& kv_tile_ids_per_batch, + const paddle::Tensor& kv_num_blocks, + const paddle::Tensor& decoder_batch_ids, + const paddle::Tensor& decoder_tile_ids_per_batch, + const paddle::Tensor& decoder_num_blocks, + const paddle::Tensor& max_enc_len_this_time, + const paddle::Tensor& max_dec_len_this_time, + const paddle::optional& rotary_embs, + const paddle::optional& attn_mask, + const paddle::optional& qkv_bias, + const paddle::optional& qkv_out_scales, + const paddle::optional& cache_k_quant_scales, + const paddle::optional& cache_v_quant_scales, + const paddle::optional& cache_k_dequant_scales, + const paddle::optional& cache_v_dequant_scales, + const paddle::optional& cache_k_zp, + const paddle::optional& cache_v_zp, + const paddle::optional& out_linear_shifts, + const paddle::optional& out_linear_smooths, + const std::string& compute_dtype, + const std::string& cache_quant_type_str, + const bool use_neox_rotary_style, + const int max_input_length, + const float out_linear_in_scale, + const int encoder_block_shape_q, + const int decoder_block_shape_q, + const int max_partition_size, + const int encoder_max_partition_size, + const int speculate_max_draft_token_num, + const bool causal, + const bool enable_prefill) { + switch (qkv.dtype()) { + case paddle::DataType::FLOAT16: { + return AppendAttentionKernel( + qkv, + key_cache, + value_cache, + seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + padding_offsets, + cum_offsets, + block_tables, + encoder_batch_ids, + encoder_tile_ids_per_batch, + encoder_num_blocks, + kv_batch_ids, + kv_tile_ids_per_batch, + kv_num_blocks, + decoder_batch_ids, + decoder_tile_ids_per_batch, + decoder_num_blocks, + max_enc_len_this_time, + max_dec_len_this_time, + rotary_embs, + attn_mask, + qkv_bias, + qkv_out_scales, + cache_k_quant_scales, + cache_v_quant_scales, + cache_k_dequant_scales, + cache_v_dequant_scales, + cache_k_zp, + cache_v_zp, + out_linear_shifts, + out_linear_smooths, + cache_quant_type_str, + use_neox_rotary_style, + max_input_length, + out_linear_in_scale, + encoder_block_shape_q, + decoder_block_shape_q, + max_partition_size, + encoder_max_partition_size, + speculate_max_draft_token_num, + causal, + enable_prefill); + } + case paddle::DataType::BFLOAT16: { + return AppendAttentionKernel( + qkv, + key_cache, + value_cache, + seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + padding_offsets, + cum_offsets, + block_tables, + encoder_batch_ids, + encoder_tile_ids_per_batch, + encoder_num_blocks, + kv_batch_ids, + kv_tile_ids_per_batch, + kv_num_blocks, + decoder_batch_ids, + decoder_tile_ids_per_batch, + decoder_num_blocks, + max_enc_len_this_time, + max_dec_len_this_time, + rotary_embs, + attn_mask, + qkv_bias, + qkv_out_scales, + cache_k_quant_scales, + cache_v_quant_scales, + cache_k_dequant_scales, + cache_v_dequant_scales, + cache_k_zp, + cache_v_zp, + out_linear_shifts, + out_linear_smooths, + cache_quant_type_str, + use_neox_rotary_style, + max_input_length, + out_linear_in_scale, + encoder_block_shape_q, + decoder_block_shape_q, + max_partition_size, + encoder_max_partition_size, + speculate_max_draft_token_num, + causal, + enable_prefill); + } + case paddle::DataType::INT32: { + if (compute_dtype == "bf16") { + return AppendAttentionKernel( + qkv, + key_cache, + value_cache, + seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + padding_offsets, + cum_offsets, + block_tables, + encoder_batch_ids, + encoder_tile_ids_per_batch, + encoder_num_blocks, + kv_batch_ids, + 
kv_tile_ids_per_batch, + kv_num_blocks, + decoder_batch_ids, + decoder_tile_ids_per_batch, + decoder_num_blocks, + max_enc_len_this_time, + max_dec_len_this_time, + rotary_embs, + attn_mask, + qkv_bias, + qkv_out_scales, + cache_k_quant_scales, + cache_v_quant_scales, + cache_k_dequant_scales, + cache_v_dequant_scales, + cache_k_zp, + cache_v_zp, + out_linear_shifts, + out_linear_smooths, + cache_quant_type_str, + use_neox_rotary_style, + max_input_length, + out_linear_in_scale, + encoder_block_shape_q, + decoder_block_shape_q, + max_partition_size, + encoder_max_partition_size, + speculate_max_draft_token_num, + causal, + enable_prefill); + } else if (compute_dtype == "fp16") { + return AppendAttentionKernel( + qkv, + key_cache, + value_cache, + seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + padding_offsets, + cum_offsets, + block_tables, + encoder_batch_ids, + encoder_tile_ids_per_batch, + encoder_num_blocks, + kv_batch_ids, + kv_tile_ids_per_batch, + kv_num_blocks, + decoder_batch_ids, + decoder_tile_ids_per_batch, + decoder_num_blocks, + max_enc_len_this_time, + max_dec_len_this_time, + rotary_embs, + attn_mask, + qkv_bias, + qkv_out_scales, + cache_k_quant_scales, + cache_v_quant_scales, + cache_k_dequant_scales, + cache_v_dequant_scales, + cache_k_zp, + cache_v_zp, + out_linear_shifts, + out_linear_smooths, + cache_quant_type_str, + use_neox_rotary_style, + max_input_length, + out_linear_in_scale, + encoder_block_shape_q, + decoder_block_shape_q, + max_partition_size, + encoder_max_partition_size, + speculate_max_draft_token_num, + causal, + enable_prefill); + } else { + PD_THROW("Only supported attr of compute_dtype in ['fp16', 'bf16']."); + break; + } + } + default: { + PD_THROW( + "NOT supported data type. " + "Only float16 and bfloat16 are supported. 
"); + break; + } + } + return {paddle::Tensor{}}; +} + +std::vector> AppendAttentionInferShape( + const std::vector& qkv_shape, + const std::vector& key_cache_shape, + const std::vector& value_cache_shape, + const std::vector& seq_lens_encoder_shape, + const std::vector& seq_lens_decoder_shape, + const std::vector& seq_lens_this_time_shape, + const std::vector& padding_offsets_shape, + const std::vector& cum_offsets_shape, + const std::vector& block_tables_shape, + const std::vector& encoder_batch_ids_shape, + const std::vector& encoder_tile_ids_per_batch_shape, + const std::vector& encoder_num_blocks_shape, + const std::vector& kv_batch_ids_shape, + const std::vector& kv_tile_ids_per_batch_shape, + const std::vector& kv_num_blocks_shape, + const std::vector& decoder_batch_ids_shape, + const std::vector& decoder_tile_ids_per_batch_shape, + const std::vector& decoder_num_blocks_shape, + const std::vector& max_enc_len_this_time_shape, + const std::vector& max_dec_len_this_time_shape, + const paddle::optional>& rotary_embs_shape, + const paddle::optional>& attn_mask_shape, + const paddle::optional>& qkv_bias_shape, + const paddle::optional>& qkv_out_scales_shape, + const paddle::optional>& cache_k_quant_scales_shape, + const paddle::optional>& cache_v_quant_scales_shape, + const paddle::optional>& cache_k_dequant_scales_shape, + const paddle::optional>& cache_v_dequant_scales_shape, + const paddle::optional>& cache_k_zp_shape, + const paddle::optional>& cache_v_zp_shape, + const paddle::optional>& out_linear_shifts_shape, + const paddle::optional>& out_linear_smooths_shape) { + const int token_num = qkv_shape[0]; + const int kv_num_heads = key_cache_shape[1]; + const int head_dim = key_cache_shape[3]; + const int total_num_head = qkv_shape[qkv_shape.size() - 1] / head_dim; + const int num_heads = total_num_head - 2 * kv_num_heads; + return {{token_num, num_heads * head_dim}, qkv_shape}; +} + +std::vector AppendAttentionInferDtype( + const paddle::DataType& qkv_dtype, + const paddle::DataType& key_cache_dtype, + const paddle::DataType& value_cache_dtype, + const paddle::DataType& seq_lens_encoder_dtype, + const paddle::DataType& seq_lens_decoder_dtype, + const paddle::DataType& seq_lens_this_time_dtype, + const paddle::DataType& padding_offsets_dtype, + const paddle::DataType& cum_offsets_dtype, + const paddle::DataType& block_tables_dtype, + const paddle::DataType& encoder_batch_ids_dtype, + const paddle::DataType& encoder_tile_ids_per_batch_dtype, + const paddle::DataType& encoder_num_blocks_dtype, + const paddle::DataType& kv_batch_ids_dtype, + const paddle::DataType& kv_tile_ids_per_batch_dtype, + const paddle::DataType& kv_num_blocks_dtype, + const paddle::DataType& decoder_batch_ids_dtype, + const paddle::DataType& decoder_tile_ids_per_batch_dtype, + const paddle::DataType& decoder_num_blocks_dtype, + const paddle::DataType& max_enc_len_this_time_dtype, + const paddle::DataType& max_dec_len_this_time_dtype, + const paddle::optional& rotary_embs_dtype, + const paddle::optional& attn_mask_dtype, + const paddle::optional& qkv_bias_dtype, + const paddle::optional& qkv_out_scales_dtype, + const paddle::optional& cache_k_quant_scales_dtype, + const paddle::optional& cache_v_quant_scales_dtype, + const paddle::optional& cache_k_dequant_scales_dtype, + const paddle::optional& cache_v_dequant_scales_dtype, + const paddle::optional& cache_k_zp_dtype, + const paddle::optional& cache_v_zp_dtype, + const paddle::optional& out_linear_shifts_dtype, + const paddle::optional& out_linear_smooths_dtype, + 
+    const std::string& compute_dtype,
+    const std::string& cache_quant_type_str,
+    const bool use_neox_rotary_style,
+    const int max_input_length,
+    const float out_linear_in_scale,
+    const int encoder_block_shape_q,
+    const int decoder_block_shape_q,
+    const int max_partition_size,
+    const int encoder_max_partition_size,
+    const int speculate_max_draft_token_num,
+    const bool causal,
+    const bool enable_prefill) {
+  if (compute_dtype == "bf16") {
+    if (out_linear_in_scale > 0.0) {
+      return {paddle::DataType::INT8, paddle::DataType::BFLOAT16};
+    } else {
+      return {paddle::DataType::BFLOAT16, paddle::DataType::BFLOAT16};
+    }
+  } else if (compute_dtype == "fp16") {
+    if (out_linear_in_scale > 0.0) {
+      return {paddle::DataType::INT8, paddle::DataType::FLOAT16};
+    } else {
+      return {paddle::DataType::FLOAT16, paddle::DataType::FLOAT16};
+    }
+  } else {
+    PD_THROW("Only supported attr of compute_dtype in ['fp16', 'bf16'].");
+  }
+}
+
+PD_BUILD_OP(append_attention)
+    .Inputs({"qkv",
+             "key_cache",
+             "value_cache",
+             "seq_lens_encoder",
+             "seq_lens_decoder",
+             "seq_lens_this_time",
+             "padding_offsets",
+             "cum_offsets",
+             "block_tables",
+             "encoder_batch_ids",
+             "encoder_tile_ids_per_batch",
+             "encoder_num_blocks",
+             "kv_batch_ids",
+             "kv_tile_ids_per_batch",
+             "kv_num_blocks",
+             "decoder_batch_ids",
+             "decoder_tile_ids_per_batch",
+             "decoder_num_blocks",
+             "max_enc_len_this_time",
+             "max_dec_len_this_time",
+             paddle::Optional("rotary_embs"),
+             paddle::Optional("attn_mask"),
+             paddle::Optional("qkv_bias"),
+             paddle::Optional("qkv_out_scales"),
+             paddle::Optional("cache_k_quant_scales"),
+             paddle::Optional("cache_v_quant_scales"),
+             paddle::Optional("cache_k_dequant_scales"),
+             paddle::Optional("cache_v_dequant_scales"),
+             paddle::Optional("cache_k_zp"),
+             paddle::Optional("cache_v_zp"),
+             paddle::Optional("out_linear_shifts"),
+             paddle::Optional("out_linear_smooths")})
+    .Outputs({"fmha_out", "qkv_out", "key_cache_out", "value_cache_out"})
+    .SetInplaceMap({{"key_cache", "key_cache_out"},
+                    {"value_cache", "value_cache_out"}})
+    .Attrs({"compute_type: std::string",
+            "cache_quant_type: std::string",
+            "use_neox_rotary_style: bool",
+            "max_input_length: int",
+            "out_linear_in_scale: float",
+            "encoder_block_shape_q: int",
+            "decoder_block_shape_q: int",
+            "max_partition_size: int",
+            "encoder_max_partition_size: int",
+            "speculate_max_draft_token_num: int",
+            "causal: bool",
+            "enable_prefill: bool"})
+    .SetKernelFn(PD_KERNEL(AppendAttention))
+    .SetInferShapeFn(PD_INFER_SHAPE(AppendAttentionInferShape))
+    .SetInferDtypeFn(PD_INFER_DTYPE(AppendAttentionInferDtype));
\ No newline at end of file
diff --git a/csrc/gpu/append_attn/append_attention_bfloat16_bfloat16_kernel.cu b/csrc/gpu/append_attn/append_attention_bfloat16_bfloat16_kernel.cu
new file mode 100644
index 000000000000..baf27ce9b3c9
--- /dev/null
+++ b/csrc/gpu/append_attn/append_attention_bfloat16_bfloat16_kernel.cu
@@ -0,0 +1,51 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#include "append_attention_kernel.h"
+
+template void CascadeAppendAttentionKernel<paddle::bfloat16, paddle::bfloat16>(
+    const paddle::Tensor& qkv,      // [token_num, num_heads, head_dim]
+    const paddle::Tensor& cache_k,  // [max_block_num, num_heads, block_size, head_dim]
+    const paddle::Tensor& cache_v,  // [max_block_num, num_heads, head_dim, block_size]
+    const paddle::optional<paddle::Tensor>& attn_mask,
+    const paddle::optional<paddle::Tensor>& cache_k_scale,  // [num_kv_heads, head_dim]
+    const paddle::optional<paddle::Tensor>& cache_v_scale,  // [num_kv_heads, head_dim]
+    const paddle::optional<paddle::Tensor>& cache_k_zp,     // [num_kv_heads, head_dim]
+    const paddle::optional<paddle::Tensor>& cache_v_zp,     // [num_kv_heads, head_dim]
+    const paddle::optional<paddle::Tensor>& shift_bias,     // [num_kv_heads, head_dim]
+    const paddle::optional<paddle::Tensor>& smooth_weight,  // [num_kv_heads, head_dim]
+    const paddle::Tensor& seq_lens_q,
+    const paddle::Tensor& seq_lens_kv,
+    const paddle::Tensor& seq_lens_encoder,
+    const paddle::Tensor& padding_offsets,
+    const paddle::Tensor& cum_offsets,
+    const paddle::Tensor& block_table,
+    const paddle::Tensor& batch_ids,
+    const paddle::Tensor& tile_ids_per_batch,
+    const std::string& cache_quant_type_str,
+    const int num_blocks,
+    const int block_shape_q,
+    const int max_seq_len,
+    const int max_dec_len,
+    const int num_heads,
+    const int kv_num_heads,
+    const int head_dim,
+    const float in_scale,
+    const int max_partition_size,
+    const int encoder_max_partition_size,
+    const int speculate_max_draft_token_num,
+    const bool causal,
+    const bool is_decoder,
+    const bool enable_prefill,
+    cudaStream_t& stream,
+    paddle::Tensor* out);
\ No newline at end of file
diff --git a/csrc/gpu/append_attn/append_attention_bfloat16_int8_kernel.cu b/csrc/gpu/append_attn/append_attention_bfloat16_int8_kernel.cu
new file mode 100644
index 000000000000..437865585aa3
--- /dev/null
+++ b/csrc/gpu/append_attn/append_attention_bfloat16_int8_kernel.cu
@@ -0,0 +1,51 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
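// Annotation (not part of the patch): like the bfloat16/bfloat16 translation
// unit above, the file below only pins down a single explicit instantiation of
// CascadeAppendAttentionKernel (bfloat16 compute, int8 output). Splitting the
// instantiations across small .cu files is presumably done so they can be
// compiled in parallel and keep per-file build times manageable.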
+#include "append_attention_kernel.h" + +template void CascadeAppendAttentionKernel( + const paddle::Tensor& qkv, // [token_num, num_heads, head_dim] + const paddle::Tensor& cache_k, // [max_block_num, num_heads, block_size, head_dim] + const paddle::Tensor& cache_v, // [max_block_num, num_heads, head_dim, block_size] + const paddle::optional& attn_mask, + const paddle::optional& cache_k_scale, // [num_kv_heads, head_dim] + const paddle::optional& cache_v_scale, // [num_kv_heads, head_dim] + const paddle::optional& cache_k_zp, // [num_kv_heads, head_dim] + const paddle::optional& cache_v_zp, // [num_kv_heads, head_dim] + const paddle::optional& shift_bias, // [num_kv_heads, head_dim] + const paddle::optional& smooth_weight, // [num_kv_heads, head_dim] + const paddle::Tensor& seq_lens_q, + const paddle::Tensor& seq_lens_kv, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& padding_offsets, + const paddle::Tensor& cum_offsets, + const paddle::Tensor& block_table, + const paddle::Tensor& batch_ids, + const paddle::Tensor& tile_ids_per_batch, + const std::string& cache_quant_type_str, + const int num_blocks, + const int block_shape_q, + const int max_seq_len, + const int max_dec_len, + const int num_heads, + const int kv_num_heads, + const int head_dim, + const float in_scale, + const int max_partition_size, + const int encoder_max_partition_size, + const int speculate_max_draft_token_num, + const bool causal, + const bool is_decoder, + const bool enable_prefill, + cudaStream_t& stream, + paddle::Tensor* out); \ No newline at end of file diff --git a/csrc/gpu/append_attn/append_attention_func.cuh b/csrc/gpu/append_attn/append_attention_func.cuh new file mode 100644 index 000000000000..2a29fd2030bd --- /dev/null +++ b/csrc/gpu/append_attn/append_attention_func.cuh @@ -0,0 +1,2567 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+#pragma once + +#include "helper.h" +#include "mem_util.cuh" +#include "mma_tensor_op.cuh" +#include "utils.cuh" +// #define DEBUG_WRITE_C4 +// #define DEBUG_ATTN_C4 +// #define DEBUG_ATTN_C8 + +#define PRINT_TID 0 +#define PRINT_WID 0 +// #define DEBUG_ATTN + +template +__forceinline__ __device__ float fixed_expf(float x1, float x2) { + if constexpr (std::is_same::value) { + if (x1 == -5e4f) { + return 0; + } else { + return __expf(x1 - x2); + } + } else if constexpr (std::is_same::value) { + if (x1 == -3.0e+30f) { + return 0; + } else { + return __expf(x1 - x2); + } + } +} + +template +struct prefill_softmax_state_t { + AlignedVector o; + float m; + float d; + + __device__ __forceinline__ void init() { + if constexpr (std::is_same::value) { +#pragma unroll + for (int i = 0; i < vec_size / 2; ++i) { + *((half2*)(&o) + i) = make_half2(0, 0); + } + } else if constexpr (std::is_same::value) { +#pragma unroll + for (int i = 0; i < vec_size / 2; ++i) { + *((nv_bfloat162*)(&o) + i) = make_bfloat162(0, 0); + } + } + d = 1.f; + if constexpr (std::is_same::value) { + m = -5e4f; + } else if constexpr (std::is_same::value) { + m = -3.38953e38f; + } + } + + // __device__ __forceinline__ prefill_softmax_state_t() { + // init(); + // } + + __device__ __forceinline__ void merge( + const AlignedVector& other_o, float other_m, float other_d) { + float m_prev = m, d_prev = d; + m = m_prev > other_m ? m_prev : other_m; + const float scale1 = __expf(m_prev - m), scale2 = __expf(other_m - m); + const T scale1_T = static_cast(scale1), + scale2_T = static_cast(scale2); + d = d_prev * scale1 + other_d * scale2; +#pragma unroll + for (size_t i = 0; i < vec_size; ++i) { + o[i] = o[i] * scale1_T + other_o[i] * scale2_T; + } + } + + __device__ __forceinline__ void normalize() { + const T d_t = static_cast(d); +#pragma unroll + for (size_t i = 0; i < vec_size; ++i) { + o[i] /= d_t; + } + } +}; + +template +__device__ __forceinline__ void init_states(float (*o_frag)[num_frags_y][8], + float (*m)[2], + float (*d)[2]) { +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y; ++fy) { +#pragma unroll + for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) { + o_frag[fx][fy][reg_id] = 0.f; + } + } + } +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t j = 0; j < 2; ++j) { + if constexpr (std::is_same::value) { + m[fx][j] = -5e4f; + } else if constexpr (std::is_same::value) { + m[fx][j] = -3.0e+30f; + } + d[fx][j] = 1.f; + } + } +} + +template +__device__ __forceinline__ void load_q_global_smem_multi_warps( + T* q_ptr_base, + smem_t* q_smem, + uint32_t q_idx_base, + const uint32_t qo_upper_bound, + const uint32_t qo_n_stride, + const uint32_t qo_h_stride) { + constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b(); + + const uint32_t tx = threadIdx.x, ty = threadIdx.y; + uint32_t q_smem_offset_w = // [NUM_WARP_Q, num_frags_x, 16, head_dim] + smem_t::get_permuted_offset(ty * 4 + tx / 8, + tx % 8); // 4 * 64 + + // q_idx_base += (tx / 8) / group_size; + // q_ptr_base += ((tx / 8) / group_size) * qo_n_stride + ((tx / 8) % + // group_size) * qo_h_stride; + const uint32_t tx_offset = tx / 8; +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { + // num_frags_x * 16 * head_dim + // load 4 row per warp + const uint32_t base_offset = q_idx_base + fx * 16 + tx_offset; +#pragma unroll + // for (uint32_t j = 0; j < 4; ++j) { + const int j = ty; + const uint32_t offset_now = base_offset + j * 4; + 
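// Annotation (not part of the patch): a minimal worked example of the
// row-to-(token, head) mapping computed just below, assuming grouped-query
// attention where group_size = num_heads / kv_num_heads. Each row of the
// 16-row fragment indexes a flattened (token, q-head-within-kv-group) pair,
// so the division recovers the token and the remainder the head in the group.
// With hypothetical values group_size = 4 and offset_now = 13:
//   n_offset = 13 / 4 = 3  -> 4th token of this tile (bounds-checked against
//                             qo_upper_bound before the async copy)
//   h_offset = 13 % 4 = 1  -> 2nd query head sharing this kv head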
const uint32_t n_offset = offset_now / group_size; + const uint32_t h_offset = offset_now % group_size; + T* q_ptr = q_ptr_base + n_offset * qo_n_stride + h_offset * qo_h_stride; +#pragma unroll + for (uint32_t fyo = 0; fyo < num_frags_y / 4; + ++fyo) { // (num_frags_y * 16) / (8 * num_elems_per_128b()) + // load q from gmem to smem + q_smem->load_128b_async( + q_smem_offset_w, q_ptr, n_offset < qo_upper_bound); + q_smem_offset_w = + q_smem->advance_offset_by_column<8>(q_smem_offset_w, fyo); + q_ptr += 8 * num_elems_per_128b(); + } + q_smem_offset_w = + q_smem->advance_offset_by_row<16, num_vecs_per_head>(q_smem_offset_w) - + 2 * num_frags_y; // num_frags_y / 4 * 8 + // } + } +} + +template +__device__ __forceinline__ void load_q_global_smem( + T* q_ptr_base, + smem_t* q_smem, + uint32_t q_idx_base, + const uint32_t qo_upper_bound, + const uint32_t qo_n_stride, + const uint32_t qo_h_stride) { + constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b(); + + const uint32_t tx = threadIdx.x, ty = threadIdx.y; + + uint32_t q_smem_offset_w = // [NUM_WARP_Q, num_frags_x, 16, head_dim] + smem_t::get_permuted_offset( + ty * num_frags_x * 16 + tx / 8, tx % 8); // 4 * 64 + + // q_idx_base += (tx / 8) / group_size; + // q_ptr_base += ((tx / 8) / group_size) * qo_n_stride + ((tx / 8) % + // group_size) * qo_h_stride; + const uint32_t tx_offset = tx / 8; +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { + // NUM_WARP_Q * num_frags_x * 16 * head_dim + const uint32_t base_offset = q_idx_base + fx * 16 + tx_offset; +#pragma unroll + for (uint32_t j = 0; j < 4; ++j) { + const uint32_t offset_now = base_offset + j * 4; + const uint32_t n_offset = offset_now / group_size; + const uint32_t h_offset = offset_now % group_size; + T* q_ptr = q_ptr_base + n_offset * qo_n_stride + h_offset * qo_h_stride; +#pragma unroll + for (uint32_t fyo = 0; fyo < num_frags_y / 4; + ++fyo) { // (num_frags_y * 16) / (8 * num_elems_per_128b()) + // load q from gmem to smem + q_smem->load_128b_async( + q_smem_offset_w, q_ptr, n_offset < qo_upper_bound); + q_smem_offset_w = + q_smem->advance_offset_by_column<8>(q_smem_offset_w, fyo); + q_ptr += 8 * num_elems_per_128b(); + } + q_smem_offset_w = + q_smem->advance_offset_by_row<4, num_vecs_per_head>(q_smem_offset_w) - + 2 * num_frags_y; // num_frags_y / 4 * 8 + } + } +} + +template +__device__ __forceinline__ void q_smem_inplace_multiply_sm_scale_multi_warps( + smem_t* q_smem, // [num_frags_x * 16, num_frags_y * 16] + const float sm_scale) { + constexpr int vec_size = 16 / sizeof(T); + using LoadT = AlignedVector; + LoadT tmp_vec; + const uint32_t tx = threadIdx.x, ty = threadIdx.y; + constexpr uint32_t head_dim = num_frags_y * 16; + constexpr uint32_t num_vecs_per_head = head_dim / num_elems_per_128b(); + +#pragma unroll + for (uint32_t i = 0; i < num_frags_x * 16 * head_dim / 1024; + ++i) { // 32 * 8 * 4 all warp + const int offset = i * 1024 + ty * 256 + tx * 8; + Load(reinterpret_cast(q_smem->base) + offset, &tmp_vec); +#pragma unroll + for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) { + tmp_vec[reg_id] *= sm_scale; + } + Store(tmp_vec, reinterpret_cast(q_smem->base) + offset); + } +} + +template +__device__ __forceinline__ void q_smem_inplace_multiply_sm_scale( + smem_t* q_smem, // [num_frags_x * 16, num_frags_y * 16] + const float sm_scale) { + constexpr int vec_size = 16 / sizeof(T); + using LoadT = AlignedVector; + LoadT tmp_vec; + const uint32_t tx = threadIdx.x, ty = threadIdx.y; + constexpr uint32_t head_dim = num_frags_y * 16; + constexpr 
uint32_t num_vecs_per_head = head_dim / num_elems_per_128b(); + +#pragma unroll + for (uint32_t i = 0; i < num_frags_x * 16 * head_dim / 256; + ++i) { // 32 * 8 per warp + Load( + reinterpret_cast(q_smem->base + + ty * num_frags_x * 16 * num_vecs_per_head) + + i * 256 + tx * 8, + &tmp_vec); +#pragma unroll + for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) { + tmp_vec[reg_id] *= sm_scale; + } + Store( + tmp_vec, + reinterpret_cast(q_smem->base + + ty * num_frags_x * 16 * num_vecs_per_head) + + i * 256 + tx * 8); + } +} + +template +__device__ __forceinline__ void produce_kv_blockwise( + smem_t smem, + uint32_t* smem_offset, + T** gptr, // [max_block_num, num_heads, block_size, head_dim] + const uint32_t kv_head_idx, + const uint32_t kv_n_stride, + const uint32_t kv_h_stride, + const uint32_t kv_b_stride, + const uint32_t kv_idx_base, + const uint32_t kv_len) { + constexpr uint32_t head_dim = num_frags_y * 16; + constexpr uint32_t num_vecs_per_head = head_dim / num_elems_per_128b(); + constexpr uint32_t NUM_WARP_KV = num_warps / NUM_WARP_Q; + const uint32_t tx = threadIdx.x, ty = threadIdx.y; + uint32_t kv_idx = kv_idx_base + ty * 4 + tx / 8; // kv_idx used to check +#pragma unroll + for (uint32_t i = 0; i < NUM_WARP_KV * num_frags_z * 4 / num_warps; + ++i) { // m num_frags_z * 16 / (num_warps * 4) + // 16 rows each time +#pragma unroll + for (uint32_t j = 0; j < num_frags_y / 4; + ++j) { // k num_frags_y * 16 / 8 / num_elems_per_128b() + smem.load_128b_async(*smem_offset, *gptr, kv_idx < kv_len); + *smem_offset = smem.advance_offset_by_column<8>(*smem_offset, j); + *gptr += 8 * num_elems_per_128b(); + } + kv_idx += num_warps * 4; + *smem_offset = smem.advance_offset_by_row( + *smem_offset) - + 2 * num_frags_y; // num_frags_y / 4 * 8 + *gptr += + num_warps * 4 * kv_b_stride - 2 * num_frags_y * num_elems_per_128b(); + } + *gptr -= NUM_WARP_KV * num_frags_z * 16 * kv_b_stride; + *smem_offset -= NUM_WARP_KV * num_frags_z * 16 * num_vecs_per_head; +} + +template +__device__ __forceinline__ void produce_v_blockwise_c8( + smem_t smem, + uint32_t* smem_offset, + CacheT* cache_v, + const int* block_table_now, + const uint32_t kv_head_idx, + const uint32_t kv_n_stride, + const uint32_t kv_h_stride, + const uint32_t kv_d_stride, + const uint32_t kv_idx_base, + const uint32_t kv_len, + const uint32_t const_v_offset) { + // [num_frags_y * 16, num_frags_z * 16] + constexpr uint32_t num_vecs_per_blocksize = + block_size / num_elems_per_128b(); // 8 + constexpr uint32_t NUM_WARP_KV = num_warps / NUM_WARP_Q; + const uint32_t tx = threadIdx.x, ty = threadIdx.y; + uint32_t kv_idx = + kv_idx_base + + tx % 4 * num_elems_per_128b(); // kv_idx used to check + if constexpr (NUM_WARP_Q == 4) { + int block_id = __ldg(&block_table_now[kv_idx / block_size]); + if (block_id < 0) block_id = 0; + CacheT* cache_v_now = cache_v + block_id * kv_n_stride + const_v_offset; +#pragma unroll + for (uint32_t i = 0; i < num_frags_y * 2 / num_warps; + ++i) { // m (num_frags_y * 16 / (num_warps * 8)) +#pragma unroll + for (uint32_t j = 0; j < num_frags_z / 4; + ++j) { // k num_frags_z * 16 / 4 / num_elems_per_128b() + // smem.load_128b_async(*smem_offset, *gptr, kv_idx < + // kv_len); + smem.load_128b_async(*smem_offset, cache_v_now, true); + *smem_offset = smem.advance_offset_by_column<4, num_vecs_per_blocksize>( + *smem_offset, j); + cache_v_now += 4 * num_elems_per_128b(); + // kv_idx += 4 * num_elems_per_128b(); + } + // kv_idx -= num_frags_z * num_elems_per_128b(); + *smem_offset = + smem.advance_offset_by_row( + *smem_offset) 
- + num_frags_z; // num_frags_z / 4 * 4 + cache_v_now += num_warps * 8 * kv_d_stride - + num_frags_z * num_elems_per_128b(); + } + // *gptr -= num_frags_y * 16 * kv_d_stride; + *smem_offset -= num_frags_y * 16 * num_vecs_per_blocksize; + } else { +#pragma unroll + for (uint32_t kv_i = 0; kv_i < NUM_WARP_KV / 2; ++kv_i) { + int block_id = __ldg(&block_table_now[kv_idx / block_size]); + if (block_id < 0) block_id = 0; + CacheT* cache_v_now = cache_v + block_id * kv_n_stride + const_v_offset; +#ifdef DEBUG_ATTN_C8 + if (threadIdx.x == PRINT_TID && threadIdx.y == 0 && blockIdx.z == 0 && + blockIdx.x == gridDim.x - 1 && blockIdx.y == gridDim.y - 1) { + printf( + "kv_i: %d, kv_idx: %d, tid: %d, block_id: %d, v_smem_offset_w: %d, " + "cache_v_now: %f, cache_v_now_p: %p\n", + (int)kv_i, + (int)kv_idx, + (int)threadIdx.x, + (int)block_id, + (int)*smem_offset, + (float)(*cache_v_now), + cache_v_now); + } + __syncthreads(); +#endif +#pragma unroll + for (uint32_t i = 0; i < num_frags_y * 2 / num_warps; + ++i) { // m (num_frags_y * 16 / (num_warps * 8)) +#pragma unroll + for (uint32_t j = 0; j < 2 * num_frags_z / 4; + ++j) { // k num_frags_z * 16 / 4 / num_elems_per_128b() +#ifdef DEBUG_ATTN_C8 + if ((threadIdx.x == PRINT_TID) && threadIdx.y == 0 && + blockIdx.z == 0 && blockIdx.x == gridDim.x - 1 && + blockIdx.y == gridDim.y - 1) { + printf( + "i: %d, j: %d, v_smem_offset_w: %d, cache_v_now: %f, " + "cache_v_now_p: %p\n", + (int)i, + (int)j, + (int)(*smem_offset), + (float)(*cache_v_now), + cache_v_now); + } + __syncthreads(); +#endif + smem.load_128b_async(*smem_offset, cache_v_now, true); + *smem_offset = + smem.advance_offset_by_column<4, num_vecs_per_blocksize>( + *smem_offset, j); + cache_v_now += 4 * num_elems_per_128b(); + kv_idx += 4 * num_elems_per_128b(); + } + kv_idx -= 2 * num_frags_z * num_elems_per_128b(); + *smem_offset = + smem.advance_offset_by_row( + *smem_offset) - + 2 * num_frags_z; // num_frags_z / 4 * 4 + cache_v_now += num_warps * 8 * kv_d_stride - + 2 * num_frags_z * num_elems_per_128b(); + } + kv_idx += block_size; + } + *smem_offset -= NUM_WARP_KV / 2 * num_frags_y * 16 * num_vecs_per_blocksize; + } +} + +template +__device__ __forceinline__ void produce_k_blockwise_c8( + smem_t smem, + uint32_t* smem_offset, + CacheT* cache_k, + const int* block_table_now, + const uint32_t kv_head_idx, + const uint32_t kv_n_stride, + const uint32_t kv_h_stride, + const uint32_t kv_b_stride, + const uint32_t kv_idx_base, + const uint32_t kv_len, + const uint32_t const_k_offset) { + // [num_frags_z * 16, num_frags_y * 16] + constexpr uint32_t head_dim = num_frags_y * 16; + constexpr uint32_t num_vecs_per_head = + head_dim / num_elems_per_128b(); // 8 + constexpr uint32_t NUM_WARP_KV = num_warps / NUM_WARP_Q; + const uint32_t tx = threadIdx.x, ty = threadIdx.y; + uint32_t kv_idx = kv_idx_base + ty * 4 + tx / 8; // kv_idx used to check + if constexpr (NUM_WARP_Q == 4) { + int block_id = __ldg(&block_table_now[kv_idx / block_size]); + if (block_id < 0) block_id = 0; + CacheT* cache_k_now = cache_k + block_id * kv_n_stride + const_k_offset; +#pragma unroll + for (uint32_t i = 0; i < num_frags_z * 4 / num_warps; + ++i) { // m num_frags_z * 16 / (num_warps * 4) +#pragma unroll + for (uint32_t j = 0; j < num_frags_y / 8; + ++j) { // k num_frags_y * 16 / 8 / num_elems_per_128b() + // smem.load_128b_async(*smem_offset, *gptr, kv_idx < + // kv_len); + smem.load_128b_async(*smem_offset, cache_k_now, true); + *smem_offset = smem.advance_offset_by_column<8, num_vecs_per_head>( + *smem_offset, j); + 
cache_k_now += 8 * num_elems_per_128b(); + } + // kv_idx += num_warps * 4; + *smem_offset = + smem.advance_offset_by_row( + *smem_offset) - + num_frags_y; // num_frags_y / 4 * 4 + cache_k_now += num_warps * 4 * kv_b_stride - + num_frags_y * num_elems_per_128b(); + } + // *gptr -= num_frags_z * 16 * kv_b_stride; + *smem_offset -= num_frags_z * 16 * num_vecs_per_head; + } else { +#pragma unroll + for (uint32_t kv_i = 0; kv_i < NUM_WARP_KV / 2; ++kv_i) { + int block_id = __ldg(&block_table_now[kv_idx / block_size]); + if (block_id < 0) block_id = 0; + CacheT* cache_k_now = cache_k + block_id * kv_n_stride + const_k_offset; +#pragma unroll + for (uint32_t i = 0; i < 2 * num_frags_z * 4 / num_warps; + ++i) { // m num_frags_z * 16 / (num_warps * 4) +#pragma unroll + for (uint32_t j = 0; j < num_frags_y / 8; + ++j) { // k num_frags_y * 16 / 8 / num_elems_per_128b() + // smem.load_128b_async(*smem_offset, *gptr, kv_idx < + // kv_len); + smem.load_128b_async(*smem_offset, cache_k_now, true); + *smem_offset = smem.advance_offset_by_column<8, num_vecs_per_head>( + *smem_offset, j); + cache_k_now += 8 * num_elems_per_128b(); + } + kv_idx += num_warps * 4; + *smem_offset = + smem.advance_offset_by_row( + *smem_offset) - + num_frags_y; // num_frags_y / 4 * 4 + cache_k_now += num_warps * 4 * kv_b_stride - + num_frags_y * num_elems_per_128b(); + } + } + *smem_offset -= NUM_WARP_KV * num_frags_z * 16 * num_vecs_per_head; + } +} + +template +__device__ __forceinline__ void produce_v_blockwise_c4( + smem_t smem, + uint32_t* smem_offset, + CacheT* cache_v, + const int* block_table_now, + const uint32_t kv_head_idx, + const uint32_t kv_n_stride, + const uint32_t kv_h_stride, + const uint32_t kv_d_stride, + const uint32_t kv_idx_base, + const uint32_t kv_len, + const uint32_t const_v_offset) { + constexpr uint32_t num_vecs_per_blocksize = + block_size / 2 / num_elems_per_128b(); // 2 + const uint32_t tx = threadIdx.x, ty = threadIdx.y; + constexpr uint32_t NUM_WARP_KV = num_warps / NUM_WARP_Q; + uint32_t kv_idx = + kv_idx_base + + tx % 2 * 2 * num_elems_per_128b(); // kv_idx used to check +#pragma unroll + for (uint32_t kv_i = 0; kv_i < NUM_WARP_KV; ++kv_i) { + int block_id = __ldg(&block_table_now[(kv_idx) / block_size]); + if (block_id < 0) block_id = 0; + CacheT* cache_v_now = cache_v + block_id * kv_n_stride + const_v_offset; +#ifdef DEBUG_ATTN_C4 + if (threadIdx.x == PRINT_TID && threadIdx.y == 0 && blockIdx.z == 0 && + blockIdx.x == gridDim.x - 1 && blockIdx.y == gridDim.y - 1) { + printf( + "kv_i: %d, kv_idx: %d, tid: %d, block_id: %d, v_smem_offset_w: %d, " + "cache_v_now: %f, cache_v_now_p: %p\n", + (int)kv_i, + (int)kv_idx, + (int)threadIdx.x, + (int)block_id, + (int)*smem_offset, + (float)(*cache_v_now), + cache_v_now); + } + __syncthreads(); +#endif +#pragma unroll + for (uint32_t i = 0; i < num_frags_y / num_warps; ++i) { // m +#pragma unroll + for (uint32_t j = 0; j < num_frags_z / 4; + ++j) { // k num_frags_z * 16 / 2 / 2 / num_elems_per_128b() + // smem.load_128b_async(*smem_offset, *gptr, kv_idx < + // kv_len); + smem.load_128b_async(*smem_offset, cache_v_now, true); + *smem_offset = smem.advance_offset_by_column<2, num_vecs_per_blocksize>( + *smem_offset, j); + cache_v_now += 2 * num_elems_per_128b(); + kv_idx += 4 * num_elems_per_128b(); + } + kv_idx -= num_frags_z * num_elems_per_128b(); + *smem_offset = + smem.advance_offset_by_row( + *smem_offset) - + num_frags_z / 2; // num_frags_y / 4 * 2 + cache_v_now += num_warps * 16 * kv_d_stride - + num_frags_z / 2 * num_elems_per_128b(); + } + 
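// Annotation (not part of the patch): at this point one cache block worth of
// quantized V has been staged into shared memory. Bumping kv_idx by
// block_size (next statement) makes the following kv_i iteration look up the
// next entry of block_table_now, so the paged KV cache is walked one block
// per iteration rather than assuming physically contiguous storage.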
kv_idx += block_size; + } + // *gptr -= num_frags_y * 16 * kv_d_stride; + *smem_offset -= NUM_WARP_KV * num_frags_y * 16 * num_vecs_per_blocksize; +} + +template +__device__ __forceinline__ void produce_k_blockwise_c4( + smem_t smem, + uint32_t* smem_offset, + CacheT* cache_k, + const int* block_table_now, + const uint32_t kv_head_idx, + const uint32_t kv_n_stride, + const uint32_t kv_h_stride, + const uint32_t kv_b_stride, + const uint32_t kv_idx_base, + const uint32_t kv_len, + const uint32_t const_k_offset) { + constexpr uint32_t head_dim = num_frags_y * 16; + constexpr uint32_t num_vecs_per_head = + head_dim / 2 / num_elems_per_128b(); // 4 + constexpr uint32_t NUM_WARP_KV = num_warps / NUM_WARP_Q; + const uint32_t tx = threadIdx.x, ty = threadIdx.y; + uint32_t kv_idx = kv_idx_base + ty * 8 + tx / 4; // kv_idx used to check + +#pragma unroll + for (uint32_t kv_i = 0; kv_i < NUM_WARP_KV; ++kv_i) { + int block_id = __ldg(&block_table_now[kv_idx / block_size]); + if (block_id < 0) block_id = 0; + CacheT* cache_k_now = cache_k + block_id * kv_n_stride + const_k_offset; + +#pragma unroll + for (uint32_t i = 0; i < num_frags_z * 2 / num_warps; + ++i) { // m num_frags_z * 16 / (num_warps * 8) +#pragma unroll + for (uint32_t j = 0; j < num_frags_y / 8; + ++j) { // k num_frags_y * 16 / 2 / 4 / num_elems_per_128b() + // smem.load_128b_async(*smem_offset, *gptr, kv_idx < + // kv_len); + smem.load_128b_async(*smem_offset, cache_k_now, true); + *smem_offset = smem.advance_offset_by_column<4, num_vecs_per_head>( + *smem_offset, j); + cache_k_now += 4 * num_elems_per_128b(); + } + kv_idx += num_warps * 8; + *smem_offset = + smem.advance_offset_by_row( + *smem_offset) - + num_frags_y / 2; // num_frags_y / 8 * 4 + cache_k_now += num_warps * 8 * kv_b_stride - + num_frags_y / 2 * num_elems_per_128b(); + } + } + *smem_offset -= NUM_WARP_KV * num_frags_z * 16 * num_vecs_per_head; +} + +template +__device__ __forceinline__ void block_produce_kv( + smem_t smem, + uint32_t* smem_offset, + T* gptr_base, // [max_block_num, num_heads, block_size, head_dim] + const int* block_table, + const uint32_t kv_head_idx, + const uint32_t kv_n_stride, + const uint32_t kv_h_stride, + const uint32_t kv_b_stride, + const uint32_t kv_idx_base, + const uint32_t kv_len) { + constexpr uint32_t head_dim = num_frags_y * 16; + constexpr uint32_t num_vecs_per_head = head_dim / num_elems_per_128b(); + const uint32_t tx = threadIdx.x, ty = threadIdx.y; + if constexpr (NUM_WARP_Q == 4) { +#pragma unroll + for (uint32_t i = 0; i < num_frags_z * 4 / num_warps; + ++i) { // m num_frags_z * 16 / (num_warps * 4) + const uint32_t row_now = + kv_idx_base + (i * 4 * num_warps + ty * 4 + tx / 8); + const uint32_t kv_n_idx = row_now / block_size; + const uint32_t kv_bid = row_now % block_size; + T* gptr = gptr_base + __ldg(&block_table[kv_n_idx]) * kv_n_stride + + kv_head_idx * kv_h_stride + kv_bid * kv_b_stride + + tx % 8 * num_elems_per_128b(); +#pragma unroll + for (uint32_t j = 0; j < num_frags_y / 4; + ++j) { // k num_frags_y * 16 / 8 / num_elems_per_128b() + smem.load_128b_async(*smem_offset, gptr, row_now < kv_len); + *smem_offset = smem.advance_offset_by_column<8>(*smem_offset, j); + gptr += 8 * num_elems_per_128b(); + } + *smem_offset = + smem.advance_offset_by_row( + *smem_offset) - + 2 * num_frags_y; // num_frags_y / 4 * 8 + } + *smem_offset -= num_frags_z * 16 * num_vecs_per_head; + } else { + const uint32_t row_id_per_tx = tx / 8; + const uint32_t col_id_per_tx = tx % 8; +#pragma unroll + for (uint32_t i = 0; i < num_frags_z; + ++i) { 
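// Annotation (not part of the patch): row_now computed below is a logical KV
// position. Its block index (row_now / block_size) is translated through
// block_table into a physical cache block, row_now % block_size addresses the
// row inside that block, and rows past kv_len are suppressed via the
// predicate passed to load_128b_async rather than by branching.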
// m num_warps * num_frags_z * 16 +#pragma unroll + for (uint32_t j = 0; j < 4; j++) { + const uint32_t row_now = kv_idx_base + (i * 16 + j * 4 + row_id_per_tx); + const uint32_t kv_n_idx = row_now / block_size; + const uint32_t kv_bid = row_now % block_size; + T* gptr = gptr_base + __ldg(&block_table[kv_n_idx]) * kv_n_stride + + kv_head_idx * kv_h_stride + kv_bid * kv_b_stride + + col_id_per_tx * num_elems_per_128b(); +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y / 4; + ++fy) { // k num_frags_y * 16 / 8 / num_elems_per_128b() + smem.load_128b_async(*smem_offset, gptr, row_now < kv_len); + *smem_offset = smem.advance_offset_by_column<8>(*smem_offset, fy); + gptr += 8 * num_elems_per_128b(); + } + *smem_offset = + smem.advance_offset_by_row<4, num_vecs_per_head>(*smem_offset) - + 2 * num_frags_y; // num_frags_y / 4 * 8 + } + } + *smem_offset -= num_frags_z * 16 * num_vecs_per_head; + } +} + +template +__device__ __forceinline__ void produce_kv(smem_t smem, + uint32_t* smem_offset, + T** gptr, + const uint32_t kv_n_stride, + const uint32_t kv_idx_base, + const uint32_t kv_len) { + constexpr uint32_t head_dim = num_frags_y * 16; + constexpr uint32_t num_vecs_per_head = head_dim / num_elems_per_128b(); + const uint32_t tx = threadIdx.x, ty = threadIdx.y; + uint32_t kv_idx = kv_idx_base + ty * 4 + tx / 8; // kv_idx used to check +#pragma unroll + for (uint32_t i = 0; i < num_frags_z * 4 / num_warps; + ++i) { // m num_frags_z * 16 / (num_warps * 4) +#pragma unroll + for (uint32_t j = 0; j < num_frags_y / 4; + ++j) { // k num_frags_y * 16 / 8 / num_elems_per_128b() + smem.load_128b_async(*smem_offset, *gptr, kv_idx < kv_len); + *smem_offset = smem.advance_offset_by_column<8>(*smem_offset, j); + *gptr += 8 * num_elems_per_128b(); + } + kv_idx += num_warps * 4; + *smem_offset = smem.advance_offset_by_row( + *smem_offset) - + 2 * num_frags_y; // num_frags_y / 4 * 8 + *gptr += + num_warps * 4 * kv_n_stride - 2 * num_frags_y * num_elems_per_128b(); + } + *smem_offset -= num_frags_z * 16 * num_vecs_per_head; +} + +template +__device__ __forceinline__ void compute_qk(smem_t* q_smem, + uint32_t* q_smem_offset_r, + smem_t* k_smem, + uint32_t* k_smem_offset_r, + float (*s_frag)[num_frags_z][8]) { + // q [num_warps_q, num_frags_x, 16, head_dim], k [num_warps_kv, num_frags_z, + // 16, head_dim] + constexpr uint32_t head_dim = num_frags_y * 16; + constexpr uint32_t num_vecs_per_head = head_dim / num_elems_per_128b(); + uint32_t a_frag[num_frags_x][4], b_frag[4]; + // compute q*k^T +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y; ++fy) { // k +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { // m + q_smem->ldmatrix_m8n8x4(*q_smem_offset_r, a_frag[fx]); + *q_smem_offset_r = q_smem->advance_offset_by_row<16, num_vecs_per_head>( + *q_smem_offset_r); + } + + *q_smem_offset_r = + q_smem->advance_offset_by_column<2>(*q_smem_offset_r, fy) - + num_frags_x * 16 * num_vecs_per_head; + +#pragma unroll + for (uint32_t fz = 0; fz < num_frags_z; ++fz) { // n + k_smem->ldmatrix_m8n8x4(*k_smem_offset_r, b_frag); + *k_smem_offset_r = k_smem->advance_offset_by_row<16, num_vecs_per_head>( + *k_smem_offset_r); +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { + if (fy == 0) { + mma_sync_m16n16k16_row_col_f16f16f32( + s_frag[fx][fz], a_frag[fx], b_frag); + } else { + mma_sync_m16n16k16_row_col_f16f16f32( + s_frag[fx][fz], a_frag[fx], b_frag); + } + } + } + *k_smem_offset_r = + k_smem->advance_offset_by_column<2>(*k_smem_offset_r, fy) - + num_frags_z * 16 * num_vecs_per_head; + } 
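// Annotation (not part of the patch): compute_qk above accumulates
// S = Q * K^T with m16n16k16 tensor-core MMAs; the outer fy loop walks the
// head_dim (reduction) axis in 16-wide fragments, undoing its row advances
// inside the loop while the per-fy column advances accumulate. The two
// subtractions that follow rewind those column advances so the q/k
// shared-memory read offsets point back at the tile origin for the next
// pipeline stage.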
+ *q_smem_offset_r -= num_frags_y * 2; + *k_smem_offset_r -= num_frags_y * 2; +} + +template +__device__ __forceinline__ void compute_qk_c4(smem_t* q_smem, + uint32_t* q_smem_offset_r, + smem_t* k_smem, + uint32_t* k_smem_offset_r, + float (*s_frag)[num_frags_z][8], + T (*cache_k_scale_frag)[4], + T (*cache_k_zp_frag)[4]) { + // q [num_warps_q, num_frags_x, 16, head_dim], k [num_warps_kv, num_frags_z, + // 16, head_dim] + constexpr uint32_t head_dim = num_frags_y * 16; + constexpr uint32_t num_vecs_per_head_q = head_dim / num_elems_per_128b(); + constexpr uint32_t num_vecs_per_head_k = + head_dim / 2 / num_elems_per_128b(); + + uint32_t a_frag[num_frags_x][4][4], b_frag[4], b_frag_dq[4]; + +#pragma unroll + for (uint32_t ky = 0; ky < num_frags_y / 4; ++ky) { // k + // load q +#pragma unroll + for (uint32_t fy = 0; fy < 4; ++fy) { +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { + q_smem->ldmatrix_m8n8x4(*q_smem_offset_r, a_frag[fx][fy]); +#ifdef DEBUG_ATTN_C4 + if (threadIdx.x == PRINT_TID && threadIdx.y == 0 && blockIdx.z == 0) { + printf("q_smem_offset_r: %d\n", (int)*q_smem_offset_r); + T* a_frag_t = reinterpret_cast(a_frag[fx][fy]); + for (int k = 0; k < 8; k++) { + printf( + "compute_qk_c4 fx: %d, ky: %d, fy: %d, a_frag[%d][%d][%d]: %f " + " ", + (int)fx, + (int)ky, + (int)fy, + (int)fx, + (int)fy, + (int)k, + (float)a_frag_t[k]); + } + printf("\n"); + } + __syncthreads(); +#endif + *q_smem_offset_r = + q_smem->advance_offset_by_row<16, num_vecs_per_head_q>( + *q_smem_offset_r); + } + *q_smem_offset_r = + q_smem->advance_offset_by_column<2>(*q_smem_offset_r, ky * 4 + fy) - + num_frags_x * 16 * num_vecs_per_head_q; + } + +#pragma unroll + for (uint32_t fz = 0; fz < num_frags_z; ++fz) { + // load + k_smem->ldmatrix_m8n8x4(*k_smem_offset_r, b_frag); + *k_smem_offset_r = k_smem->advance_offset_by_row<16, num_vecs_per_head_k>( + *k_smem_offset_r); +#pragma unroll + for (uint32_t fy = 0; fy < 4; ++fy) { + // dequant b_frag[fy] -> b_frag_dq + T* b_frag_dq_T = reinterpret_cast(b_frag_dq); + convert_int4(b_frag_dq_T, b_frag[fy]); + // scale zp +#pragma unroll + for (uint32_t b_i = 0; b_i < 8; ++b_i) { + const int b_offset = b_i % 4; +#ifdef DEBUG_ATTN_C4 + if (threadIdx.x == PRINT_TID && threadIdx.y == 0 && blockIdx.z == 0) { + printf("compute_qk_c4, fz: %d, ky: %d, fy: %d\n", + (int)fz, + (int)ky, + (int)fy); + printf( + "compute_qk_c4 b_frag_dq_T[%d]: %f, cache_k_zp_frag[%d][%d]: " + "%f, cache_k_scale_frag[%d][%d]: %f\n", + (int)b_i, + (float)b_frag_dq_T[b_i], + int(ky * 4 + fy), + (int)b_offset, + (float)cache_k_zp_frag[ky * 4 + fy][b_offset], + int(ky * 4 + fy), + (int)b_offset, + (float)cache_k_scale_frag[ky * 4 + fy][b_offset]); + } + __syncthreads(); +#endif + b_frag_dq_T[b_i] = + (b_frag_dq_T[b_i] - cache_k_zp_frag[ky * 4 + fy][b_offset]) * + cache_k_scale_frag[ky * 4 + fy][b_offset]; + } + +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { + if (ky == 0 && fy == 0) { + mma_sync_m16n16k16_row_col_f16f16f32( + s_frag[fx][fz], a_frag[fx][fy], b_frag_dq); + } else { + mma_sync_m16n16k16_row_col_f16f16f32( + s_frag[fx][fz], a_frag[fx][fy], b_frag_dq); + } + } + } + } + *k_smem_offset_r = k_smem->advance_offset_by_column<2, num_vecs_per_head_k>( + *k_smem_offset_r, ky) - + num_frags_z * 16 * num_vecs_per_head_k; + } +#ifdef DEBUG_ATTN_C4 + if (threadIdx.x == PRINT_TID && threadIdx.y == 0 && blockIdx.z == 0) { + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { + for (uint32_t fz = 0; fz < num_frags_z; ++fz) { + for (int k = 0; k < 8; k++) { + 
printf("s_frag[%d][%d][%d]: %f ", + (int)fx, + (int)fz, + (int)k, + s_frag[fx][fz][k]); + } + printf("\n"); + } + printf("\n"); + } + } + __syncthreads(); +#endif + // advance by col + *q_smem_offset_r -= num_frags_y * 2; + *k_smem_offset_r -= num_frags_y / 4 * 2; // !!!, check if recover correctly +} + +template +__device__ __forceinline__ void compute_qk_c8(smem_t* q_smem, + uint32_t* q_smem_offset_r, + smem_t* k_smem, + uint32_t* k_smem_offset_r, + const T cache_k_scale, + float (*s_frag)[num_frags_z][8]) { + // q [num_warps_q, num_frags_x, 16, head_dim], k [num_warps_kv, num_frags_z, + // 16, head_dim] + constexpr uint32_t head_dim = num_frags_y * 16; + constexpr uint32_t num_vecs_per_head_q = head_dim / num_elems_per_128b(); + constexpr uint32_t num_vecs_per_head_k = + head_dim / num_elems_per_128b(); + + uint32_t a_frag[num_frags_x][2][4], b_frag[4], b_frag_dq[4]; + +#pragma unroll + for (uint32_t ky = 0; ky < num_frags_y / 2; ++ky) { // k + // load q +#pragma unroll + for (uint32_t fy = 0; fy < 2; ++fy) { +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { + q_smem->ldmatrix_m8n8x4(*q_smem_offset_r, a_frag[fx][fy]); +#ifdef DEBUG_ATTN_C8 + if (threadIdx.x == PRINT_TID && threadIdx.y == 0 && blockIdx.z == 0 && + blockIdx.x == gridDim.x - 1 && blockIdx.y == gridDim.y - 1) { + printf("q_smem_offset_r: %d\n", (int)*q_smem_offset_r); + T* a_frag_t = reinterpret_cast(a_frag[fx][fy]); + for (int k = 0; k < 8; k++) { + printf( + "compute_qk_c8 fx: %d, ky: %d, fy: %d, a_frag[%d][%d][%d]: %f " + " ", + (int)fx, + (int)ky, + (int)fy, + (int)fx, + (int)fy, + (int)k, + (float)a_frag_t[k]); + } + printf("\n"); + } + __syncthreads(); +#endif + + *q_smem_offset_r = + q_smem->advance_offset_by_row<16, num_vecs_per_head_q>( + *q_smem_offset_r); + } + *q_smem_offset_r = + q_smem->advance_offset_by_column<2>(*q_smem_offset_r, ky * 2 + fy) - + num_frags_x * 16 * num_vecs_per_head_q; + } + +#pragma unroll + for (uint32_t fz = 0; fz < num_frags_z; ++fz) { + // load + k_smem->ldmatrix_m8n8x4(*k_smem_offset_r, b_frag); + *k_smem_offset_r = k_smem->advance_offset_by_row<16, num_vecs_per_head_k>( + *k_smem_offset_r); +#pragma unroll + for (uint32_t fy = 0; fy < 2; ++fy) { + // dequant b_frag[fy] -> b_frag_dq + T* b_frag_dq_T = reinterpret_cast(b_frag_dq); + convert_int8(b_frag_dq_T, b_frag[fy * 2]); + convert_int8(b_frag_dq_T + 4, b_frag[fy * 2 + 1]); + // scale zp +#pragma unroll + for (uint32_t b_i = 0; b_i < 8; ++b_i) { +#ifdef DEBUG_ATTN_C8 + if (threadIdx.x == PRINT_TID && threadIdx.y == 0 && blockIdx.z == 0 && + blockIdx.x == gridDim.x - 1 && blockIdx.y == gridDim.y - 1) { + printf("compute_qk_c8, fz: %d, ky: %d, fy: %d\n", + (int)fz, + (int)ky, + (int)fy); + printf("compute_qk_c8 b_frag_dq_T[%d]: %f, cache_k_scale: %f\n", + (int)b_i, + (float)b_frag_dq_T[b_i], + (float)cache_k_scale); + } + __syncthreads(); +#endif + + b_frag_dq_T[b_i] *= cache_k_scale; + } +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { + if (ky == 0 && fy == 0) { + mma_sync_m16n16k16_row_col_f16f16f32( + s_frag[fx][fz], a_frag[fx][fy], b_frag_dq); + } else { + mma_sync_m16n16k16_row_col_f16f16f32( + s_frag[fx][fz], a_frag[fx][fy], b_frag_dq); + } + } + } + } + *k_smem_offset_r = k_smem->advance_offset_by_column<2, num_vecs_per_head_k>( + *k_smem_offset_r, ky) - + num_frags_z * 16 * num_vecs_per_head_k; + } +#ifdef DEBUG_ATTN_C8 + if (threadIdx.x == PRINT_TID && threadIdx.y == 0 && blockIdx.z == 0 && + blockIdx.x == gridDim.x - 1 && blockIdx.y == gridDim.y - 1) { + for (uint32_t fx = 0; fx < 
num_frags_x; ++fx) { + for (uint32_t fz = 0; fz < num_frags_z; ++fz) { + for (int k = 0; k < 8; k++) { + printf("s_frag[%d][%d][%d]: %f ", + (int)fx, + (int)fz, + (int)k, + s_frag[fx][fz][k]); + } + printf("\n"); + } + printf("\n"); + } + } + __syncthreads(); +#endif + + // advance by col + *q_smem_offset_r -= num_frags_y * 2; + *k_smem_offset_r -= num_frags_y / 2 * 2; // !!!, check if recover correctly +} + +template +__device__ __forceinline__ void mask_s(const uint32_t qo_idx_base, + const uint32_t kv_idx_base, + const uint32_t qo_len, + const uint32_t kv_len, + const uint32_t chunk_end, + float (*s_frag)[num_frags_z][8]) { + const uint32_t tx = threadIdx.x; +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t fz = 0; fz < num_frags_z; ++fz) { +#pragma unroll + for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) { + if constexpr (!IS_SYSTEM) { + const uint32_t q_idx = (qo_idx_base + fx * 16 + tx / 4 + + 8 * ((reg_id % 4) / 2)) / + group_size, + kv_idx = kv_idx_base + fz * 16 + 2 * (tx % 4) + + 8 * (reg_id / 4) + reg_id % 2; + const bool out_of_boundary = + (causal ? (kv_idx > kv_len + q_idx - qo_len || + (kv_idx >= chunk_end)) + : kv_idx >= chunk_end); + if constexpr (std::is_same::value) { + s_frag[fx][fz][reg_id] = + out_of_boundary ? -5e4f : s_frag[fx][fz][reg_id]; + } else if constexpr (std::is_same::value) { + s_frag[fx][fz][reg_id] = + out_of_boundary ? -3.0e+30f : s_frag[fx][fz][reg_id]; + } + } else { // 共享前缀decoder加速,不增加q_idx,每位置q_idx相同 + const uint32_t q_idx = qo_idx_base, + kv_idx = kv_idx_base + fz * 16 + 2 * (tx % 4) + + 8 * (reg_id / 4) + reg_id % 2; + const bool out_of_boundary = + (causal ? (kv_idx > kv_len + q_idx - qo_len || + (kv_idx >= chunk_end)) + : kv_idx >= chunk_end); +#ifdef DEBUG_ATTN_C4 + if (threadIdx.x == PRINT_TID && threadIdx.y == PRINT_WID && + blockIdx.z == 0 && blockIdx.x == 0 && + blockIdx.y == gridDim.y - 1 && fx == 0 && fz == 3 && + reg_id == 4) { + printf( + "q_idx: %d, kv_idx: %d, kv_len: %d, qo_len: %d, chunk_end: " + "%d\n", + (int)q_idx, + (int)kv_idx, + (int)kv_len, + (int)qo_len, + (int)chunk_end); + } + __syncthreads(); +#endif + if constexpr (std::is_same::value) { + s_frag[fx][fz][reg_id] = + out_of_boundary ? -5e4f : s_frag[fx][fz][reg_id]; + } else if constexpr (std::is_same::value) { + s_frag[fx][fz][reg_id] = + out_of_boundary ? 
-3.0e+30f : s_frag[fx][fz][reg_id]; + } + } + } + } + } +} + +template +__device__ __forceinline__ void update_mdo_states( + float (*s_frag)[num_frags_z][8], + float (*o_frag)[num_frags_y][8], + float (*m)[2], + float (*d)[2]) { + // [num_warps * num_frags_x * 16, num_frags_z * 16] +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { + // 16 * (num_frags_z * 16) +#pragma unroll + for (uint32_t j = 0; j < 2; ++j) { // 2行 + float m_prev = m[fx][j]; +#pragma unroll + for (uint32_t fz = 0; fz < num_frags_z; ++fz) { + float m_local = + max(max(s_frag[fx][fz][j * 2 + 0], s_frag[fx][fz][j * 2 + 1]), + max(s_frag[fx][fz][j * 2 + 4], s_frag[fx][fz][j * 2 + 5])); + m[fx][j] = max(m[fx][j], m_local); + } + m[fx][j] = max(m[fx][j], __shfl_xor_sync(-1, m[fx][j], 0x2, 32)); + m[fx][j] = max(m[fx][j], __shfl_xor_sync(-1, m[fx][j], 0x1, 32)); + float o_scale = __expf(m_prev - m[fx][j]); + d[fx][j] *= o_scale; +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y; ++fy) { + o_frag[fx][fy][j * 2 + 0] *= o_scale; + o_frag[fx][fy][j * 2 + 1] *= o_scale; + o_frag[fx][fy][j * 2 + 4] *= o_scale; + o_frag[fx][fy][j * 2 + 5] *= o_scale; + } +#pragma unroll + for (uint32_t fz = 0; fz < num_frags_z; ++fz) { + s_frag[fx][fz][j * 2 + 0] = + __expf(s_frag[fx][fz][j * 2 + 0] - m[fx][j]); + s_frag[fx][fz][j * 2 + 1] = + __expf(s_frag[fx][fz][j * 2 + 1] - m[fx][j]); + s_frag[fx][fz][j * 2 + 4] = + __expf(s_frag[fx][fz][j * 2 + 4] - m[fx][j]); + s_frag[fx][fz][j * 2 + 5] = + __expf(s_frag[fx][fz][j * 2 + 5] - m[fx][j]); + } + } + } +} + +template +__device__ __forceinline__ void compute_sfm_v_c4( + smem_t* v_smem, + uint32_t* v_smem_offset_r, + float (*s_frag)[num_frags_z][8], + float (*o_frag)[num_frags_y][8], + float (*d)[2], + T (*cache_v_scale_frag)[2], + T (*cache_v_zp_frag)[2]) { + // [num_frags_x, 16, num_frags_z, 16] [num_frags_y, 16, num_frags_z, 16] -> + // [num_frags_x, 16, num_frags_y, 16] + constexpr uint32_t num_vecs_per_blocksize = + block_size / 2 / num_elems_per_128b(); +#ifdef DEBUG_ATTN_C4 + if (threadIdx.x == PRINT_TID && threadIdx.y == 0 && blockIdx.z == 0) { + for (int i = 0; i < num_frags_x; i++) { + for (int j = 0; j < num_frags_z; j++) { + for (int k = 0; k < 8; k++) { + printf("compute_sfm_v_c4 s_frag[%d][%d][%d]: %f ", + i, + j, + k, + s_frag[i][j][k]); + } + printf("\n"); + } + printf("\n"); + } + } + __syncthreads(); +#endif + + T s_frag_f16[num_frags_x][num_frags_z][8]; + uint32_t b_frag[4], b_frag_dq[4]; +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t fz = 0; fz < num_frags_z; ++fz) { + vec_cast(s_frag_f16[fx][fz], s_frag[fx][fz]); + } + } + +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t fz = 0; fz < num_frags_z; ++fz) { + rowsum_f16f16f32(d[fx], s_frag_f16[fx][fz]); + } + } + +#pragma unroll + for (uint32_t kz = 0; kz < num_frags_z / 4; ++kz) { // k +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y; ++fy) { + v_smem->ldmatrix_m8n8x4(*v_smem_offset_r, b_frag); + *v_smem_offset_r = + v_smem->advance_offset_by_row<16, num_vecs_per_blocksize>( + *v_smem_offset_r); +#pragma unroll + for (uint32_t fz = 0; fz < 4; ++fz) { + // dequant b_frag -> b_frag_dq + T* b_frag_dq_T = reinterpret_cast(b_frag_dq); + convert_int4(b_frag_dq_T, b_frag[fz]); + // scale zp +#pragma unroll + for (uint32_t b_i = 0; b_i < 8; ++b_i) { + const int b_offset = b_i / 4; +#ifdef DEBUG_ATTN_C4 + if (threadIdx.x == PRINT_TID && threadIdx.y == 0 && blockIdx.z == 0) { + printf( + "compute_sfm_v_c4, kz: %d, 
fz: %d, b_frag_dq_T[%d]: %f, " + "cache_v_zp_frag[%d][%d]: %f, cache_v_scale_frag[%d][%d]: %f\n", + (int)kz, + (int)fz, + (int)b_i, + (float)b_frag_dq_T[b_i], + int(fy), + (int)b_offset, + (float)cache_v_zp_frag[fy][b_offset], + int(fy), + (int)b_offset, + (float)cache_v_scale_frag[fy][b_offset]); + } + __syncthreads(); +#endif + b_frag_dq_T[b_i] = + (b_frag_dq_T[b_i] - cache_v_zp_frag[fy][b_offset]) * + cache_v_scale_frag[fy][b_offset]; + } +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { // m: num_frags_x * 16 + mma_sync_m16n16k16_row_col_f16f16f32( + o_frag[fx][fy], + (uint32_t*)(s_frag_f16[fx][kz * 4 + fz]), + b_frag_dq); +#ifdef DEBUG_ATTN_C4 + if (threadIdx.x == PRINT_TID && threadIdx.y == 0 && blockIdx.z == 0) { + printf("compute_sfm_v_c4 o_frag\n"); + for (int k = 0; k < 8; k++) { + printf("compute_sfm_v_c4 o_frag[%d][%d][%d]: %f ", + (int)fx, + (int)fy, + (int)k, + o_frag[fx][fy][k]); + } + printf("\n"); + } + __syncthreads(); +#endif + } + } + } + *v_smem_offset_r = + v_smem->advance_offset_by_column<2, num_vecs_per_blocksize>( + *v_smem_offset_r, kz) - + num_frags_y * 16 * num_vecs_per_blocksize; + } +#ifdef DEBUG_ATTN_C4 + if (threadIdx.x == PRINT_TID && threadIdx.y == 0 && blockIdx.z == 0) { + printf("res o_frag\n"); + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { + for (uint32_t fy = 0; fy < num_frags_y; ++fy) { + for (int k = 0; k < 8; k++) { + printf("compute_sfm_v_c4 o_frag[%d][%d][%d]: %f ", + (int)fx, + (int)fy, + (int)k, + o_frag[fx][fy][k]); + } + printf("\n"); + } + printf("\n"); + } + printf("\n"); + } + __syncthreads(); +#endif + *v_smem_offset_r -= num_frags_z / 4 * 2; +} + +template +__device__ __forceinline__ void compute_sfm_v_c8( + smem_t* v_smem, + uint32_t* v_smem_offset_r, + float (*s_frag)[num_frags_z][8], + float (*o_frag)[num_frags_y][8], + float (*d)[2], + T cache_v_scale) { + // [num_frags_x, 16, num_frags_z, 16] [num_frags_y, 16, num_frags_z, 16] -> + // [num_frags_x, 16, num_frags_y, 16] + constexpr uint32_t num_vecs_per_blocksize = + block_size / num_elems_per_128b(); +#ifdef DEBUG_ATTN_C8 + if (threadIdx.x == PRINT_TID && threadIdx.y == 0 && blockIdx.z == 0 && + blockIdx.x == gridDim.x - 1 && blockIdx.y == gridDim.y - 1) { + for (int i = 0; i < num_frags_x; i++) { + for (int j = 0; j < num_frags_y; j++) { + for (int k = 0; k < 8; k++) { + printf("after_update compute_sfm_v_c8 o_frag[%d][%d][%d]: %f\n", + i, + j, + k, + o_frag[i][j][k]); + } + printf("\n"); + } + printf("\n"); + } + } + __syncthreads(); + + if ((threadIdx.x == 0 || threadIdx.x == 1 || threadIdx.x == 2 || + threadIdx.x == 3) && + threadIdx.y == 0 && blockIdx.z == 0 && blockIdx.x == gridDim.x - 1 && + blockIdx.y == gridDim.y - 1) { + for (int i = 0; i < num_frags_x; i++) { + for (int j = 0; j < num_frags_z; j++) { + for (int k = 0; k < 8; k++) { + printf("compute_sfm_v_c8 tid: %d, s_frag[%d][%d][%d]: %f\n", + (int)threadIdx.x, + i, + j, + k, + s_frag[i][j][k]); + } + printf("\n"); + } + printf("\n"); + } + } + __syncthreads(); +#endif + + T s_frag_f16[num_frags_x][num_frags_z][8]; + uint32_t b_frag[4], b_frag_dq[4]; +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t fz = 0; fz < num_frags_z; ++fz) { + vec_cast(s_frag_f16[fx][fz], s_frag[fx][fz]); + } + } + +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t fz = 0; fz < num_frags_z; ++fz) { + rowsum_f16f16f32(d[fx], s_frag_f16[fx][fz]); + } + } + +#pragma unroll + for (uint32_t kz = 0; kz < num_frags_z / 2; ++kz) { // k 
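+    // For each kz step: load an int8 V fragment from shared memory with
+    // ldmatrix, widen it to T via convert_int8, scale it by cache_v_scale,
+    // and accumulate P*V into o_frag with the fp16 m16n16k16 MMA.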
+#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y; ++fy) { + v_smem->ldmatrix_m8n8x4(*v_smem_offset_r, b_frag); +#ifdef DEBUG_ATTN_C8 + if ((threadIdx.x == PRINT_TID) && threadIdx.y == 0 && blockIdx.z == 0 && + blockIdx.x == gridDim.x - 1 && blockIdx.y == gridDim.y - 1) { + printf("kz: %d, fy: %d, v_smem_offset_r: %d\n", + (int)kz, + (int)fy, + (int)*v_smem_offset_r); + } + __syncthreads(); +#endif + *v_smem_offset_r = + v_smem->advance_offset_by_row<16, num_vecs_per_blocksize>( + *v_smem_offset_r); +#pragma unroll + for (uint32_t fz = 0; fz < 2; ++fz) { + // dequant b_frag -> b_frag_dq + T* b_frag_dq_T = reinterpret_cast(b_frag_dq); + convert_int8(b_frag_dq_T, b_frag[fz * 2]); + convert_int8(b_frag_dq_T + 4, b_frag[fz * 2 + 1]); + // scale zp +#pragma unroll + for (uint32_t b_i = 0; b_i < 8; ++b_i) { + b_frag_dq_T[b_i] *= cache_v_scale; +#ifdef DEBUG_ATTN_C8 + // if (threadIdx.x == PRINT_TID && threadIdx.y == 0 && blockIdx.z == 0 + // && blockIdx.x == gridDim.x - 1 && blockIdx.y == gridDim.y - 1) { + if ((threadIdx.x == 16 || threadIdx.x == 17 || threadIdx.x == 18 || + threadIdx.x == 19) && + threadIdx.y == 0 && blockIdx.z == 0 && + blockIdx.x == gridDim.x - 1 && blockIdx.y == gridDim.y - 1) { + uint8_t* b_frag_dq_uint8 = + reinterpret_cast(&b_frag[fz * 2]); + printf( + "compute_sfm_v_c8, tid: %d, kz: %d, fz: %d, " + "b_frag_dq_uint8[%d]: %d, b_frag_dq_T[%d]: %f, cache_v_scale: " + "%f\n", + (int)threadIdx.x, + (int)kz, + (int)fz, + (int)b_i, + (int)b_frag_dq_uint8[b_i], + (int)b_i, + (float)b_frag_dq_T[b_i], + (float)cache_v_scale); + } + __syncthreads(); +#endif + } +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { // m: num_frags_x * 16 + mma_sync_m16n16k16_row_col_f16f16f32( + o_frag[fx][fy], + (uint32_t*)(s_frag_f16[fx][kz * 2 + fz]), + b_frag_dq); +#ifdef DEBUG_ATTN_C8 + if (threadIdx.x == PRINT_TID && threadIdx.y == 0 && blockIdx.z == 0 && + blockIdx.x == gridDim.x - 1 && blockIdx.y == gridDim.y - 1) { + printf("compute_sfm_v_c8 o_frag\n"); + for (int k = 0; k < 8; k++) { + printf("compute_sfm_v_c8 o_frag[%d][%d][%d]: %f ", + (int)fx, + (int)fy, + (int)k, + o_frag[fx][fy][k]); + } + printf("\n"); + } + __syncthreads(); +#endif + } + } + } +#ifdef DEBUG_ATTN_C8 + if ((threadIdx.x == PRINT_TID) && threadIdx.y == 0 && blockIdx.z == 0 && + blockIdx.x == gridDim.x - 1 && blockIdx.y == gridDim.y - 1) { + printf( + "111 kz: %d, v_smem_offset_r: %d\n", (int)kz, (int)*v_smem_offset_r); + } + __syncthreads(); +#endif + *v_smem_offset_r = + v_smem->advance_offset_by_column<2, num_vecs_per_blocksize>( + *v_smem_offset_r, kz) - + num_frags_y * 16 * num_vecs_per_blocksize; +#ifdef DEBUG_ATTN_C8 + if ((threadIdx.x == PRINT_TID) && threadIdx.y == 0 && blockIdx.z == 0 && + blockIdx.x == gridDim.x - 1 && blockIdx.y == gridDim.y - 1) { + printf( + "222 kz: %d, v_smem_offset_r: %d\n", (int)kz, (int)*v_smem_offset_r); + } + __syncthreads(); +#endif + } + *v_smem_offset_r -= num_frags_z / 2 * 2; +} + +template +__device__ __forceinline__ void compute_sfm_v_c8_iter_sq_bvec( + smem_t* v_smem, + uint32_t* v_smem_offset_r, + float (*s_frag)[num_frags_z][8], + float (*o_frag)[num_frags_y][8], + float (*d)[2], + T cache_v_scale) { + // [num_frags_x, 16, num_frags_z, 16] [num_frags_y, 16, num_frags_z, 16] -> + // [num_frags_x, 16, num_frags_y, 16] + constexpr uint32_t num_vecs_per_blocksize = + block_size / num_elems_per_128b(); + + T s_frag_f16[num_frags_x][num_frags_z][8]; + uint32_t b_frag[4], b_frag_dq[4]; +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma 
unroll + for (uint32_t fz = 0; fz < num_frags_z; ++fz) { + vec_cast(s_frag_f16[fx][fz], s_frag[fx][fz]); + } + } + +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t fz = 0; fz < num_frags_z; ++fz) { + rowsum_f16f16f32(d[fx], s_frag_f16[fx][fz]); + } + } + +#pragma unroll + for (uint32_t kz = 0; kz < num_frags_z / 2; ++kz) { // k +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y; ++fy) { + v_smem->ldmatrix_m8n8x4(*v_smem_offset_r, b_frag); + *v_smem_offset_r = + v_smem->advance_offset_by_row<16, num_vecs_per_blocksize>( + *v_smem_offset_r); +#pragma unroll + for (uint32_t fz = 0; fz < 2; ++fz) { + // dequant b_frag -> b_frag_dq + T* b_frag_dq_T = reinterpret_cast(b_frag_dq); + convert_int8(b_frag_dq_T, b_frag[fz * 2]); + convert_int8(b_frag_dq_T + 4, b_frag[fz * 2 + 1]); + // scale zp +#pragma unroll + for (uint32_t b_i = 0; b_i < 8; ++b_i) { + b_frag_dq_T[b_i] *= cache_v_scale; + } +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { // m: num_frags_x * 16 + mma_sync_m16n16k16_row_col_f16f16f32( + o_frag[fx][fy], + (uint32_t*)(s_frag_f16[fx][kz * 2 + fz]), + b_frag_dq); + } + } + } + *v_smem_offset_r -= num_frags_y * 16 * num_vecs_per_blocksize; + } +} + +template +__device__ __forceinline__ void compute_sfm_v(smem_t* v_smem, + uint32_t* v_smem_offset_r, + float (*s_frag)[num_frags_z][8], + float (*o_frag)[num_frags_y][8], + float (*d)[2]) { + // [num_frags_x, 16, num_frags_z, 16] [num_frags_z, 16, num_frags_y, 16] -> + // [num_frags_x, 16, num_frags_y, 16] + constexpr uint32_t head_dim = num_frags_y * 16; + constexpr uint32_t num_vecs_per_head = head_dim / num_elems_per_128b(); + + T s_frag_f16[num_frags_x][num_frags_z][8]; +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t fz = 0; fz < num_frags_z; ++fz) { + vec_cast(s_frag_f16[fx][fz], s_frag[fx][fz]); + } + } +#ifdef DEBUG_ATTN + if (threadIdx.x == 0 && threadIdx.y == 0 && blockIdx.x == 0 && + blockIdx.z == 0) { + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { + for (uint32_t fz = 0; fz < num_frags_z; ++fz) { + for (int i = 0; i < 8; ++i) { + printf("fx: %d, fz: %d, s_frag_f16[%d][%d][%d]: %f\n", + (int)fx, + (int)fz, + (int)fx, + (int)fz, + (int)i, + (float)s_frag_f16[fx][fz][i]); + } + } + } + } + __syncthreads(); +#endif + +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t fz = 0; fz < num_frags_z; ++fz) { + rowsum_f16f16f32(d[fx], s_frag_f16[fx][fz]); + } + } + +#pragma unroll + for (uint32_t fz = 0; fz < num_frags_z; + ++fz) { // k: num_warps_kv * num_frags_z * 16 +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y; ++fy) { // n: num_frags_y * 16 + // [num_warps * num_frags_z * 16, num_frags_y * 16] + uint32_t b_frag[4]; + v_smem->ldmatrix_m8n8x4_trans(*v_smem_offset_r, b_frag); +#ifdef DEBUG_ATTN + if (threadIdx.x == 0 && threadIdx.y == 0 && blockIdx.x == 0 && + blockIdx.z == 0) { + T* b_frag_T = reinterpret_cast(b_frag); + for (int i = 0; i < 8; ++i) { + printf("bbb fz: %d, fy: %d, b_frag[%d]: %f\n", + (int)fz, + (int)fy, + (int)i, + (float)b_frag_T[i]); + } + } + __syncthreads(); +#endif +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { // m: num_frags_x * 16 + mma_sync_m16n16k16_row_col_f16f16f32( + o_frag[fx][fy], (uint32_t*)(s_frag_f16[fx][fz]), b_frag); +#ifdef DEBUG_ATTN + if (threadIdx.x == 0 && threadIdx.y == 0 && blockIdx.x == 0 && + blockIdx.z == 0) { + for (int i = 0; i < 8; ++i) { + printf("ooo fx: %d, fy: %d, o_frag[%d]: %f\n", + 
(int)fx, + (int)fy, + (int)i, + (float)o_frag[fx][fy][i]); + } + } +#endif + } + *v_smem_offset_r = + v_smem->advance_offset_by_column<2>(*v_smem_offset_r, fy); + } + *v_smem_offset_r = + v_smem->advance_offset_by_row<16, num_vecs_per_head>(*v_smem_offset_r) - + 2 * num_frags_y; + } + *v_smem_offset_r -= 16 * num_frags_z * num_vecs_per_head; +} + +template +__device__ __forceinline__ void normalize_d(float (*o_frag)[num_frags_y][8], + float (*d)[2]) { + float d_rcp[num_frags_x][2]; +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t j = 0; j < 2; ++j) { + d_rcp[fx][j] = 1.f / d[fx][j]; + } + } + +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y; ++fy) { +#pragma unroll + for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) { + o_frag[fx][fy][reg_id] = + o_frag[fx][fy][reg_id] * d_rcp[fx][(reg_id % 4) / 2]; + } + } + } +} + +template +__device__ __forceinline__ void merge_res_multi_warps( + T* o_smem, // [num_threads, num_frags_x, num_frags_y, 8] + T* md_smem, // [num_warps, num_frags_x * 16 * 2] + T (*o_frag)[num_frags_y][8], + T (*m_frag)[2], + T (*d_frag)[2]) { + constexpr uint32_t head_dim = num_frags_y * 16; + constexpr uint32_t num_vecs_per_head = head_dim / num_elems_per_128b(); + const uint32_t tx = threadIdx.x, ty = threadIdx.y; + const uint32_t tidx = ty * 32 + tx; + const uint32_t row_id = tx / 4; + + // [num_frags_x * 16, num_frags_y * 16] +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y; ++fy) { + const int offset = + tidx * num_frags_x * num_frags_y * 8 + fx * num_frags_y * 8 + fy * 8; + *(b128_t*)(&o_smem[offset]) = *(b128_t*)&o_frag[fx][fy]; // !!! + *(b128_t*)(&o_smem[offset + 4]) = *(b128_t*)&o_frag[fx][fy][4]; // !!! 
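+      // Spill this thread's (fx, fy) o_frag accumulator to shared memory as
+      // two 16-byte vectors; warp 0 later combines the per-warp partial
+      // results using the m/d statistics staged in md_smem below.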
+ } + } + if (tx % 4 == 0) { +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { + const int offset = ty * num_frags_x * 16 + fx * 16 + row_id; + md_smem[offset * 2] = m_frag[fx][0]; + md_smem[offset * 2 + 1] = d_frag[fx][0]; + md_smem[(offset + 8) * 2] = m_frag[fx][1]; + md_smem[(offset + 8) * 2 + 1] = d_frag[fx][1]; + } + } + __syncthreads(); + + if (ty == 0) { +#pragma unroll + for (uint32_t warp_id = 0; warp_id < NUM_WARPS; ++warp_id) { + const int tmp_tidx = warp_id * 32 + tx; +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { + const int offset = warp_id * num_frags_x * 16 + fx * 16 + row_id; +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y; ++fy) { + const int o_offset = tmp_tidx * num_frags_x * num_frags_y * 8 + + fx * num_frags_y * 8 + fy * 8; + AlignedVector o_now; + Load(&o_smem[o_offset], &o_now); + + float m_prev = m_frag[fx][0], d_prev = d_frag[fx][0]; + float m_now = md_smem[offset * 2], d_now = md_smem[offset * 2 + 1]; + float tmp_m = max(m_prev, m_now); + float scale1 = __expf(m_prev - tmp_m), scale2 = __expf(m_now - tmp_m); + float tmp_d = scale1 * d_prev + scale2 * d_now; + o_frag[fx][fx][0] = scale1 * o_frag[fx][fx][0] + scale2 * o_now[0]; + o_frag[fx][fx][1] = scale1 * o_frag[fx][fx][1] + scale2 * o_now[1]; + o_frag[fx][fx][4] = scale1 * o_frag[fx][fx][4] + scale2 * o_now[4]; + o_frag[fx][fx][5] = scale1 * o_frag[fx][fx][5] + scale2 * o_now[5]; + m_frag[fx][0] = tmp_m; + d_frag[fx][0] = tmp_d; + + m_prev = m_frag[fx][1], d_prev = d_frag[fx][1]; + m_now = md_smem[(offset + 8) * 2], + d_now = md_smem[(offset + 8) * 2 + 1]; + tmp_m = max(m_prev, m_now); + scale1 = __expf(m_prev - tmp_m), scale2 = __expf(m_now - tmp_m); + tmp_d = scale1 * d_prev + scale2 * d_now; + o_frag[fx][fx][2] = scale1 * o_frag[fx][fx][2] + scale2 * o_now[2]; + o_frag[fx][fx][3] = scale1 * o_frag[fx][fx][3] + scale2 * o_now[3]; + o_frag[fx][fx][6] = scale1 * o_frag[fx][fx][6] + scale2 * o_now[6]; + o_frag[fx][fx][7] = scale1 * o_frag[fx][fx][7] + scale2 * o_now[7]; + m_frag[fx][1] = tmp_m; + d_frag[fx][1] = tmp_d; + } + } + } + } +} + +template +__device__ __forceinline__ void write_o_reg_gmem_kv_multi_warps( + float (*o_frag)[num_frags_y][8], + smem_t* o_smem, + T* o_ptr_base, + uint32_t o_idx_base, + const uint32_t qo_upper_bound, + const uint32_t qo_n_stride, + const uint32_t qo_h_stride) { + constexpr uint32_t head_dim = num_frags_y * 16; + constexpr uint32_t num_vecs_per_head = head_dim / num_elems_per_128b(); + const uint32_t tx = threadIdx.x, ty = threadIdx.y; + + if (ty == 0) { + // [num_frags_x * 16, num_frags_y * 16] +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y; ++fy) { + // 16 * 16 + // 每个fy放16个数,vec size为8(f16/bf16),所以y轴为2fy + uint32_t o_frag_f16[4]; + vec_cast((T*)o_frag_f16, o_frag[fx][fy]); + uint32_t o_smem_offset_w = smem_t::get_permuted_offset< + num_vecs_per_head>( // num_vecs_per_head = num_frags_y * 16 / 8 = + // num_frags_y * 2 + fx * 16 + tx / 4, + fy * 2); + ((uint32_t*)(o_smem->base + o_smem_offset_w))[tx % 4] = o_frag_f16[0]; + ((uint32_t*)(o_smem->base + o_smem_offset_w + + 8 * num_vecs_per_head))[tx % 4] = o_frag_f16[1]; + ((uint32_t*)(o_smem->base + (o_smem_offset_w ^ 0x1)))[tx % 4] = + o_frag_f16[2]; // 2fy,异或1往右移一位 + ((uint32_t*)(o_smem->base + (o_smem_offset_w ^ 0x1) + + 8 * num_vecs_per_head))[tx % 4] = o_frag_f16[3]; + } + } + } + __syncthreads(); + + // smem连续存储到gmem上, [num_frags_x * 16, num_frags_y * 16] + uint32_t o_smem_offset_w = 
smem_t::get_permuted_offset(
+          ty * 4 + tx / 8, tx % 8);  // each warp copies 4 rows per pass, 64 elements at a time
+
+  o_idx_base += (tx / 8) / group_size;
+  o_ptr_base += ((tx / 8) / group_size) * qo_n_stride +
+                ((tx / 8) % group_size) * qo_h_stride;
+#pragma unroll
+  for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
+#pragma unroll
+    // for (uint32_t j = 0; j < 4; ++j) {  // 4 * 4 = 16
+    const int j = ty;
+    const uint32_t o_idx = o_idx_base + (fx * 16 + j * 4) / group_size;
+    T* o_ptr = o_ptr_base + ((fx * 16 + j * 4) / group_size) * qo_n_stride +
+               ((fx * 16 + j * 4) % group_size) * qo_h_stride;
+#pragma unroll
+    for (uint32_t fyo = 0; fyo < num_frags_y / 4;
+         ++fyo) {  // num_frags_y * 16 / (8[tid] *
+                   // num_elems_per_128b()[vec_per_thread])
+      if (o_idx < qo_upper_bound) {
+        // need write
+        o_smem->store_128b(o_smem_offset_w, o_ptr);
+      }
+      o_ptr += 8 * num_elems_per_128b();
+      o_smem_offset_w =
+          o_smem->advance_offset_by_column<8>(o_smem_offset_w, fyo);
+    }
+    o_smem_offset_w =
+        o_smem->advance_offset_by_row<16, num_vecs_per_head>(o_smem_offset_w) -
+        2 * num_frags_y;
+    // }
+  }
+}
+
+
+template
+struct StoreFunc {
+  __device__ __forceinline__ void operator()(
+      const AlignedVector& ori_out_vec,
+      const AlignedVector& shift_bias_vec,
+      const AlignedVector& smooth_weight_vec,
+      AlignedVector& out_vec,
+      const float in_scale,
+      const int i) {
+    out_vec[i] = static_cast(ori_out_vec[i]);
+    printf("Fatal!!! Unimplemented StoreFunc for cascade append attention\n");
+  }
+};
+
+template
+struct StoreFunc {
+  __device__ __forceinline__ void operator()(
+      const AlignedVector& ori_out_vec,
+      const AlignedVector& shift_bias_vec,
+      const AlignedVector& smooth_weight_vec,
+      AlignedVector& out_vec,
+      const float in_scale,
+      const int i) {
+    float quant_value =
+        127.0f *
+        static_cast((ori_out_vec[i] + shift_bias_vec[i]) *
+                    smooth_weight_vec[i]) *
+        in_scale;
+    quant_value = rintf(quant_value);
+    quant_value = quant_value > 127.0f ? 127.0f : quant_value;
+    quant_value = quant_value < -127.0f ?
-127.0f : quant_value; + out_vec[i] = static_cast(quant_value); + } +}; + +template +struct StoreFunc { + __device__ __forceinline__ void operator()( + const AlignedVector& ori_out_vec, + const AlignedVector& shift_bias_vec, + const AlignedVector& smooth_weight_vec, + AlignedVector& out_vec, + const float in_scale, + const int i) { + out_vec[i] = ori_out_vec[i]; + } +}; + +template +__device__ __forceinline__ void write_o_reg_gmem_multi_warps_shift_smooth_quant( + float (*o_frag)[num_frags_y][8], + smem_t* o_smem, + OutT* o_ptr_base, + const T* shift_bias, + const T* smooth_weight, + uint32_t o_idx_base, + const uint32_t q_head_idx_base, + const float in_scale, + const uint32_t qo_upper_bound, + const uint32_t qo_n_stride, + const uint32_t qo_h_stride) { + constexpr uint32_t head_dim = num_frags_y * 16; + constexpr uint32_t num_vecs_per_head = head_dim / num_elems_per_128b(); + const uint32_t tx = threadIdx.x, ty = threadIdx.y; + constexpr int VEC_SIZE = 16 / sizeof(T); + AlignedVector ori_out_vec; + AlignedVector shift_bias_vec; + AlignedVector smooth_weight_vec; + AlignedVector out_vec; + // [num_warps * num_frags_x * 16, num_frags_y * 16] + if (ty == 0) { + // [num_frags_x * 16, num_frags_y * 16] +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y; ++fy) { + // 16 * 16 + // 每个fy放16个数,vec size为8(f16/bf16),所以y轴为2fy + uint32_t o_frag_f16[4]; + vec_cast((T*)o_frag_f16, o_frag[fx][fy]); + uint32_t o_smem_offset_w = smem_t::get_permuted_offset< + num_vecs_per_head>( // num_vecs_per_head = num_frags_y * 16 / 8 = + // num_frags_y * 2 + fx * 16 + tx / 4, + fy * 2); + ((uint32_t*)(o_smem->base + o_smem_offset_w))[tx % 4] = o_frag_f16[0]; + ((uint32_t*)(o_smem->base + o_smem_offset_w + + 8 * num_vecs_per_head))[tx % 4] = o_frag_f16[1]; + ((uint32_t*)(o_smem->base + (o_smem_offset_w ^ 0x1)))[tx % 4] = + o_frag_f16[2]; // 2fy,异或1往右移一位 + ((uint32_t*)(o_smem->base + (o_smem_offset_w ^ 0x1) + + 8 * num_vecs_per_head))[tx % 4] = o_frag_f16[3]; + } + } + } + __syncthreads(); +#ifdef DEBUG_ATTN + if (threadIdx.x == PRINT_TID && threadIdx.y == 0 && blockIdx.z == 0 && + blockIdx.x == gridDim.x - 1) { + printf("o_smem\n"); + T* o_smem_t = reinterpret_cast(o_smem->base); + for (uint32_t i = 0; i < num_frags_x * 16; ++i) { + for (uint32_t j = 0; j < num_frags_y * 16; ++j) { + printf("o_smem[%d][%d] = %f ", + (int)i, + (int)j, + (float)o_smem_t[i * num_frags_y * 16 + j]); + } + printf("\n"); + } + } + __syncthreads(); +#endif + + // smem连续存储到gmem上, [num_frags_x * 16, num_frags_y * 16] + uint32_t o_smem_offset_w = smem_t::get_permuted_offset( + ty * 4 + tx / 8, tx % 8); // 每个warp一次搬4行,每次搬64个数 + + const uint32_t tx_offset = tx / 8; + // o_idx_base += (tx / 8) / group_size; + // o_ptr_base += ((tx / 8) / group_size) * qo_n_stride + ((tx / 8) % + // group_size) * qo_h_stride; uint32_t q_head_idx_now_base = q_head_idx_base + + // (tx / 8) % group_size; +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { + const uint32_t base_offset = o_idx_base + fx * 16 + tx_offset; +#pragma unroll + const int j = ty; + const uint32_t offset_now = base_offset + j * 4; + const uint32_t n_offset = offset_now / group_size; + const uint32_t h_offset = offset_now % group_size; +#ifdef DEBUG_ATTN + __syncthreads(); + if (threadIdx.x == PRINT_TID && threadIdx.y == 0 && blockIdx.z == 0 && + blockIdx.x == gridDim.x - 1) { + printf("o_smem\n"); + T* o_smem_t = reinterpret_cast(o_smem->base); + for (uint32_t i = 0; i < num_frags_x * 16; ++i) { + for (uint32_t 
j = 0; j < num_frags_y * 16; ++j) { + printf("o_smem[%d][%d] = %f ", + (int)i, + (int)j, + (float)o_smem_t[i * num_frags_y * 16 + j]); + } + printf("index:%d \n", n_offset * qo_n_stride + h_offset * qo_h_stride); + } + } + __syncthreads(); +#endif + OutT* o_ptr = o_ptr_base + n_offset * qo_n_stride + h_offset * qo_h_stride; + + uint32_t shift_smooth_offset = (q_head_idx_base + h_offset) * head_dim + + tx % 8 * num_elems_per_128b(); +#pragma unroll + for (uint32_t fyo = 0; fyo < num_frags_y / 4; + ++fyo) { // num_frags_y * 16 / (8[tid] * + // num_elems_per_128b()[vec_per_thread]) + + if (n_offset < qo_upper_bound) { + if constexpr (!partition_kv) { + + if (in_scale > 0.0) { + if (shift_bias) { + Load(shift_bias + shift_smooth_offset, + &shift_bias_vec); + Load(smooth_weight + shift_smooth_offset, + &smooth_weight_vec); + } + } + Load( + reinterpret_cast(o_smem->base + o_smem_offset_w), + &ori_out_vec); + +#pragma unroll + for (int i = 0; i < VEC_SIZE; ++i) { + StoreFunc()(ori_out_vec, + shift_bias_vec, + smooth_weight_vec, + out_vec, + in_scale, + i); +#ifdef DEBUG_ATTN + __syncthreads(); + if (threadIdx.x == PRINT_TID && threadIdx.y == 0 && + blockIdx.z == 0 && blockIdx.x == gridDim.x - 1 && + blockIdx.y == 0) { + printf( + "write_o fx: %d, j: %d, fyo: %d, in_scale: %f, i: %d, " + "shift_bias = %f, smooth_weight = %f, ori_out = %f, out_vec: " + "%f\n", + (int)fx, + (int)j, + (int)fyo, + (float)in_scale, + (int)i, + (float)shift_bias_vec[i], + (float)smooth_weight_vec[i], + (float)ori_out_vec[i], + (float)out_vec[i]); + } + __syncthreads(); +#endif + } + Store(out_vec, o_ptr); + } else { + o_smem->store_128b(o_smem_offset_w, o_ptr); + } + } + o_ptr += 8 * num_elems_per_128b(); + shift_smooth_offset += 8 * num_elems_per_128b(); + o_smem_offset_w = + o_smem->advance_offset_by_column<8>(o_smem_offset_w, fyo); + } + o_smem_offset_w = + o_smem->advance_offset_by_row<16, num_vecs_per_head>(o_smem_offset_w) - + 2 * num_frags_y; + // } + } +} + +template +__device__ __forceinline__ void write_o_reg_gmem_shift_smooth_quant( + float (*o_frag)[num_frags_y][8], + smem_t* o_smem, + OutT* o_ptr_base, + const T* shift_bias, + const T* smooth_weight, + uint32_t o_idx_base, + const uint32_t q_head_idx_base, + const float in_scale, + const uint32_t qo_upper_bound, + const uint32_t qo_n_stride, + const uint32_t qo_h_stride) { + constexpr uint32_t head_dim = num_frags_y * 16; + constexpr uint32_t num_vecs_per_head = head_dim / num_elems_per_128b(); + const uint32_t tx = threadIdx.x, ty = threadIdx.y; + constexpr int VEC_SIZE = 8; + AlignedVector ori_out_vec; + AlignedVector shift_bias_vec; + AlignedVector smooth_weight_vec; + AlignedVector out_vec; + // [num_warps * num_frags_x * 16, num_frags_y * 16] +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y; ++fy) { + // 每个fy放16个数,vec size为8(f16/bf16),所以y轴为2fy + uint32_t o_frag_f16[4]; + vec_cast((T*)o_frag_f16, o_frag[fx][fy]); + uint32_t o_smem_offset_w = smem_t::get_permuted_offset( // num_vecs_per_head = num_frags_y * 16 / 8 = num_frags_y * 2 + (ty * num_frags_x + fx) * 16 + tx / 4, fy * 2); + ((uint32_t*)(o_smem->base + o_smem_offset_w))[tx % 4] = o_frag_f16[0]; + ((uint32_t*)(o_smem->base + o_smem_offset_w + 8 * num_vecs_per_head))[tx % 4] = + o_frag_f16[1]; + ((uint32_t*)(o_smem->base + (o_smem_offset_w ^ 0x1)))[tx % 4] = o_frag_f16[2]; // 2fy,异或1往右移一位 + ((uint32_t*)(o_smem->base + (o_smem_offset_w ^ 0x1) + 8 * num_vecs_per_head))[tx % 4] = + o_frag_f16[3]; + } + } + __syncthreads(); + + 
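+  // The output tile is now staged in shared memory; the loop below streams it
+  // to global memory in 128-bit chunks per thread. When partition_kv is false
+  // and in_scale > 0, StoreFunc applies the optional shift_bias /
+  // smooth_weight vectors and quantizes the result to int8; otherwise the
+  // tile is copied out unchanged.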
// smem连续存储到gmem上, [num_frags_x * 16, num_frags_y * 16] + uint32_t o_smem_offset_w = + smem_t::get_permuted_offset(ty * num_frags_x * 16 + tx / 8, tx % 8); // 每个warp一次搬4行,每次搬64个数 + + const uint32_t tx_offset = tx / 8; + // o_idx_base += (tx / 8) / group_size; + // o_ptr_base += ((tx / 8) / group_size) * qo_n_stride + ((tx / 8) % group_size) * qo_h_stride; + // uint32_t q_head_idx_now_base = q_head_idx_base + (tx / 8) % group_size; +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { + const uint32_t base_offset = o_idx_base + fx * 16 + tx_offset; +#pragma unroll + for (uint32_t j = 0; j < 4; ++j) { // 4 * 4 = 16 + const uint32_t offset_now = base_offset + j * 4; + const uint32_t n_offset = offset_now / group_size; + const uint32_t h_offset = offset_now % group_size; + OutT* o_ptr = o_ptr_base + n_offset * qo_n_stride + h_offset * qo_h_stride; + uint32_t shift_smooth_offset = (q_head_idx_base + h_offset) * head_dim + tx % 8 * num_elems_per_128b(); +#pragma unroll + for (uint32_t fyo = 0; fyo < num_frags_y / 4; ++fyo) { // num_frags_y * 16 / (8[tid] * num_elems_per_128b()[vec_per_thread]) + if (n_offset < qo_upper_bound) { + if (!partition_kv && in_scale > 0.0) { + if (shift_bias) { + Load(shift_bias + shift_smooth_offset, &shift_bias_vec); + Load(smooth_weight + shift_smooth_offset, &smooth_weight_vec); + } + Load(reinterpret_cast(o_smem->base + o_smem_offset_w), &ori_out_vec); +#pragma unroll + for (int i = 0; i < VEC_SIZE; ++i) { + StoreFunc()(ori_out_vec, shift_bias_vec, smooth_weight_vec, out_vec, in_scale, i); +#ifdef DEBUG_ATTN_C4 + if (threadIdx.x == PRINT_TID && threadIdx.y == 0 && blockIdx.z == 0) { + printf("write_o fx: %d, j: %d, fyo: %d, shift_bias[%d] = %f, smooth_weight[%d] = %f, ori_out[%d] = %f, out_vec[%d]: %f\n", + (int)fx, (int)j, (int)fyo, i, (float)shift_bias_vec[i], i, (float)smooth_weight_vec[i], i, (float)ori_out_vec[i], (float)out_vec[i]); + } + __syncthreads(); +#endif + } + Store(out_vec, o_ptr); + } else { + o_smem->store_128b(o_smem_offset_w, o_ptr); + } + } + o_ptr += 8 * num_elems_per_128b(); + shift_smooth_offset += 8 * num_elems_per_128b(); + o_smem_offset_w = o_smem->advance_offset_by_column<8>(o_smem_offset_w, fyo); + } + o_smem_offset_w = o_smem->advance_offset_by_row<4, num_vecs_per_head>(o_smem_offset_w) - + 2 * num_frags_y; + } + } +} + +template +__device__ __forceinline__ void write_o_reg_gmem( + float (*o_frag)[num_frags_y][8], + smem_t* o_smem, + T* o_ptr_base, + uint32_t o_idx_base, + const uint32_t qo_upper_bound, + const uint32_t qo_n_stride, + const uint32_t qo_h_stride) { + constexpr uint32_t head_dim = num_frags_y * 16; + constexpr uint32_t num_vecs_per_head = head_dim / num_elems_per_128b(); + const uint32_t tx = threadIdx.x, ty = threadIdx.y; + + // [num_warps * num_frags_x * 16, num_frags_y * 16] +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y; ++fy) { + // 每个fy放16个数,vec size为8(f16/bf16),所以y轴为2fy + uint32_t o_frag_f16[4]; + vec_cast((T*)o_frag_f16, o_frag[fx][fy]); + uint32_t o_smem_offset_w = smem_t::get_permuted_offset< + num_vecs_per_head>( // num_vecs_per_head = num_frags_y * 16 / 8 = + // num_frags_y * 2 + (ty * num_frags_x + fx) * 16 + tx / 4, + fy * 2); + ((uint32_t*)(o_smem->base + o_smem_offset_w))[tx % 4] = o_frag_f16[0]; + ((uint32_t*)(o_smem->base + o_smem_offset_w + + 8 * num_vecs_per_head))[tx % 4] = o_frag_f16[1]; + ((uint32_t*)(o_smem->base + (o_smem_offset_w ^ 0x1)))[tx % 4] = + o_frag_f16[2]; // 2fy,异或1往右移一位 + ((uint32_t*)(o_smem->base 
+ (o_smem_offset_w ^ 0x1) + + 8 * num_vecs_per_head))[tx % 4] = o_frag_f16[3]; + } + } + __syncthreads(); + + // smem连续存储到gmem上, [num_frags_x * 16, num_frags_y * 16] + uint32_t o_smem_offset_w = smem_t::get_permuted_offset( + ty * num_frags_x * 16 + tx / 8, + tx % 8); // 每个warp一次搬4行,每次搬64个数 + + o_idx_base += (tx / 8) / group_size; + o_ptr_base += ((tx / 8) / group_size) * qo_n_stride + + ((tx / 8) % group_size) * qo_h_stride; +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t j = 0; j < 4; ++j) { // 4 * 4 = 16 + const uint32_t o_idx = o_idx_base + (fx * 16 + j * 4) / group_size; + T* o_ptr = o_ptr_base + ((fx * 16 + j * 4) / group_size) * qo_n_stride + + ((fx * 16 + j * 4) % group_size) * qo_h_stride; +#pragma unroll + for (uint32_t fyo = 0; fyo < num_frags_y / 4; + ++fyo) { // num_frags_y * 16 / (8[tid] * + // num_elems_per_128b()[vec_per_thread]) + if (o_idx < qo_upper_bound) { + // need write + o_smem->store_128b(o_smem_offset_w, o_ptr); + } + o_ptr += 8 * num_elems_per_128b(); + o_smem_offset_w = + o_smem->advance_offset_by_column<8>(o_smem_offset_w, fyo); + } + o_smem_offset_w = + o_smem->advance_offset_by_row<4, num_vecs_per_head>(o_smem_offset_w) - + 2 * num_frags_y; + } + } +} + +template +__global__ void split_q_block(const int* __restrict__ seq_lens_q, + int* __restrict__ batch_ids, + int* __restrict__ tile_ids_per_batch, + int* __restrict__ num_blocks_x, + const uint32_t bsz, + const uint32_t num_rows_per_block) { + if (threadIdx.x == 0) { + int gridx = 0; + int index = 0; + for (uint32_t bid = 0; bid < bsz; bid++) { + const int seq_len = seq_lens_q[bid]; + const int loop_times = div_up(seq_len * GROUP_SIZE, num_rows_per_block); + for (uint32_t tile_id = 0; tile_id < loop_times; tile_id++) { + batch_ids[index] = bid; + tile_ids_per_batch[index++] = tile_id; + } + gridx += loop_times; + } + *num_blocks_x = gridx; + } +} + +template +__global__ void merge_multi_chunks_kernel( + const T* __restrict__ multi_out, // [token_num, num_chunks, num_heads, + // head_dim] + const float* __restrict__ multi_m, // [token_num, num_chunks, num_heads] + const float* __restrict__ multi_d, // [token_num, num_chunks, num_heads] + const int* __restrict__ seq_lens_q, + const int* __restrict__ seq_lens_kv, + const int* __restrict__ padding_offsets, + const T* __restrict__ shift_bias, // [q_num_heads * HEAD_DIM] + const T* __restrict__ smooth_weight, // [q_num_heads * HEAD_DIM] + T* __restrict__ out, + const float in_scale, + const int max_seq_len, + const int num_chunks, + const int num_heads, + const int chunk_size, + const int head_dim) { + const int vid = threadIdx.x, hid = threadIdx.y; + const int qid = blockIdx.x; + const uint32_t ori_token_id = qid + padding_offsets[qid]; + const uint32_t bid = ori_token_id / max_seq_len; + if (seq_lens_q[bid] <= 0 || seq_lens_kv[bid] <= 0) { + return; + } + const int seq_len_kv = seq_lens_kv[bid]; + const int num_chunks_this_seq = div_up(seq_len_kv, chunk_size); + + using LoadT = AlignedVector; + LoadT load_vec; + LoadT res_vec; + if constexpr (std::is_same::value) { +#pragma unroll + for (int i = 0; i < vec_size / 2; ++i) { + *((half2*)(&res_vec) + i) = make_half2(0, 0); + } + } else { +#pragma unroll + for (int i = 0; i < vec_size / 2; ++i) { + *((nv_bfloat162*)(&res_vec) + i) = make_bfloat162(0, 0); + } + } + float m; + float d = 1.f; + if constexpr (std::is_same::value) { + m = -5e4f; + } else if constexpr (std::is_same::value) { + m = -3.0e+30f; + } +#pragma unroll 2 + for (int i = 0; i < num_chunks_this_seq; 
++i) { + uint32_t offset = (qid * num_chunks + i) * num_heads + hid; + float m_prev = m; + float d_prev = d; + const float m_now = multi_m[offset]; + const float d_now = multi_d[offset]; + m = max(m_prev, m_now); + offset = (qid * num_chunks * num_heads + i * num_heads + hid) * head_dim + + vid * vec_size; + Load(&multi_out[offset], &load_vec); + const float scale1 = __expf(m_prev - m), scale2 = __expf(m_now - m); + const T scale1_T = static_cast(scale1), + scale2_T = static_cast(scale2); + d = d * scale1 + d_now * scale2; + for (int j = 0; j < vec_size; j++) { + res_vec[j] = res_vec[j] * scale1_T + load_vec[j] * scale2_T; + } + } +#pragma unroll + for (int j = 0; j < vec_size; j++) { + res_vec[j] /= d; + } + Store(res_vec, + &out[(qid * num_heads + hid) * head_dim + vid * vec_size]); +} + + +template +__device__ __forceinline__ void merge_block_res(float (*o_frag)[num_frags_y][8], + float* md_smem, + float (*m)[2], + float (*d)[2], + const uint32_t wid, + const uint32_t tid) { + // o [num_warps, 32, num_frags_x, num_frags_y, 8] -> [32, num_frags_x, + // num_frags_y, 8] md [num_warps, num_frags_x, 2, 32, 2] -> [num_frags_x, 2, + // 32, 2] + float2* smem_md = reinterpret_cast(md_smem); +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t j = 0; j < 2; ++j) { + smem_md[((wid * num_frags_x + fx) * 2 + j) * 32 + tid] = + make_float2(m[fx][j], d[fx][j]); + } + } + __syncthreads(); + float o_scale[4][num_frags_x][2]; + + // deal md/scale +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t j = 0; j < 2; ++j) { + float m_new; + float d_new = 1.f; + if constexpr (std::is_same::value) { + m_new = -5e4f; + } else { + m_new = -3.0e+30f; + } +#pragma unroll + for (uint32_t i = 0; i < 4; ++i) { + float2 md = smem_md[((i * num_frags_x + fx) * 2 + j) * 32 + tid]; + float m_prev = m_new, d_prev = d_new; + m_new = max(m_new, md.x); + d_new = d_prev * __expf(m_prev - m_new) + md.y * __expf(md.x - m_new); + } +#pragma unroll + for (uint32_t i = 0; i < 4; ++i) { + float2 md = smem_md[((i * num_frags_x + fx) * 2 + j) * 32 + tid]; + o_scale[i][fx][j] = __expf(md.x - m_new); + } + m[fx][j] = m_new; + d[fx][j] = d_new; + } + } + __syncthreads(); + +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y; ++fy) { + // num_warps * 32 * 8 each time + AlignedVector o_new; +#pragma + for (uint32_t o_id = 0; o_id < 8; ++o_id) { + o_new[o_id] = 0.f; + } + *(reinterpret_cast(md_smem + (wid * 32 + tid) * 8)) = + *(reinterpret_cast(&o_frag[fx][fy][0])); + *(reinterpret_cast(md_smem + (wid * 32 + tid) * 8 + 4)) = + *(reinterpret_cast(&o_frag[fx][fy][4])); + __syncthreads(); +#pragma unroll + for (uint32_t i = 0; i < 4; ++i) { + AlignedVector oi; + Load(md_smem + (i * 32 + tid) * 8, &oi); +#pragma unroll + for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) { + o_new[reg_id] += oi[reg_id] * o_scale[i][fx][(reg_id % 4) / 2]; + } + } + *(reinterpret_cast(&o_frag[fx][fy][0])) = + *(reinterpret_cast(&o_new[0])); + *(reinterpret_cast(&o_frag[fx][fy][4])) = + *(reinterpret_cast(&o_new[4])); + __syncthreads(); + } + } +} + +template +__device__ __forceinline__ void merge_block_res_v2( + float (*o_frag)[num_frags_y][8], + float* md_smem, + float (*m)[2], + float (*d)[2], + const uint32_t wid, + const uint32_t tid) { + // o: [num_warps, num_frags_x, num_frags_y, warp_size(32), 8] + // md: [num_warps, num_frags_x, 2, warp_size(32), 2 (m/d)] + float2* smem_md = reinterpret_cast( + 
md_smem + num_frags_x * num_frags_y * 1024); // 4 * 32 * 8 +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t j = 0; j < 2; ++j) { + smem_md[((wid * num_frags_x + fx) * 2 + j) * 32 + tid] = + make_float2(m[fx][j], d[fx][j]); + } + } +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y; ++fy) { + *(reinterpret_cast( + md_smem + (((wid * num_frags_x + fx) * num_frags_y + fy) * 32 + tid) * + 8)) = *(reinterpret_cast(&o_frag[fx][fy][0])); + *(reinterpret_cast( + md_smem + + (((wid * num_frags_x + fx) * num_frags_y + fy) * 32 + tid) * 8 + 4)) = + *(reinterpret_cast(&o_frag[fx][fy][4])); + } + } + __syncthreads(); + float o_scale[4][num_frags_x][2]; + + // deal md/scale +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t j = 0; j < 2; ++j) { + float m_new; + float d_new = 1.f; + if constexpr (std::is_same::value) { + m_new = -5e4f; + } else { + m_new = -3.0e+30f; + } +#pragma unroll + for (uint32_t i = 0; i < 4; ++i) { + float2 md = smem_md[((i * num_frags_x + fx) * 2 + j) * 32 + tid]; + float m_prev = m_new, d_prev = d_new; + m_new = max(m_new, md.x); + d_new = d_prev * __expf(m_prev - m_new) + md.y * __expf(md.x - m_new); + } +#pragma unroll + for (uint32_t i = 0; i < 4; ++i) { + float2 md = smem_md[((i * num_frags_x + fx) * 2 + j) * 32 + tid]; + o_scale[i][fx][j] = __expf(md.x - m_new); + } + m[fx][j] = m_new; + d[fx][j] = d_new; + } + } + +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y; ++fy) { + // num_warps * 32 * 8 each time + AlignedVector o_new; +#pragma + for (uint32_t o_id = 0; o_id < 4; ++o_id) { + *(reinterpret_cast(&o_new[o_id * 2])) = make_float2(0.f, 0.f); + } +#pragma unroll + for (uint32_t i = 0; i < 4; ++i) { + AlignedVector oi; + Load( + md_smem + + (((i * num_frags_x + fx) * num_frags_y + fy) * 32 + tid) * 8, + &oi); +#pragma unroll + for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) { + o_new[reg_id] += oi[reg_id] * o_scale[i][fx][(reg_id % 4) / 2]; + } + } + *(reinterpret_cast(&o_frag[fx][fy][0])) = + *(reinterpret_cast(&o_new[0])); + *(reinterpret_cast(&o_frag[fx][fy][4])) = + *(reinterpret_cast(&o_new[4])); + } + } +} diff --git a/csrc/gpu/append_attn/append_attention_impl.cuh b/csrc/gpu/append_attn/append_attention_impl.cuh new file mode 100644 index 000000000000..65620b158730 --- /dev/null +++ b/csrc/gpu/append_attn/append_attention_impl.cuh @@ -0,0 +1,5225 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once +#include "append_attention_func.cuh" + +// #define DEBUG_ATTN_C8 +template +__global__ void multi_query_append_attention_kernel( + T *__restrict__ q, // [token_num. 
num_heads, head_dim] + T *__restrict__ cache_k, // [max_block_num, num_heads, block_size, head_dim] + T *__restrict__ cache_v, + const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM] + const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM] + const int *__restrict__ seq_lens, + const int *__restrict__ seq_lens_kv, + const int *__restrict__ batch_ids, + const int *__restrict__ tile_ids_per_batch, + const int *__restrict__ cum_offsets, + const int *__restrict__ block_table, // [bsz, block_num_per_seq] + const int max_seq_len, + const int max_dec_len, + const int max_block_num_per_seq, + const float scale, + const float in_scale, + const uint32_t chunk_size, + T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks, num_heads, head_dim] + float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads] + float *__restrict__ tmp_d, // [token_num, num_chunks, num_heads] + OutT *__restrict__ out, + const int speculate_max_draft_token_num = 5) { + constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b(); + const uint32_t btid = blockIdx.x, kv_head_idx = blockIdx.z; + const uint32_t kv_num_heads = gridDim.z; + const uint32_t q_num_heads = kv_num_heads * GROUP_SIZE; + const uint32_t q_head_idx = kv_head_idx * GROUP_SIZE; + const uint32_t tid = threadIdx.x, wid = threadIdx.y; + const uint32_t num_chunks = gridDim.y; + const uint32_t chunk_idx = blockIdx.y; + + const uint32_t batch_id = batch_ids[btid]; + const uint32_t tile_id = tile_ids_per_batch[btid]; + const uint32_t num_rows_per_block = NUM_WARPS * num_frags_x * 16; + const int *block_table_now = nullptr; + + block_table_now = block_table + batch_id * max_block_num_per_seq; + + const uint32_t q_len = seq_lens[batch_id]; + if (q_len <= 0) { + return; + } + const uint32_t q_end = + min(q_len, div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE)); + uint32_t kv_len = seq_lens_kv[batch_id]; + if (ENABLE_PREFILL) { + kv_len += q_len; // !!! + if (kv_len <= 0) { + return; + } + } else { + if (kv_len <= 0) { + return; + } + kv_len += q_len; + } + + const uint32_t num_chunks_this_seq = div_up(kv_len, chunk_size); + if (chunk_idx >= num_chunks_this_seq) { + return; + } + + const uint32_t chunk_start = partition_kv ? chunk_idx * chunk_size : 0; + const uint32_t chunk_end = + partition_kv ? 
min(kv_len, chunk_start + chunk_size) : kv_len; + const uint32_t chunk_len = chunk_end - chunk_start; + + extern __shared__ uint8_t smem[]; + float s_frag[num_frags_x][num_frags_z][8]; + float o_frag[num_frags_x][num_frags_y][8]; + float m_frag[num_frags_x][2]; + float d_frag[num_frags_x][2]; + init_states(o_frag, m_frag, d_frag); + + const uint32_t q_n_stride = q_num_heads * HEAD_DIM; + const uint32_t q_ori_n_stride = (q_num_heads + kv_num_heads * 2) * HEAD_DIM; + const uint32_t kv_n_stride = kv_num_heads * BLOCK_SIZE * HEAD_DIM; + const uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM; + const uint32_t kv_b_stride = HEAD_DIM; + const uint32_t q_start_seq_id = + batch_id * max_seq_len - __ldg(&cum_offsets[batch_id]); + const uint32_t q_base_seq_id_this_block = + (tile_id * NUM_WARPS + wid) * num_frags_x * 16; + const uint32_t q_offset = q_start_seq_id * q_ori_n_stride + + q_head_idx * HEAD_DIM + + tid % 8 * num_elems_per_128b(); + const uint32_t o_offset = q_start_seq_id * q_n_stride + + q_head_idx * HEAD_DIM + + tid % 8 * num_elems_per_128b(); + T *q_base_ptr = q + q_offset; + T *o_base_ptr_T = nullptr; + OutT *o_base_ptr_int8 = nullptr; + if constexpr (partition_kv) { + if (ENABLE_PREFILL) { + o_base_ptr_T = tmp_workspace + q_start_seq_id * num_chunks * q_n_stride + + chunk_idx * q_n_stride + q_head_idx * HEAD_DIM + + tid % 8 * num_elems_per_128b(); + } else { + o_base_ptr_T = + tmp_workspace + + batch_id * speculate_max_draft_token_num * num_chunks * q_n_stride + + chunk_idx * q_n_stride + q_head_idx * HEAD_DIM + + tid % 8 * num_elems_per_128b(); + } + } else { + o_base_ptr_int8 = out + o_offset; + } + smem_t qo_smem(smem); + + /* + 1 | 3 + —————— + 2 | 4 + */ + uint32_t q_smem_offset_r = smem_t::get_permuted_offset( + wid * num_frags_x * 16 + tid % 16, tid / 16); // 16 * 16 + load_q_global_smem( + q_base_ptr, + &qo_smem, + q_base_seq_id_this_block, + q_end, + q_ori_n_stride, + HEAD_DIM); + commit_group(); + wait_group<0>(); + __syncthreads(); + + q_smem_inplace_multiply_sm_scale(&qo_smem, + scale); + + smem_t k_smem(smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T)), + v_smem(smem + (NUM_WARPS * num_frags_x + num_frags_z) * 16 * HEAD_DIM * + sizeof(T)); + + + const uint32_t num_iterations = div_up( + CAUSAL + ? (min(chunk_len, + sub_if_greater_or_zero( + kv_len - q_len + + div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE), + chunk_start))) + : chunk_len, + num_frags_z * 16); + const uint32_t mask_check_iteration = + (CAUSAL ? 
(min(chunk_len, + sub_if_greater_or_zero( + kv_len - q_len + + tile_id * num_rows_per_block / GROUP_SIZE, + chunk_start))) + : chunk_len) / + (num_frags_z * 16); +#ifdef DEBUG_ATTN + if (tid == 0 && wid == 0 && kv_head_idx == 0) { + printf( + "batch_id: %d, tile_id: %d, chunk_size: %d, q_len: %d, kv_len: %d, " + "chunk_start: %d, chunk_end: %d, num_iterations: %d, " + "mask_check_iteration: %d\n", + (int)batch_id, + (int)tile_id, + (int)chunk_size, + (int)q_len, + (int)kv_len, + (int)chunk_start, + (int)chunk_end, + (int)num_iterations, + (int)mask_check_iteration); + } + __syncthreads(); +#endif + + /* + 1 | 2 + —————— + 3 | 4 + */ + uint32_t k_smem_offset_r = smem_t::get_permuted_offset( + 8 * (tid / 16) + tid % 8, (tid % 16) / 8); + /* + 1 | 3 + —————— + 2 | 4 transpose + */ + uint32_t v_smem_offset_r = + smem_t::get_permuted_offset(tid % 16, tid / 16); + + uint32_t kv_smem_offset_w = smem_t::get_permuted_offset( + wid * 4 + tid / 8, tid % 8); + + uint32_t kv_idx_base = chunk_start; + int block_id = __ldg(&block_table_now[kv_idx_base / BLOCK_SIZE]); + const uint32_t const_offset = kv_head_idx * kv_h_stride + + (wid * 4 + tid / 8) * kv_b_stride + + tid % 8 * num_elems_per_128b(); + T *cache_k_now = cache_k + block_id * kv_n_stride + const_offset; + T *cache_v_now = cache_v + block_id * kv_n_stride + const_offset; + + produce_kv_blockwise(k_smem, + &kv_smem_offset_w, + &cache_k_now, + kv_head_idx, + kv_n_stride, + kv_h_stride, + kv_b_stride, + kv_idx_base, + chunk_end); + commit_group(); + produce_kv_blockwise(v_smem, + &kv_smem_offset_w, + &cache_v_now, + kv_head_idx, + kv_n_stride, + kv_h_stride, + kv_b_stride, + kv_idx_base, + chunk_end); + commit_group(); +#pragma unroll 1 + for (uint32_t iter = 0; iter < num_iterations; ++iter) { + wait_group<1>(); + __syncthreads(); +#ifdef DEBUG_ATTN + if (tid == PRINT_TID && wid == 0 && blockIdx.z == 0 && blockIdx.x == 0) { + printf("cache_k_smem\n"); + T *k_smem_t = reinterpret_cast(k_smem.base); + for (uint32_t i = 0; i < NUM_WARP_KV * num_frags_z * 16; ++i) { + for (uint32_t j = 0; j < num_frags_y * 16; ++j) { + printf("k_smem[%d][%d] = %f ", + (int)i, + (int)j, + (float)k_smem_t[i * num_frags_y * 16 + j]); + } + printf("\n"); + } + } + __syncthreads(); +#endif + // s = qk + compute_qk( + &qo_smem, &q_smem_offset_r, &k_smem, &k_smem_offset_r, s_frag); + // mask according to kv_idx and q_idx + if (iter >= mask_check_iteration) { + mask_s(q_base_seq_id_this_block, + kv_idx_base, + q_len, + kv_len, + chunk_end, + s_frag); + } + + // update m,d + update_mdo_states( + s_frag, o_frag, m_frag, d_frag); + __syncthreads(); +#ifdef DEBUG_ATTN + if (threadIdx.y == 0 && threadIdx.x == 0 && blockIdx.x == 0 && + blockIdx.y == 0 && blockIdx.z == 0) { + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { + for (uint32_t fz = 0; fz < num_frags_z; ++fz) { + for (int k = 0; k < 8; k++) { + printf( + "after_update_mdo_states_tid:%d_mask_s_s_frag[%d][%d][%d]:%f ", + (int)threadIdx.x, + (int)fx, + (int)fz, + (int)k, + s_frag[fx][fz][k]); + } + printf("\n"); + } + printf("\n"); + } + } + __syncthreads(); +#endif + + kv_idx_base += num_frags_z * 16; + block_id = __ldg(&block_table_now[kv_idx_base / BLOCK_SIZE]); + if (block_id < 0) { + block_id = 0; // 搬但不算 + } + cache_k_now = cache_k + block_id * kv_n_stride + const_offset; + produce_kv_blockwise(k_smem, + &kv_smem_offset_w, + &cache_k_now, + kv_head_idx, + kv_n_stride, + kv_h_stride, + kv_b_stride, + kv_idx_base, + chunk_end); + commit_group(); + wait_group<1>(); + __syncthreads(); +#ifdef DEBUG_ATTN + if (tid == 
PRINT_TID && wid == 0 && blockIdx.z == 0 && blockIdx.x == 0) { + printf("cache_v_smem\n"); + T *v_smem_t = reinterpret_cast(v_smem.base); + for (uint32_t i = 0; i < NUM_WARP_KV * num_frags_z * 16; ++i) { + for (uint32_t j = 0; j < num_frags_y * 16; ++j) { + printf("v_smem[%d][%d] = %f ", + (int)i, + (int)j, + (float)v_smem_t[i * num_frags_y * 16 + j]); + } + printf("\n"); + } + } + __syncthreads(); +#endif + // compute sfm*v + compute_sfm_v( + &v_smem, &v_smem_offset_r, s_frag, o_frag, d_frag); + + __syncthreads(); + cache_v_now = cache_v + block_id * kv_n_stride + const_offset; + produce_kv_blockwise(v_smem, + &kv_smem_offset_w, + &cache_v_now, + kv_head_idx, + kv_n_stride, + kv_h_stride, + kv_b_stride, + kv_idx_base, + chunk_end); + commit_group(); + } + wait_group<0>(); + __syncthreads(); +#ifdef DEBUG_ATTN + if (threadIdx.x == PRINT_TID && threadIdx.y == 0 && blockIdx.z == 0) { + printf("o_res\n"); + for (uint32_t i = 0; i < num_frags_x; ++i) { + printf("m1: %f, m2: %f\n", m_frag[i][0], m_frag[i][1]); + printf("d1: %f, d2: %f\n", d_frag[i][0], d_frag[i][1]); + for (uint32_t j = 0; j < num_frags_y; ++j) { + for (int r_id = 0; r_id < 8; r_id++) { + printf("o_frag[%d][%d][%d]: %f ", + (int)i, + (int)j, + r_id, + o_frag[i][j][r_id]); + } + } + printf("\n"); + } + } + __syncthreads(); +#endif + + if constexpr (!partition_kv) { + normalize_d(o_frag, d_frag); + } + + // write o + // [num_frags_x, 16, num_frags_y, 16] + // if (in_scale > 0.0) { + if constexpr (partition_kv) { + write_o_reg_gmem_shift_smooth_quant( + o_frag, + &qo_smem, + o_base_ptr_T, + shift_bias, + smooth_weight, + q_base_seq_id_this_block, + q_head_idx, + in_scale, + q_len, + partition_kv ? q_n_stride * num_chunks : q_n_stride, + HEAD_DIM); + } else { + write_o_reg_gmem_shift_smooth_quant( + o_frag, + &qo_smem, + o_base_ptr_int8, + shift_bias, + smooth_weight, + q_base_seq_id_this_block, + q_head_idx, + in_scale, + q_len, + partition_kv ? q_n_stride * num_chunks : q_n_stride, + HEAD_DIM); + } + + + if constexpr (partition_kv) { +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t j = 0; j < 2; ++j) { + const uint32_t qo_idx_now = + q_base_seq_id_this_block + tid / 4 + j * 8 + fx * 16; + const uint32_t qo_head_idx = q_head_idx + qo_idx_now % GROUP_SIZE; + const uint32_t qo_idx = q_start_seq_id + qo_idx_now / GROUP_SIZE; + if (qo_idx - q_start_seq_id < q_len) { + uint32_t offset; + if (ENABLE_PREFILL) { + offset = + (qo_idx * num_chunks + chunk_idx) * q_num_heads + qo_head_idx; + } else { + offset = ((batch_id * speculate_max_draft_token_num + + qo_idx_now / GROUP_SIZE) * + num_chunks + + chunk_idx) * + q_num_heads + + qo_head_idx; + } + tmp_m[offset] = m_frag[fx][j]; + tmp_d[offset] = d_frag[fx][j]; + } + } + } + } +} + +template +__global__ void multi_query_append_attention_warp1_4_kernel( + T *__restrict__ q, // [token_num. 
num_heads, head_dim] + T *__restrict__ cache_k, // [max_block_num, num_heads, block_size, + // head_dim] + T *__restrict__ cache_v, + const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM] + const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM] + const int *__restrict__ seq_lens, + const int *__restrict__ seq_lens_kv, + const int *__restrict__ batch_ids, + const int *__restrict__ tile_ids_per_batch, + const int *__restrict__ cum_offsets, + const int *__restrict__ block_table, // [bsz, block_num_per_seq] + const int max_seq_len, + const int max_dec_len, + const int max_block_num_per_seq, + const float scale, + const float in_scale, + const uint32_t chunk_size, + T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks, + // num_heads, head_dim] + float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads] + float *__restrict__ tmp_d, // [token_num, num_chunks, num_heads] + OutT *__restrict__ out, + const int speculate_max_draft_token_num = 5) { + // q_len <= 32, num_frags_x = 1/2, num_frags_z = 4 / 4 * 1/2/4, num_frags_y = + // HEAD_DIM / 16 + constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b(); + static_assert(NUM_WARP_Q == 1, "NUM_WARP_Q must be 1"); + static_assert(NUM_WARP_KV == 4, "NUM_WARP_KV must be 4"); + const uint32_t btid = blockIdx.x, kv_head_idx = blockIdx.z; + const uint32_t kv_num_heads = gridDim.z; + const uint32_t q_num_heads = kv_num_heads * GROUP_SIZE; + const uint32_t q_head_idx = kv_head_idx * GROUP_SIZE; + const uint32_t tid = threadIdx.x, wid = threadIdx.y; + const uint32_t num_chunks = gridDim.y; + const uint32_t chunk_idx = blockIdx.y; + + const uint32_t batch_id = batch_ids[btid]; + const uint32_t tile_id = tile_ids_per_batch[btid]; + const uint32_t num_rows_per_block = num_frags_x * 16; + const int *block_table_now = block_table + batch_id * max_block_num_per_seq; + + const uint32_t q_len = seq_lens[batch_id]; + if (q_len <= 0) { + return; + } + const uint32_t q_end = + min(q_len, div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE)); + uint32_t kv_len = seq_lens_kv[batch_id]; + if (ENABLE_PREFILL) { + kv_len += q_len; // !!! + if (kv_len <= 0) { + return; + } + } else { + if (kv_len <= 0) { + return; + } + kv_len += q_len; + } + const uint32_t num_chunks_this_seq = div_up(kv_len, chunk_size); + if (chunk_idx >= num_chunks_this_seq) { + return; + } + + const uint32_t chunk_start = partition_kv ? chunk_idx * chunk_size : 0; + const uint32_t chunk_end = + partition_kv ? 
min(kv_len, chunk_start + chunk_size) : kv_len; + const uint32_t chunk_len = chunk_end - chunk_start; + + extern __shared__ uint8_t smem[]; + float s_frag[num_frags_x][num_frags_z][8]; + float o_frag[num_frags_x][num_frags_y][8]; + float m_frag[num_frags_x][2]; + float d_frag[num_frags_x][2]; + init_states(o_frag, m_frag, d_frag); + + const uint32_t q_n_stride = q_num_heads * HEAD_DIM; + const uint32_t q_ori_n_stride = (q_num_heads + kv_num_heads * 2) * HEAD_DIM; + const uint32_t kv_n_stride = kv_num_heads * BLOCK_SIZE * HEAD_DIM; + const uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM; + const uint32_t kv_b_stride = HEAD_DIM; + const uint32_t q_start_seq_id = + batch_id * max_seq_len - __ldg(&cum_offsets[batch_id]); + const uint32_t q_base_seq_id_this_block = tile_id * num_frags_x * 16; + const uint32_t q_offset = q_start_seq_id * q_ori_n_stride + + q_head_idx * HEAD_DIM + + tid % 8 * num_elems_per_128b(); + const uint32_t o_offset = q_start_seq_id * q_n_stride + + q_head_idx * HEAD_DIM + + tid % 8 * num_elems_per_128b(); + T *q_base_ptr = q + q_offset; + T *o_base_ptr_T = nullptr; + OutT *o_base_ptr_int8 = nullptr; + if (num_chunks_this_seq <= 1) { + o_base_ptr_int8 = out + o_offset; + } else { + if (ENABLE_PREFILL) { + o_base_ptr_T = tmp_workspace + batch_id * num_chunks * q_n_stride + + chunk_idx * q_n_stride + q_head_idx * HEAD_DIM + + tid % 8 * num_elems_per_128b(); + } else { + o_base_ptr_T = + tmp_workspace + + batch_id * speculate_max_draft_token_num * num_chunks * q_n_stride + + chunk_idx * q_n_stride + q_head_idx * HEAD_DIM + + tid % 8 * num_elems_per_128b(); + } + // } else { + // o_base_ptr_int8 = out + o_offset; + } + + smem_t qo_smem(smem); + + /* + 1 | 3 + —————— + 2 | 4 + */ + uint32_t q_smem_offset_r = smem_t::get_permuted_offset( + tid % 16, tid / 16); // 16 * 16 + load_q_global_smem_multi_warps(q_base_ptr, + &qo_smem, + q_base_seq_id_this_block, + q_end, + q_ori_n_stride, + HEAD_DIM); + commit_group(); + wait_group<0>(); + __syncthreads(); + + q_smem_inplace_multiply_sm_scale_multi_warps( + &qo_smem, scale); + + smem_t k_smem(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T)), + v_smem(smem + (num_frags_x + NUM_WARP_KV * num_frags_z) * 16 * HEAD_DIM * + sizeof(T)); + + const uint32_t num_iterations = div_up( + CAUSAL + ? (min(chunk_len, + sub_if_greater_or_zero( + kv_len - q_len + + div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE), + chunk_start))) + : chunk_len, + NUM_WARP_KV * num_frags_z * 16); + const uint32_t mask_check_iteration = + (CAUSAL ? 
(min(chunk_len, + sub_if_greater_or_zero( + kv_len - q_len + + tile_id * num_rows_per_block / GROUP_SIZE, + chunk_start))) + : chunk_len) / + (NUM_WARP_KV * num_frags_z * 16); + + /* + 1 | 2 + —————— + 3 | 4 + */ + uint32_t k_smem_offset_r = smem_t::get_permuted_offset( + wid * num_frags_z * 16 + 8 * (tid / 16) + tid % 8, (tid % 16) / 8); + /* + 1 | 3 + —————— + 2 | 4 transpose + */ + uint32_t v_smem_offset_r = smem_t::get_permuted_offset( + wid * num_frags_z * 16 + tid % 16, tid / 16); + uint32_t kv_smem_offset_w = smem_t::get_permuted_offset( + wid * 4 + tid / 8, tid % 8); // note the memory transaction size: 8 * 128 / 8 = 128B + // uint32_t kv_smem_offset_w = + // smem_t::get_permuted_offset(wid * num_frags_z * 16 + tid + // / 8, tid % 8); // note the memory transaction size: 8 * 128 / 8 = 128B + + uint32_t kv_idx_base = chunk_start; + int block_id = __ldg(&block_table_now[kv_idx_base / BLOCK_SIZE]); + const uint32_t const_offset = kv_head_idx * kv_h_stride + + (wid * 4 + tid / 8) * kv_b_stride + + tid % 8 * num_elems_per_128b(); + // uint32_t kv_idx_base = chunk_start + wid * num_frags_z * 16; + // int block_id = __ldg(&block_table_now[kv_idx_base / BLOCK_SIZE]); + // const uint32_t const_offset = kv_head_idx * kv_h_stride + (wid * + // num_frags_z * 16 % BLOCK_SIZE + tid / 8) * kv_b_stride + tid % 8 * + // num_elems_per_128b(); + T *cache_k_now = cache_k + block_id * kv_n_stride + const_offset; + T *cache_v_now = cache_v + block_id * kv_n_stride + const_offset; + + // load BLOCK_SIZE * HEAD_DIM each time + produce_kv_blockwise(k_smem, + &kv_smem_offset_w, + &cache_k_now, + kv_head_idx, + kv_n_stride, + kv_h_stride, + kv_b_stride, + kv_idx_base, + chunk_end); + commit_group(); + + produce_kv_blockwise(v_smem, + &kv_smem_offset_w, + &cache_v_now, + kv_head_idx, + kv_n_stride, + kv_h_stride, + kv_b_stride, + kv_idx_base, + chunk_end); + commit_group(); + +#pragma unroll 1 + for (uint32_t iter = 0; iter < num_iterations; ++iter) { + wait_group<1>(); + __syncthreads(); + + // s = qk + compute_qk( + &qo_smem, &q_smem_offset_r, &k_smem, &k_smem_offset_r, s_frag); + // mask according to kv_idx and q_idx + if (iter >= mask_check_iteration) { + mask_s(q_base_seq_id_this_block, + kv_idx_base + wid * num_frags_z * 16, + q_len, + kv_len, + chunk_end, + s_frag); + } + + // update m,d + update_mdo_states( + s_frag, o_frag, m_frag, d_frag); + __syncthreads(); + + kv_idx_base += NUM_WARP_KV * num_frags_z * 16; + block_id = __ldg(&block_table_now[kv_idx_base / BLOCK_SIZE]); + if (block_id < 0) { + block_id = 0; // loaded but not used in compute + } + cache_k_now = cache_k + block_id * kv_n_stride + const_offset; + produce_kv_blockwise(k_smem, + &kv_smem_offset_w, + &cache_k_now, + kv_head_idx, + kv_n_stride, + kv_h_stride, + kv_b_stride, + kv_idx_base, + chunk_end); + commit_group(); + wait_group<1>(); + __syncthreads(); + + // compute sfm*v + compute_sfm_v( + &v_smem, &v_smem_offset_r, s_frag, o_frag, d_frag); + __syncthreads(); + + cache_v_now = cache_v + block_id * kv_n_stride + const_offset; + produce_kv_blockwise(v_smem, + &kv_smem_offset_w, + &cache_v_now, + kv_head_idx, + kv_n_stride, + kv_h_stride, + kv_b_stride, + kv_idx_base, + chunk_end); + commit_group(); + } + wait_group<0>(); + __syncthreads(); + + merge_block_res_v2( + o_frag, reinterpret_cast(smem), m_frag, d_frag, wid, tid); + + if (num_chunks_this_seq <= 1) { + normalize_d(o_frag, d_frag); + } + + // write o + // [num_frags_x, 16, num_frags_y, 16] + if (num_chunks_this_seq <= 1) { + write_o_reg_gmem_multi_warps_shift_smooth_quant( + o_frag, + &qo_smem, + o_base_ptr_int8, + shift_bias, + smooth_weight, +
q_base_seq_id_this_block, + q_head_idx, + in_scale, + q_len, + q_n_stride, + HEAD_DIM + ); + } else { + write_o_reg_gmem_multi_warps_shift_smooth_quant( + o_frag, + &qo_smem, + o_base_ptr_T, + shift_bias, + smooth_weight, + q_base_seq_id_this_block, + q_head_idx, + in_scale, + q_len, + q_n_stride * num_chunks, + HEAD_DIM); + // } else { + // write_o_reg_gmem_multi_warps_shift_smooth_quant( + // o_frag, + // &qo_smem, + // o_base_ptr_int8, + // shift_bias, + // smooth_weight, + // q_base_seq_id_this_block, + // q_head_idx, + // in_scale, + // q_len, + // partition_kv ? q_n_stride * num_chunks : q_n_stride, + // HEAD_DIM); + } + + if (num_chunks_this_seq > 1) { + if (wid == 0) { +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t j = 0; j < 2; ++j) { + const uint32_t qo_idx_now = + q_base_seq_id_this_block + tid / 4 + j * 8 + fx * 16; + const uint32_t qo_head_idx = q_head_idx + qo_idx_now % GROUP_SIZE; + const uint32_t qo_idx = q_start_seq_id + qo_idx_now / GROUP_SIZE; + + if (qo_idx - q_start_seq_id < q_len) { + uint32_t offset; + if (ENABLE_PREFILL) { + offset = (batch_id * num_chunks + chunk_idx) * q_num_heads + + qo_head_idx; + } else { + offset = ((batch_id * speculate_max_draft_token_num + + qo_idx_now / GROUP_SIZE) * + num_chunks + + chunk_idx) * + q_num_heads + + qo_head_idx; + } + tmp_m[offset] = m_frag[fx][j]; + tmp_d[offset] = d_frag[fx][j]; + } + } + } + } + } +} + +template +__global__ void multi_query_append_attention_c8_kernel( + T *__restrict__ q, // [token_num. num_heads, head_dim] + CacheT *__restrict__ cache_k, // [max_block_num, num_heads, block_size, head_dim] + CacheT *__restrict__ cache_v, + const T *__restrict__ cache_k_scale, // [num_kv_heads] + const T *__restrict__ cache_v_scale, // [num_kv_heads] + const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM] + const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM] + const int *__restrict__ seq_lens, + const int *__restrict__ seq_lens_kv, + const int *__restrict__ batch_ids, + const int *__restrict__ tile_ids_per_batch, + const int *__restrict__ cum_offsets, + const int *__restrict__ block_table, // [bsz, block_num_per_seq] + const int max_seq_len, + const int max_dec_len, + const int max_block_num_per_seq, + const float scale, + const float in_scale, + const uint32_t chunk_size, + T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks, + // num_heads, head_dim] + float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads] + float *__restrict__ tmp_d, // [token_num, num_chunks, num_heads] + OutT *__restrict__ out, + const int speculate_max_draft_token_num = 5) { +#ifdef DEBUG_ATTN_C8 + __syncthreads(); + printf("launched multi_query_append_attention_c8_kernel"); + __syncthreads(); +#endif + constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b(); // 128 / 8 = 16 + constexpr uint32_t num_vecs_per_head_k = HEAD_DIM / num_elems_per_128b(); // 128 / 16 = 8 + constexpr uint32_t num_vecs_per_blocksize = BLOCK_SIZE / num_elems_per_128b(); // 64 / 16 = 4 + constexpr uint32_t inv_k_stride = 8 / num_vecs_per_head_k; + constexpr uint32_t inv_v_stride = 8 / num_vecs_per_blocksize; + const uint32_t btid = blockIdx.x, kv_head_idx = blockIdx.z; + const uint32_t kv_num_heads = gridDim.z; + const uint32_t q_num_heads = kv_num_heads * GROUP_SIZE; + const uint32_t q_head_idx = kv_head_idx * GROUP_SIZE; + const uint32_t tid = threadIdx.x, wid = threadIdx.y; + const uint32_t num_chunks = gridDim.y; + const uint32_t chunk_idx = blockIdx.y; +#ifdef 
DEBUG_ATTN_C8 + if (tid == PRINT_TID && wid == 0 && blockIdx.z == 0) { + printf("num_vecs_per_head: %d, num_vecs_per_head_k: %d, num_vecs_per_blocksize: %d, inv_k_stride: %d, inv_v_stride: %d\n", + (int)num_vecs_per_head, (int)num_vecs_per_head_k, (int)num_vecs_per_blocksize, (int)inv_k_stride, (int)inv_v_stride); + } + __syncthreads(); +#endif + + const uint32_t batch_id = batch_ids[btid]; + const uint32_t tile_id = tile_ids_per_batch[btid]; + const uint32_t num_rows_per_block = NUM_WARPS * num_frags_x * 16; + const int *block_table_now = nullptr; + + block_table_now = block_table + batch_id * max_block_num_per_seq; + + const uint32_t q_len = seq_lens[batch_id]; + if (q_len <= 0) { + return; + } + const T cache_k_scale_reg = cache_k_scale[kv_head_idx]; + const T cache_v_scale_reg = cache_v_scale[kv_head_idx]; + + const uint32_t q_end = min(q_len, div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE)); + uint32_t kv_len = seq_lens_kv[batch_id]; + if (ENABLE_PREFILL) { + kv_len += q_len; // !!! + if (kv_len <= 0) { + return; + } + } else { + if (kv_len <= 0) { + return; + } + kv_len += q_len; + } + const uint32_t num_chunks_this_seq = div_up(kv_len, chunk_size); + if (chunk_idx >= num_chunks_this_seq) { + return; + } + + const uint32_t chunk_start = partition_kv ? chunk_idx * chunk_size : 0; + const uint32_t chunk_end = partition_kv ? min(kv_len, chunk_start + chunk_size) : kv_len; + const uint32_t chunk_len = chunk_end - chunk_start; + + extern __shared__ uint8_t smem[]; + float s_frag[num_frags_x][num_frags_z][8]; + float o_frag[num_frags_x][num_frags_y][8]; + float m_frag[num_frags_x][2]; + float d_frag[num_frags_x][2]; + init_states(o_frag, m_frag, d_frag); + + const uint32_t q_n_stride = q_num_heads * HEAD_DIM; + const uint32_t q_ori_n_stride = (q_num_heads + kv_num_heads * 2) * HEAD_DIM; + const uint32_t kv_n_stride = kv_num_heads * BLOCK_SIZE * HEAD_DIM; + const uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM; + const uint32_t kv_b_stride = HEAD_DIM; + const uint32_t kv_d_stride = BLOCK_SIZE; + const uint32_t q_start_seq_id = batch_id * max_seq_len - __ldg(&cum_offsets[batch_id]); + const uint32_t q_base_seq_id_this_block = (tile_id * NUM_WARPS + wid) * num_frags_x * 16; + const uint32_t q_offset = q_start_seq_id * q_ori_n_stride + q_head_idx * HEAD_DIM + tid % 8 * num_elems_per_128b(); + const uint32_t o_offset = q_start_seq_id * q_n_stride + q_head_idx * HEAD_DIM + tid % 8 * num_elems_per_128b(); + T *q_base_ptr = q + q_offset; +#ifdef DEBUG_ATTN_C8 + if (tid == PRINT_TID && wid == 0 && blockIdx.z == 0) { + printf("q_start_seq_id: %d, q_offset: %d, q_ori_n_stride: %d, q_base: %f\n", + (int)q_start_seq_id, (int)q_offset, (int)q_ori_n_stride, (float)*q_base_ptr); + } + __syncthreads(); +#endif + + T *o_base_ptr_T = nullptr; + OutT *o_base_ptr_int8 = nullptr; + if constexpr (partition_kv) { + if (ENABLE_PREFILL) { + o_base_ptr_T = tmp_workspace + q_start_seq_id * num_chunks * q_n_stride + chunk_idx * q_n_stride + q_head_idx * HEAD_DIM + tid % 8 * num_elems_per_128b(); + } else { + o_base_ptr_T = tmp_workspace + batch_id * speculate_max_draft_token_num * num_chunks * q_n_stride + chunk_idx * q_n_stride + q_head_idx * HEAD_DIM + tid % 8 * num_elems_per_128b(); + } + } else { + o_base_ptr_int8 = out + o_offset; + } + smem_t qo_smem(smem); + + /* + 1 | 3 + —————— + 2 | 4 + */ + uint32_t q_smem_offset_r = smem_t::get_permuted_offset(wid * num_frags_x * 16 + tid % 16, tid / 16); // 16 * 16 + load_q_global_smem( + q_base_ptr, + &qo_smem, + q_base_seq_id_this_block, + q_end, + q_ori_n_stride, + 
HEAD_DIM + ); + commit_group(); + wait_group<0>(); + __syncthreads(); + + q_smem_inplace_multiply_sm_scale(&qo_smem, scale); + smem_t k_smem(smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T)), + v_smem(smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T) + num_frags_z * 16 * HEAD_DIM * sizeof(CacheT)); + + + const uint32_t num_iterations = div_up(CAUSAL ? + (min(chunk_len, + sub_if_greater_or_zero(kv_len - q_len + div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE), chunk_start))) + : chunk_len, num_frags_z * 16); + const uint32_t mask_check_iteration = (CAUSAL ? + (min(chunk_len, + sub_if_greater_or_zero(kv_len - q_len + tile_id * num_rows_per_block / GROUP_SIZE, chunk_start))) + : chunk_len) / (num_frags_z * 16); +#ifdef DEBUG_ATTN_C8 + if (tid == 0 && wid == 0) { + printf("batch_id: %d, tile_id: %d, chunk_size: %d, q_len: %d, kv_len: %d, chunk_start: %d, chunk_end: %d, num_iterations: %d, mask_check_iteration: %d\n", + (int)batch_id, (int)tile_id, (int)chunk_size, (int)q_len, (int)kv_len, (int)chunk_start, (int)chunk_end, (int)num_iterations, (int)mask_check_iteration); + } + __syncthreads(); +#endif + + /* + 1 | 2 + —————— + 3 | 4 + */ + uint32_t k_smem_offset_r = smem_t::get_permuted_offset(8 * (tid / 16) + tid % 8, (tid % 16) / 8); + /* + 1 | 2 + —————— + 3 | 4 + */ + uint32_t v_smem_offset_r = smem_t::get_permuted_offset(8 * (tid / 16) + tid % 8, (tid % 16) / 8); + + uint32_t k_smem_offset_w = smem_t::get_permuted_offset(wid * 4 + tid / 8, tid % 8); // 8 * 128 / 8 = 128 !!! just for HEAD_DIM >= 128 + uint32_t v_smem_offset_w = smem_t::get_permuted_offset(wid * 8 + tid / 4, tid % 4); // 4 * 128 / 8 = 64 + + uint32_t kv_idx_base = chunk_start; + // int block_id = __ldg(&block_table_now[kv_idx_base / BLOCK_SIZE]); + const uint32_t const_k_offset = kv_head_idx * kv_h_stride + (wid * 4 + tid / 8) * kv_b_stride + tid % 8 * num_elems_per_128b(); + const uint32_t const_v_offset = kv_head_idx * kv_h_stride + (wid * 8 + tid / 4) * kv_d_stride + tid % 4 * num_elems_per_128b(); + // CacheT *cache_k_now = cache_k + block_id * kv_n_stride + const_k_offset; + // CacheT *cache_v_now = cache_v + block_id * kv_n_stride + const_v_offset; +#ifdef DEBUG_ATTN_C8 + if (threadIdx.x == PRINT_TID && threadIdx.y == 0 && blockIdx.z == 0) { + printf("387 ori q_smem_offset_r: %d, k_smem_offset_r: %d, v_smem_offset_r: %d, k_smem_offset_w: %d, v_smem_offset_w: %d\n", + (int)q_smem_offset_r, (int)k_smem_offset_r, (int)v_smem_offset_r, (int)k_smem_offset_w, (int)v_smem_offset_w); + } + __syncthreads(); +#endif + + produce_k_blockwise_c8( + k_smem, + &k_smem_offset_w, + cache_k, + block_table_now, + kv_head_idx, + kv_n_stride, + kv_h_stride, + kv_b_stride, + kv_idx_base, + chunk_end, + const_k_offset + ); + commit_group(); + produce_v_blockwise_c8( + v_smem, + &v_smem_offset_w, + cache_v, + block_table_now, + kv_head_idx, + kv_n_stride, + kv_h_stride, + kv_d_stride, + kv_idx_base, + chunk_end, + const_v_offset + ); + commit_group(); +#ifdef DEBUG_ATTN_C8 + if (threadIdx.x == PRINT_TID && threadIdx.y == 0 && blockIdx.z == 0) { + printf("418 ori q_smem_offset_r: %d, k_smem_offset_r: %d, v_smem_offset_r: %d, k_smem_offset_w: %d, v_smem_offset_w: %d\n", + (int)q_smem_offset_r, (int)k_smem_offset_r, (int)v_smem_offset_r, (int)k_smem_offset_w, (int)v_smem_offset_w); + } + __syncthreads(); +#endif + +#pragma unroll 1 + for (uint32_t iter = 0; iter < num_iterations; ++iter) { + wait_group<1>(); + __syncthreads(); +#ifdef DEBUG_ATTN_C8 + if (tid == PRINT_TID && wid == 0 && blockIdx.z == 0) { + 
printf("cache_k_smem\n"); + uint8_t *k_smem_t = reinterpret_cast(k_smem.base); + for (uint32_t i = 0; i < num_frags_z * 16; ++i) { + for (uint32_t j = 0; j < num_frags_y * 16; ++j) { + printf("k_smem[%d][%d] = %d ", (int)i, (int)j, (int)k_smem_t[i * num_frags_y * 16 + j]); + } + printf("\n"); + } + } + __syncthreads(); +#endif + + // s = qk + compute_qk_c8( + &qo_smem, + &q_smem_offset_r, + &k_smem, + &k_smem_offset_r, + cache_k_scale_reg, + s_frag); +#ifdef DEBUG_ATTN_C8 + if (threadIdx.x == PRINT_TID && threadIdx.y == 0 && blockIdx.z == 0) { + printf("111 iter: %d, q_smem_offset_r: %d, k_smem_offset_r: %d, v_smem_offset_r: %d, k_smem_offset_w: %d, v_smem_offset_w: %d\n", + (int)iter, (int)q_smem_offset_r, (int)k_smem_offset_r, (int)v_smem_offset_r, (int)k_smem_offset_w, (int)v_smem_offset_w); + } + __syncthreads(); +#endif + + // mask according to kv_idx and q_idx + if (iter >= mask_check_iteration) { + mask_s( + q_base_seq_id_this_block, kv_idx_base, q_len, kv_len, chunk_end, s_frag); + } +#ifdef DEBUG_ATTN_C8 + if (threadIdx.x == PRINT_TID && threadIdx.y == 0 && blockIdx.z == 0) { + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { + for (uint32_t fz = 0; fz < num_frags_z; ++fz) { + for (int k = 0; k < 8; k++) { + printf("mask_s s_frag[%d][%d][%d]: %f ", (int)fx, (int)fz, (int)k, s_frag[fx][fz][k]); + } + printf("\n"); + } + printf("\n"); + } + } + __syncthreads(); +#endif + + // update m,d + update_mdo_states( + s_frag, o_frag, m_frag, d_frag); + __syncthreads(); + + kv_idx_base += num_frags_z * 16; + produce_k_blockwise_c8( + k_smem, + &k_smem_offset_w, + cache_k, + block_table_now, + kv_head_idx, + kv_n_stride, + kv_h_stride, + kv_b_stride, + kv_idx_base, + chunk_end, + const_k_offset + ); + commit_group(); + wait_group<1>(); + __syncthreads(); +#ifdef DEBUG_ATTN_C8 + if (threadIdx.x == PRINT_TID && threadIdx.y == 0 && blockIdx.z == 0) { + printf("222 iter: %d, q_smem_offset_r: %d, k_smem_offset_r: %d, v_smem_offset_r: %d, k_smem_offset_w: %d, v_smem_offset_w: %d\n", + (int)iter, (int)q_smem_offset_r, (int)k_smem_offset_r, (int)v_smem_offset_r, (int)k_smem_offset_w, (int)v_smem_offset_w); + } + __syncthreads(); +#endif + + // compute sfm*v + compute_sfm_v_c8( + &v_smem, + &v_smem_offset_r, + s_frag, + o_frag, + d_frag, + cache_v_scale_reg); + __syncthreads(); +#ifdef DEBUG_ATTN_C8 + if (threadIdx.x == PRINT_TID && threadIdx.y == 0 && blockIdx.z == 0) { + printf("333 iter: %d, q_smem_offset_r: %d, k_smem_offset_r: %d, v_smem_offset_r: %d, k_smem_offset_w: %d, v_smem_offset_w: %d\n", + (int)iter, (int)q_smem_offset_r, (int)k_smem_offset_r, (int)v_smem_offset_r, (int)k_smem_offset_w, (int)v_smem_offset_w); + } + __syncthreads(); +#endif + produce_v_blockwise_c8( + v_smem, + &v_smem_offset_w, + cache_v, + block_table_now, + kv_head_idx, + kv_n_stride, + kv_h_stride, + kv_d_stride, + kv_idx_base, + chunk_end, + const_v_offset + ); + commit_group(); +#ifdef DEBUG_ATTN_C8 + if (threadIdx.x == PRINT_TID && threadIdx.y == 0 && blockIdx.z == 0) { + printf("444 iter: %d, q_smem_offset_r: %d, k_smem_offset_r: %d, v_smem_offset_r: %d, k_smem_offset_w: %d, v_smem_offset_w: %d\n", + (int)iter, (int)q_smem_offset_r, (int)k_smem_offset_r, (int)v_smem_offset_r, (int)k_smem_offset_w, (int)v_smem_offset_w); + } + __syncthreads(); +#endif + } + wait_group<0>(); + __syncthreads(); + + if constexpr (!partition_kv) { + normalize_d(o_frag, d_frag); + } + + // write o + // [num_frags_x, 16, num_frags_y, 16] + if constexpr (partition_kv) { + write_o_reg_gmem_shift_smooth_quant( + o_frag, + &qo_smem, + 
o_base_ptr_T, + shift_bias, + smooth_weight, + q_base_seq_id_this_block, + q_head_idx, + in_scale, + q_len, + partition_kv ? q_n_stride * num_chunks : q_n_stride, + HEAD_DIM + ); + } else { + write_o_reg_gmem_shift_smooth_quant( + o_frag, + &qo_smem, + o_base_ptr_int8, + shift_bias, + smooth_weight, + q_base_seq_id_this_block, + q_head_idx, + in_scale, + q_len, + partition_kv ? q_n_stride * num_chunks : q_n_stride, + HEAD_DIM + ); + } + + + if constexpr (partition_kv) { +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t j = 0; j < 2; ++j) { + const uint32_t qo_idx_now = q_base_seq_id_this_block + tid / 4 + j * 8 + fx * 16; + const uint32_t qo_head_idx = q_head_idx + qo_idx_now % GROUP_SIZE; + const uint32_t qo_idx = q_start_seq_id + qo_idx_now / GROUP_SIZE; + if (qo_idx - q_start_seq_id < q_len) { + uint32_t offset; + if (ENABLE_PREFILL) { + offset = (qo_idx * num_chunks + chunk_idx) * q_num_heads + qo_head_idx; + } else { + offset = ((batch_id * speculate_max_draft_token_num + qo_idx_now / GROUP_SIZE) * num_chunks + chunk_idx) * q_num_heads + qo_head_idx; + } + tmp_m[offset] = m_frag[fx][j]; + tmp_d[offset] = d_frag[fx][j]; + } + } + } + } +} + +template +__global__ void multi_query_append_attention_c8_warp1_4_kernel( + T *__restrict__ q, // [token_num. num_heads, head_dim] + CacheT *__restrict__ cache_k, // [max_block_num, num_heads, block_size, head_dim] + CacheT *__restrict__ cache_v, + const T *__restrict__ cache_k_scale, // [num_kv_heads, head_dim] + const T *__restrict__ cache_v_scale, // [num_kv_heads, head_dim] + const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM] + const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM] + const int *__restrict__ seq_lens, + const int *__restrict__ seq_lens_kv, + const int *__restrict__ batch_ids, + const int *__restrict__ tile_ids_per_batch, + const int *__restrict__ cum_offsets, + const int *__restrict__ block_table, // [bsz, block_num_per_seq] + const int max_seq_len, + const int max_dec_len, + const int max_block_num_per_seq, + const float scale, + const float in_scale, + const uint32_t chunk_size, + T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks, + // num_heads, head_dim] + float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads] + float *__restrict__ tmp_d, // [token_num, num_chunks, num_heads] + OutT *__restrict__ out, + const int speculate_max_draft_token_num = 5) { + + constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b(); + constexpr uint32_t num_vecs_per_head_k = + HEAD_DIM / num_elems_per_128b(); + constexpr uint32_t num_vecs_per_blocksize = + BLOCK_SIZE / num_elems_per_128b(); + constexpr uint32_t inv_k_stride = 8 / num_vecs_per_head_k; + constexpr uint32_t inv_v_stride = 8 / num_vecs_per_blocksize; + static_assert(NUM_WARP_Q == 1, "NUM_WARP_Q must be 1"); + static_assert(NUM_WARP_KV == 4, "NUM_WARP_KV must be 4"); + const uint32_t btid = blockIdx.x, kv_head_idx = blockIdx.z; + const uint32_t kv_num_heads = gridDim.z; + const uint32_t q_num_heads = kv_num_heads * GROUP_SIZE; + const uint32_t q_head_idx = kv_head_idx * GROUP_SIZE; + const uint32_t tid = threadIdx.x, wid = threadIdx.y; + const uint32_t num_chunks = gridDim.y; + const uint32_t chunk_idx = blockIdx.y; + + const uint32_t batch_id = batch_ids[btid]; + const uint32_t tile_id = tile_ids_per_batch[btid]; + const uint32_t num_rows_per_block = num_frags_x * 16; + const int *block_table_now = block_table + batch_id * max_block_num_per_seq; + + const uint32_t q_len = 
seq_lens[batch_id]; + if (q_len <= 0) { + return; + } + const T cache_k_scale_reg = cache_k_scale[kv_head_idx]; + const T cache_v_scale_reg = cache_v_scale[kv_head_idx]; + const uint32_t q_end = + min(q_len, div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE)); + uint32_t kv_len = seq_lens_kv[batch_id]; + if (ENABLE_PREFILL) { + kv_len += q_len; // !!! + if (kv_len <= 0) { + return; + } + } else { + if (kv_len <= 0) { + return; + } + kv_len += q_len; + } + const uint32_t num_chunks_this_seq = div_up(kv_len, chunk_size); + if (chunk_idx >= num_chunks_this_seq) { + return; + } + + const uint32_t chunk_start = partition_kv ? chunk_idx * chunk_size : 0; + const uint32_t chunk_end = + partition_kv ? min(kv_len, chunk_start + chunk_size) : kv_len; + const uint32_t chunk_len = chunk_end - chunk_start; + + extern __shared__ uint8_t smem[]; + float s_frag[num_frags_x][num_frags_z][8]; + float o_frag[num_frags_x][num_frags_y][8]; + float m_frag[num_frags_x][2]; + float d_frag[num_frags_x][2]; + init_states(o_frag, m_frag, d_frag); + + const uint32_t q_n_stride = q_num_heads * HEAD_DIM; + const uint32_t q_ori_n_stride = (q_num_heads + kv_num_heads * 2) * HEAD_DIM; + const uint32_t kv_n_stride = kv_num_heads * BLOCK_SIZE * HEAD_DIM; + const uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM; + const uint32_t kv_b_stride = HEAD_DIM; + const uint32_t kv_d_stride = BLOCK_SIZE; + const uint32_t q_start_seq_id = + batch_id * max_seq_len - __ldg(&cum_offsets[batch_id]); + const uint32_t q_base_seq_id_this_block = tile_id * num_frags_x * 16; + const uint32_t q_offset = q_start_seq_id * q_ori_n_stride + + q_head_idx * HEAD_DIM + + tid % 8 * num_elems_per_128b(); + const uint32_t o_offset = q_start_seq_id * q_n_stride + + q_head_idx * HEAD_DIM + + tid % 8 * num_elems_per_128b(); + T *q_base_ptr = q + q_offset; + + T *o_base_ptr_T = nullptr; + OutT *o_base_ptr_int8 = nullptr; + if (num_chunks_this_seq <= 1) { + o_base_ptr_int8 = out + o_offset; + } else { + // if constexpr (partition_kv) { + if (ENABLE_PREFILL) { + o_base_ptr_T = tmp_workspace + batch_id * num_chunks * q_n_stride + + chunk_idx * q_n_stride + q_head_idx * HEAD_DIM + + tid % 8 * num_elems_per_128b(); + } else { + o_base_ptr_T = + tmp_workspace + + batch_id * speculate_max_draft_token_num * num_chunks * q_n_stride + + chunk_idx * q_n_stride + q_head_idx * HEAD_DIM + + tid % 8 * num_elems_per_128b(); + } + // } else { + // o_base_ptr_int8 = out + o_offset; + } +#ifdef DEBUG_ATTN_C8 + if (tid == PRINT_TID && wid == 0 && blockIdx.z == 0) { + printf( + "q_base_seq_id_this_block: %d, q_base_seq_id_this_block: %d, q_offset: " + "%d, o_offset: %d\n", + (int)q_base_seq_id_this_block, + (int)q_base_seq_id_this_block, + (int)q_offset, + (int)o_offset); + } + __syncthreads(); +#endif + + smem_t qo_smem(smem); + + /* + 1 | 3 + —————— + 2 | 4 + */ + uint32_t q_smem_offset_r = smem_t::get_permuted_offset( + tid % 16, tid / 16); // 16 * 16 + load_q_global_smem_multi_warps(q_base_ptr, + &qo_smem, + q_base_seq_id_this_block, + q_end, + q_ori_n_stride, + HEAD_DIM); + commit_group(); + wait_group<0>(); + __syncthreads(); + + q_smem_inplace_multiply_sm_scale_multi_warps( + &qo_smem, scale); +#ifdef DEBUG_ATTN_C8 + if (tid == PRINT_TID && wid == 0 && blockIdx.z == 0 && + blockIdx.x == gridDim.x - 1) { + printf("after scale\n"); + T *q_smem_t = reinterpret_cast(qo_smem.base); + for (uint32_t i = 0; i < num_frags_x * 16; ++i) { + for (uint32_t j = 0; j < num_frags_y * 16; ++j) { + if (blockIdx.z == 0) { + printf("q_smem[%d][%d] = %f ", + (int)i, + (int)(j), + 
(float)q_smem_t[i * num_frags_y * 16 + j]); + } else { + int res = q_smem_t[i * num_frags_y * 16 + j] + static_cast(1.f); + } + } + printf("\n"); + } + } + __syncthreads(); +#endif + + smem_t k_smem(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T)), + v_smem(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T) + + NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM * sizeof(CacheT)); + + const uint32_t num_iterations = div_up( + CAUSAL + ? (min(chunk_len, + sub_if_greater_or_zero( + kv_len - q_len + + div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE), + chunk_start))) + : chunk_len, + NUM_WARP_KV * num_frags_z * 16); + const uint32_t mask_check_iteration = + (CAUSAL ? (min(chunk_len, + sub_if_greater_or_zero( + kv_len - q_len + + tile_id * num_rows_per_block / GROUP_SIZE, + chunk_start))) + : chunk_len) / + (NUM_WARP_KV * num_frags_z * 16); +#ifdef DEBUG_ATTN_C8 + if (threadIdx.x == PRINT_TID && threadIdx.y == 0 && blockIdx.z == 0 && + blockIdx.x == gridDim.x - 1) { + printf( + "cid: %d, batch_id: %d, tile_id: %d, chunk_size: %d, q_len: %d, " + "kv_len: %d, chunk_start: %d, chunk_end: %d, num_iterations: %d, " + "mask_check_iteration: %d\n", + (int)blockIdx.y, + (int)batch_id, + (int)tile_id, + (int)chunk_size, + (int)q_len, + (int)kv_len, + (int)chunk_start, + (int)chunk_end, + (int)num_iterations, + (int)mask_check_iteration); + } + __syncthreads(); +#endif + /* + 1 | 2 + —————— + 3 | 4 + */ + uint32_t k_smem_offset_r = + smem_t::get_permuted_offset( + wid * num_frags_z * 16 + 8 * (tid / 16) + tid % 8, (tid % 16) / 8); + /* + 1 | 2 + —————— + 3 | 4 transpose + */ + uint32_t v_smem_offset_r = + smem_t::get_permuted_offset( + (wid / 2) * num_frags_y * 16 + 8 * (tid / 16) + tid % 8, + (wid % 2) * num_frags_z + (tid % 16) / 8); + + uint32_t k_smem_offset_w = + smem_t::get_permuted_offset( + wid * 4 + tid / 8, + tid % + 8); // 4 * 128 / 8 = 64B, 128 nums, just for head_dim >= 128 !!!
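+  // NOTE(added commentary, not from the original author): a worked example of
+  // the num_iterations / mask_check_iteration bounds computed above, with
+  // purely illustrative numbers. Assume a decode-style tile with q_len = 1,
+  // kv_len = 300, a single chunk (chunk_start = 0, chunk_len = 300),
+  // tile_id = 0, num_rows_per_block = 16, GROUP_SIZE = 8, and a KV step of
+  // NUM_WARP_KV * num_frags_z * 16 = 64 rows per iteration. Then:
+  //   causal bound         = min(300, 299 + div_up(16, 8) - 0) = 300
+  //   num_iterations       = div_up(300, 64)                   = 5
+  //   mask bound           = min(300, 299 + 16 * 0 / 8 - 0)    = 299
+  //   mask_check_iteration = 299 / 64                          = 4
+  // Only the final iteration (iter == 4) can touch KV positions beyond the
+  // visible range, so mask_s runs there alone; iterations 0..3 skip the mask.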
+ uint32_t v_smem_offset_w = + smem_t::get_permuted_offset( + wid * 8 + tid / 4, tid % 4); // 2 * 128 / 8 = 32B, 64 nums + + // uint32_t kv_idx_base = chunk_start; + // int block_id = __ldg(&block_table_now[kv_idx_base / BLOCK_SIZE]); + // const uint32_t const_offset = kv_head_idx * kv_h_stride + (wid * 4 + tid / + // 8) * kv_b_stride + tid % 8 * num_elems_per_128b(); + uint32_t kv_idx_base = chunk_start; + // int block_id = __ldg(&block_table_now[kv_idx_base / BLOCK_SIZE]); + const uint32_t const_k_offset = kv_head_idx * kv_h_stride + + (wid * 4 + tid / 8) * kv_b_stride + + tid % 8 * num_elems_per_128b(); + const uint32_t const_v_offset = kv_head_idx * kv_h_stride + + (wid * 8 + tid / 4) * kv_d_stride + + tid % 4 * num_elems_per_128b(); + // T *cache_k_now = cache_k + block_id * kv_n_stride + const_k_offset; + // T *cache_v_now = cache_v + block_id * kv_n_stride + const_v_offset; + +#ifdef DEBUG_ATTN_C8 + if (threadIdx.y == 0 && blockIdx.z == 0 && blockIdx.x == gridDim.x - 1 && + blockIdx.y == gridDim.y - 1) { + printf( + "000 tid: %d, ori q_smem_offset_r: %d, k_smem_offset_r: %d, " + "v_smem_offset_r: %d, k_smem_offset_w: %d, v_smem_offset_w: %d, " + "cache_k: %f, cache_k_p: %p, const_k_offset: %d, const_v_offset: %d\n", + (int)threadIdx.x, + (int)q_smem_offset_r, + (int)k_smem_offset_r, + (int)v_smem_offset_r, + (int)k_smem_offset_w, + (int)v_smem_offset_w, + (float)(*cache_k), + cache_k, + (int)const_k_offset, + (int)const_v_offset); + } + __syncthreads(); +#endif + + // load BLOCK_SIZE * HEAD_DIM each time + produce_k_blockwise_c8(k_smem, + &k_smem_offset_w, + cache_k, + block_table_now, + kv_head_idx, + kv_n_stride, + kv_h_stride, + kv_b_stride, + kv_idx_base, + chunk_end, + const_k_offset); + commit_group(); + produce_v_blockwise_c8(v_smem, + &v_smem_offset_w, + cache_v, + block_table_now, + kv_head_idx, + kv_n_stride, + kv_h_stride, + kv_d_stride, + kv_idx_base, + chunk_end, + const_v_offset); + commit_group(); +#pragma unroll 1 + for (uint32_t iter = 0; iter < num_iterations; ++iter) { + wait_group<1>(); + __syncthreads(); +#ifdef DEBUG_ATTN_C8 + if (tid == PRINT_TID && wid == 0 && blockIdx.z == 0 && + blockIdx.x == gridDim.x - 1) { + printf("cache_k_smem\n"); + uint8_t *k_smem_t = reinterpret_cast(k_smem.base); + for (uint32_t i = 0; i < NUM_WARP_KV * num_frags_z * 16; ++i) { + for (uint32_t j = 0; j < num_frags_y * 16; ++j) { + printf("k_smem[%d][%d] = %d ", + (int)i, + (int)j, + (int)k_smem_t[i * num_frags_y * 16 + j]); + } + printf("\n"); + } + } + __syncthreads(); +#endif + // s = qk + compute_qk_c8( + &qo_smem, + &q_smem_offset_r, + &k_smem, + &k_smem_offset_r, + cache_k_scale_reg, + s_frag); + // mask according to kv_idx and q_idx + if (iter >= mask_check_iteration) { + // if (q_len > 1 && iter >= mask_check_iteration) { // not need mask in + // decoder, v will be filled with 0 + mask_s(q_base_seq_id_this_block, + kv_idx_base + wid * num_frags_z * 16, + q_len, + kv_len, + chunk_end, + s_frag); + } +#ifdef DEBUG_ATTN_C8 + if (threadIdx.x == PRINT_TID && threadIdx.y == 0 && blockIdx.z == 0 && + blockIdx.x == gridDim.x - 1) { + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { + for (uint32_t fz = 0; fz < num_frags_z; ++fz) { + for (int k = 0; k < 8; k++) { + printf("mask_s s_frag[%d][%d][%d]: %f ", + (int)fx, + (int)fz, + (int)k, + s_frag[fx][fz][k]); + } + printf("\n"); + } + printf("\n"); + } + } + __syncthreads(); +#endif + + // update m,d + update_mdo_states( + s_frag, o_frag, m_frag, d_frag); + __syncthreads(); +#ifdef DEBUG_ATTN_C8 + if (threadIdx.x == PRINT_TID && 
threadIdx.y == 0 && blockIdx.z == 0 && + blockIdx.x == gridDim.x - 1) { + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { + for (uint32_t fz = 0; fz < num_frags_z; ++fz) { + for (int k = 0; k < 8; k++) { + printf("update_mdo_states s_frag[%d][%d][%d]: %f ", + (int)fx, + (int)fz, + (int)k, + s_frag[fx][fz][k]); + } + printf("\n"); + } + printf("\n"); + } + } + __syncthreads(); +#endif + + kv_idx_base += NUM_WARP_KV * num_frags_z * 16; + produce_k_blockwise_c8(k_smem, + &k_smem_offset_w, + cache_k, + block_table_now, + kv_head_idx, + kv_n_stride, + kv_h_stride, + kv_b_stride, + kv_idx_base, + chunk_end, + const_k_offset); + commit_group(); + wait_group<1>(); + __syncthreads(); +#ifdef DEBUG_ATTN_C8 + if (tid == PRINT_TID && wid == 0 && blockIdx.z == 0 && + blockIdx.x == gridDim.x - 1) { + printf("cache_v_smem\n"); + uint8_t *v_smem_t = reinterpret_cast(v_smem.base); + for (uint32_t i = 0; i < NUM_WARP_KV / 2 * num_frags_y * 16; ++i) { + for (uint32_t j = 0; j < 2 * num_frags_z * 16; ++j) { + printf("v_smem[%d][%d] = %d ", + (int)i, + (int)j, + (int)v_smem_t[i * 2 * num_frags_z * 16 + j]); + } + printf("\n"); + } + } + __syncthreads(); +#endif + + // compute sfm * v + compute_sfm_v_c8_iter_sq_bvec( + &v_smem, &v_smem_offset_r, s_frag, o_frag, d_frag, cache_v_scale_reg); + __syncthreads(); + + produce_v_blockwise_c8(v_smem, + &v_smem_offset_w, + cache_v, + block_table_now, + kv_head_idx, + kv_n_stride, + kv_h_stride, + kv_d_stride, + kv_idx_base, + chunk_end, + const_v_offset); + commit_group(); + } + wait_group<0>(); + __syncthreads(); +#ifdef DEBUG_ATTN_C8 + if (threadIdx.x == PRINT_TID && threadIdx.y == 0 && blockIdx.z == 0 && + blockIdx.x == gridDim.x - 1) { + printf("before merge z\n"); + for (uint32_t i = 0; i < num_frags_x; ++i) { + printf("m1: %f, m2: %f\n", m_frag[i][0], m_frag[i][1]); + printf("d1: %f, d2: %f\n", d_frag[i][0], d_frag[i][1]); + for (uint32_t j = 0; j < num_frags_y; ++j) { + for (int r_id = 0; r_id < 8; r_id++) { + printf("o_frag[%d][%d][%d]: %f ", + (int)i, + (int)j, + r_id, + o_frag[i][j][r_id]); + } + } + printf("\n"); + } + } + __syncthreads(); +#endif + + merge_block_res_v2( + o_frag, reinterpret_cast(smem), m_frag, d_frag, wid, tid); +#ifdef DEBUG_ATTN_C8 + if (threadIdx.x == PRINT_TID && threadIdx.y == 0 && blockIdx.z == 0 && + blockIdx.x == 0 && blockIdx.y == 0) { + printf("after merge z\n"); + for (uint32_t i = 0; i < num_frags_x; ++i) { + printf("m1: %f, m2: %f\n", m_frag[i][0], m_frag[i][1]); + printf("d1: %f, d2: %f\n", d_frag[i][0], d_frag[i][1]); + for (uint32_t j = 0; j < num_frags_y; ++j) { + for (int r_id = 0; r_id < 8; r_id++) { + printf("o_frag[%d][%d][%d]: %f ", + (int)i, + (int)j, + r_id, + o_frag[i][j][r_id]); + } + } + printf("\n"); + } + } + __syncthreads(); +#endif + + if (num_chunks_this_seq <= 1) { + normalize_d(o_frag, d_frag); + } +#ifdef DEBUG_ATTN_C8 + __syncthreads(); + if (threadIdx.x == PRINT_TID && threadIdx.y == 0 && blockIdx.z == 0 && + blockIdx.x == gridDim.x - 1) { + printf("after normalize_d\n"); + for (uint32_t i = 0; i < num_frags_x; ++i) { + printf("m1: %f, m2: %f\n", m_frag[i][0], m_frag[i][1]); + printf("d1: %f, d2: %f\n", d_frag[i][0], d_frag[i][1]); + for (uint32_t j = 0; j < num_frags_y; ++j) { + for (int r_id = 0; r_id < 8; r_id++) { + printf("o_frag[%d][%d][%d]: %f ", + (int)i, + (int)j, + r_id, + o_frag[i][j][r_id]); + } + } + printf("\n"); + } + } + __syncthreads(); +#endif + + // write o + // [num_frags_x, 16, num_frags_y, 16] + if (num_chunks_this_seq <= 1) { + write_o_reg_gmem_multi_warps_shift_smooth_quant( + 
o_frag, + &qo_smem, + o_base_ptr_int8, + shift_bias, + smooth_weight, + q_base_seq_id_this_block, + q_head_idx, + in_scale, + q_len, + q_n_stride, + HEAD_DIM + ); + } else { + write_o_reg_gmem_multi_warps_shift_smooth_quant( + o_frag, + &qo_smem, + o_base_ptr_T, + shift_bias, + smooth_weight, + q_base_seq_id_this_block, + q_head_idx, + in_scale, + q_len, + q_n_stride * num_chunks, + HEAD_DIM); + // } else { + // write_o_reg_gmem_multi_warps_shift_smooth_quant( + // o_frag, + // &qo_smem, + // o_base_ptr_int8, + // shift_bias, + // smooth_weight, + // q_base_seq_id_this_block, + // q_head_idx, + // in_scale, + // q_len, + // partition_kv ? q_n_stride * num_chunks : q_n_stride, + // HEAD_DIM); + } +#ifdef DEBUG_ATTN_C8 + __syncthreads(); + if (threadIdx.x == PRINT_TID && threadIdx.y == 0 && blockIdx.z == 0 && + blockIdx.x == gridDim.x - 1) { + printf("after normalize_d\n"); + for (uint32_t i = 0; i < num_frags_x; ++i) { + printf("m1: %f, m2: %f\n", m_frag[i][0], m_frag[i][1]); + printf("d1: %f, d2: %f\n", d_frag[i][0], d_frag[i][1]); + for (uint32_t j = 0; j < num_frags_y; ++j) { + for (int r_id = 0; r_id < 8; r_id++) { + printf("o_frag[%d][%d][%d]: %f ", + (int)i, + (int)j, + r_id, + o_frag[i][j][r_id]); + } + } + printf("\n"); + } + } + __syncthreads(); +#endif + + if (num_chunks_this_seq > 1) { + if (wid == 0) { +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t j = 0; j < 2; ++j) { + const uint32_t qo_idx_now = + q_base_seq_id_this_block + tid / 4 + j * 8 + fx * 16; + const uint32_t qo_head_idx = q_head_idx + qo_idx_now % GROUP_SIZE; + const uint32_t qo_idx = q_start_seq_id + qo_idx_now / GROUP_SIZE; +#ifdef DEBUG_ATTN_C8 + if (batch_id == 0) { + printf( + "bid: %d, tid: %d, wid: %d, q_base_seq_id_this_block: %d, " + "qo_idx_now: %d, qo_idx: %d, q_start_seq_id: %d, q_len: %d, m: " + "%f, d: %f\n", + (int)batch_id, + (int)tid, + (int)wid, + (int)q_base_seq_id_this_block, + (int)qo_idx_now, + (int)qo_idx, + (int)q_start_seq_id, + (int)q_len, + (float)m_frag[fx][j], + (float)d_frag[fx][j]); + } +#endif + if (qo_idx - q_start_seq_id < q_len) { + // const uint32_t offset = (qo_idx * num_chunks + chunk_idx) * + // q_num_heads + qo_head_idx; + + uint32_t offset; + if (ENABLE_PREFILL) { + offset = (batch_id * num_chunks + chunk_idx) * q_num_heads + + qo_head_idx; + } else { + offset = ((batch_id * speculate_max_draft_token_num + + qo_idx_now / GROUP_SIZE) * + num_chunks + + chunk_idx) * + q_num_heads + + qo_head_idx; + } + tmp_m[offset] = m_frag[fx][j]; + tmp_d[offset] = d_frag[fx][j]; + } + } + } + } + } +} + +template +__global__ void multi_query_append_attention_c4_kernel( + T *__restrict__ q, // [token_num. 
num_heads, head_dim] + CacheT *__restrict__ cache_k, // [max_block_num, num_heads, block_size, + // head_dim] + CacheT *__restrict__ cache_v, + const T *__restrict__ cache_k_scale, // [num_kv_heads, head_dim] + const T *__restrict__ cache_k_zero_point, // [num_kv_heads, head_dim] + const T *__restrict__ cache_v_scale, // [num_kv_heads, head_dim] + const T *__restrict__ cache_v_zero_point, // [num_kv_heads, head_dim] + const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM] + const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM] + const int *__restrict__ seq_lens, + const int *__restrict__ seq_lens_kv, + const int *__restrict__ batch_ids, + const int *__restrict__ tile_ids_per_batch, + const int *__restrict__ cum_offsets, + const int *__restrict__ block_table, // [bsz, block_num_per_seq] + const int max_seq_len, + const int max_dec_len, + const int max_block_num_per_seq, + const float scale, + const float in_scale, + const uint32_t chunk_size, + T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks, + // num_heads, head_dim] + float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads] + float *__restrict__ tmp_d, // [token_num, num_chunks, num_heads] + OutT *__restrict__ out, + const int speculate_max_draft_token_num = 5) { + constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b(); + constexpr uint32_t num_vecs_per_head_k = + HEAD_DIM / 2 / num_elems_per_128b(); + constexpr uint32_t num_vecs_per_blocksize = + BLOCK_SIZE / 2 / num_elems_per_128b(); + constexpr uint32_t inv_k_stride = 8 / num_vecs_per_head_k; + constexpr uint32_t inv_v_stride = 8 / num_vecs_per_blocksize; + const uint32_t btid = blockIdx.x, kv_head_idx = blockIdx.z; + const uint32_t kv_num_heads = gridDim.z; + const uint32_t q_num_heads = kv_num_heads * GROUP_SIZE; + const uint32_t q_head_idx = kv_head_idx * GROUP_SIZE; + const uint32_t tid = threadIdx.x, wid = threadIdx.y; + const uint32_t num_chunks = gridDim.y; + const uint32_t chunk_idx = blockIdx.y; +#ifdef DEBUG_ATTN_C4 + if (tid == PRINT_TID && wid == 0 && blockIdx.z == 0) { + printf( + "num_vecs_per_head: %d, num_vecs_per_head_k: %d, " + "num_vecs_per_blocksize: %d, inv_k_stride: %d, inv_v_stride: %d\n", + (int)num_vecs_per_head, + (int)num_vecs_per_head_k, + (int)num_vecs_per_blocksize, + (int)inv_k_stride, + (int)inv_v_stride); + } + __syncthreads(); +#endif + + const uint32_t batch_id = batch_ids[btid]; + const uint32_t tile_id = tile_ids_per_batch[btid]; + const uint32_t num_rows_per_block = NUM_WARPS * num_frags_x * 16; + const int *block_table_now = nullptr; + + block_table_now = block_table + batch_id * max_block_num_per_seq; + + const uint32_t q_len = seq_lens[batch_id]; + if (q_len <= 0) { + return; + } + const uint32_t q_end = + min(q_len, div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE)); + uint32_t kv_len = seq_lens_kv[batch_id]; + if (ENABLE_PREFILL) { + kv_len += q_len; // !!! + if (kv_len <= 0) { + return; + } + } else { + if (kv_len <= 0) { + return; + } + kv_len += q_len; + } + const uint32_t num_chunks_this_seq = div_up(kv_len, chunk_size); + if (chunk_idx >= num_chunks_this_seq) { + return; + } + + const uint32_t chunk_start = partition_kv ? chunk_idx * chunk_size : 0; + const uint32_t chunk_end = + partition_kv ? 
min(kv_len, chunk_start + chunk_size) : kv_len; + const uint32_t chunk_len = chunk_end - chunk_start; + + extern __shared__ uint8_t smem[]; + float s_frag[num_frags_x][num_frags_z][8]; + float o_frag[num_frags_x][num_frags_y][8]; + float m_frag[num_frags_x][2]; + float d_frag[num_frags_x][2]; + + // load kv scale/zp + // TODO(load kv scale and zp to smem) + const T *cache_k_scale_now = cache_k_scale + kv_head_idx * HEAD_DIM; + const T *cache_k_zp_now = cache_k_zero_point + kv_head_idx * HEAD_DIM; + const T *cache_v_scale_now = cache_v_scale + kv_head_idx * HEAD_DIM; + const T *cache_v_zp_now = cache_v_zero_point + kv_head_idx * HEAD_DIM; + // constexpr uint32_t HEAD_DIM_PAD = div_up(HEAD_DIM, 4) * 4; + T *cache_k_scale_smem = reinterpret_cast( + smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T) + + num_frags_z * 16 * HEAD_DIM / 2 * sizeof(CacheT) * 2); + T *cache_k_zero_point_smem = cache_k_scale_smem + HEAD_DIM; + T *cache_v_scale_smem = cache_k_zero_point_smem + HEAD_DIM; + T *cache_v_zero_point_smem = cache_v_scale_smem + HEAD_DIM; +#pragma unroll + for (uint32_t i = wid * 32 + tid; i < HEAD_DIM; i += 128) { + cache_k_scale_smem[i] = cache_k_scale_now[i]; + cache_k_zero_point_smem[i] = cache_k_zp_now[i]; + cache_v_scale_smem[i] = cache_v_scale_now[i]; + cache_v_zero_point_smem[i] = cache_v_zp_now[i]; + } + + init_states(o_frag, m_frag, d_frag); + + const uint32_t q_n_stride = q_num_heads * HEAD_DIM; + const uint32_t q_ori_n_stride = (q_num_heads + kv_num_heads * 2) * HEAD_DIM; + const uint32_t kv_n_stride = kv_num_heads * BLOCK_SIZE * HEAD_DIM / 2; + const uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM / 2; + const uint32_t kv_b_stride = HEAD_DIM / 2; + const uint32_t kv_d_stride = BLOCK_SIZE / 2; + const uint32_t q_start_seq_id = + batch_id * max_seq_len - __ldg(&cum_offsets[batch_id]); + const uint32_t q_base_seq_id_this_block = + (tile_id * NUM_WARPS + wid) * num_frags_x * 16; + const uint32_t q_offset = q_start_seq_id * q_ori_n_stride + + q_head_idx * HEAD_DIM + + tid % 8 * num_elems_per_128b(); + const uint32_t o_offset = q_start_seq_id * q_n_stride + + q_head_idx * HEAD_DIM + + tid % 8 * num_elems_per_128b(); + T *q_base_ptr = q + q_offset; +#ifdef DEBUG_ATTN_C4 + if (tid == PRINT_TID && wid == 0 && blockIdx.z == 0) { + printf( + "q_base_seq_id_this_block: %d, q_start_seq_id: %d, q_offset: %d, " + "q_ori_n_stride: %d, q_base: %f\n", + (int)q_base_seq_id_this_block, + (int)q_start_seq_id, + (int)q_offset, + (int)q_ori_n_stride, + (float)*q_base_ptr); + } + __syncthreads(); +#endif + T *o_base_ptr_T = nullptr; + OutT *o_base_ptr_int8 = nullptr; + if constexpr (partition_kv) { + if (ENABLE_PREFILL) { + o_base_ptr_T = tmp_workspace + q_start_seq_id * num_chunks * q_n_stride + + chunk_idx * q_n_stride + q_head_idx * HEAD_DIM + + tid % 8 * num_elems_per_128b(); + } else { + o_base_ptr_T = + tmp_workspace + + batch_id * speculate_max_draft_token_num * num_chunks * q_n_stride + + chunk_idx * q_n_stride + q_head_idx * HEAD_DIM + + tid % 8 * num_elems_per_128b(); + } + } else { + o_base_ptr_int8 = out + o_offset; + } + smem_t qo_smem(smem); + + /* + 1 | 3 + —————— + 2 | 4 + */ + uint32_t q_smem_offset_r = smem_t::get_permuted_offset( + wid * num_frags_x * 16 + tid % 16, tid / 16); // 16 * 16 + load_q_global_smem( + q_base_ptr, + &qo_smem, + q_base_seq_id_this_block, + q_end, + q_ori_n_stride, + HEAD_DIM); + commit_group(); + wait_group<0>(); + __syncthreads(); + + q_smem_inplace_multiply_sm_scale(&qo_smem, + scale); +#ifdef DEBUG_ATTN_C4 + if (tid == PRINT_TID && wid == 0 && 
blockIdx.z == 0) { + printf("after scale\n"); + T *q_smem_t = reinterpret_cast(qo_smem.base); + for (uint32_t i = 0; i < 4 * num_frags_x * 16; ++i) { + for (uint32_t j = 0; j < num_frags_y * 16; ++j) { + printf("q_smem[%d][%d] = %f ", + (int)i, + (int)j, + (float)q_smem_t[i * num_frags_y * 16 + j]); + } + printf("\n"); + } + } + __syncthreads(); +#endif + + T cache_k_scale_frag[num_frags_y][4]; + T cache_k_zp_frag[num_frags_y][4]; + T magic_number; + if constexpr (std::is_same::value) { + magic_number = static_cast(1032.f); + } else { + magic_number = static_cast(136.f); + } +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y; ++fy) { + *(reinterpret_cast(&cache_k_scale_frag[fy][0])) = + *(reinterpret_cast(&cache_k_scale_smem[fy * 16]) + tid % 4); + *(reinterpret_cast(&cache_k_scale_frag[fy][2])) = + *(reinterpret_cast(&cache_k_scale_smem[fy * 16]) + tid % 4 + + 4); + *(reinterpret_cast(&cache_k_zp_frag[fy][0])) = + *(reinterpret_cast(&cache_k_zero_point_smem[fy * 16]) + + tid % 4); + *(reinterpret_cast(&cache_k_zp_frag[fy][2])) = + *(reinterpret_cast(&cache_k_zero_point_smem[fy * 16]) + + tid % 4 + 4); +#pragma unroll + for (uint32_t zp_i = 0; zp_i < 4; ++zp_i) { + cache_k_zp_frag[fy][zp_i] += magic_number; // 128 + 8 + } + } + T cache_v_scale_frag[num_frags_y][2]; + T cache_v_zp_frag[num_frags_y][2]; +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y; ++fy) { + cache_v_scale_frag[fy][0] = cache_v_scale_smem[fy * 16 + tid / 4]; + cache_v_scale_frag[fy][1] = cache_v_scale_smem[fy * 16 + tid / 4 + 8]; + cache_v_zp_frag[fy][0] = + cache_v_zero_point_smem[fy * 16 + tid / 4] + magic_number; + cache_v_zp_frag[fy][1] = + cache_v_zero_point_smem[fy * 16 + tid / 4 + 8] + magic_number; + } + + smem_t k_smem(smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T)), + v_smem(smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T) + + num_frags_z * 16 * HEAD_DIM / 2 * sizeof(CacheT)); + + + const uint32_t num_iterations = div_up( + CAUSAL + ? (min(chunk_len, + sub_if_greater_or_zero( + kv_len - q_len + + div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE), + chunk_start))) + : chunk_len, + num_frags_z * 16); + const uint32_t mask_check_iteration = + (CAUSAL ? (min(chunk_len, + sub_if_greater_or_zero( + kv_len - q_len + + tile_id * num_rows_per_block / GROUP_SIZE, + chunk_start))) + : chunk_len) / + (num_frags_z * 16); +#ifdef DEBUG_ATTN + if (tid == 0 && wid == 0) { + printf( + "batch_id: %d, tile_id: %d, chunk_size: %d, q_len: %d, kv_len: %d, " + "chunk_start: %d, chunk_end: %d, num_iterations: %d, " + "mask_check_iteration: %d\n", + (int)batch_id, + (int)tile_id, + (int)chunk_size, + (int)q_len, + (int)kv_len, + (int)chunk_start, + (int)chunk_end, + (int)num_iterations, + (int)mask_check_iteration); + } + __syncthreads(); +#endif + + /* + 1 | 2 + —————— + 3 | 4 + */ + uint32_t k_smem_offset_r = + smem_t::get_permuted_offset( + 8 * (tid / 16) + tid % 8, (tid % 16) / 8); + /* + 1 | 2 + —————— + 3 | 4 + */ + uint32_t v_smem_offset_r = + smem_t::get_permuted_offset( + 8 * (tid / 16) + tid % 8, (tid % 16) / 8); + + uint32_t k_smem_offset_w = + smem_t::get_permuted_offset( + wid * 8 + tid / 4, + tid % + 4); // 4 * 128 / 8 = 64B, 128 nums, just for head_dim >= 128 !!!
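+  // NOTE(added commentary, not from the original author): in this c4 path each
+  // byte of cache_k / cache_v packs two 4-bit values, which is why every cache
+  // stride above is halved (kv_b_stride = HEAD_DIM / 2, kv_d_stride =
+  // BLOCK_SIZE / 2) and why the scale / zero-point tables are per
+  // [kv_head, head_dim] channel. The dequantization that compute_qk_c4 and
+  // compute_sfm_v_c4 are assumed to apply is the usual asymmetric form,
+  //   x = (x_int4 - zero_point) * scale,
+  // e.g. with scale = 0.05 and zero_point = 8, a stored nibble 11 decodes to
+  // (11 - 8) * 0.05 = 0.15. The zero points prepared above are pre-biased by
+  // magic_number (1032.f or 136.f depending on the 16-bit float type, per the
+  // "128 + 8" comment), which presumably matches the bias left by the
+  // bit-level int4 -> float unpack so that a single subtraction removes both
+  // the unpack bias and the quantization zero point.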
+ uint32_t v_smem_offset_w = + smem_t::get_permuted_offset( + wid * 16 + tid / 2, tid % 2); // 2 * 128 / 8 = 32B, 64 nums + + uint32_t kv_idx_base = chunk_start; + // int block_id = __ldg(&block_table_now[kv_idx_base / BLOCK_SIZE]); + const uint32_t const_k_offset = kv_head_idx * kv_h_stride + + (wid * 8 + tid / 4) * kv_b_stride + + tid % 4 * num_elems_per_128b(); + const uint32_t const_v_offset = kv_head_idx * kv_h_stride + + (wid * 16 + tid / 2) * kv_d_stride + + tid % 2 * num_elems_per_128b(); + // CacheT *cache_k_now = cache_k + block_id * kv_n_stride + const_k_offset; + // CacheT *cache_v_now = cache_v + block_id * kv_n_stride + const_v_offset; +#ifdef DEBUG_ATTN_C4 + if (threadIdx.x == PRINT_TID && threadIdx.y == 0 && blockIdx.z == 0) { + printf( + "752 ori q_smem_offset_r: %d, k_smem_offset_r: %d, v_smem_offset_r: " + "%d, k_smem_offset_w: %d, v_smem_offset_w: %d\n", + (int)q_smem_offset_r, + (int)k_smem_offset_r, + (int)v_smem_offset_r, + (int)k_smem_offset_w, + (int)v_smem_offset_w); + } + __syncthreads(); +#endif + + produce_k_blockwise_c4(k_smem, + &k_smem_offset_w, + // &cache_k_now, + cache_k, + block_table_now, + kv_head_idx, + kv_n_stride, + kv_h_stride, + kv_b_stride, + kv_idx_base, + chunk_end, + const_k_offset); + commit_group(); + produce_v_blockwise_c4(v_smem, + &v_smem_offset_w, + // &cache_v_now, + cache_v, + block_table_now, + kv_head_idx, + kv_n_stride, + kv_h_stride, + kv_d_stride, + kv_idx_base, + chunk_end, + const_v_offset); + commit_group(); +#ifdef DEBUG_ATTN_C4 + if (threadIdx.x == PRINT_TID && threadIdx.y == 0 && blockIdx.z == 0) { + printf( + "782 ori q_smem_offset_r: %d, k_smem_offset_r: %d, v_smem_offset_r: " + "%d, k_smem_offset_w: %d, v_smem_offset_w: %d\n", + (int)q_smem_offset_r, + (int)k_smem_offset_r, + (int)v_smem_offset_r, + (int)k_smem_offset_w, + (int)v_smem_offset_w); + } + __syncthreads(); +#endif + +#pragma unroll 1 + for (uint32_t iter = 0; iter < num_iterations; ++iter) { + wait_group<1>(); + __syncthreads(); +#ifdef DEBUG_ATTN_C4 + if (tid == PRINT_TID && wid == 0 && blockIdx.z == 0 && + blockIdx.x == gridDim.x - 1) { + printf("cache_k_smem\n"); + uint8_t *k_smem_t = reinterpret_cast(k_smem.base); + for (uint32_t i = 0; i < num_frags_z * 16; ++i) { + for (uint32_t j = 0; j < num_frags_y * 16 / 2; ++j) { + printf("k_smem[%d][%d] = %d ", + (int)i, + (int)j, + (int)k_smem_t[i * num_frags_y * 16 / 2 + j]); + } + printf("\n"); + } + } + __syncthreads(); +#endif + // s = qk + compute_qk_c4( + &qo_smem, + &q_smem_offset_r, + &k_smem, + &k_smem_offset_r, + s_frag, + cache_k_scale_frag, + cache_k_zp_frag); +#ifdef DEBUG_ATTN_C4 + if (threadIdx.x == PRINT_TID && threadIdx.y == 0 && blockIdx.z == 0) { + printf( + "111 iter: %d, q_smem_offset_r: %d, k_smem_offset_r: %d, " + "v_smem_offset_r: %d, k_smem_offset_w: %d, v_smem_offset_w: %d\n", + (int)iter, + (int)q_smem_offset_r, + (int)k_smem_offset_r, + (int)v_smem_offset_r, + (int)k_smem_offset_w, + (int)v_smem_offset_w); + } + __syncthreads(); +#endif + // mask according to kv_idx and q_idx + if (iter >= mask_check_iteration) { + mask_s(q_base_seq_id_this_block, + kv_idx_base, + q_len, + kv_len, + chunk_end, + s_frag); + } +#ifdef DEBUG_ATTN_C4 + if (threadIdx.x == PRINT_TID && threadIdx.y == 0 && blockIdx.z == 0 && + blockIdx.x == gridDim.x - 1) { + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { + for (uint32_t fz = 0; fz < num_frags_z; ++fz) { + for (int k = 0; k < 8; k++) { + printf("mask_s s_frag[%d][%d][%d]: %f ", + (int)fx, + (int)fz, + (int)k, + s_frag[fx][fz][k]); + } + printf("\n"); + } + 
printf("\n"); + } + } + __syncthreads(); +#endif + + // update m,d + update_mdo_states( + s_frag, o_frag, m_frag, d_frag); + __syncthreads(); +#ifdef DEBUG_ATTN_C4 + if (threadIdx.x == PRINT_TID && threadIdx.y == 0 && blockIdx.z == 0 && + blockIdx.x == gridDim.x - 1) { + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { + for (uint32_t fz = 0; fz < num_frags_z; ++fz) { + for (int k = 0; k < 8; k++) { + printf("update_mdo_states s_frag[%d][%d][%d]: %f ", + (int)fx, + (int)fz, + (int)k, + s_frag[fx][fz][k]); + } + printf("\n"); + } + printf("\n"); + } + } + __syncthreads(); +#endif + + kv_idx_base += num_frags_z * 16; + produce_k_blockwise_c4(k_smem, + &k_smem_offset_w, + // &cache_k_now, + cache_k, + block_table_now, + kv_head_idx, + kv_n_stride, + kv_h_stride, + kv_b_stride, + kv_idx_base, + chunk_end, + const_k_offset); + commit_group(); + wait_group<1>(); + __syncthreads(); +#ifdef DEBUG_ATTN_C4 + if (tid == PRINT_TID && wid == 0 && blockIdx.z == 0 && + blockIdx.x == gridDim.x - 1) { + printf("cache_v_smem\n"); + uint8_t *v_smem_t = reinterpret_cast(v_smem.base); + for (uint32_t i = 0; i < num_frags_y * 16; ++i) { + for (uint32_t j = 0; j < num_frags_z * 16 / 2; ++j) { + printf("v_smem[%d][%d] = %d ", + (int)(iter * 128 + i), + (int)j, + (int)v_smem_t[i * num_frags_z * 16 / 2 + j]); + } + printf("\n"); + } + } + __syncthreads(); +#endif + + // compute sfm*v + compute_sfm_v_c4(&v_smem, + &v_smem_offset_r, + s_frag, + o_frag, + d_frag, + cache_v_scale_frag, + cache_v_zp_frag); + __syncthreads(); + + produce_v_blockwise_c4(v_smem, + &v_smem_offset_w, + // &cache_v_now, + cache_v, + block_table_now, + kv_head_idx, + kv_n_stride, + kv_h_stride, + kv_d_stride, + kv_idx_base, + chunk_end, + const_v_offset); + commit_group(); + } + wait_group<0>(); + __syncthreads(); +#ifdef DEBUG_ATTN_C4 + if (threadIdx.x == PRINT_TID && threadIdx.y == 0 && blockIdx.z == 0 && + blockIdx.x == gridDim.x - 1) { + printf("tmp res\n"); + for (uint32_t i = 0; i < num_frags_x; ++i) { + printf("m1: %f, m2: %f\n", m_frag[i][0], m_frag[i][1]); + printf("d1: %f, d2: %f\n", d_frag[i][0], d_frag[i][1]); + for (uint32_t j = 0; j < num_frags_y; ++j) { + for (int r_id = 0; r_id < 8; r_id++) { + printf("o_frag[%d][%d][%d]: %f ", + (int)i, + (int)j, + r_id, + o_frag[i][j][r_id]); + } + } + printf("\n"); + } + } + __syncthreads(); +#endif + + if constexpr (!partition_kv) { + normalize_d(o_frag, d_frag); + } + + // write o + // [num_frags_x, 16, num_frags_y, 16] + if constexpr (partition_kv) { + write_o_reg_gmem_shift_smooth_quant( + o_frag, + &qo_smem, + o_base_ptr_T, + shift_bias, + smooth_weight, + q_base_seq_id_this_block, + q_head_idx, + in_scale, + q_len, + partition_kv ? q_n_stride * num_chunks : q_n_stride, + HEAD_DIM); + } else { + write_o_reg_gmem_shift_smooth_quant( + o_frag, + &qo_smem, + o_base_ptr_int8, + shift_bias, + smooth_weight, + q_base_seq_id_this_block, + q_head_idx, + in_scale, + q_len, + partition_kv ? 
q_n_stride * num_chunks : q_n_stride, + HEAD_DIM); + } + + if constexpr (partition_kv) { +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t j = 0; j < 2; ++j) { + const uint32_t qo_idx_now = + q_base_seq_id_this_block + tid / 4 + j * 8 + fx * 16; + const uint32_t qo_head_idx = q_head_idx + qo_idx_now % GROUP_SIZE; + const uint32_t qo_idx = q_start_seq_id + qo_idx_now / GROUP_SIZE; + if (qo_idx - q_start_seq_id < q_len) { + uint32_t offset; + if (ENABLE_PREFILL) { + offset = + (qo_idx * num_chunks + chunk_idx) * q_num_heads + qo_head_idx; + } else { + offset = ((batch_id * speculate_max_draft_token_num + + qo_idx_now / GROUP_SIZE) * + num_chunks + + chunk_idx) * + q_num_heads + + qo_head_idx; + } + tmp_m[offset] = m_frag[fx][j]; + tmp_d[offset] = d_frag[fx][j]; + } + } + } + } +} + +template +__global__ void multi_query_append_attention_c4_warp1_4_kernel( + T *__restrict__ q, // [token_num. num_heads, head_dim] + CacheT *__restrict__ cache_k, // [max_block_num, num_heads, block_size, + // head_dim] + CacheT *__restrict__ cache_v, + const T *__restrict__ cache_k_scale, // [num_kv_heads, head_dim] + const T *__restrict__ cache_k_zero_point, // [num_kv_heads, head_dim] + const T *__restrict__ cache_v_scale, // [num_kv_heads, head_dim] + const T *__restrict__ cache_v_zero_point, // [num_kv_heads, head_dim] + const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM] + const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM] + const int *__restrict__ seq_lens, + const int *__restrict__ seq_lens_kv, + const int *__restrict__ batch_ids, + const int *__restrict__ tile_ids_per_batch, + const int *__restrict__ cum_offsets, + const int *__restrict__ block_table, // [bsz, block_num_per_seq] + const int max_seq_len, + const int max_dec_len, + const int max_block_num_per_seq, + const float scale, + const float in_scale, + const uint32_t chunk_size, + T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks, + // num_heads, head_dim] + float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads] + float *__restrict__ tmp_d, // [token_num, num_chunks, num_heads] + OutT *__restrict__ out, + const int speculate_max_draft_token_num = 5) { + // q_len <= 32, num_frags_x = 1/2, num_frags_z = 4 / 4 * 1/2/4, num_frags_y = + // HEAD_DIM / 16 + constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b(); + constexpr uint32_t num_vecs_per_head_k = + HEAD_DIM / 2 / num_elems_per_128b(); + constexpr uint32_t num_vecs_per_blocksize = + BLOCK_SIZE / 2 / num_elems_per_128b(); + constexpr uint32_t inv_k_stride = 8 / num_vecs_per_head_k; + constexpr uint32_t inv_v_stride = 8 / num_vecs_per_blocksize; + static_assert(NUM_WARP_Q == 1, "NUM_WARP_Q must be 1"); + static_assert(NUM_WARP_KV == 4, "NUM_WARP_KV must be 4"); + const uint32_t btid = blockIdx.x, kv_head_idx = blockIdx.z; + const uint32_t kv_num_heads = gridDim.z; + const uint32_t q_num_heads = kv_num_heads * GROUP_SIZE; + const uint32_t q_head_idx = kv_head_idx * GROUP_SIZE; + const uint32_t tid = threadIdx.x, wid = threadIdx.y; + const uint32_t num_chunks = gridDim.y; + const uint32_t chunk_idx = blockIdx.y; +#ifdef DEBUG_ATTN_C4 + if (tid == PRINT_TID && wid == 0 && blockIdx.z == 0) { + printf( + "num_vecs_per_head: %d, num_vecs_per_head_k: %d, " + "num_vecs_per_blocksize: %d, inv_k_stride: %d, inv_v_stride: %d\n", + (int)num_vecs_per_head, + (int)num_vecs_per_head_k, + (int)num_vecs_per_blocksize, + (int)inv_k_stride, + (int)inv_v_stride); + } + __syncthreads(); +#endif + + const 
uint32_t batch_id = batch_ids[btid]; + const uint32_t tile_id = tile_ids_per_batch[btid]; + const uint32_t num_rows_per_block = num_frags_x * 16; + const int *block_table_now = block_table + batch_id * max_block_num_per_seq; + + const uint32_t q_len = seq_lens[batch_id]; + if (q_len <= 0) { + return; + } + const uint32_t q_end = + min(q_len, div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE)); + uint32_t kv_len = seq_lens_kv[batch_id]; + if (ENABLE_PREFILL) { + kv_len += q_len; // !!! + if (kv_len <= 0) { + return; + } + } else { + if (kv_len <= 0) { + return; + } + kv_len += q_len; + } + const uint32_t num_chunks_this_seq = div_up(kv_len, chunk_size); + if (chunk_idx >= num_chunks_this_seq) { + return; + } + + const uint32_t chunk_start = partition_kv ? chunk_idx * chunk_size : 0; + const uint32_t chunk_end = + partition_kv ? min(kv_len, chunk_start + chunk_size) : kv_len; + const uint32_t chunk_len = chunk_end - chunk_start; + + extern __shared__ uint8_t smem[]; + float s_frag[num_frags_x][num_frags_z][8]; + float o_frag[num_frags_x][num_frags_y][8]; + float m_frag[num_frags_x][2]; + float d_frag[num_frags_x][2]; + init_states(o_frag, m_frag, d_frag); + + const T *cache_k_scale_now = cache_k_scale + kv_head_idx * HEAD_DIM; + const T *cache_k_zp_now = cache_k_zero_point + kv_head_idx * HEAD_DIM; + const T *cache_v_scale_now = cache_v_scale + kv_head_idx * HEAD_DIM; + const T *cache_v_zp_now = cache_v_zero_point + kv_head_idx * HEAD_DIM; + T *cache_k_scale_smem = reinterpret_cast( + smem + NUM_WARP_Q * num_frags_x * 16 * HEAD_DIM * sizeof(T) + + NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM / 2 * sizeof(CacheT) * 2); + T *cache_k_zero_point_smem = cache_k_scale_smem + HEAD_DIM; + T *cache_v_scale_smem = cache_k_zero_point_smem + HEAD_DIM; + T *cache_v_zero_point_smem = cache_v_scale_smem + HEAD_DIM; +#pragma unroll + for (uint32_t i = wid * 32 + tid; i < HEAD_DIM; i += 128) { + cache_k_scale_smem[i] = cache_k_scale_now[i]; + cache_k_zero_point_smem[i] = cache_k_zp_now[i]; + cache_v_scale_smem[i] = cache_v_scale_now[i]; + cache_v_zero_point_smem[i] = cache_v_zp_now[i]; + } + + const uint32_t q_n_stride = q_num_heads * HEAD_DIM; + const uint32_t q_ori_n_stride = (q_num_heads + kv_num_heads * 2) * HEAD_DIM; + const uint32_t kv_n_stride = kv_num_heads * BLOCK_SIZE * HEAD_DIM / 2; + const uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM / 2; + const uint32_t kv_b_stride = HEAD_DIM / 2; + const uint32_t kv_d_stride = BLOCK_SIZE / 2; + const uint32_t q_start_seq_id = + batch_id * max_seq_len - __ldg(&cum_offsets[batch_id]); + const uint32_t q_base_seq_id_this_block = tile_id * num_frags_x * 16; + const uint32_t q_offset = q_start_seq_id * q_ori_n_stride + + q_head_idx * HEAD_DIM + + tid % 8 * num_elems_per_128b(); + const uint32_t o_offset = q_start_seq_id * q_n_stride + + q_head_idx * HEAD_DIM + + tid % 8 * num_elems_per_128b(); + T *q_base_ptr = q + q_offset; +#ifdef DEBUG_ATTN_C4 + if (tid == PRINT_TID && wid == 0 && blockIdx.z == 0) { + printf( + "q_base_seq_id_this_block: %d, q_start_seq_id: %d, q_offset: %d, " + "q_ori_n_stride: %d, q_base: %f\n", + (int)q_base_seq_id_this_block, + (int)q_start_seq_id, + (int)q_offset, + (int)q_ori_n_stride, + (float)*q_base_ptr); + } + __syncthreads(); +#endif + T *o_base_ptr_T = nullptr; + OutT *o_base_ptr_int8 = nullptr; + if (num_chunks_this_seq <= 1) { + o_base_ptr_int8 = out + o_offset; + } else { + if (ENABLE_PREFILL) { + o_base_ptr_T = tmp_workspace + batch_id * num_chunks * q_n_stride + + chunk_idx * q_n_stride + q_head_idx * HEAD_DIM + + tid % 8 * 
num_elems_per_128b(); + } else { + o_base_ptr_T = + tmp_workspace + + batch_id * speculate_max_draft_token_num * num_chunks * q_n_stride + + chunk_idx * q_n_stride + q_head_idx * HEAD_DIM + + tid % 8 * num_elems_per_128b(); + } + // } else { + // o_base_ptr_int8 = out + o_offset; + } +#ifdef DEBUG_ATTN_C4 + if (tid == PRINT_TID && wid == 0 && blockIdx.z == 0) { + printf( + "q_base_seq_id_this_block: %d, q_base_seq_id_this_block: %d, q_offset: " + "%d, o_offset: %d\n", + (int)q_base_seq_id_this_block, + (int)q_base_seq_id_this_block, + (int)q_offset, + (int)o_offset); + } + __syncthreads(); +#endif + + smem_t qo_smem(smem); + + /* + 1 | 3 + —————— + 2 | 4 + */ + uint32_t q_smem_offset_r = smem_t::get_permuted_offset( + tid % 16, tid / 16); // 16 * 16 + load_q_global_smem_multi_warps(q_base_ptr, + &qo_smem, + q_base_seq_id_this_block, + q_end, + q_ori_n_stride, + HEAD_DIM); + commit_group(); + wait_group<0>(); + __syncthreads(); +#ifdef DEBUG_ATTN_C4 + if (tid == PRINT_TID && wid == 0 && blockIdx.z == 0) { + printf("before scale\n"); + T *q_smem_t = reinterpret_cast(qo_smem.base); + for (uint32_t i = 0; i < num_frags_x * 16; ++i) { + for (uint32_t j = 0; j < num_frags_y * 16; ++j) { + if (blockIdx.z == 0) { + printf("q_smem[%d][%d] = %f ", + (int)i, + (int)(j), + (float)q_smem_t[i * num_frags_y * 16 + j]); + } else { + int res = q_smem_t[i * num_frags_y * 16 + j] + static_cast(1.f); + } + } + printf("\n"); + } + } + __syncthreads(); +#endif + + q_smem_inplace_multiply_sm_scale_multi_warps( + &qo_smem, scale); +#ifdef DEBUG_ATTN_C4 + if (tid == PRINT_TID && wid == 0 && blockIdx.z == 0) { + printf("after scale\n"); + T *q_smem_t = reinterpret_cast(qo_smem.base); + for (uint32_t i = 0; i < num_frags_x * 16; ++i) { + for (uint32_t j = 0; j < num_frags_y * 16; ++j) { + if (blockIdx.z == 0) { + printf("q_smem[%d][%d] = %f ", + (int)i, + (int)(j), + (float)q_smem_t[i * num_frags_y * 16 + j]); + } else { + int res = q_smem_t[i * num_frags_y * 16 + j] + static_cast(1.f); + } + } + printf("\n"); + } + } + __syncthreads(); +#endif + + T cache_k_scale_frag[num_frags_y][4]; + T cache_k_zp_frag[num_frags_y][4]; + T magic_number; + if constexpr (std::is_same::value) { + magic_number = static_cast(1032.f); + } else { + magic_number = static_cast(136.f); + } +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y; ++fy) { + *(reinterpret_cast(&cache_k_scale_frag[fy][0])) = + *(reinterpret_cast(&cache_k_scale_smem[fy * 16]) + tid % 4); + *(reinterpret_cast(&cache_k_scale_frag[fy][2])) = + *(reinterpret_cast(&cache_k_scale_smem[fy * 16]) + tid % 4 + + 4); + *(reinterpret_cast(&cache_k_zp_frag[fy][0])) = + *(reinterpret_cast(&cache_k_zero_point_smem[fy * 16]) + + tid % 4); + *(reinterpret_cast(&cache_k_zp_frag[fy][2])) = + *(reinterpret_cast(&cache_k_zero_point_smem[fy * 16]) + + tid % 4 + 4); +#pragma unroll + for (uint32_t zp_i = 0; zp_i < 4; ++zp_i) { + cache_k_zp_frag[fy][zp_i] += magic_number; // 128 + 8 + } + } + T cache_v_scale_frag[num_frags_y][2]; + T cache_v_zp_frag[num_frags_y][2]; +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y; ++fy) { + cache_v_scale_frag[fy][0] = cache_v_scale_smem[fy * 16 + tid / 4]; + cache_v_scale_frag[fy][1] = cache_v_scale_smem[fy * 16 + tid / 4 + 8]; + cache_v_zp_frag[fy][0] = + cache_v_zero_point_smem[fy * 16 + tid / 4] + magic_number; + cache_v_zp_frag[fy][1] = + cache_v_zero_point_smem[fy * 16 + tid / 4 + 8] + magic_number; + } + + // smem_t k_smem(smem + (num_frags_x + wid * num_frags_z) * 16 * HEAD_DIM * + // sizeof(T)), + // v_smem(smem + (num_frags_x + 
(NUM_WARP_KV + wid) * num_frags_z) * 16 + // * HEAD_DIM * sizeof(T)); + smem_t k_smem(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T)), + v_smem(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T) + + NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM / 2 * sizeof(CacheT)); + + const uint32_t num_iterations = div_up( + CAUSAL + ? (min(chunk_len, + sub_if_greater_or_zero( + kv_len - q_len + + div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE), + chunk_start))) + : chunk_len, + NUM_WARP_KV * num_frags_z * 16); + const uint32_t mask_check_iteration = + (CAUSAL ? (min(chunk_len, + sub_if_greater_or_zero( + kv_len - q_len + + tile_id * num_rows_per_block / GROUP_SIZE, + chunk_start))) + : chunk_len) / + (NUM_WARP_KV * num_frags_z * 16); +#ifdef DEBUG_ATTN_C4 + if (tid == 0 && wid == 0 && kv_head_idx == 0) { + printf( + "batch_id: %d, tile_id: %d, chunk_size: %d, q_len: %d, kv_len: %d, " + "chunk_start: %d, chunk_end: %d, num_iterations: %d, " + "mask_check_iteration: %d\n", + (int)batch_id, + (int)tile_id, + (int)chunk_size, + (int)q_len, + (int)kv_len, + (int)chunk_start, + (int)chunk_end, + (int)num_iterations, + (int)mask_check_iteration); + } + __syncthreads(); +#endif + /* + 1 | 2 + —————— + 3 | 4 + */ + uint32_t k_smem_offset_r = + smem_t::get_permuted_offset( + wid * num_frags_z * 16 + 8 * (tid / 16) + tid % 8, (tid % 16) / 8); + /* + 1 | 2 + —————— + 3 | 4 transpose + */ + uint32_t v_smem_offset_r = + smem_t::get_permuted_offset( + wid * num_frags_y * 16 + 8 * (tid / 16) + tid % 8, (tid % 16) / 8); + + uint32_t k_smem_offset_w = + smem_t::get_permuted_offset( + wid * 8 + tid / 4, + tid % + 4); // 4 * 128 / 8 = 64B, 128 nums, just fot head_dim >= 128 !!! + uint32_t v_smem_offset_w = + smem_t::get_permuted_offset( + wid * 16 + tid / 2, tid % 2); // 2 * 128 / 8 = 32B, 64 nums + + // uint32_t kv_idx_base = chunk_start; + // int block_id = __ldg(&block_table_now[kv_idx_base / BLOCK_SIZE]); + // const uint32_t const_offset = kv_head_idx * kv_h_stride + (wid * 4 + tid / + // 8) * kv_b_stride + tid % 8 * num_elems_per_128b(); + uint32_t kv_idx_base = chunk_start; + // int block_id = __ldg(&block_table_now[kv_idx_base / BLOCK_SIZE]); + const uint32_t const_k_offset = kv_head_idx * kv_h_stride + + (wid * 8 + tid / 4) * kv_b_stride + + tid % 4 * num_elems_per_128b(); + const uint32_t const_v_offset = kv_head_idx * kv_h_stride + + (wid * 16 + tid / 2) * kv_d_stride + + tid % 2 * num_elems_per_128b(); + // T *cache_k_now = cache_k + block_id * kv_n_stride + const_k_offset; + // T *cache_v_now = cache_v + block_id * kv_n_stride + const_v_offset; + +#ifdef DEBUG_ATTN_C4 + if (threadIdx.x == PRINT_TID && threadIdx.y == 0 && blockIdx.z == 0) { + printf( + "ori q_smem_offset_r: %d, k_smem_offset_r: %d, v_smem_offset_r: %d, " + "k_smem_offset_w: %d, v_smem_offset_w: %d, cache_k: %f, cache_k_p: %p, " + "const_k_offset: %d, const_v_offset: %d\n", + (int)q_smem_offset_r, + (int)k_smem_offset_r, + (int)v_smem_offset_r, + (int)k_smem_offset_w, + (int)v_smem_offset_w, + (float)(*cache_k), + cache_k, + (int)const_k_offset, + (int)const_v_offset); + } + __syncthreads(); +#endif + + // load BLOCK_SIZE * HEAD_DIM each time + produce_k_blockwise_c4(k_smem, + &k_smem_offset_w, + cache_k, + block_table_now, + kv_head_idx, + kv_n_stride, + kv_h_stride, + kv_b_stride, + kv_idx_base, + chunk_end, + const_k_offset); + commit_group(); + produce_v_blockwise_c4(v_smem, + &v_smem_offset_w, + cache_v, + block_table_now, + kv_head_idx, + kv_n_stride, + kv_h_stride, + kv_d_stride, + kv_idx_base, + chunk_end, + const_v_offset); + 
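+  // Sketch of the software pipeline in the main loop below (summary comment;
+  // function names as used in this kernel). The K tile and the V tile for the
+  // first iteration have just been issued as two separate asynchronous-copy
+  // commit groups (commit_group / wait_group); each iteration then alternates
+  //
+  //   wait_group<1>();              // K tile of iteration i is resident in smem
+  //   compute_qk_c4(...);           // S = Q * K^T, dequantizing int4 K with
+  //                                 // cache_k_scale_frag / cache_k_zp_frag
+  //   mask_s(...);                  // causal / chunk-boundary masking
+  //   update_mdo_states(...);       // online-softmax running max m and sum d
+  //   produce_k_blockwise_c4(...);  // prefetch the K tile of iteration i+1
+  //   commit_group();
+  //   wait_group<1>();              // V tile of iteration i is resident in smem
+  //   compute_sfm_v_c4(...);        // accumulate softmax-weighted V into O,
+  //                                 // rescaling with the running (m, d)
+  //   produce_v_blockwise_c4(...);  // prefetch the V tile of iteration i+1
+  //   commit_group();
+  //
+  // so global-memory traffic for iteration i+1 overlaps the MMA work of
+  // iteration i.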
commit_group(); +#pragma unroll 1 + for (uint32_t iter = 0; iter < num_iterations; ++iter) { + wait_group<1>(); + __syncthreads(); +#ifdef DEBUG_ATTN_C4 + if (tid == PRINT_TID && wid == 0 && blockIdx.z == 0 && + blockIdx.x == gridDim.x - 1) { + printf("cache_k_smem\n"); + uint8_t *k_smem_t = reinterpret_cast(k_smem.base); + for (uint32_t i = 0; i < NUM_WARP_KV * num_frags_z * 16; ++i) { + for (uint32_t j = 0; j < num_frags_y * 16 / 2; ++j) { + printf("k_smem[%d][%d] = %d ", + (int)i, + (int)j, + (int)k_smem_t[i * num_frags_y * 16 / 2 + j]); + } + printf("\n"); + } + } + __syncthreads(); +#endif + // s = qk + compute_qk_c4( + &qo_smem, + &q_smem_offset_r, + &k_smem, + &k_smem_offset_r, + s_frag, + cache_k_scale_frag, + cache_k_zp_frag); + // mask according to kv_idx and q_idx + if (iter >= mask_check_iteration) { + // if (q_len > 1 && iter >= mask_check_iteration) { // not need mask in + // decoder, v will be filled with 0 + mask_s(q_base_seq_id_this_block, + kv_idx_base + wid * num_frags_z * 16, + q_len, + kv_len, + chunk_end, + s_frag); + } +#ifdef DEBUG_ATTN_C4 + if (threadIdx.x == PRINT_TID && threadIdx.y == 0 && blockIdx.z == 0 && + blockIdx.x == gridDim.x - 1) { + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { + for (uint32_t fz = 0; fz < num_frags_z; ++fz) { + for (int k = 0; k < 8; k++) { + printf("mask_s s_frag[%d][%d][%d]: %f ", + (int)fx, + (int)fz, + (int)k, + s_frag[fx][fz][k]); + } + printf("\n"); + } + printf("\n"); + } + } + __syncthreads(); +#endif + + // update m,d + update_mdo_states( + s_frag, o_frag, m_frag, d_frag); + __syncthreads(); +#ifdef DEBUG_ATTN_C4 + if (threadIdx.x == PRINT_TID && threadIdx.y == 0 && blockIdx.z == 0 && + blockIdx.x == gridDim.x - 1) { + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { + for (uint32_t fz = 0; fz < num_frags_z; ++fz) { + for (int k = 0; k < 8; k++) { + printf("update_mdo_states s_frag[%d][%d][%d]: %f ", + (int)fx, + (int)fz, + (int)k, + s_frag[fx][fz][k]); + } + printf("\n"); + } + printf("\n"); + } + } + __syncthreads(); +#endif + + kv_idx_base += NUM_WARP_KV * num_frags_z * 16; + produce_k_blockwise_c4(k_smem, + &k_smem_offset_w, + cache_k, + block_table_now, + kv_head_idx, + kv_n_stride, + kv_h_stride, + kv_b_stride, + kv_idx_base, + chunk_end, + const_k_offset); + commit_group(); + wait_group<1>(); + __syncthreads(); +#ifdef DEBUG_ATTN_C4 + if (tid == PRINT_TID && wid == 0 && blockIdx.z == 0 && + blockIdx.x == gridDim.x - 1) { + printf("cache_v_smem\n"); + uint8_t *v_smem_t = reinterpret_cast(v_smem.base); + for (uint32_t i = 0; i < NUM_WARP_KV * num_frags_y * 16; ++i) { + for (uint32_t j = 0; j < num_frags_z * 16 / 2; ++j) { + printf("v_smem[%d][%d] = %d ", + (int)i, + (int)j, + (int)v_smem_t[i * num_frags_z * 16 / 2 + j]); + } + printf("\n"); + } + } + __syncthreads(); +#endif + // compute sfm*v + compute_sfm_v_c4(&v_smem, + &v_smem_offset_r, + s_frag, + o_frag, + d_frag, + cache_v_scale_frag, + cache_v_zp_frag); + __syncthreads(); + + produce_v_blockwise_c4(v_smem, + &v_smem_offset_w, + cache_v, + block_table_now, + kv_head_idx, + kv_n_stride, + kv_h_stride, + kv_d_stride, + kv_idx_base, + chunk_end, + const_v_offset); + commit_group(); + } + wait_group<0>(); + __syncthreads(); +#ifdef DEBUG_ATTN_C4 + if (threadIdx.x == PRINT_TID && threadIdx.y == 0 && blockIdx.z == 0 && + blockIdx.x == gridDim.x - 1) { + printf("before merge z\n"); + for (uint32_t i = 0; i < num_frags_x; ++i) { + printf("m1: %f, m2: %f\n", m_frag[i][0], m_frag[i][1]); + printf("d1: %f, d2: %f\n", d_frag[i][0], d_frag[i][1]); + for (uint32_t j = 0; j < 
num_frags_y; ++j) { + for (int r_id = 0; r_id < 8; r_id++) { + printf("o_frag[%d][%d][%d]: %f ", + (int)i, + (int)j, + r_id, + o_frag[i][j][r_id]); + } + } + printf("\n"); + } + } + __syncthreads(); +#endif + + merge_block_res_v2( + o_frag, reinterpret_cast(smem), m_frag, d_frag, wid, tid); +#ifdef DEBUG_ATTN_C4 + if (threadIdx.x == PRINT_TID && threadIdx.y == 0 && blockIdx.z == 0 && + blockIdx.x == gridDim.x - 1) { + printf("after merge z\n"); + for (uint32_t i = 0; i < num_frags_x; ++i) { + printf("m1: %f, m2: %f\n", m_frag[i][0], m_frag[i][1]); + printf("d1: %f, d2: %f\n", d_frag[i][0], d_frag[i][1]); + for (uint32_t j = 0; j < num_frags_y; ++j) { + for (int r_id = 0; r_id < 8; r_id++) { + printf("o_frag[%d][%d][%d]: %f ", + (int)i, + (int)j, + r_id, + o_frag[i][j][r_id]); + } + } + printf("\n"); + } + } + __syncthreads(); +#endif + + if (num_chunks_this_seq <= 1) { + normalize_d(o_frag, d_frag); + } + + // write o + // [num_frags_x, 16, num_frags_y, 16] + if (num_chunks_this_seq <= 1) { + write_o_reg_gmem_multi_warps_shift_smooth_quant( + o_frag, + &qo_smem, + o_base_ptr_int8, + shift_bias, + smooth_weight, + q_base_seq_id_this_block, + q_head_idx, + in_scale, + q_len, + q_n_stride, + HEAD_DIM + ); + } else { + write_o_reg_gmem_multi_warps_shift_smooth_quant( + o_frag, + &qo_smem, + o_base_ptr_T, + shift_bias, + smooth_weight, + q_base_seq_id_this_block, + q_head_idx, + in_scale, + q_len, + q_n_stride * num_chunks, + HEAD_DIM); + // } else { + // write_o_reg_gmem_multi_warps_shift_smooth_quant( + // o_frag, + // &qo_smem, + // o_base_ptr_int8, + // shift_bias, + // smooth_weight, + // q_base_seq_id_this_block, + // q_head_idx, + // in_scale, + // q_len, + // partition_kv ? q_n_stride * num_chunks : q_n_stride, + // HEAD_DIM); + } + + if (num_chunks_this_seq > 1) { + if (wid == 0) { +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t j = 0; j < 2; ++j) { + const uint32_t qo_idx_now = + q_base_seq_id_this_block + tid / 4 + j * 8 + fx * 16; + const uint32_t qo_head_idx = q_head_idx + qo_idx_now % GROUP_SIZE; + const uint32_t qo_idx = q_start_seq_id + qo_idx_now / GROUP_SIZE; +#ifdef DEBUG_ATTN_C4 + if (batch_id == 0) { + printf( + "bid: %d, tid: %d, wid: %d, q_base_seq_id_this_block: %d, " + "qo_idx_now: %d, qo_idx: %d, q_start_seq_id: %d, q_len: %d, m: " + "%f, d: %f\n", + (int)batch_id, + (int)tid, + (int)wid, + (int)q_base_seq_id_this_block, + (int)qo_idx_now, + (int)qo_idx, + (int)q_start_seq_id, + (int)q_len, + (float)m_frag[fx][j], + (float)d_frag[fx][j]); + } +#endif + if (qo_idx - q_start_seq_id < q_len) { + uint32_t offset; + if (ENABLE_PREFILL) { + offset = (batch_id * num_chunks + chunk_idx) * q_num_heads + + qo_head_idx; + } else { + offset = ((batch_id * speculate_max_draft_token_num + + qo_idx_now / GROUP_SIZE) * + num_chunks + + chunk_idx) * + q_num_heads + + qo_head_idx; + } + tmp_m[offset] = m_frag[fx][j]; + tmp_d[offset] = d_frag[fx][j]; + } + } + } + } + } +} + + +template +__global__ void merge_multi_chunks_decoder_kernel( + const T *__restrict__ multi_out, // [token_num, num_chunks, num_heads, + // head_dim] + const float *__restrict__ multi_m, // [token_num, num_chunks, num_heads] + const float *__restrict__ multi_d, // [token_num, num_chunks, num_heads] + const int *__restrict__ seq_lens_q, + const int *__restrict__ seq_lens_kv, + const int *__restrict__ seq_lens_encoder, + const int *__restrict__ cum_offsets, + const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM] + const T *__restrict__ smooth_weight, // 
[q_num_heads * HEAD_DIM] + OutT *__restrict__ out, + const float in_scale, + const int max_seq_len, + const int num_chunks, + const int num_heads, + const int chunk_size, + const int head_dim) { + const int vid = threadIdx.x, ty = threadIdx.y; + const int bid = blockIdx.x, hid = blockIdx.y; + __shared__ T smem[bdy * HEAD_DIM]; + __shared__ float md_smem[bdy * 2]; + const int start_token_idx = bid * max_seq_len - cum_offsets[bid]; + const int seq_len_q = seq_lens_q[bid]; + if (seq_len_q == 0) return; + int seq_len_kv = seq_lens_kv[bid]; + + if (ENABLE_PREFILL) { + seq_len_kv += seq_len_q; + if (seq_len_kv == 0) return; + } else { + if (seq_len_kv == 0) return; + seq_len_kv += seq_len_q; + } + const int seq_len_enc = seq_lens_encoder[bid]; + if (seq_len_enc > 0) { + return; + } + const int num_chunks_this_seq = div_up(seq_len_kv, chunk_size); + if (num_chunks_this_seq <= 1) { + return; + } + + using LoadT = AlignedVector; + LoadT load_vec; + LoadT res_vec; + if constexpr (std::is_same::value) { +#pragma unroll + for (int i = 0; i < vec_size / 2; ++i) { + *((half2 *)(&res_vec) + i) = make_half2(0, 0); + } + } else { +#pragma unroll + for (int i = 0; i < vec_size / 2; ++i) { + *((nv_bfloat162 *)(&res_vec) + i) = make_bfloat162(0, 0); + } + } + float m; + float d = 1.f; + if constexpr (std::is_same::value) { + m = -5e4f; + } else if constexpr (std::is_same::value) { + m = -3.0e+30f; + } +#pragma unroll 2 + for (int i = ty; i < num_chunks_this_seq; i += bdy) { + // uint32_t offset = (start_token_idx * num_chunks + i) * num_heads + hid; + uint32_t offset = (bid * num_chunks + i) * num_heads + hid; + float m_prev = m; + float d_prev = d; + const float m_now = multi_m[offset]; + const float d_now = multi_d[offset]; + m = max(m_prev, m_now); + // offset = (start_token_idx * num_chunks * num_heads + i * num_heads + hid) + // * head_dim + vid * vec_size; + offset = (bid * num_chunks * num_heads + i * num_heads + hid) * head_dim + + vid * vec_size; + Load(&multi_out[offset], &load_vec); + const float scale1 = __expf(m_prev - m), scale2 = __expf(m_now - m); + const T scale1_T = static_cast(scale1), + scale2_T = static_cast(scale2); + d = d * scale1 + d_now * scale2; +#pragma unroll + for (int j = 0; j < vec_size; j++) { + res_vec[j] = res_vec[j] * scale1_T + load_vec[j] * scale2_T; + } + } + // store ty res + Store(res_vec, &smem[ty * head_dim + vid * vec_size]); + md_smem[2 * ty] = m; + md_smem[2 * ty + 1] = d; + __syncthreads(); + if (ty == 0) { + // merge bdy + prefill_softmax_state_t st; + st.init(); +#pragma unroll + for (int i = 0; i < bdy; i++) { + Load(&smem[i * head_dim + vid * vec_size], &load_vec); + const float m_tmp = md_smem[2 * i], d_tmp = md_smem[2 * i + 1]; + st.merge(load_vec, m_tmp, d_tmp); + } + st.normalize(); + + const uint32_t shift_smooth_offset = hid * head_dim + vid * vec_size; + AlignedVector shift_bias_vec; + AlignedVector smooth_weight_vec; + AlignedVector out_vec; + if (shift_bias) { + Load(shift_bias + shift_smooth_offset, &shift_bias_vec); + Load(smooth_weight + shift_smooth_offset, + &smooth_weight_vec); + } +#pragma unroll + for (int i = 0; i < vec_size; ++i) { + // float quant_value = 127.0f * static_cast((st.o[i] + + // shift_bias_vec[i]) * smooth_weight_vec[i]) * in_scale; quant_value = + // rintf(quant_value); quant_value = quant_value > 127.0f ? 127.0f : + // quant_value; quant_value = quant_value < -127.0f ? 
-127.0f : + // quant_value; out_vec[i] = static_cast(quant_value); + StoreFunc()( + st.o, shift_bias_vec, smooth_weight_vec, out_vec, in_scale, i); + } + Store( + out_vec, + &out[(start_token_idx * num_heads + hid) * head_dim + vid * vec_size]); + } +} + + +template +__global__ void merge_multi_chunks_v2_kernel( + const T *__restrict__ multi_out, // [token_num, num_chunks, num_heads, + // head_dim] + const float *__restrict__ multi_m, // [token_num, num_chunks, num_heads] + const float *__restrict__ multi_d, // [token_num, num_chunks, num_heads] + const int *__restrict__ seq_lens_q, + const int *__restrict__ seq_lens_kv, + const int *__restrict__ seq_lens_encoder, + const int *__restrict__ padding_offsets, + const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM] + const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM] + OutT *__restrict__ out, + const float in_scale, + const int max_seq_len, + const int num_chunks, + const int num_heads, + const int chunk_size, + const int head_dim, + const int token_num, + const int speculate_max_draft_token_num = 5) { + const int vid = threadIdx.x, ty = threadIdx.y; + // const int qid = blockIdx.x, hid = blockIdx.y; + const int hid = blockIdx.y; + __shared__ T smem[bdy * HEAD_DIM]; + __shared__ float md_smem[bdy * 2]; + for (int qid = blockIdx.x; qid < token_num; qid += gridDim.x) { + const uint32_t ori_token_id = qid + padding_offsets[qid]; + const uint32_t bid = ori_token_id / max_seq_len; + const uint32_t local_seq_id = ori_token_id % max_seq_len; + const int seq_len_q = seq_lens_q[bid]; + if (seq_len_q == 0) continue; + int seq_len_kv = seq_lens_kv[bid]; + if (ENABLE_PREFILL) { + seq_len_kv += seq_len_q; + if (seq_len_kv == 0) continue; + + const int seq_len_enc = seq_lens_encoder[bid]; + if (seq_len_enc <= 0) { + continue; + } + } else { + if (seq_len_kv == 0) continue; + seq_len_kv += seq_len_q; + } + const int num_chunks_this_seq = div_up(seq_len_kv, chunk_size); + if (num_chunks_this_seq <= 1) { + continue; + } + + using LoadT = AlignedVector; + LoadT load_vec; + LoadT res_vec; + if constexpr (std::is_same::value) { +#pragma unroll + for (int i = 0; i < vec_size / 2; ++i) { + *((half2 *)(&res_vec) + i) = make_half2(0, 0); + } + } else { +#pragma unroll + for (int i = 0; i < vec_size / 2; ++i) { + *((nv_bfloat162 *)(&res_vec) + i) = make_bfloat162(0, 0); + } + } + float m; + float d = 1.f; + if constexpr (std::is_same::value) { + m = -5e4f; + } else if constexpr (std::is_same::value) { + m = -3.0e+30f; + } +#pragma unroll 2 + for (int i = ty; i < num_chunks_this_seq; i += bdy) { + uint32_t offset; + if (ENABLE_PREFILL) { + offset = (qid * num_chunks + i) * num_heads + hid; + } else { + offset = + ((bid * speculate_max_draft_token_num + local_seq_id) * num_chunks + + i) * + num_heads + + hid; + } + float m_prev = m; + float d_prev = d; + const float m_now = multi_m[offset]; + const float d_now = multi_d[offset]; + m = max(m_prev, m_now); + if (ENABLE_PREFILL) { + offset = + (qid * num_chunks * num_heads + i * num_heads + hid) * head_dim + + vid * vec_size; + } else { + offset = ((bid * speculate_max_draft_token_num + local_seq_id) * + num_chunks * num_heads + + i * num_heads + hid) * + head_dim + + vid * vec_size; + } + Load(&multi_out[offset], &load_vec); + const float scale1 = __expf(m_prev - m), scale2 = __expf(m_now - m); + const T scale1_T = static_cast(scale1), + scale2_T = static_cast(scale2); + d = d * scale1 + d_now * scale2; +#pragma unroll + for (int j = 0; j < vec_size; j++) { + res_vec[j] = res_vec[j] * scale1_T + 
load_vec[j] * scale2_T; + } + } + // store ty res + Store(res_vec, &smem[ty * head_dim + vid * vec_size]); + md_smem[2 * ty] = m; + md_smem[2 * ty + 1] = d; + __syncthreads(); + if (ty == 0) { + // merge bdy + prefill_softmax_state_t st; + st.init(); +#pragma unroll + for (int i = 0; i < bdy; i++) { + Load(&smem[i * head_dim + vid * vec_size], &load_vec); + const float m_tmp = md_smem[2 * i], d_tmp = md_smem[2 * i + 1]; + st.merge(load_vec, m_tmp, d_tmp); + } + st.normalize(); + + const uint32_t shift_smooth_offset = hid * head_dim + vid * vec_size; + AlignedVector shift_bias_vec; + AlignedVector smooth_weight_vec; + AlignedVector out_vec; + if (shift_bias) { + Load(shift_bias + shift_smooth_offset, &shift_bias_vec); + Load(smooth_weight + shift_smooth_offset, + &smooth_weight_vec); + } +#pragma unroll + for (int i = 0; i < vec_size; ++i) { + // float quant_value = 127.0f * static_cast((st.o[i] + + // shift_bias_vec[i]) * smooth_weight_vec[i]) * in_scale; quant_value = + // rintf(quant_value); quant_value = quant_value > 127.0f ? 127.0f : + // quant_value; quant_value = quant_value < -127.0f ? -127.0f : + // quant_value; out_vec[i] = static_cast(quant_value); + + StoreFunc()( + st.o, shift_bias_vec, smooth_weight_vec, out_vec, in_scale, i); + } + Store( + out_vec, &out[(qid * num_heads + hid) * head_dim + vid * vec_size]); + } + __syncthreads(); + } +} + +template + +void MultiQueryAppendAttention( + const paddle::Tensor& qkv, + const paddle::Tensor& cache_k, + const paddle::Tensor& cache_v, + const paddle::optional& attn_mask, + const paddle::optional& shift_bias, + const paddle::optional& smooth_weight, + const paddle::Tensor& seq_lens_q, + const paddle::Tensor& seq_lens_kv, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& padding_offsets, + const paddle::Tensor& cum_offsets, + const paddle::Tensor& block_table, + const paddle::Tensor& batch_ids, + const paddle::Tensor& tile_ids_per_batch, + const int num_blocks_x_cpu, + const int max_seq_len, + const int max_dec_len, + const int num_heads, + const int kv_num_heads, + const float in_scale, + const int max_partition_size, + const int encoder_max_partition_size, + const int speculate_max_draft_token_num, + const bool is_decoder, + cudaStream_t& stream, + paddle::Tensor* out) { + using NV_TYPE = typename cascade_attn_type_traits::type; + using OUT_NV_TYPE = typename cascade_attn_type_traits::type; + + const auto& q_dims = qkv.dims(); + const auto& k_dims = cache_k.dims(); + const auto& cum_offsets_dims = cum_offsets.dims(); + const uint32_t token_num = q_dims[0]; + const uint32_t bsz = cum_offsets_dims[0]; + const uint32_t max_block_num_per_seq = block_table.dims()[1]; + + constexpr uint32_t num_warps = 4; + constexpr uint32_t NUM_WARP_KV = num_warps / NUM_WARP_Q; + constexpr uint32_t num_frags_x = BLOCK_SHAPE_Q / (16 * NUM_WARP_Q); // 1 or 2 + constexpr uint32_t num_frags_y = HEAD_DIM / 16; + constexpr uint32_t num_qrow_per_block = NUM_WARP_Q * num_frags_x * 16; + + auto* allocator = paddle::GetAllocator(qkv.place()); + + const float scale = 1.f / sqrt(HEAD_DIM); + + if constexpr (NUM_WARP_Q == 4) { + constexpr uint32_t num_frags_z = BLOCK_SIZE / 16; // !!! + // constexpr uint32_t num_frags_z = 8; // 128 per iter, 4 is better? 
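+    // Shared-memory budget for the 4-Q-warp kernel: one Q tile of
+    // num_warps * num_frags_x * 16 rows plus a K tile and a V tile of
+    // NUM_WARP_KV * num_frags_z * 16 rows each, all HEAD_DIM wide and kept in
+    // T. Illustrative numbers only (these template values are assumptions, not
+    // fixed here): with BLOCK_SHAPE_Q = 64, BLOCK_SIZE = 64, HEAD_DIM = 128 and
+    // T = half this gives
+    //   (4 * 1 + 1 * 4 * 2) * 16 * 128 * 2 B = 48 KB,
+    // exactly the point at which the cudaFuncSetAttribute call below has to opt
+    // the kernel in to more than 48 KB of dynamic shared memory. The
+    // static_assert just spells out the 16x16 fragment assumption already
+    // implicit in num_frags_y / num_frags_z above.
+    static_assert(HEAD_DIM % 16 == 0 && BLOCK_SIZE % 16 == 0,
+                  "16x16 fragment tiling assumes HEAD_DIM and BLOCK_SIZE are "
+                  "multiples of 16");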
+ constexpr uint32_t smem_size = + (num_warps * num_frags_x + NUM_WARP_KV * num_frags_z * 2) * 16 * + HEAD_DIM * sizeof(T); + auto split_kv_kernel = multi_query_append_attention_kernel; + if (smem_size >= 48 * 1024) { + cudaFuncSetAttribute(split_kv_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + } + const int dev_id = 0; + int sm_count; + cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id); + + uint32_t chunk_size = static_cast(max_partition_size); + if (!is_decoder) { + chunk_size = static_cast(encoder_max_partition_size); + } + const int num_chunks = div_up(max_dec_len, chunk_size); + dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads); + dim3 blocks(32, num_warps); + if (num_chunks <= 1) { + auto nosplit_kv_kernel = + multi_query_append_attention_kernel; + if (smem_size >= 48 * 1024) { + cudaFuncSetAttribute(nosplit_kv_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + } + + nosplit_kv_kernel<<>>( + reinterpret_cast(const_cast(qkv.data())), + reinterpret_cast(const_cast(cache_k.data())), + reinterpret_cast(const_cast(cache_v.data())), + shift_bias ? reinterpret_cast( + const_cast(shift_bias.get().data())) : nullptr, + smooth_weight ? reinterpret_cast( + const_cast(smooth_weight.get().data())) : nullptr, + seq_lens_q.data(), + seq_lens_kv.data(), + batch_ids.data(), + tile_ids_per_batch.data(), + cum_offsets.data(), + block_table.data(), + max_seq_len, + max_dec_len, + max_block_num_per_seq, + scale, + in_scale, + chunk_size, + nullptr, + nullptr, + nullptr, + reinterpret_cast(out->data()), + speculate_max_draft_token_num); + + } else { + phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d; + if (ENABLE_PREFILL) { + tmp_workspace = allocator->Allocate( + phi::SizeOf(qkv.dtype()) * + static_cast(token_num * num_chunks * num_heads * HEAD_DIM)); + tmp_m = allocator->Allocate( + phi::SizeOf(paddle::DataType::FLOAT32) * + static_cast(token_num * num_chunks * num_heads)); + tmp_d = allocator->Allocate( + phi::SizeOf(paddle::DataType::FLOAT32) * + static_cast(token_num * num_chunks * num_heads)); + } else { + tmp_workspace = allocator->Allocate( + phi::SizeOf(qkv.dtype()) * + static_cast(speculate_max_draft_token_num * bsz * + num_chunks * num_heads * HEAD_DIM)); + tmp_m = allocator->Allocate( + phi::SizeOf(paddle::DataType::FLOAT32) * + static_cast(speculate_max_draft_token_num * bsz * + num_chunks * num_heads)); + tmp_d = allocator->Allocate( + phi::SizeOf(paddle::DataType::FLOAT32) * + static_cast(speculate_max_draft_token_num * bsz * + num_chunks * num_heads)); + } + + split_kv_kernel<<>>( + reinterpret_cast(const_cast(qkv.data())), + reinterpret_cast(const_cast(cache_k.data())), + reinterpret_cast(const_cast(cache_v.data())), + shift_bias ? reinterpret_cast( + const_cast(shift_bias.get().data())) : nullptr, + smooth_weight ? 
reinterpret_cast( + const_cast(smooth_weight.get().data())) : nullptr, + seq_lens_q.data(), + seq_lens_kv.data(), + batch_ids.data(), + tile_ids_per_batch.data(), + cum_offsets.data(), + block_table.data(), + max_seq_len, + max_dec_len, + max_block_num_per_seq, + scale, + in_scale, + chunk_size, + reinterpret_cast(tmp_workspace->ptr()), + static_cast(tmp_m->ptr()), + static_cast(tmp_d->ptr()), + reinterpret_cast(out->data()), + speculate_max_draft_token_num); + // merge + constexpr int vec_size = num_elems_per_128b(); + if (is_decoder) { + constexpr int blockx = HEAD_DIM / vec_size; + constexpr int blocky = (128 + blockx - 1) / blockx; + dim3 grids_merge(bsz, num_heads); // 128k is too large + dim3 blocks_merge(blockx, blocky); + merge_multi_chunks_decoder_kernel + <<>>( + reinterpret_cast(tmp_workspace->ptr()), + static_cast(tmp_m->ptr()), + static_cast(tmp_d->ptr()), + seq_lens_q.data(), + seq_lens_kv.data(), + seq_lens_encoder.data(), + cum_offsets.data(), + shift_bias ? reinterpret_cast( + const_cast(shift_bias.get().data())) : nullptr, + smooth_weight ? reinterpret_cast(const_cast( + smooth_weight.get().data())) : nullptr, + reinterpret_cast(out->data()), + in_scale, + max_seq_len, + num_chunks, + num_heads, + chunk_size, + HEAD_DIM); + } else { + constexpr int blockx = HEAD_DIM / vec_size; + constexpr int blocky = (128 + blockx - 1) / blockx; + dim3 grids_merge(min(sm_count * 4, token_num), + num_heads); // 128k is too large + dim3 blocks_merge(blockx, blocky); + merge_multi_chunks_v2_kernel + <<>>( + reinterpret_cast(tmp_workspace->ptr()), + static_cast(tmp_m->ptr()), + static_cast(tmp_d->ptr()), + seq_lens_q.data(), + seq_lens_kv.data(), + seq_lens_encoder.data(), + padding_offsets.data(), + shift_bias ? reinterpret_cast( + const_cast(shift_bias.get().data())) : nullptr, + smooth_weight ? reinterpret_cast(const_cast( + smooth_weight.get().data())) : nullptr, + reinterpret_cast(out->data()), + in_scale, + max_seq_len, + num_chunks, + num_heads, + chunk_size, + HEAD_DIM, + token_num, + speculate_max_draft_token_num); + } + } + } else { + constexpr uint32_t num_frags_z = BLOCK_SIZE / 16 / NUM_WARP_KV; // !!! + constexpr uint32_t smem_size = + (num_frags_x + NUM_WARP_KV * num_frags_z * 2) * 16 * HEAD_DIM * + sizeof(T); + auto split_kv_kernel = + multi_query_append_attention_warp1_4_kernel; + if (smem_size >= 48 * 1024) { + cudaFuncSetAttribute(split_kv_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + } + const int dev_id = 0; + int sm_count; + // int act_blocks_per_sm; + cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id); + + uint32_t chunk_size = static_cast(max_partition_size); + if (!is_decoder) { + chunk_size = static_cast(encoder_max_partition_size); + } + const int num_chunks = div_up(max_dec_len, chunk_size); + + dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads); + // dim3 grids(num_blocks_x_cpu, num_chunks, 1); + dim3 blocks(32, num_warps); + + if (num_chunks <= 1) { + auto nosplit_kv_kernel = + multi_query_append_attention_warp1_4_kernel; + if (smem_size >= 48 * 1024) { + cudaFuncSetAttribute(nosplit_kv_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + } + + nosplit_kv_kernel<<>>( + reinterpret_cast(const_cast(qkv.data())), + reinterpret_cast(const_cast(cache_k.data())), + reinterpret_cast(const_cast(cache_v.data())), + shift_bias ? reinterpret_cast( + const_cast(shift_bias.get().data())) : nullptr, + smooth_weight ? 
reinterpret_cast( + const_cast(smooth_weight.get().data())) : nullptr, + seq_lens_q.data(), + seq_lens_kv.data(), + batch_ids.data(), + tile_ids_per_batch.data(), + cum_offsets.data(), + block_table.data(), + max_seq_len, + max_dec_len, + max_block_num_per_seq, + scale, + in_scale, + chunk_size, + nullptr, + nullptr, + nullptr, + reinterpret_cast(out->data()), + speculate_max_draft_token_num); + } else { + phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d; + if (is_decoder) { + tmp_workspace = allocator->Allocate( + phi::SizeOf(qkv.dtype()) * + static_cast(bsz * num_chunks * num_heads * HEAD_DIM)); + tmp_m = allocator->Allocate( + phi::SizeOf(paddle::DataType::FLOAT32) * + static_cast(bsz * num_chunks * num_heads)); + tmp_d = allocator->Allocate( + phi::SizeOf(paddle::DataType::FLOAT32) * + static_cast(bsz * num_chunks * num_heads)); + } else { + if (ENABLE_PREFILL) { + tmp_workspace = + allocator->Allocate(phi::SizeOf(qkv.dtype()) * + static_cast(token_num * num_chunks * + num_heads * HEAD_DIM)); + tmp_m = allocator->Allocate( + phi::SizeOf(paddle::DataType::FLOAT32) * + static_cast(token_num * num_chunks * num_heads)); + tmp_d = allocator->Allocate( + phi::SizeOf(paddle::DataType::FLOAT32) * + static_cast(token_num * num_chunks * num_heads)); + } else { + tmp_workspace = allocator->Allocate( + phi::SizeOf(qkv.dtype()) * + static_cast(speculate_max_draft_token_num * bsz * + num_chunks * num_heads * HEAD_DIM)); + tmp_m = allocator->Allocate( + phi::SizeOf(paddle::DataType::FLOAT32) * + static_cast(speculate_max_draft_token_num * bsz * + num_chunks * num_heads)); + tmp_d = allocator->Allocate( + phi::SizeOf(paddle::DataType::FLOAT32) * + static_cast(speculate_max_draft_token_num * bsz * + num_chunks * num_heads)); + } + } + split_kv_kernel<<>>( + reinterpret_cast(const_cast(qkv.data())), + reinterpret_cast(const_cast(cache_k.data())), + reinterpret_cast(const_cast(cache_v.data())), + shift_bias ? reinterpret_cast( + const_cast(shift_bias.get().data())) : nullptr, + smooth_weight ? reinterpret_cast( + const_cast(smooth_weight.get().data())) : nullptr, + seq_lens_q.data(), + seq_lens_kv.data(), + batch_ids.data(), + tile_ids_per_batch.data(), + cum_offsets.data(), + block_table.data(), + max_seq_len, + max_dec_len, + max_block_num_per_seq, + scale, + in_scale, + chunk_size, + reinterpret_cast(tmp_workspace->ptr()), + static_cast(tmp_m->ptr()), + static_cast(tmp_d->ptr()), + reinterpret_cast(out->data()), + speculate_max_draft_token_num); + + // merge + constexpr int vec_size = num_elems_per_128b(); + if (is_decoder) { + constexpr int blockx = HEAD_DIM / vec_size; + constexpr int blocky = (128 + blockx - 1) / blockx; + dim3 grids_merge(bsz, num_heads); // 128k is too large + dim3 blocks_merge(blockx, blocky); + merge_multi_chunks_decoder_kernel + <<>>( + reinterpret_cast(tmp_workspace->ptr()), + static_cast(tmp_m->ptr()), + static_cast(tmp_d->ptr()), + seq_lens_q.data(), + seq_lens_kv.data(), + seq_lens_encoder.data(), + cum_offsets.data(), + shift_bias ? reinterpret_cast( + const_cast(shift_bias.get().data())) : nullptr, + smooth_weight ? 
reinterpret_cast(const_cast( + smooth_weight.get().data())) : nullptr, + reinterpret_cast(out->data()), + in_scale, + max_seq_len, + num_chunks, + num_heads, + chunk_size, + HEAD_DIM); + } else { + constexpr int blockx = HEAD_DIM / vec_size; + constexpr int blocky = (128 + blockx - 1) / blockx; + dim3 grids_merge(min(sm_count * 4, token_num), + num_heads); // 128k is too large + dim3 blocks_merge(blockx, blocky); + merge_multi_chunks_v2_kernel + <<>>( + reinterpret_cast(tmp_workspace->ptr()), + static_cast(tmp_m->ptr()), + static_cast(tmp_d->ptr()), + seq_lens_q.data(), + seq_lens_kv.data(), + seq_lens_encoder.data(), + padding_offsets.data(), + shift_bias ? reinterpret_cast( + const_cast(shift_bias.get().data())) : nullptr, + smooth_weight ? reinterpret_cast(const_cast( + smooth_weight.get().data())) : nullptr, + reinterpret_cast(out->data()), + in_scale, + max_seq_len, + num_chunks, + num_heads, + chunk_size, + HEAD_DIM, + token_num, + speculate_max_draft_token_num); + } + } + } +} + +template +void MultiQueryAppendC8Attention( + const paddle::Tensor& qkv, + const paddle::Tensor& cache_k, + const paddle::Tensor& cache_v, + const paddle::optional& attn_mask, + const paddle::Tensor& cache_k_scale, + const paddle::Tensor& cache_v_scale, + const paddle::optional& shift_bias, + const paddle::optional& smooth_weight, + const paddle::Tensor& seq_lens_q, + const paddle::Tensor& seq_lens_kv, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& padding_offsets, + const paddle::Tensor& cum_offsets, + const paddle::Tensor& block_table, + const paddle::Tensor& batch_ids, + const paddle::Tensor& tile_ids_per_batch, + const int num_blocks_x_cpu, + const int max_seq_len, + const int max_dec_len, + const int num_heads, + const int kv_num_heads, + const float in_scale, + const int max_partition_size, + const int encoder_max_partition_size, + const int speculate_max_draft_token_num, + const bool is_decoder, + cudaStream_t& stream, + paddle::Tensor* out) { + using NV_TYPE = typename cascade_attn_type_traits::type; + using OUT_NV_TYPE = typename cascade_attn_type_traits::type; + + const auto& q_dims = qkv.dims(); + const auto& k_dims = cache_k.dims(); + const auto& cum_offsets_dims = cum_offsets.dims(); + const uint32_t token_num = q_dims[0]; + const uint32_t bsz = cum_offsets_dims[0]; + const uint32_t max_block_num_per_seq = block_table.dims()[1]; + + constexpr uint32_t num_warps = 4; + constexpr uint32_t NUM_WARP_KV = num_warps / NUM_WARP_Q; + constexpr uint32_t num_frags_x = BLOCK_SHAPE_Q / (16 * NUM_WARP_Q); // 1 or 2 + constexpr uint32_t num_frags_y = HEAD_DIM / 16; + constexpr uint32_t num_qrow_per_block = NUM_WARP_Q * num_frags_x * 16; + + auto* allocator = paddle::GetAllocator(qkv.place()); + + const float scale = 1.f / sqrt(HEAD_DIM); + + if constexpr (NUM_WARP_Q == 4) { + constexpr uint32_t num_frags_z = BLOCK_SIZE / 16; // !!! 
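+    // C8 path: the K/V cache is stored as uint8, so the K and V tiles cost one
+    // byte per element in shared memory while the Q tile stays in T, and the
+    // per-channel cache_k_scale / cache_v_scale factors are not part of this
+    // shared-memory budget. Illustrative numbers under assumed shapes
+    // (BLOCK_SHAPE_Q = 64, BLOCK_SIZE = 64, HEAD_DIM = 128, T = half):
+    //   Q tile      : 4 * 1 * 16 * 128 * 2 B = 16 KB
+    //   K + V tiles : 2 * (4 * 16 * 128) B   = 16 KB
+    // roughly 32 KB in total, inside the default 48 KB limit, so the
+    // cudaFuncSetAttribute opt-in below only matters for larger head dims or
+    // block sizes.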
+ constexpr uint32_t smem_size = + num_warps * num_frags_x * 16 * HEAD_DIM * sizeof(T) + + num_frags_z * 16 * HEAD_DIM * sizeof(uint8_t) * 2; + auto split_kv_kernel = + multi_query_append_attention_c8_kernel; + if (smem_size >= 48 * 1024) { + cudaFuncSetAttribute(split_kv_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + } + const int dev_id = 0; + int sm_count; + // int act_blocks_per_sm; + cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id); + uint32_t chunk_size = static_cast(max_partition_size); + if (!is_decoder) { + chunk_size = static_cast(encoder_max_partition_size); + } + const int num_chunks = div_up(max_dec_len, chunk_size); + dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads); + dim3 blocks(32, num_warps); + if (num_chunks <= 1) { + auto nosplit_kv_kernel = + multi_query_append_attention_c8_kernel; + if (smem_size >= 48 * 1024) { + cudaFuncSetAttribute(nosplit_kv_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + } + + nosplit_kv_kernel<<>>( + reinterpret_cast(const_cast(qkv.data())), + const_cast(cache_k.data()), + const_cast(cache_v.data()), + reinterpret_cast(const_cast(cache_k_scale.data())), + reinterpret_cast(const_cast(cache_v_scale.data())), + shift_bias ? reinterpret_cast( + const_cast(shift_bias.get().data())) : nullptr, + smooth_weight ? reinterpret_cast( + const_cast(smooth_weight.get().data())) : nullptr, + seq_lens_q.data(), + seq_lens_kv.data(), + batch_ids.data(), + tile_ids_per_batch.data(), + cum_offsets.data(), + block_table.data(), + max_seq_len, + max_dec_len, + max_block_num_per_seq, + scale, + in_scale, + chunk_size, + nullptr, + nullptr, + nullptr, + reinterpret_cast(out->data()), + speculate_max_draft_token_num); + } else { + phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d; + if (ENABLE_PREFILL) { + tmp_workspace = allocator->Allocate( + phi::SizeOf(qkv.dtype()) * + static_cast(token_num * num_chunks * num_heads * HEAD_DIM)); + tmp_m = allocator->Allocate( + phi::SizeOf(paddle::DataType::FLOAT32) * + static_cast(token_num * num_chunks * num_heads)); + tmp_d = allocator->Allocate( + phi::SizeOf(paddle::DataType::FLOAT32) * + static_cast(token_num * num_chunks * num_heads)); + } else { + tmp_workspace = allocator->Allocate( + phi::SizeOf(qkv.dtype()) * + static_cast(speculate_max_draft_token_num * bsz * + num_chunks * num_heads * HEAD_DIM)); + tmp_m = allocator->Allocate( + phi::SizeOf(paddle::DataType::FLOAT32) * + static_cast(speculate_max_draft_token_num * bsz * + num_chunks * num_heads)); + tmp_d = allocator->Allocate( + phi::SizeOf(paddle::DataType::FLOAT32) * + static_cast(speculate_max_draft_token_num * bsz * + num_chunks * num_heads)); + } + split_kv_kernel<<>>( + reinterpret_cast(const_cast(qkv.data())), + const_cast(cache_k.data()), + const_cast(cache_v.data()), + reinterpret_cast(const_cast(cache_k_scale.data())), + reinterpret_cast(const_cast(cache_v_scale.data())), + shift_bias ? reinterpret_cast( + const_cast(shift_bias.get().data())) : nullptr, + smooth_weight ? 
reinterpret_cast( + const_cast(smooth_weight.get().data())) : nullptr, + seq_lens_q.data(), + seq_lens_kv.data(), + batch_ids.data(), + tile_ids_per_batch.data(), + cum_offsets.data(), + block_table.data(), + max_seq_len, + max_dec_len, + max_block_num_per_seq, + scale, + in_scale, + chunk_size, + reinterpret_cast(tmp_workspace->ptr()), + static_cast(tmp_m->ptr()), + static_cast(tmp_d->ptr()), + reinterpret_cast(out->data()), + speculate_max_draft_token_num); + // merge + constexpr int vec_size = num_elems_per_128b(); + if (is_decoder) { + constexpr int blockx = HEAD_DIM / vec_size; + constexpr int blocky = (128 + blockx - 1) / blockx; + dim3 grids_merge(bsz, num_heads); // 128k is too large + dim3 blocks_merge(blockx, blocky); + merge_multi_chunks_decoder_kernel + <<>>( + reinterpret_cast(tmp_workspace->ptr()), + static_cast(tmp_m->ptr()), + static_cast(tmp_d->ptr()), + seq_lens_q.data(), + seq_lens_kv.data(), + seq_lens_encoder.data(), + cum_offsets.data(), + shift_bias ? reinterpret_cast( + const_cast(shift_bias.get().data())) : nullptr, + smooth_weight ? reinterpret_cast(const_cast( + smooth_weight.get().data())) : nullptr, + reinterpret_cast(out->data()), + in_scale, + max_seq_len, + num_chunks, + num_heads, + chunk_size, + HEAD_DIM); + } else { + constexpr int blockx = HEAD_DIM / vec_size; + constexpr int blocky = (128 + blockx - 1) / blockx; + dim3 grids_merge(min(sm_count * 4, token_num), + num_heads); // 128k is too large + dim3 blocks_merge(blockx, blocky); + merge_multi_chunks_v2_kernel + <<>>( + reinterpret_cast(tmp_workspace->ptr()), + static_cast(tmp_m->ptr()), + static_cast(tmp_d->ptr()), + seq_lens_q.data(), + seq_lens_kv.data(), + seq_lens_encoder.data(), + padding_offsets.data(), + shift_bias ? reinterpret_cast( + const_cast(shift_bias.get().data())) : nullptr, + smooth_weight ? reinterpret_cast(const_cast( + smooth_weight.get().data())) : nullptr, + reinterpret_cast(out->data()), + in_scale, + max_seq_len, + num_chunks, + num_heads, + chunk_size, + HEAD_DIM, + token_num, + speculate_max_draft_token_num); + } + } + } else { + constexpr uint32_t num_frags_z = BLOCK_SIZE / 16 / NUM_WARP_KV * 2; // !!! + constexpr uint32_t smem_size = + num_frags_x * 16 * HEAD_DIM * sizeof(T) + + NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM * sizeof(uint8_t) * 2; + auto split_kv_kernel = + multi_query_append_attention_c8_warp1_4_kernel; + if (smem_size >= 48 * 1024) { + cudaFuncSetAttribute(split_kv_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + } + const int dev_id = 0; + int sm_count; + // int act_blocks_per_sm; + cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id); + uint32_t chunk_size = static_cast(max_partition_size); + if (!is_decoder) { + chunk_size = static_cast(encoder_max_partition_size); + } + + const int num_chunks = div_up(max_dec_len, chunk_size); + dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads); + // dim3 grids(num_blocks_x_cpu, num_chunks, 1); + dim3 blocks(32, num_warps); + if (num_chunks <= 1) { + auto nosplit_kv_kernel = + multi_query_append_attention_c8_warp1_4_kernel; + if (smem_size >= 48 * 1024) { + cudaFuncSetAttribute(nosplit_kv_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + } + + nosplit_kv_kernel<<>>( + reinterpret_cast(const_cast(qkv.data())), + const_cast(cache_k.data()), + const_cast(cache_v.data()), + reinterpret_cast(const_cast(cache_k_scale.data())), + reinterpret_cast(const_cast(cache_v_scale.data())), + shift_bias ? 
reinterpret_cast( + const_cast(shift_bias.get().data())) : nullptr, + smooth_weight ? reinterpret_cast( + const_cast(smooth_weight.get().data())) : nullptr, + seq_lens_q.data(), + seq_lens_kv.data(), + batch_ids.data(), + tile_ids_per_batch.data(), + cum_offsets.data(), + block_table.data(), + max_seq_len, + max_dec_len, + max_block_num_per_seq, + scale, + in_scale, + chunk_size, + nullptr, + nullptr, + nullptr, + reinterpret_cast(out->data()), + speculate_max_draft_token_num); + } else { + phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d; + if (is_decoder) { + tmp_workspace = allocator->Allocate( + phi::SizeOf(qkv.dtype()) * + static_cast(bsz * num_chunks * num_heads * HEAD_DIM)); + tmp_m = allocator->Allocate( + phi::SizeOf(paddle::DataType::FLOAT32) * + static_cast(bsz * num_chunks * num_heads)); + tmp_d = allocator->Allocate( + phi::SizeOf(paddle::DataType::FLOAT32) * + static_cast(bsz * num_chunks * num_heads)); + } else { + if (ENABLE_PREFILL) { + tmp_workspace = + allocator->Allocate(phi::SizeOf(qkv.dtype()) * + static_cast(token_num * num_chunks * + num_heads * HEAD_DIM)); + tmp_m = allocator->Allocate( + phi::SizeOf(paddle::DataType::FLOAT32) * + static_cast(token_num * num_chunks * num_heads)); + tmp_d = allocator->Allocate( + phi::SizeOf(paddle::DataType::FLOAT32) * + static_cast(token_num * num_chunks * num_heads)); + } else { + tmp_workspace = allocator->Allocate( + phi::SizeOf(qkv.dtype()) * + static_cast(speculate_max_draft_token_num * bsz * + num_chunks * num_heads * HEAD_DIM)); + tmp_m = allocator->Allocate( + phi::SizeOf(paddle::DataType::FLOAT32) * + static_cast(speculate_max_draft_token_num * bsz * + num_chunks * num_heads)); + tmp_d = allocator->Allocate( + phi::SizeOf(paddle::DataType::FLOAT32) * + static_cast(speculate_max_draft_token_num * bsz * + num_chunks * num_heads)); + } + } + split_kv_kernel<<>>( + reinterpret_cast(const_cast(qkv.data())), + const_cast(cache_k.data()), + const_cast(cache_v.data()), + reinterpret_cast(const_cast(cache_k_scale.data())), + reinterpret_cast(const_cast(cache_v_scale.data())), + shift_bias ? reinterpret_cast( + const_cast(shift_bias.get().data())) : nullptr, + smooth_weight ? reinterpret_cast( + const_cast(smooth_weight.get().data())) : nullptr, + seq_lens_q.data(), + seq_lens_kv.data(), + batch_ids.data(), + tile_ids_per_batch.data(), + cum_offsets.data(), + block_table.data(), + max_seq_len, + max_dec_len, + max_block_num_per_seq, + scale, + in_scale, + chunk_size, + reinterpret_cast(tmp_workspace->ptr()), + static_cast(tmp_m->ptr()), + static_cast(tmp_d->ptr()), + reinterpret_cast(out->data()), + speculate_max_draft_token_num); + // merge + constexpr int vec_size = num_elems_per_128b(); + if (is_decoder) { + constexpr int blockx = HEAD_DIM / vec_size; + constexpr int blocky = (128 + blockx - 1) / blockx; + dim3 grids_merge(bsz, num_heads); // 128k is too large + dim3 blocks_merge(blockx, blocky); + merge_multi_chunks_decoder_kernel + <<>>( + reinterpret_cast(tmp_workspace->ptr()), + static_cast(tmp_m->ptr()), + static_cast(tmp_d->ptr()), + seq_lens_q.data(), + seq_lens_kv.data(), + seq_lens_encoder.data(), + cum_offsets.data(), + shift_bias ? reinterpret_cast( + const_cast(shift_bias.get().data())) : nullptr, + smooth_weight ? 
reinterpret_cast(const_cast( + smooth_weight.get().data())) : nullptr, + reinterpret_cast(out->data()), + in_scale, + max_seq_len, + num_chunks, + num_heads, + chunk_size, + HEAD_DIM); + } else { + constexpr int blockx = HEAD_DIM / vec_size; + constexpr int blocky = (128 + blockx - 1) / blockx; + dim3 grids_merge(min(sm_count * 4, token_num), + num_heads); // 128k is too large + dim3 blocks_merge(blockx, blocky); + merge_multi_chunks_v2_kernel + <<>>( + reinterpret_cast(tmp_workspace->ptr()), + static_cast(tmp_m->ptr()), + static_cast(tmp_d->ptr()), + seq_lens_q.data(), + seq_lens_kv.data(), + seq_lens_encoder.data(), + padding_offsets.data(), + shift_bias ? reinterpret_cast( + const_cast(shift_bias.get().data())) : nullptr, + smooth_weight ? reinterpret_cast(const_cast( + smooth_weight.get().data())) : nullptr, + reinterpret_cast(out->data()), + in_scale, + max_seq_len, + num_chunks, + num_heads, + chunk_size, + HEAD_DIM, + token_num, + speculate_max_draft_token_num); + } + } + } +} + +template +void MultiQueryAppendC4Attention( + const paddle::Tensor& qkv, + const paddle::Tensor& cache_k, + const paddle::Tensor& cache_v, + const paddle::optional& attn_mask, + const paddle::Tensor& cache_k_scale, + const paddle::Tensor& cache_v_scale, + const paddle::optional& cache_k_zp, + const paddle::optional& cache_v_zp, + const paddle::optional& shift_bias, + const paddle::optional& smooth_weight, + const paddle::Tensor& seq_lens_q, + const paddle::Tensor& seq_lens_kv, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& padding_offsets, + const paddle::Tensor& cum_offsets, + const paddle::Tensor& block_table, + const paddle::Tensor& batch_ids, + const paddle::Tensor& tile_ids_per_batch, + const int num_blocks_x_cpu, + const int max_seq_len, + const int max_dec_len, + const int num_heads, + const int kv_num_heads, + const float in_scale, + const int max_partition_size, + const int encoder_max_partition_size, + const int speculate_max_draft_token_num, + const bool is_decoder, + cudaStream_t& stream, + paddle::Tensor* out) { + using NV_TYPE = typename cascade_attn_type_traits::type; + using OUT_NV_TYPE = typename cascade_attn_type_traits::type; + + const auto& q_dims = qkv.dims(); + const auto& k_dims = cache_k.dims(); + const auto& cum_offsets_dims = cum_offsets.dims(); + const uint32_t token_num = q_dims[0]; + const uint32_t bsz = cum_offsets_dims[0]; + const uint32_t max_block_num_per_seq = block_table.dims()[1]; + constexpr uint32_t num_warps = 4; + constexpr uint32_t NUM_WARP_KV = num_warps / NUM_WARP_Q; + constexpr uint32_t num_frags_x = BLOCK_SHAPE_Q / (16 * NUM_WARP_Q); // 1 or 2 + constexpr uint32_t num_frags_y = HEAD_DIM / 16; + constexpr uint32_t num_qrow_per_block = NUM_WARP_Q * num_frags_x * 16; + + auto* allocator = paddle::GetAllocator(qkv.place()); + + const float scale = 1.f / sqrt(HEAD_DIM); + + if constexpr (NUM_WARP_Q == 4) { + constexpr uint32_t num_frags_z = BLOCK_SIZE / 16; // !!! + // constexpr uint32_t num_frags_z = 8; // 128 per iter, 4 is better? 
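+    // C4 path: two int4 values are packed per byte, so a K row needs only
+    // HEAD_DIM / 2 bytes (and a V column BLOCK_SIZE / 2 bytes, V being kept
+    // transposed), while the trailing HEAD_DIM * 4 * sizeof(T) term below
+    // reserves the four per-channel rows the kernel stages in shared memory:
+    // cache_k_scale, cache_k_zero_point, cache_v_scale, cache_v_zero_point.
+    // Illustrative numbers under assumed shapes (BLOCK_SHAPE_Q = 64,
+    // BLOCK_SIZE = 64, HEAD_DIM = 128, T = half):
+    //   Q tile        : 4 * 1 * 16 * 128 * 2 B   = 16 KB
+    //   K + V tiles   : 2 * (4 * 16 * 128 / 2) B =  8 KB
+    //   scale/zp rows : 4 * 128 * 2 B            =  1 KB
+    // about 25 KB in total.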
+ constexpr uint32_t smem_size = + num_warps * num_frags_x * 16 * HEAD_DIM * sizeof(T) + + num_frags_z * 16 * HEAD_DIM / 2 * sizeof(uint8_t) * 2 + + HEAD_DIM * 4 * sizeof(T); + auto split_kv_kernel = + multi_query_append_attention_c4_kernel; + // if (smem_size >= 48 * 1024) { + cudaFuncSetAttribute(split_kv_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + // } + const int dev_id = 0; + int sm_count; + int act_blocks_per_sm; + cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id); + cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &act_blocks_per_sm, split_kv_kernel, num_warps * 32, smem_size); + assert(act_blocks_per_sm > 1); + const int num_blocks_per_wave = sm_count * act_blocks_per_sm; + const int num_blocks_need = num_blocks_x_cpu * kv_num_heads; + const int max_num_chunks = div_up(num_blocks_per_wave, num_blocks_need); + const float ratio = static_cast(num_blocks_need) / + static_cast(num_blocks_per_wave); + + uint32_t chunk_size = static_cast(max_partition_size); + if (!is_decoder) { + chunk_size = static_cast(encoder_max_partition_size); + } + const int num_chunks = div_up(max_dec_len, chunk_size); + + dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads); + dim3 blocks(32, num_warps); + if (num_chunks <= 1) { + auto nosplit_kv_kernel = + multi_query_append_attention_c4_kernel; + if (smem_size >= 48 * 1024) { + cudaFuncSetAttribute(nosplit_kv_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + } + nosplit_kv_kernel<<>>( + reinterpret_cast(const_cast(qkv.data())), + const_cast(cache_k.data()), + const_cast(cache_v.data()), + reinterpret_cast(const_cast(cache_k_scale.data())), + cache_k_zp ? reinterpret_cast( + const_cast(cache_k_zp.get().data())) : nullptr, + reinterpret_cast(const_cast(cache_v_scale.data())), + cache_v_zp ? reinterpret_cast( + const_cast(cache_v_zp.get().data())) : nullptr, + shift_bias ? reinterpret_cast( + const_cast(shift_bias.get().data())) : nullptr, + smooth_weight ? reinterpret_cast( + const_cast(smooth_weight.get().data())) : nullptr, + seq_lens_q.data(), + seq_lens_kv.data(), + batch_ids.data(), + tile_ids_per_batch.data(), + cum_offsets.data(), + block_table.data(), + max_seq_len, + max_dec_len, + max_block_num_per_seq, + scale, + in_scale, + chunk_size, + nullptr, + nullptr, + nullptr, + reinterpret_cast(out->data()), + speculate_max_draft_token_num); + } else { + phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d; + if (ENABLE_PREFILL) { + tmp_workspace = allocator->Allocate( + phi::SizeOf(qkv.dtype()) * + static_cast(token_num * num_chunks * num_heads * HEAD_DIM)); + tmp_m = allocator->Allocate( + phi::SizeOf(paddle::DataType::FLOAT32) * + static_cast(token_num * num_chunks * num_heads)); + tmp_d = allocator->Allocate( + phi::SizeOf(paddle::DataType::FLOAT32) * + static_cast(token_num * num_chunks * num_heads)); + } else { + tmp_workspace = allocator->Allocate( + phi::SizeOf(qkv.dtype()) * + static_cast(speculate_max_draft_token_num * bsz * + num_chunks * num_heads * HEAD_DIM)); + tmp_m = allocator->Allocate( + phi::SizeOf(paddle::DataType::FLOAT32) * + static_cast(speculate_max_draft_token_num * bsz * + num_chunks * num_heads)); + tmp_d = allocator->Allocate( + phi::SizeOf(paddle::DataType::FLOAT32) * + static_cast(speculate_max_draft_token_num * bsz * + num_chunks * num_heads)); + } + split_kv_kernel<<>>( + reinterpret_cast(const_cast(qkv.data())), + const_cast(cache_k.data()), + const_cast(cache_v.data()), + reinterpret_cast(const_cast(cache_k_scale.data())), + cache_k_zp ? 
reinterpret_cast( + const_cast(cache_k_zp.get().data())) : nullptr, + reinterpret_cast(const_cast(cache_v_scale.data())), + cache_v_zp ? reinterpret_cast( + const_cast(cache_v_zp.get().data())) : nullptr, + shift_bias ? reinterpret_cast( + const_cast(shift_bias.get().data())) : nullptr, + smooth_weight ? reinterpret_cast( + const_cast(smooth_weight.get().data())) : nullptr, + seq_lens_q.data(), + seq_lens_kv.data(), + batch_ids.data(), + tile_ids_per_batch.data(), + cum_offsets.data(), + block_table.data(), + max_seq_len, + max_dec_len, + max_block_num_per_seq, + scale, + in_scale, + chunk_size, + reinterpret_cast(tmp_workspace->ptr()), + static_cast(tmp_m->ptr()), + static_cast(tmp_d->ptr()), + reinterpret_cast(out->data()), + speculate_max_draft_token_num); + // merge + constexpr int vec_size = num_elems_per_128b(); + if (is_decoder) { + constexpr int blockx = HEAD_DIM / vec_size; + constexpr int blocky = (128 + blockx - 1) / blockx; + dim3 grids_merge(bsz, num_heads); // 128k is too large + dim3 blocks_merge(blockx, blocky); + merge_multi_chunks_decoder_kernel + <<>>( + reinterpret_cast(tmp_workspace->ptr()), + static_cast(tmp_m->ptr()), + static_cast(tmp_d->ptr()), + seq_lens_q.data(), + seq_lens_kv.data(), + seq_lens_encoder.data(), + cum_offsets.data(), + shift_bias ? reinterpret_cast( + const_cast(shift_bias.get().data())) : nullptr, + smooth_weight ? reinterpret_cast(const_cast( + smooth_weight.get().data())) : nullptr, + reinterpret_cast(out->data()), + in_scale, + max_seq_len, + num_chunks, + num_heads, + chunk_size, + HEAD_DIM); + } else { + constexpr int blockx = HEAD_DIM / vec_size; + constexpr int blocky = (128 + blockx - 1) / blockx; + dim3 grids_merge(min(sm_count * 4, token_num), + num_heads); // 128k is too large + dim3 blocks_merge(blockx, blocky); + merge_multi_chunks_v2_kernel + <<>>( + reinterpret_cast(tmp_workspace->ptr()), + static_cast(tmp_m->ptr()), + static_cast(tmp_d->ptr()), + seq_lens_q.data(), + seq_lens_kv.data(), + seq_lens_encoder.data(), + padding_offsets.data(), + shift_bias ? reinterpret_cast( + const_cast(shift_bias.get().data())) : nullptr, + smooth_weight ? reinterpret_cast(const_cast( + smooth_weight.get().data())) : nullptr, + reinterpret_cast(out->data()), + in_scale, + max_seq_len, + num_chunks, + num_heads, + chunk_size, + HEAD_DIM, + token_num, + speculate_max_draft_token_num); + } + } + } else { + constexpr uint32_t num_frags_z = BLOCK_SIZE / 16 / NUM_WARP_KV * 4; // !!! 
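+    // NUM_WARP_Q == 1 configuration (warp1_4 kernel): a single small Q tile is
+    // shared by all four warps, which instead split the KV tile along the
+    // sequence dimension (NUM_WARP_KV == 4). Each warp keeps its own (m, d, o)
+    // partial state, which merge_block_res_v2 reduces through shared memory
+    // before the output is written; when the sequence is additionally split
+    // into chunks (num_chunks > 1), the per-chunk partials go to
+    // tmp_workspace / tmp_m / tmp_d and the merge_multi_chunks_* kernels finish
+    // the reduction. Illustrative shared-memory numbers under assumed shapes
+    // (BLOCK_SHAPE_Q = 16, BLOCK_SIZE = 64, HEAD_DIM = 128, T = half):
+    //   Q tile        : 1 * 16 * 128 * 2 B           =  4 KB
+    //   K + V tiles   : 2 * (4 * 4 * 16 * 128 / 2) B = 32 KB
+    //   scale/zp rows : 4 * 128 * 2 B                =  1 KB
+    // about 37 KB; the occupancy query just below then checks that at least two
+    // such CTAs fit per SM.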
+ constexpr uint32_t smem_size = + num_frags_x * 16 * HEAD_DIM * sizeof(T) + + NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM / 2 * sizeof(uint8_t) * 2 + + HEAD_DIM * 4 * sizeof(T); + auto split_kv_kernel = + multi_query_append_attention_c4_warp1_4_kernel; + if (smem_size >= 48 * 1024) { + cudaFuncSetAttribute(split_kv_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + } + const int dev_id = 0; + int sm_count; + int act_blocks_per_sm; + cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id); + cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &act_blocks_per_sm, split_kv_kernel, num_warps * 32, smem_size); + assert(act_blocks_per_sm > 1); + const int num_blocks_per_wave = sm_count * act_blocks_per_sm; + const int num_blocks_need = num_blocks_x_cpu * kv_num_heads; + const int max_num_chunks = div_up(num_blocks_per_wave, num_blocks_need); + const float ratio = static_cast(num_blocks_need) / + static_cast(num_blocks_per_wave); + + + uint32_t chunk_size = static_cast(max_partition_size); + if (!is_decoder) { + chunk_size = static_cast(encoder_max_partition_size); + } + const int num_chunks = div_up(max_dec_len, chunk_size); + dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads); + // dim3 grids(num_blocks_x_cpu, num_chunks, 1); + dim3 blocks(32, num_warps); + if (num_chunks <= 1) { + auto nosplit_kv_kernel = + multi_query_append_attention_c4_warp1_4_kernel; + if (smem_size >= 48 * 1024) { + cudaFuncSetAttribute(nosplit_kv_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + } + nosplit_kv_kernel<<>>( + reinterpret_cast(const_cast(qkv.data())), + const_cast(cache_k.data()), + const_cast(cache_v.data()), + reinterpret_cast(const_cast(cache_k_scale.data())), + cache_k_zp ? reinterpret_cast( + const_cast(cache_k_zp.get().data())) : nullptr, + reinterpret_cast(const_cast(cache_v_scale.data())), + cache_v_zp ? reinterpret_cast( + const_cast(cache_v_zp.get().data())) : nullptr, + shift_bias ? reinterpret_cast( + const_cast(shift_bias.get().data())) : nullptr, + smooth_weight ? 
reinterpret_cast( + const_cast(smooth_weight.get().data())) : nullptr, + seq_lens_q.data(), + seq_lens_kv.data(), + batch_ids.data(), + tile_ids_per_batch.data(), + cum_offsets.data(), + block_table.data(), + max_seq_len, + max_dec_len, + max_block_num_per_seq, + scale, + in_scale, + chunk_size, + nullptr, + nullptr, + nullptr, + reinterpret_cast(out->data()), + speculate_max_draft_token_num); + } else { + phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d; + if (is_decoder) { + tmp_workspace = allocator->Allocate( + phi::SizeOf(qkv.dtype()) * + static_cast(bsz * num_chunks * num_heads * HEAD_DIM)); + tmp_m = allocator->Allocate( + phi::SizeOf(paddle::DataType::FLOAT32) * + static_cast(bsz * num_chunks * num_heads)); + tmp_d = allocator->Allocate( + phi::SizeOf(paddle::DataType::FLOAT32) * + static_cast(bsz * num_chunks * num_heads)); + } else { + if (ENABLE_PREFILL) { + tmp_workspace = + allocator->Allocate(phi::SizeOf(qkv.dtype()) * + static_cast(token_num * num_chunks * + num_heads * HEAD_DIM)); + tmp_m = allocator->Allocate( + phi::SizeOf(paddle::DataType::FLOAT32) * + static_cast(token_num * num_chunks * num_heads)); + tmp_d = allocator->Allocate( + phi::SizeOf(paddle::DataType::FLOAT32) * + static_cast(token_num * num_chunks * num_heads)); + } else { + tmp_workspace = allocator->Allocate( + phi::SizeOf(qkv.dtype()) * + static_cast(speculate_max_draft_token_num * bsz * + num_chunks * num_heads * HEAD_DIM)); + tmp_m = allocator->Allocate( + phi::SizeOf(paddle::DataType::FLOAT32) * + static_cast(speculate_max_draft_token_num * bsz * + num_chunks * num_heads)); + tmp_d = allocator->Allocate( + phi::SizeOf(paddle::DataType::FLOAT32) * + static_cast(speculate_max_draft_token_num * bsz * + num_chunks * num_heads)); + } + } + split_kv_kernel<<>>( + reinterpret_cast(const_cast(qkv.data())), + const_cast(cache_k.data()), + const_cast(cache_v.data()), + reinterpret_cast(const_cast(cache_k_scale.data())), + cache_k_zp ? reinterpret_cast( + const_cast(cache_k_zp.get().data())) : nullptr, + reinterpret_cast(const_cast(cache_v_scale.data())), + cache_v_zp ? reinterpret_cast( + const_cast(cache_v_zp.get().data())) : nullptr, + shift_bias ? reinterpret_cast( + const_cast(shift_bias.get().data())) : nullptr, + smooth_weight ? reinterpret_cast( + const_cast(smooth_weight.get().data())) : nullptr, + seq_lens_q.data(), + seq_lens_kv.data(), + batch_ids.data(), + tile_ids_per_batch.data(), + cum_offsets.data(), + block_table.data(), + max_seq_len, + max_dec_len, + max_block_num_per_seq, + scale, + in_scale, + chunk_size, + reinterpret_cast(tmp_workspace->ptr()), + static_cast(tmp_m->ptr()), + static_cast(tmp_d->ptr()), + reinterpret_cast(out->data()), + speculate_max_draft_token_num); + // merge + constexpr int vec_size = num_elems_per_128b(); + if (is_decoder) { + constexpr int blockx = HEAD_DIM / vec_size; + constexpr int blocky = (128 + blockx - 1) / blockx; + dim3 grids_merge(bsz, num_heads); // 128k is too large + dim3 blocks_merge(blockx, blocky); + merge_multi_chunks_decoder_kernel + <<>>( + reinterpret_cast(tmp_workspace->ptr()), + static_cast(tmp_m->ptr()), + static_cast(tmp_d->ptr()), + seq_lens_q.data(), + seq_lens_kv.data(), + seq_lens_encoder.data(), + cum_offsets.data(), + shift_bias ? reinterpret_cast( + const_cast(shift_bias.get().data())) : nullptr, + smooth_weight ? 
reinterpret_cast(const_cast( + smooth_weight.get().data())) : nullptr, + reinterpret_cast(out->data()), + in_scale, + max_seq_len, + num_chunks, + num_heads, + chunk_size, + HEAD_DIM); + } else { + constexpr int blockx = HEAD_DIM / vec_size; + constexpr int blocky = (128 + blockx - 1) / blockx; + dim3 grids_merge(min(sm_count * 4, token_num), + num_heads); // 128k is too large + dim3 blocks_merge(blockx, blocky); + merge_multi_chunks_v2_kernel + <<>>( + reinterpret_cast(tmp_workspace->ptr()), + static_cast(tmp_m->ptr()), + static_cast(tmp_d->ptr()), + seq_lens_q.data(), + seq_lens_kv.data(), + seq_lens_encoder.data(), + padding_offsets.data(), + shift_bias ? reinterpret_cast( + const_cast(shift_bias.get().data())) : nullptr, + smooth_weight ? reinterpret_cast(const_cast( + smooth_weight.get().data())) : nullptr, + reinterpret_cast(out->data()), + in_scale, + max_seq_len, + num_chunks, + num_heads, + chunk_size, + HEAD_DIM, + token_num, + speculate_max_draft_token_num); + } + } + } +} diff --git a/csrc/gpu/append_attn/append_attention_kernel.h b/csrc/gpu/append_attn/append_attention_kernel.h new file mode 100644 index 000000000000..656d2e89938a --- /dev/null +++ b/csrc/gpu/append_attn/append_attention_kernel.h @@ -0,0 +1,180 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
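// Sketch: the split-KV launches above leave, for each KV chunk, a partial
// attention output plus its softmax row max m and normalizer d in
// tmp_workspace / tmp_m / tmp_d; merge_multi_chunks_decoder_kernel and
// merge_multi_chunks_v2_kernel then fold the chunks together. A minimal
// host-side version of that combine step, assuming each chunk's partial
// output is already normalized by its own d (the names below are
// illustrative, not the kernels' actual signatures):
#include <cmath>
#include <vector>

void merge_split_kv_chunks(
    const std::vector<std::vector<float>>& chunk_out,  // [num_chunks][head_dim]
    const std::vector<float>& chunk_m,                 // per-chunk row max of the logits
    const std::vector<float>& chunk_d,                 // per-chunk sum of exp(logit - m_c)
    std::vector<float>* out) {                         // [head_dim] merged result
  const size_t num_chunks = chunk_m.size();
  const size_t head_dim = chunk_out.front().size();
  // Global row max, so every chunk can be rescaled into one exponent range.
  float m = -INFINITY;
  for (float mc : chunk_m) m = std::fmax(m, mc);
  // Re-weight each chunk by d_c * exp(m_c - m) and renormalize once.
  float d = 0.f;
  std::vector<float> acc(head_dim, 0.f);
  for (size_t c = 0; c < num_chunks; ++c) {
    const float w = chunk_d[c] * std::exp(chunk_m[c] - m);
    for (size_t i = 0; i < head_dim; ++i) acc[i] += w * chunk_out[c][i];
    d += w;
  }
  out->resize(head_dim);
  for (size_t i = 0; i < head_dim; ++i) (*out)[i] = acc[i] / d;
}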
+#pragma once +#include "append_attention_impl.cuh" +// #define DEBUG_DEC_ATTN + +template +void CascadeAppendAttentionKernel( + const paddle::Tensor& qkv, // [token_num, num_heads, head_dim] + const paddle::Tensor& cache_k, // [max_block_num, num_heads, block_size, head_dim] + const paddle::Tensor& cache_v, // [max_block_num, num_heads, head_dim, block_size] + const paddle::optional& attn_mask, + const paddle::optional& cache_k_scale, // [num_kv_heads, head_dim] + const paddle::optional& cache_v_scale, // [num_kv_heads, head_dim] + const paddle::optional& cache_k_zp, // [num_kv_heads, head_dim] + const paddle::optional& cache_v_zp, // [num_kv_heads, head_dim] + const paddle::optional& shift_bias, // [num_kv_heads, head_dim] + const paddle::optional& smooth_weight, // [num_kv_heads, head_dim] + const paddle::Tensor& seq_lens_q, + const paddle::Tensor& seq_lens_kv, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& padding_offsets, + const paddle::Tensor& cum_offsets, + const paddle::Tensor& block_table, + const paddle::Tensor& batch_ids, + const paddle::Tensor& tile_ids_per_batch, + const std::string& cache_quant_type_str, + const int num_blocks, + const int block_shape_q, + const int max_seq_len, + const int max_dec_len, + const int num_heads, + const int kv_num_heads, + const int head_dim, + const float in_scale, + const int max_partition_size, + const int encoder_max_partition_size, + const int speculate_max_draft_token_num, + const bool causal, + const bool is_decoder, + const bool enable_prefill, + cudaStream_t& stream, + paddle::Tensor* out) { + const auto& q_dims = qkv.dims(); + const auto& k_dims = cache_k.dims(); + const auto& cum_offsets_dims = cum_offsets.dims(); + const uint32_t token_num = q_dims[0]; + const uint32_t block_size = k_dims[2]; + const uint32_t bsz = cum_offsets_dims[0]; + const uint32_t group_size = num_heads / kv_num_heads; + + if (cache_quant_type_str == "none") { + DISPATCH_CAUSAL(causal, CAUSAL, + {DISPATCH_ENABLE_PREFILL(enable_prefill, ENABLE_PREFILL, + {DISPATCH_GQA_GROUP_SIZE(group_size, GROUP_SIZE, + {DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, + {DISPATCH_BLOCK_SIZE(block_size, BLOCK_SIZE, + {DISPATCH_BLOCKSHAPE_Q(block_shape_q, BLOCK_SHAPE_Q, NUM_WARP_Q, + {MultiQueryAppendAttention( + qkv, + cache_k, + cache_v, + attn_mask, + shift_bias, + smooth_weight, + seq_lens_q, + seq_lens_kv, + seq_lens_encoder, + padding_offsets, + cum_offsets, + block_table, + batch_ids, + tile_ids_per_batch, + num_blocks, + max_seq_len, + max_dec_len, + num_heads, + kv_num_heads, + in_scale, + max_partition_size, + encoder_max_partition_size, + speculate_max_draft_token_num, + is_decoder, + stream, + out); + })})})})})}) + } else if (cache_quant_type_str == "cache_int8") { + DISPATCH_CAUSAL(causal, CAUSAL, + {DISPATCH_ENABLE_PREFILL(enable_prefill, ENABLE_PREFILL, + {DISPATCH_GQA_GROUP_SIZE(group_size, GROUP_SIZE, + {DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, + {DISPATCH_BLOCK_SIZE(block_size, BLOCK_SIZE, + {DISPATCH_BLOCKSHAPE_Q(block_shape_q, BLOCK_SHAPE_Q, NUM_WARP_Q, { + MultiQueryAppendC8Attention( + qkv, + cache_k, + cache_v, + attn_mask, + cache_k_scale.get(), + cache_v_scale.get(), + shift_bias, + smooth_weight, + seq_lens_q, + seq_lens_kv, + seq_lens_encoder, + padding_offsets, + cum_offsets, + block_table, + batch_ids, + tile_ids_per_batch, + num_blocks, + max_seq_len, + max_dec_len, + num_heads, + kv_num_heads, + in_scale, + max_partition_size, + encoder_max_partition_size, + speculate_max_draft_token_num, + is_decoder, + stream, + out); + })})})})})}) + } 
else if (cache_quant_type_str == "cache_int4") { + DISPATCH_CAUSAL(causal, CAUSAL, + {DISPATCH_ENABLE_PREFILL(enable_prefill, ENABLE_PREFILL, + {DISPATCH_GQA_GROUP_SIZE(group_size, GROUP_SIZE, + {DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, + {DISPATCH_BLOCK_SIZE(block_size, BLOCK_SIZE, + {DISPATCH_BLOCKSHAPE_Q(block_shape_q, BLOCK_SHAPE_Q, NUM_WARP_Q, + {MultiQueryAppendC4Attention( + qkv, + cache_k, + cache_v, + attn_mask, + cache_k_scale.get(), + cache_v_scale.get(), + cache_k_zp, + cache_v_zp, + shift_bias, + smooth_weight, + seq_lens_q, + seq_lens_kv, + seq_lens_encoder, + padding_offsets, + cum_offsets, + block_table, + batch_ids, + tile_ids_per_batch, + num_blocks, + max_seq_len, + max_dec_len, + num_heads, + kv_num_heads, + in_scale, + max_partition_size, + encoder_max_partition_size, + speculate_max_draft_token_num, + is_decoder, + stream, + out); + })})})})})}) + } else { + PD_THROW("append attention just support C16/C8/C4_zp now!"); + } +} \ No newline at end of file diff --git a/csrc/gpu/append_attn/decoder_write_cache_with_rope_impl.cuh b/csrc/gpu/append_attn/decoder_write_cache_with_rope_impl.cuh new file mode 100644 index 000000000000..ba12454ee0c1 --- /dev/null +++ b/csrc/gpu/append_attn/decoder_write_cache_with_rope_impl.cuh @@ -0,0 +1,2452 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
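// Sketch of the two rotary layouts implemented by the write-cache kernels in
// this file, reduced to a single head of floats. The real kernels fuse this
// rotation with dequant, bias add and the blocked KV-cache write, and apply
// it only to the q and k heads (v is stored unrotated); the helper names
// below are illustrative.

// Adjacent-pair RoPE (the non-neox kernels): elements (2i, 2i + 1) are
// rotated by the angle at rotary index i, so the cos/sin row holds
// head_dim / 2 entries for the token's position.
void rope_adjacent_pairs(float* head, const float* cos_row,
                         const float* sin_row, int head_dim) {
  for (int i = 0; i < head_dim / 2; ++i) {
    const float x0 = head[2 * i];
    const float x1 = head[2 * i + 1];
    head[2 * i]     = x0 * cos_row[i] - x1 * sin_row[i];
    head[2 * i + 1] = x1 * cos_row[i] + x0 * sin_row[i];
  }
}

// NeoX-style RoPE (the *_neox_rope_kernel variants): element i of the first
// half is paired with element i + head_dim / 2 of the second half, and the
// cos/sin row is indexed directly by i.
void rope_neox_halves(float* head, const float* cos_row,
                      const float* sin_row, int head_dim) {
  const int half = head_dim / 2;
  for (int i = 0; i < half; ++i) {
    const float x0 = head[i];
    const float x1 = head[i + half];
    head[i]        = x0 * cos_row[i] - x1 * sin_row[i];
    head[i + half] = x1 * cos_row[i] + x0 * sin_row[i];
  }
}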
+#pragma once + +#include "helper.h" +#include "mem_util.cuh" +#include "mma_tensor_op.cuh" +#include "utils.cuh" + +template +__global__ void append_decode_cache_T_rope_kernel( + const T* __restrict__ quant_qkv, // [bsz, num_heads + 2 * gqa_group_size, + // head_size] + T* __restrict__ key_cache, // [num_blocks, gqa_group_size, block_size, + // head_size // 2] + T* __restrict__ value_cache, // [num_blocks, gqa_group_size, block_size, + // head_size // 2] + T* __restrict__ qkv_out, + const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] + const int* __restrict__ padding_offsets, // [num_tokens] + const int* __restrict__ cum_offsets, + const int* __restrict__ seq_lens, // [bsz] + const int* __restrict__ seq_lens_encoder, // [bsz] + const float* __restrict__ cos_emb, + const float* __restrict__ sin_emb, + const int max_seq_len, + const int max_blocks_per_seq, + const int num_heads, + const int head_size, + const int block_size, + const uint32_t elem_cnt, + const int gqa_group_size) { + using LoadT = AlignedVector; + using LoadBiasT = AlignedVector; + using LoadKVT = AlignedVector; + constexpr int HalfVecSize = VecSize / 2; + using LoadEmbT = AlignedVector; + LoadT src_vec; + LoadBiasT out_vec; + LoadKVT cache_vec; + LoadEmbT cos_emb_vec; + LoadEmbT sin_emb_vec; + + int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x; + const int64_t hidden_size = (num_heads + 2 * gqa_group_size) * head_size; + // const int64_t offset = 2 * hidden_size; + const int half_head_size = head_size / 2; + for (int32_t linear_index = global_thread_idx * VecSize, + step = gridDim.x * blockDim.x * VecSize; + linear_index < elem_cnt; + linear_index += step) { + const int ori_bi = linear_index / hidden_size; + const int bias = linear_index % hidden_size; + const int hi = bias / head_size; // q + k + v + const int h_bias = bias % head_size; + const int start_token_idx = ori_bi * max_seq_len - cum_offsets[ori_bi]; + if (seq_lens_encoder[ori_bi] > 0) return; + const int write_seq_id = seq_lens[ori_bi]; + if (write_seq_id == 0) continue; + + const int* block_table_now = nullptr; + + block_table_now = block_tables + ori_bi * max_blocks_per_seq; + const int block_idx = block_table_now[write_seq_id / block_size]; + const int block_offset = write_seq_id % block_size; + const uint32_t ori_idx = + start_token_idx * hidden_size + hi * head_size + h_bias; + + const int bias_idx = hi * head_size + h_bias; + Load(&quant_qkv[ori_idx], &src_vec); + if (hi < num_heads + gqa_group_size) { + // q k rope + const uint32_t emb_idx = write_seq_id * half_head_size + h_bias / 2; + Load(&cos_emb[emb_idx], &cos_emb_vec); + Load(&sin_emb[emb_idx], &sin_emb_vec); + } +#pragma unroll + for (int i = 0; i < HalfVecSize; i++) { + // dequant + add_bias + rope + float input_left = static_cast(src_vec[2 * i]); + float input_right = static_cast(src_vec[2 * i + 1]); + + if (hi < num_heads + gqa_group_size) { + const float cos_tmp = cos_emb_vec[i]; + const float sin_tmp = sin_emb_vec[i]; + out_vec[2 * i] = + static_cast(input_left * cos_tmp - input_right * sin_tmp); + out_vec[2 * i + 1] = + static_cast(input_right * cos_tmp + input_left * sin_tmp); + } else { + out_vec[2 * i] = src_vec[2 * i]; + out_vec[2 * i + 1] = src_vec[2 * i + 1]; + } + } + if (hi < num_heads) { + // write q + Store(out_vec, &qkv_out[ori_idx]); + } else { + // quant + write k/v + const uint32_t kv_head_idx = (hi - num_heads) % gqa_group_size; + const uint32_t tgt_idx = + block_idx * gqa_group_size * block_size * head_size + + kv_head_idx * block_size * 
head_size + block_offset * head_size + + h_bias; + if (hi < num_heads + gqa_group_size) { + Store(out_vec, &key_cache[tgt_idx]); + } else { + Store(out_vec, &value_cache[tgt_idx]); + } + } + } +} + +template +__global__ void append_decode_cache_T_rope_kernel( + const int* __restrict__ quant_qkv, // [bsz, num_heads + 2 * gqa_group_size, + // head_size] + T* __restrict__ key_cache, // [num_blocks, gqa_group_size, block_size, + // head_size // 2] + T* __restrict__ value_cache, // [num_blocks, gqa_group_size, block_size, + // head_size // 2] + T* __restrict__ qkv_out, + const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] + const int* __restrict__ padding_offsets, // [num_tokens] + const int* __restrict__ cum_offsets, + const int* __restrict__ seq_lens, // [bsz] + const int* __restrict__ seq_lens_encoder, // [bsz] + const float* __restrict__ cos_emb, + const float* __restrict__ sin_emb, + const float* __restrict__ qkv_out_scales, // [num_head + 2 * + // gqa_group_size, dim_head] + const T* __restrict__ qkv_biases, // [num_head + 2 * gqa_group_size, + // dim_head] + const int max_seq_len, + const int max_blocks_per_seq, + const int num_heads, + const int head_size, + const int block_size, + const uint32_t elem_cnt, + const int gqa_group_size) { + using LoadT = AlignedVector; + using LoadBiasT = AlignedVector; + using LoadOutScaleT = AlignedVector; + using LoadKVT = AlignedVector; + constexpr int HalfVecSize = VecSize / 2; + using LoadEmbT = AlignedVector; + LoadT src_vec; + LoadBiasT bias_vec; + LoadOutScaleT out_scale_vec; + LoadKVT cache_vec; + LoadEmbT cos_emb_vec; + LoadEmbT sin_emb_vec; + + int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x; + const int64_t hidden_size = (num_heads + 2 * gqa_group_size) * head_size; + // const int64_t offset = 2 * hidden_size; + const int half_head_size = head_size / 2; + for (int32_t linear_index = global_thread_idx * VecSize, + step = gridDim.x * blockDim.x * VecSize; + linear_index < elem_cnt; + linear_index += step) { + const int ori_bi = linear_index / hidden_size; + const int bias = linear_index % hidden_size; + const int hi = bias / head_size; // q + k + v + const int h_bias = bias % head_size; + const int start_token_idx = ori_bi * max_seq_len - cum_offsets[ori_bi]; + if (seq_lens_encoder[ori_bi] > 0) return; + const int write_seq_id = seq_lens[ori_bi]; + if (write_seq_id == 0) continue; + + const int* block_table_now = nullptr; + + block_table_now = block_tables + ori_bi * max_blocks_per_seq; + const int block_idx = block_table_now[write_seq_id / block_size]; + const int block_offset = write_seq_id % block_size; + const uint32_t ori_idx = + start_token_idx * hidden_size + hi * head_size + h_bias; + + const int bias_idx = hi * head_size + h_bias; + Load(&quant_qkv[ori_idx], &src_vec); + if (qkv_biases) { + Load(&qkv_biases[bias_idx], &bias_vec); + } + Load(&qkv_out_scales[bias_idx], &out_scale_vec); + if (hi < num_heads + gqa_group_size) { + // q k rope + const uint32_t emb_idx = write_seq_id * half_head_size + h_bias / 2; + Load(&cos_emb[emb_idx], &cos_emb_vec); + Load(&sin_emb[emb_idx], &sin_emb_vec); + } +#pragma unroll + for (int i = 0; i < HalfVecSize; i++) { + // dequant + add_bias + rope + float input_left = static_cast(src_vec[2 * i]); + float input_right = static_cast(src_vec[2 * i + 1]); + input_left = qkv_biases ? input_left * out_scale_vec[2 * i] + + static_cast(bias_vec[2 * i]) : input_left * out_scale_vec[2 * i]; + input_right = qkv_biases ? 
input_right * out_scale_vec[2 * i + 1] + + static_cast(bias_vec[2 * i + 1]) : input_right * out_scale_vec[2 * i + 1]; + if (hi < num_heads + gqa_group_size) { + const float cos_tmp = cos_emb_vec[i]; + const float sin_tmp = sin_emb_vec[i]; + bias_vec[2 * i] = + static_cast(input_left * cos_tmp - input_right * sin_tmp); + bias_vec[2 * i + 1] = + static_cast(input_right * cos_tmp + input_left * sin_tmp); + } else { + bias_vec[2 * i] = static_cast(input_left); + bias_vec[2 * i + 1] = static_cast(input_right); + } + } + if (hi < num_heads) { + // write q + Store(bias_vec, &qkv_out[ori_idx]); + } else { + // quant + write k/v + const uint32_t kv_head_idx = (hi - num_heads) % gqa_group_size; + const uint32_t tgt_idx = + block_idx * gqa_group_size * block_size * head_size + + kv_head_idx * block_size * head_size + block_offset * head_size + + h_bias; + if (hi < num_heads + gqa_group_size) { + Store(bias_vec, &key_cache[tgt_idx]); + } else { + Store(bias_vec, &value_cache[tgt_idx]); + } + } + } +} + +template +__global__ void append_decode_cache_T_neox_rope_kernel( + const T* __restrict__ qkv, // [bsz, num_heads + 2 * gqa_group_size, + // head_size] + T* __restrict__ key_cache, // [num_blocks, gqa_group_size, block_size, head_size // 2] + T* __restrict__ value_cache, // [num_blocks, gqa_group_size, block_size, head_size // 2] + T* __restrict__ qkv_out, + const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] + const int* __restrict__ padding_offsets, // [num_tokens] + const int* __restrict__ cum_offsets, + const int* __restrict__ seq_lens, // [bsz] + const int* __restrict__ seq_lens_encoder, // [bsz] + const float* __restrict__ cos_emb, + const float* __restrict__ sin_emb, + const int max_seq_len, + const int max_blocks_per_seq, + const int num_heads, + const int head_size, + const int block_size, + const uint32_t elem_cnt, + const int gqa_group_size) { + using LoadT = AlignedVector; + using LoadBiasT = AlignedVector; + using LoadKVT = AlignedVector; + constexpr int HalfVecSize = VecSize / 2; + using LoadEmbT = AlignedVector; + + LoadT left_vec, right_vec; + LoadBiasT left_bias_vec, right_bias_vec; + LoadKVT left_cache_vec, right_cache_vec; + LoadEmbT cos_emb_vec; + LoadEmbT sin_emb_vec; + + int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x; + const int half_head_size = head_size / 2; + const int64_t hidden_size = (num_heads + 2 * gqa_group_size) * head_size; + const int64_t half_hidden_size = hidden_size / 2; + // const int64_t offset = 2 * hidden_size; + + for (int32_t linear_index = global_thread_idx * VecSize, + step = gridDim.x * blockDim.x * VecSize; + linear_index < elem_cnt; + linear_index += step) { + const int ori_bi = linear_index / half_hidden_size; + const int bias = linear_index % half_hidden_size; + const int hi = bias / half_head_size; // q + k + v + const int h_bias = bias % half_head_size; + const int start_token_idx = ori_bi * max_seq_len - cum_offsets[ori_bi]; + if (seq_lens_encoder[ori_bi] > 0) return; + const int write_seq_id = seq_lens[ori_bi]; + if (write_seq_id == 0) continue; + + const int* block_table_now = nullptr; + + block_table_now = block_tables + ori_bi * max_blocks_per_seq; + const int block_idx = block_table_now[write_seq_id / block_size]; + const int block_offset = write_seq_id % block_size; + const uint32_t ori_idx_left = + start_token_idx * hidden_size + hi * head_size + h_bias; + const uint32_t ori_idx_right = + ori_idx_left + half_head_size; + + Load(&qkv[ori_idx_left], &left_vec); + Load(&qkv[ori_idx_right], &right_vec); + + if 
(hi < num_heads + gqa_group_size) { + // q k rope + const uint32_t emb_idx = write_seq_id * head_size + h_bias; + Load(&cos_emb[emb_idx], &cos_emb_vec); + Load(&sin_emb[emb_idx], &sin_emb_vec); + } +#pragma unroll + for (int i = 0; i < VecSize; i++) { + // rope + float input_left = static_cast(left_vec[i]); + float input_right = static_cast(right_vec[i]); + if (hi < num_heads + gqa_group_size) { + const float cos_tmp = cos_emb_vec[i]; + const float sin_tmp = sin_emb_vec[i]; + left_bias_vec[i] = + static_cast(input_left * cos_tmp - input_right * sin_tmp); + right_bias_vec[i] = + static_cast(input_right * cos_tmp + input_left * sin_tmp); + } else { + left_bias_vec[i] = static_cast(input_left); + right_bias_vec[i] = static_cast(input_right); + } + } + if (hi < num_heads) { + // write q + Store(left_bias_vec, &qkv_out[ori_idx_left]); + Store(right_bias_vec, &qkv_out[ori_idx_right]); + } else { + // write k/v + const uint32_t kv_head_idx = (hi - num_heads) % gqa_group_size; + const uint32_t tgt_idx_left = + block_idx * gqa_group_size * block_size * head_size + + kv_head_idx * block_size * head_size + block_offset * head_size + + h_bias; + const uint32_t tgt_idx_right = tgt_idx_left + half_head_size; + if (hi < num_heads + gqa_group_size) { + Store(left_bias_vec, &key_cache[tgt_idx_left]); + Store(right_bias_vec, &key_cache[tgt_idx_right]); + } else { + Store(left_bias_vec, &value_cache[tgt_idx_left]); + Store(right_bias_vec, &value_cache[tgt_idx_right]); + } + } + } +} + +template +__global__ void append_decode_cache_T_neox_rope_kernel( + const int* __restrict__ quant_qkv, // [bsz, num_heads + 2 * gqa_group_size, + // head_size] + T* __restrict__ key_cache, // [num_blocks, gqa_group_size, block_size, head_size // 2] + T* __restrict__ value_cache, // [num_blocks, gqa_group_size, block_size, head_size // 2] + T* __restrict__ qkv_out, + const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] + const int* __restrict__ padding_offsets, // [num_tokens] + const int* __restrict__ cum_offsets, + const int* __restrict__ seq_lens, // [bsz] + const int* __restrict__ seq_lens_encoder, // [bsz] + const float* __restrict__ cos_emb, + const float* __restrict__ sin_emb, + const float* __restrict__ qkv_out_scales, // [num_head + 2 * gqa_group_size, dim_head] + const T* __restrict__ qkv_biases, // [num_head + 2 * gqa_group_size, dim_head] + const int max_seq_len, + const int max_blocks_per_seq, + const int num_heads, + const int head_size, + const int block_size, + const uint32_t elem_cnt, + const int gqa_group_size) { + using LoadT = AlignedVector; + using LoadBiasT = AlignedVector; + using LoadOutScaleT = AlignedVector; + using LoadKVT = AlignedVector; + constexpr int HalfVecSize = VecSize / 2; + using LoadEmbT = AlignedVector; + // LoadT src_vec; + // LoadBiasT bias_vec; + // LoadOutScaleT out_scale_vec; + // LoadKVT cache_vec; + LoadT left_vec, right_vec; + LoadBiasT left_bias_vec, right_bias_vec; + LoadOutScaleT left_out_scale_vec, right_out_scale_vec; + LoadKVT left_cache_vec, right_cache_vec; + LoadEmbT cos_emb_vec; + LoadEmbT sin_emb_vec; + + int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x; + const int half_head_size = head_size / 2; + const int64_t hidden_size = (num_heads + 2 * gqa_group_size) * head_size; + const int64_t half_hidden_size = hidden_size / 2; + + for (int32_t linear_index = global_thread_idx * VecSize, + step = gridDim.x * blockDim.x * VecSize; + linear_index < elem_cnt; + linear_index += step) { + const int ori_bi = linear_index / half_hidden_size; + const 
int bias = linear_index % half_hidden_size; + const int hi = bias / half_head_size; // q + k + v + const int h_bias = bias % half_head_size; + const int start_token_idx = ori_bi * max_seq_len - cum_offsets[ori_bi]; + if (seq_lens_encoder[ori_bi] > 0) return; + const int write_seq_id = seq_lens[ori_bi]; + if (write_seq_id == 0) continue; + + const int* block_table_now = nullptr; + + block_table_now = block_tables + ori_bi * max_blocks_per_seq; + const int block_idx = block_table_now[write_seq_id / block_size]; + const int block_offset = write_seq_id % block_size; + const uint32_t ori_idx_left = + start_token_idx * hidden_size + hi * head_size + h_bias; + const uint32_t ori_idx_right = + ori_idx_left + half_head_size; + + // const int bias_idx = hi * head_size + h_bias; + const int bias_idx_left = hi * head_size + h_bias; + const int bias_idx_right = bias_idx_left + half_head_size; + + // Load(&quant_qkv[ori_idx], &src_vec); + Load(&quant_qkv[ori_idx_left], &left_vec); + Load(&quant_qkv[ori_idx_right], &right_vec); + if (qkv_biases) { + Load(&qkv_biases[bias_idx_left], &left_bias_vec); + Load(&qkv_biases[bias_idx_right], &right_bias_vec); + } + + // Load(&qkv_out_scales[bias_idx], &out_scale_vec); + Load(&qkv_out_scales[bias_idx_left], &left_out_scale_vec); + Load(&qkv_out_scales[bias_idx_right], &right_out_scale_vec); + + if (hi < num_heads + gqa_group_size) { + // q k rope + const uint32_t emb_idx = write_seq_id * head_size + h_bias; + Load(&cos_emb[emb_idx], &cos_emb_vec); + Load(&sin_emb[emb_idx], &sin_emb_vec); + } +#pragma unroll + for (int i = 0; i < VecSize; i++) { + // dequant + add_bias + rope + float input_left = static_cast(left_vec[i]); + float input_right = static_cast(right_vec[i]); + input_left = qkv_biases ? input_left * left_out_scale_vec[i] + + static_cast(left_bias_vec[i]) : input_left * left_out_scale_vec[i]; + input_right = qkv_biases ? 
input_right * right_out_scale_vec[i] + + static_cast(right_bias_vec[i]) : input_right * right_out_scale_vec[i]; + if (hi < num_heads + gqa_group_size) { + const float cos_tmp = cos_emb_vec[i]; + const float sin_tmp = sin_emb_vec[i]; + left_bias_vec[i] = + static_cast(input_left * cos_tmp - input_right * sin_tmp); + right_bias_vec[i] = + static_cast(input_right * cos_tmp + input_left * sin_tmp); + } else { + left_bias_vec[i] = static_cast(input_left); + right_bias_vec[i] = static_cast(input_right); + } + } + if (hi < num_heads) { + // write q + Store(left_bias_vec, &qkv_out[ori_idx_left]); + Store(right_bias_vec, &qkv_out[ori_idx_right]); + } else { + // quant + write k/v + const uint32_t kv_head_idx = (hi - num_heads) % gqa_group_size; + const uint32_t tgt_idx_left = + block_idx * gqa_group_size * block_size * head_size + + kv_head_idx * block_size * head_size + block_offset * head_size + + h_bias; + const uint32_t tgt_idx_right = tgt_idx_left + half_head_size; + if (hi < num_heads + gqa_group_size) { + Store(left_bias_vec, &key_cache[tgt_idx_left]); + Store(right_bias_vec, &key_cache[tgt_idx_right]); + } else { + Store(left_bias_vec, &value_cache[tgt_idx_left]); + Store(right_bias_vec, &value_cache[tgt_idx_right]); + } + } + } +} + +template +__global__ void append_decode_cache_int8_rope_kernel( + const T* __restrict__ quant_qkv, // [bsz, num_heads + 2 * gqa_group_size, + // head_size] + uint8_t* __restrict__ key_cache, // [num_blocks, gqa_group_size, + // block_size, head_size // 2] + uint8_t* __restrict__ value_cache, // [num_blocks, gqa_group_size, + // block_size, head_size // 2] + T* __restrict__ qkv_out, + const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] + const int* __restrict__ padding_offsets, // [num_tokens] + const int* __restrict__ cum_offsets, + const int* __restrict__ seq_lens, // [bsz] + const int* __restrict__ seq_lens_encoder, // [bsz] + const float* __restrict__ cos_emb, + const float* __restrict__ sin_emb, + const T* __restrict__ cache_k_scale, + const T* __restrict__ cache_v_scale, + const int max_seq_len, + const int max_blocks_per_seq, + const int num_heads, + const int block_size, + const float max_bound, + const float min_bound, + const int gqa_group_size) { + static_assert(HeadDim == 128, "just support HeadDim be 128 now!"); + static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!"); + constexpr int NUM_WARPS = 4; + const int tid = threadIdx.x; + const int wid = tid / 32; + const int lane_id = tid % 32; + const int bid = blockIdx.x, head_idx = blockIdx.y * NUM_WARPS + wid; + int q_head_idx, k_head_idx, v_idx; + // q : dequant + add_bias + rope + write + // k : dequant + add_bias + rope + quant + write + // v : dequant + add_bias + quant + write + // kv在0位置全补0 + const int64_t hidden_size = (num_heads + 2 * gqa_group_size) * HeadDim; + constexpr int half_head_size = HeadDim / 2; + const int start_token_idx = bid * max_seq_len - __ldg(&cum_offsets[bid]); + if (seq_lens_encoder[bid] > 0) return; + const int write_seq_id = seq_lens[bid]; + if (write_seq_id == 0) return; + const int* block_table_now = nullptr; + + block_table_now = block_tables + bid * max_blocks_per_seq; + const int block_idx = __ldg(&block_table_now[write_seq_id / block_size]); + const int block_offset = write_seq_id % block_size; + + if (head_idx < num_heads) { + // q + using LoadT = AlignedVector; + using LoadBiasT = AlignedVector; + using LoadOutScaleT = AlignedVector; + constexpr int HalfVecSize = VecSize / 2; + using LoadEmbT = AlignedVector; + + LoadT src_vec; + 
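    // Query heads stay in T (only k/v go to the int8 cache): the 32 lanes of
    // this warp sweep the 128-wide head in VecSize chunks, rotate each
    // (even, odd) element pair with the cos/sin row for write_seq_id, and
    // store the result straight back to qkv_out for the attention kernel.
    // K/V heads are handled in the branch below: when a new cache block is
    // opened (block_offset == 0) the tile is zero-filled first, then the
    // rotated k (or raw v) values are multiplied by the per-head cache scale,
    // rounded (ties-to-even for RoundType == 0, otherwise round()), clamped
    // to [min_bound, max_bound], offset by +128 and stored as uint8 in the
    // tiled cache layout.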
LoadBiasT out_vec; + LoadEmbT cos_emb_vec; + LoadEmbT sin_emb_vec; + const T* qkv_now = quant_qkv + start_token_idx * hidden_size; + T* qkv_out_now = qkv_out + start_token_idx * hidden_size; +#pragma unroll + for (uint32_t head_bias = lane_id * VecSize; head_bias < HeadDim; + head_bias += 32 * VecSize) { + const int bias_idx = head_idx * HeadDim + head_bias; + Load(&qkv_now[bias_idx], &src_vec); + + // q rope + const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2; + Load(&cos_emb[emb_idx], &cos_emb_vec); + Load(&sin_emb[emb_idx], &sin_emb_vec); +#pragma unroll + for (int i = 0; i < HalfVecSize; i++) { + // dequant + add_bias + rope + float input_left = static_cast(src_vec[2 * i]); + float input_right = static_cast(src_vec[2 * i + 1]); + + const float cos_tmp = cos_emb_vec[i]; + const float sin_tmp = sin_emb_vec[i]; + out_vec[2 * i] = + static_cast(input_left * cos_tmp - input_right * sin_tmp); + out_vec[2 * i + 1] = + static_cast(input_right * cos_tmp + input_left * sin_tmp); + } + Store(out_vec, &qkv_out_now[bias_idx]); + } + } else if (head_idx < num_heads + 2 * gqa_group_size) { + // k + constexpr int KV_VEC_SIZE = 16 / sizeof(uint8_t); // 16 + using LoadPadKVT = AlignedVector; + const uint32_t kv_head_idx = (head_idx - num_heads) % gqa_group_size; + if (block_offset == 0) { + // pad zero for this kv_head_idx for this block + LoadPadKVT pad_cache_vec; + *(reinterpret_cast(pad_cache_vec.val)) = make_uint4(0, 0, 0, 0); + if (head_idx < num_heads + gqa_group_size) { + constexpr int num_vecs_per_head_dim = HeadDim / KV_VEC_SIZE; + constexpr int num_token_each_time = 32 / num_vecs_per_head_dim; + const uint32_t tgt_idx = + (block_idx * gqa_group_size + kv_head_idx) * block_size * HeadDim + + lane_id % num_vecs_per_head_dim * KV_VEC_SIZE; + for (int block_i = lane_id / num_vecs_per_head_dim; + block_i < block_size; + block_i += num_token_each_time) { + Store(pad_cache_vec, + &key_cache[tgt_idx + block_i * HeadDim]); + } + } else { + const int num_vecs_per_head_dim = block_size / KV_VEC_SIZE; + const int num_token_each_time = 32 / num_vecs_per_head_dim; + const uint32_t tgt_idx = + (block_idx * gqa_group_size + kv_head_idx) * HeadDim * block_size + + lane_id % num_vecs_per_head_dim * KV_VEC_SIZE; + for (int block_i = lane_id / num_vecs_per_head_dim; block_i < HeadDim; + block_i += num_token_each_time) { + Store( + pad_cache_vec, &value_cache[tgt_idx + block_i * block_size]); + } + } + } + + constexpr int K_VEC_SIZE = 4; + constexpr int HALF_K_VEC_SIZE = 2; + using LoadKVResT = AlignedVector; + using LoadKVT = AlignedVector; + using LoadT = AlignedVector; + using LoadBiasT = AlignedVector; + using LoadOutScaleT = AlignedVector; + using LoadEmbT = AlignedVector; + LoadKVResT cache_vec; + LoadT src_vec1, src_vec2; + LoadBiasT out_vec1, out_vec2; + LoadEmbT cos_emb_vec1, cos_emb_vec2; + LoadEmbT sin_emb_vec1, sin_emb_vec2; + + const T* qkv_now = quant_qkv + start_token_idx * hidden_size; + const int head_bias = lane_id / 4 * 16 + lane_id % 4 * 2; + const int bias_idx = head_idx * HeadDim + head_bias; + Load(&qkv_now[bias_idx], &src_vec1); + Load(&qkv_now[bias_idx + 8], &src_vec2); + T scale; + if (head_idx < num_heads + gqa_group_size) { + const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2; + Load(&cos_emb[emb_idx], &cos_emb_vec1); + Load(&cos_emb[emb_idx + 4], &cos_emb_vec2); + Load(&sin_emb[emb_idx], &sin_emb_vec1); + Load(&sin_emb[emb_idx + 4], &sin_emb_vec2); + scale = __ldg(&cache_k_scale[kv_head_idx]); + } else { + scale = 
__ldg(&cache_v_scale[kv_head_idx]); + } + + float input_left = static_cast(src_vec1[0]); + float input_right = static_cast(src_vec1[1]); + if (head_idx < num_heads + gqa_group_size) { + float cos_tmp = cos_emb_vec1[0]; + float sin_tmp = sin_emb_vec1[0]; + out_vec1[0] = + static_cast(input_left * cos_tmp - input_right * sin_tmp); + out_vec1[1] = + static_cast(input_right * cos_tmp + input_left * sin_tmp); + } else { + out_vec1[0] = src_vec1[0]; + out_vec1[1] = src_vec1[1]; + } + + input_left = static_cast(src_vec2[0]); + input_right = static_cast(src_vec2[1]); + if (head_idx < num_heads + gqa_group_size) { + float cos_tmp = cos_emb_vec2[0]; + float sin_tmp = sin_emb_vec2[0]; + out_vec2[0] = + static_cast(input_left * cos_tmp - input_right * sin_tmp); + out_vec2[1] = + static_cast(input_right * cos_tmp + input_left * sin_tmp); + } else { + out_vec2[0] = src_vec2[0]; + out_vec2[1] = src_vec2[1]; + } +#pragma unroll + for (uint32_t i = 0; i < HALF_K_VEC_SIZE; i++) { + float quant_value1 = static_cast(scale * out_vec1[i]); + float quant_value2 = static_cast(scale * out_vec2[i]); + if constexpr (RoundType == 0) { + quant_value1 = static_cast(roundWithTiesToEven(quant_value1)); + quant_value2 = static_cast(roundWithTiesToEven(quant_value2)); + } else { + quant_value1 = static_cast(round(quant_value1)); + quant_value2 = static_cast(round(quant_value2)); + } + quant_value1 = quant_value1 > max_bound ? max_bound : quant_value1; + quant_value1 = quant_value1 < min_bound ? min_bound : quant_value1; + quant_value2 = quant_value2 > max_bound ? max_bound : quant_value2; + quant_value2 = quant_value2 < min_bound ? min_bound : quant_value2; + cache_vec[i] = static_cast(quant_value1 + 128.0f); + cache_vec[i + HALF_K_VEC_SIZE] = + static_cast(quant_value2 + 128.0f); + } + if (head_idx < num_heads + gqa_group_size) { + // write k + // 大分块 lane_id / 4 / 2 + // 上下 lane_id / 4 % 2 + // 左16还是右16 (block_offset % 16) / 8 + // 小偏移 lane_id % 4 * 2 + const int start_block_16 = + block_offset / 16 * 16 + block_offset % 8 + lane_id / 4 % 2 * 8; + const uint32_t tgt_cache_idx = + block_idx * gqa_group_size * block_size * HeadDim + + kv_head_idx * block_size * HeadDim + start_block_16 * HeadDim + + lane_id / 4 / 2 * 32 + (block_offset % 16) / 8 * 16 + lane_id % 4 * 4; + Store(cache_vec, &key_cache[tgt_cache_idx]); + } else { + // write v transpose + // 大分块 block_offset / 16 / 2 * 32 + // 大上下 lane_id / 4 * 16 * block_size + lane_id % 4 * 2 + // 小上下 block_offset / 16 % 2 * block_size + // 左16还是右16 + // 小偏移 + const uint32_t base_tgt_cache_idx = + block_idx * gqa_group_size * HeadDim * block_size + + kv_head_idx * HeadDim * block_size + + (lane_id / 4 * 16 + lane_id % 4 * 2) * block_size + + block_offset / 16 % 2 * 8 * block_size + block_offset / 16 / 2 * 32; + const uint32_t tgt_cache_idx1 = base_tgt_cache_idx + + block_offset % 8 / 2 * 4 // per 4 + + block_offset % 16 / 8 * 2 // per 2 + + block_offset % 2; // per 1 + const uint32_t tgt_cache_idx2 = tgt_cache_idx1 + block_size; + const uint32_t tgt_cache_idx3 = tgt_cache_idx1 + 16; + const uint32_t tgt_cache_idx4 = tgt_cache_idx3 + block_size; + value_cache[tgt_cache_idx1] = cache_vec[0]; + value_cache[tgt_cache_idx2] = cache_vec[1]; + value_cache[tgt_cache_idx3] = cache_vec[2]; + value_cache[tgt_cache_idx4] = cache_vec[3]; + } + } +} + +template +__global__ void append_decode_cache_int8_rope_kernel( + const int* __restrict__ quant_qkv, // [bsz, num_heads + 2 * gqa_group_size, + // head_size] + uint8_t* __restrict__ key_cache, // [num_blocks, gqa_group_size, + // block_size, 
head_size // 2] + uint8_t* __restrict__ value_cache, // [num_blocks, gqa_group_size, + // block_size, head_size // 2] + T* __restrict__ qkv_out, + const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] + const int* __restrict__ padding_offsets, // [num_tokens] + const int* __restrict__ cum_offsets, + const int* __restrict__ seq_lens, // [bsz] + const int* __restrict__ seq_lens_encoder, // [bsz] + const float* __restrict__ cos_emb, + const float* __restrict__ sin_emb, + const float* __restrict__ qkv_out_scales, // [num_head + 2 * + // gqa_group_size, dim_head] + const T* __restrict__ qkv_biases, // [num_head + 2 * gqa_group_size, + // dim_head] + const T* __restrict__ cache_k_scales, + const T* __restrict__ cache_v_scales, + const int max_seq_len, + const int max_blocks_per_seq, + const int num_heads, + const int block_size, + const float max_bound, + const float min_bound, + const int gqa_group_size) { + static_assert(HeadDim == 128, "just support HeadDim be 128 now!"); + static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!"); + constexpr int NUM_WARPS = 4; + const int tid = threadIdx.x; + const int wid = tid / 32; + const int lane_id = tid % 32; + const int by = blockIdx.y; + const int bid = blockIdx.x, head_idx = blockIdx.y * NUM_WARPS + wid; + int q_head_idx, k_head_idx, v_idx; + // q : dequant + add_bias + rope + write + // k : dequant + add_bias + rope + quant + write + // v : dequant + add_bias + quant + write + // kv在0位置全补0 + const int64_t hidden_size = (num_heads + 2 * gqa_group_size) * HeadDim; + constexpr int half_head_size = HeadDim / 2; + const int start_token_idx = bid * max_seq_len - __ldg(&cum_offsets[bid]); + if (seq_lens_encoder[bid] > 0) return; + const int write_seq_id = seq_lens[bid]; + if (write_seq_id == 0) return; + const int* block_table_now = nullptr; + + block_table_now = block_tables + bid * max_blocks_per_seq; + const int block_idx = __ldg(&block_table_now[write_seq_id / block_size]); + const int block_offset = write_seq_id % block_size; + + if (head_idx < num_heads) { + // q + using LoadT = AlignedVector; + using LoadBiasT = AlignedVector; + using LoadOutScaleT = AlignedVector; + constexpr int HalfVecSize = VecSize / 2; + using LoadEmbT = AlignedVector; + + LoadT src_vec; + LoadBiasT bias_vec; + LoadOutScaleT out_scale_vec; + LoadEmbT cos_emb_vec; + LoadEmbT sin_emb_vec; + const int* qkv_now = quant_qkv + start_token_idx * hidden_size; + T* qkv_out_now = qkv_out + start_token_idx * hidden_size; +#pragma unroll + for (uint32_t head_bias = lane_id * VecSize; head_bias < HeadDim; + head_bias += 32 * VecSize) { + + const int bias_idx = head_idx * HeadDim + head_bias; + Load(&qkv_now[bias_idx], &src_vec); + + if (qkv_biases) { + Load(&qkv_biases[bias_idx], &bias_vec); + } + Load(&qkv_out_scales[bias_idx], &out_scale_vec); + + // q rope + const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2; + Load(&cos_emb[emb_idx], &cos_emb_vec); + + Load(&sin_emb[emb_idx], &sin_emb_vec); + +#pragma unroll + for (int i = 0; i < HalfVecSize; i++) { + // dequant + add_bias + rope + float input_left = static_cast(src_vec[2 * i]); + float input_right = static_cast(src_vec[2 * i + 1]); + input_left = qkv_biases ? input_left * out_scale_vec[2 * i] + + static_cast(bias_vec[2 * i]) + : input_left * out_scale_vec[2 * i]; + input_right = qkv_biases ? 
input_right * out_scale_vec[2 * i + 1] + + static_cast(bias_vec[2 * i + 1]) + : input_right * out_scale_vec[2 * i + 1]; + const float cos_tmp = cos_emb_vec[i]; + const float sin_tmp = sin_emb_vec[i]; + bias_vec[2 * i] = + static_cast(input_left * cos_tmp - input_right * sin_tmp); + bias_vec[2 * i + 1] = + static_cast(input_right * cos_tmp + input_left * sin_tmp); + } + Store(bias_vec, &qkv_out_now[bias_idx]); + } + } else if (head_idx < num_heads + 2 * gqa_group_size) { + // k + constexpr int KV_VEC_SIZE = 16 / sizeof(uint8_t); // 16 + using LoadPadKVT = AlignedVector; + const uint32_t kv_head_idx = (head_idx - num_heads) % gqa_group_size; + + if (block_offset == 0) { + // pad zero for this kv_head_idx for this block + LoadPadKVT pad_cache_vec; + *(reinterpret_cast(pad_cache_vec.val)) = make_uint4(0, 0, 0, 0); + if (head_idx < num_heads + gqa_group_size) { + constexpr int num_vecs_per_head_dim = HeadDim / KV_VEC_SIZE; + constexpr int num_token_each_time = 32 / num_vecs_per_head_dim; + const uint32_t tgt_idx = + (block_idx * gqa_group_size + kv_head_idx) * block_size * HeadDim + + lane_id % num_vecs_per_head_dim * KV_VEC_SIZE; + for (int block_i = lane_id / num_vecs_per_head_dim; + block_i < block_size; + block_i += num_token_each_time) { + Store(pad_cache_vec, + &key_cache[tgt_idx + block_i * HeadDim]); + } + } else { + const int num_vecs_per_head_dim = block_size / KV_VEC_SIZE; + const int num_token_each_time = 32 / num_vecs_per_head_dim; + const uint32_t tgt_idx = + (block_idx * gqa_group_size + kv_head_idx) * HeadDim * block_size + + lane_id % num_vecs_per_head_dim * KV_VEC_SIZE; + for (int block_i = lane_id / num_vecs_per_head_dim; block_i < HeadDim; + block_i += num_token_each_time) { + Store( + pad_cache_vec, &value_cache[tgt_idx + block_i * block_size]); + } + } + } + + constexpr int K_VEC_SIZE = 4; + constexpr int HALF_K_VEC_SIZE = 2; + using LoadKVResT = AlignedVector; + using LoadKVT = AlignedVector; + using LoadT = AlignedVector; + using LoadBiasT = AlignedVector; + using LoadOutScaleT = AlignedVector; + using LoadEmbT = AlignedVector; + LoadKVResT cache_vec; + LoadT src_vec1, src_vec2; + LoadBiasT bias_vec1, bias_vec2; + LoadOutScaleT out_scale_vec1, out_scale_vec2; + LoadEmbT cos_emb_vec1, cos_emb_vec2; + LoadEmbT sin_emb_vec1, sin_emb_vec2; + + const int* qkv_now = quant_qkv + start_token_idx * hidden_size; + const int head_bias = lane_id / 4 * 16 + lane_id % 4 * 2; + const int bias_idx = head_idx * HeadDim + head_bias; + Load(&qkv_now[bias_idx], &src_vec1); + Load(&qkv_now[bias_idx + 8], &src_vec2); + if (qkv_biases) { + Load(&qkv_biases[bias_idx], &bias_vec1); + Load(&qkv_biases[bias_idx + 8], &bias_vec2); + } + Load(&qkv_out_scales[bias_idx], &out_scale_vec1); + Load(&qkv_out_scales[bias_idx + 8], + &out_scale_vec2); + + T scale; + if (head_idx < num_heads + gqa_group_size) { + const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2; + Load(&cos_emb[emb_idx], &cos_emb_vec1); + Load(&cos_emb[emb_idx + 4], &cos_emb_vec2); + Load(&sin_emb[emb_idx], &sin_emb_vec1); + Load(&sin_emb[emb_idx + 4], &sin_emb_vec2); + scale = __ldg(&cache_k_scales[kv_head_idx]); + } else { + scale = __ldg(&cache_v_scales[kv_head_idx]); + } + + float input_left = static_cast(src_vec1[0]); + float input_right = static_cast(src_vec1[1]); + input_left = qkv_biases ? input_left * out_scale_vec1[0] + + static_cast(bias_vec1[0]) + : input_left * out_scale_vec1[0]; + input_right = qkv_biases ? 
input_right * out_scale_vec1[1] + + static_cast(bias_vec1[1]) + : input_right * out_scale_vec1[1]; + if (head_idx < num_heads + gqa_group_size) { + float cos_tmp = cos_emb_vec1[0]; + float sin_tmp = sin_emb_vec1[0]; + bias_vec1[0] = + static_cast(input_left * cos_tmp - input_right * sin_tmp); + bias_vec1[1] = + static_cast(input_right * cos_tmp + input_left * sin_tmp); + } else { + bias_vec1[0] = static_cast(input_left); + bias_vec1[1] = static_cast(input_right); + } + + input_left = static_cast(src_vec2[0]); + input_right = static_cast(src_vec2[1]); + input_left = qkv_biases ? input_left * out_scale_vec2[0] + + static_cast(bias_vec2[0]) + : input_left * out_scale_vec2[0]; + input_right = qkv_biases ? input_right * out_scale_vec2[1] + + static_cast(bias_vec2[1]) + : input_right * out_scale_vec2[1]; + if (head_idx < num_heads + gqa_group_size) { + float cos_tmp = cos_emb_vec2[0]; + float sin_tmp = sin_emb_vec2[0]; + bias_vec2[0] = + static_cast(input_left * cos_tmp - input_right * sin_tmp); + bias_vec2[1] = + static_cast(input_right * cos_tmp + input_left * sin_tmp); + } else { + bias_vec2[0] = static_cast(input_left); + bias_vec2[1] = static_cast(input_right); + } +#pragma unroll + for (uint32_t i = 0; i < HALF_K_VEC_SIZE; i++) { + float quant_value1 = static_cast(scale * bias_vec1[i]); + float quant_value2 = static_cast(scale * bias_vec2[i]); + if constexpr (RoundType == 0) { + quant_value1 = static_cast(roundWithTiesToEven(quant_value1)); + quant_value2 = static_cast(roundWithTiesToEven(quant_value2)); + } else { + quant_value1 = static_cast(round(quant_value1)); + quant_value2 = static_cast(round(quant_value2)); + } + quant_value1 = quant_value1 > max_bound ? max_bound : quant_value1; + quant_value1 = quant_value1 < min_bound ? min_bound : quant_value1; + quant_value2 = quant_value2 > max_bound ? max_bound : quant_value2; + quant_value2 = quant_value2 < min_bound ? 
min_bound : quant_value2; + cache_vec[i] = static_cast(quant_value1 + 128.0f); + cache_vec[i + HALF_K_VEC_SIZE] = + static_cast(quant_value2 + 128.0f); + } + if (head_idx < num_heads + gqa_group_size) { + // write k + // 大分块 lane_id / 4 / 2 + // 上下 lane_id / 4 % 2 + // 左16还是右16 (block_offset % 16) / 8 + // 小偏移 lane_id % 4 * 2 + const int start_block_16 = + block_offset / 16 * 16 + block_offset % 8 + lane_id / 4 % 2 * 8; + const uint32_t tgt_cache_idx = + block_idx * gqa_group_size * block_size * HeadDim + + kv_head_idx * block_size * HeadDim + start_block_16 * HeadDim + + lane_id / 4 / 2 * 32 + (block_offset % 16) / 8 * 16 + lane_id % 4 * 4; + Store(cache_vec, &key_cache[tgt_cache_idx]); + } else { + // write v transpose + // 大分块 block_offset / 16 / 2 * 32 + // 大上下 lane_id / 4 * 16 * block_size + lane_id % 4 * 2 + // 小上下 block_offset / 16 % 2 * block_size + // 左16还是右16 + // 小偏移 + const uint32_t base_tgt_cache_idx = + block_idx * gqa_group_size * HeadDim * block_size + + kv_head_idx * HeadDim * block_size + + (lane_id / 4 * 16 + lane_id % 4 * 2) * block_size + + block_offset / 16 % 2 * 8 * block_size + block_offset / 16 / 2 * 32; + const uint32_t tgt_cache_idx1 = base_tgt_cache_idx + + block_offset % 8 / 2 * 4 // per 4 + + block_offset % 16 / 8 * 2 // per 2 + + block_offset % 2; // per 1 + const uint32_t tgt_cache_idx2 = tgt_cache_idx1 + block_size; + const uint32_t tgt_cache_idx3 = tgt_cache_idx1 + 16; + const uint32_t tgt_cache_idx4 = tgt_cache_idx3 + block_size; + value_cache[tgt_cache_idx1] = cache_vec[0]; + value_cache[tgt_cache_idx2] = cache_vec[1]; + value_cache[tgt_cache_idx3] = cache_vec[2]; + value_cache[tgt_cache_idx4] = cache_vec[3]; + } + } +} + + +template +__global__ void append_decode_cache_int8_neox_rope_kernel( + const T* __restrict__ quant_qkv, // [bsz, num_heads + 2 * gqa_group_size, head_size] + uint8_t* __restrict__ key_cache, // [num_blocks, gqa_group_size, block_size, head_size // 2] + uint8_t* __restrict__ value_cache, // [num_blocks, gqa_group_size, block_size, head_size // 2] + T* __restrict__ qkv_out, + const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] + const int* __restrict__ padding_offsets, // [num_tokens] + const int* __restrict__ cum_offsets, + const int* __restrict__ seq_lens, // [bsz] + const int* __restrict__ seq_lens_encoder, // [bsz] + const float* __restrict__ cos_emb, + const float* __restrict__ sin_emb, + const T* __restrict__ cache_k_scales, + const T* __restrict__ cache_v_scales, + const int max_seq_len, + const int max_blocks_per_seq, + const int num_heads, + const int block_size, + const float max_bound, + const float min_bound, + const int gqa_group_size) { + static_assert(HeadDim == 128, "just support HeadDim be 128 now!"); + static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!"); + constexpr int NUM_WARPS = 4; + const int tid = threadIdx.x; + const int wid = tid / 32; + const int lane_id = tid % 32; + const int by = blockIdx.y; + const int bid = blockIdx.x, head_idx = blockIdx.y * NUM_WARPS + wid; + int q_head_idx, k_head_idx, v_idx; + // q : dequant + add_bias + rope + write + // k : dequant + add_bias + rope + quant + write + // v : dequant + add_bias + quant + write + // kv在0位置全补0 + const int64_t hidden_size = (num_heads + 2 * gqa_group_size) * HeadDim; + constexpr int half_head_size = HeadDim / 2; + const int start_token_idx = bid * max_seq_len - __ldg(&cum_offsets[bid]); + if (seq_lens_encoder[bid] > 0) return; + const int write_seq_id = seq_lens[bid]; + if (write_seq_id == 0) return; + const int* 
block_table_now = nullptr; + + block_table_now = block_tables + bid * max_blocks_per_seq; + const int block_idx = __ldg(&block_table_now[write_seq_id / block_size]); + const int block_offset = write_seq_id % block_size; + + if (head_idx < num_heads) { + // q + using LoadT = AlignedVector; + using LoadBiasT = AlignedVector; + constexpr int HalfVecSize = VecSize / 2; + using LoadEmbT = AlignedVector; + + // LoadT src_vec; + LoadT left_vec; + LoadT right_vec; + // LoadBiasT bias_vec; + LoadBiasT left_bias_vec; + LoadBiasT right_bias_vec; + LoadEmbT cos_emb_vec; + LoadEmbT sin_emb_vec; + const T* qkv_now = quant_qkv + start_token_idx * hidden_size; + T* qkv_out_now = qkv_out + start_token_idx * hidden_size; +#pragma unroll + for (uint32_t head_bias = lane_id * VecSize; head_bias < half_head_size; + head_bias += 32 * VecSize) { + + const int bias_idx_left = head_idx * HeadDim + head_bias; + const int bias_idx_right = bias_idx_left + half_head_size; + + Load(&qkv_now[bias_idx_left], &left_vec); + Load(&qkv_now[bias_idx_right], &right_vec); + + // q rope + const uint32_t emb_idx = write_seq_id * HeadDim + head_bias; + Load(&cos_emb[emb_idx], &cos_emb_vec); + Load(&sin_emb[emb_idx], &sin_emb_vec); + +#pragma unroll + for (int i = 0; i < VecSize; i++) { + // dequant + add_bias + rope + float input_left = static_cast(left_vec[i]); + float input_right = static_cast(right_vec[i]); + const float cos_tmp = cos_emb_vec[i]; + const float sin_tmp = sin_emb_vec[i]; + left_bias_vec[i] = + static_cast(input_left * cos_tmp - input_right * sin_tmp); + right_bias_vec[i] = + static_cast(input_right * cos_tmp + input_left * sin_tmp); + } + Store(left_bias_vec, &qkv_out_now[bias_idx_left]); + Store(right_bias_vec, &qkv_out_now[bias_idx_right]); + } + } else if (head_idx < num_heads + 2 * gqa_group_size) { + // k v + constexpr int KV_VEC_SIZE = 16 / sizeof(uint8_t); // 16 + using LoadPadKVT = AlignedVector; + const uint32_t kv_head_idx = (head_idx - num_heads) % gqa_group_size; + if (block_offset == 0) { + // pad zero for this kv_head_idx for this block + LoadPadKVT pad_cache_vec; + *(reinterpret_cast(pad_cache_vec.val)) = make_uint4(0, 0, 0, 0); + if (head_idx < num_heads + gqa_group_size) { + constexpr int num_vecs_per_head_dim = HeadDim / KV_VEC_SIZE; + constexpr int num_token_each_time = 32 / num_vecs_per_head_dim; + const uint32_t tgt_idx = + (block_idx * gqa_group_size + kv_head_idx) * block_size * HeadDim + + lane_id % num_vecs_per_head_dim * KV_VEC_SIZE; + for (int block_i = lane_id / num_vecs_per_head_dim; + block_i < block_size; + block_i += num_token_each_time) { + Store(pad_cache_vec, + &key_cache[tgt_idx + block_i * HeadDim]); + } + } else { + const int num_vecs_per_head_dim = block_size / KV_VEC_SIZE; + const int num_token_each_time = 32 / num_vecs_per_head_dim; + const uint32_t tgt_idx = + (block_idx * gqa_group_size + kv_head_idx) * HeadDim * block_size + + lane_id % num_vecs_per_head_dim * KV_VEC_SIZE; + for (int block_i = lane_id / num_vecs_per_head_dim; block_i < HeadDim; + block_i += num_token_each_time) { + Store( + pad_cache_vec, &value_cache[tgt_idx + block_i * block_size]); + } + } + } + if (head_idx < num_heads + gqa_group_size) { + // k + const int head_bias = lane_id / 4 * 16 + lane_id % 4 * 2; + if (head_bias < half_head_size) { + constexpr int K_VEC_SIZE = 4; + constexpr int HALF_K_VEC_SIZE = 2; + using LoadKVResT = AlignedVector; + using LoadKVT = AlignedVector; + using LoadT = AlignedVector; + using LoadBiasT = AlignedVector; + using LoadEmbT = AlignedVector; + + LoadKVResT 
left_cache_vec, right_cache_vec; + LoadT left_src_vec1, left_src_vec2, right_src_vec1, right_src_vec2; + LoadBiasT left_bias_vec1, left_bias_vec2, right_bias_vec1, right_bias_vec2; + LoadEmbT cos_emb_vec1, cos_emb_vec2; + LoadEmbT sin_emb_vec1, sin_emb_vec2; + + const T* qkv_now = quant_qkv + start_token_idx * hidden_size; + const int left_bias_idx = head_idx * HeadDim + head_bias; + const int right_bias_idx = left_bias_idx + half_head_size; + + Load(&qkv_now[left_bias_idx], &left_src_vec1); + Load(&qkv_now[left_bias_idx + 8], &left_src_vec2); + Load(&qkv_now[right_bias_idx], &right_src_vec1); + Load(&qkv_now[right_bias_idx + 8], &right_src_vec2); + + T scale; + const uint32_t emb_idx = write_seq_id * HeadDim + head_bias; + Load(&cos_emb[emb_idx], &cos_emb_vec1); + Load(&cos_emb[emb_idx + 8], &cos_emb_vec2); + Load(&sin_emb[emb_idx], &sin_emb_vec1); + Load(&sin_emb[emb_idx + 8], &sin_emb_vec2); + scale = __ldg(&cache_k_scales[kv_head_idx]); + #pragma unroll + for(int i = 0; i < HALF_K_VEC_SIZE; i++) { + float input_left = static_cast(left_src_vec1[i]); + float input_right = static_cast(right_src_vec1[i]); + + float cos_tmp = cos_emb_vec1[i]; + float sin_tmp = sin_emb_vec1[i]; + left_bias_vec1[i] = + static_cast(input_left * cos_tmp - input_right * sin_tmp); + right_bias_vec1[i] = + static_cast(input_right * cos_tmp + input_left * sin_tmp); + + input_left = static_cast(left_src_vec2[i]); + input_right = static_cast(right_src_vec2[i]); + cos_tmp = cos_emb_vec2[i]; + sin_tmp = sin_emb_vec2[i]; + left_bias_vec2[i] = + static_cast(input_left * cos_tmp - input_right * sin_tmp); + right_bias_vec2[i] = + static_cast(input_right * cos_tmp + input_left * sin_tmp); + + float quant_value1 = static_cast(scale * left_bias_vec1[i]); + float quant_value2 = static_cast(scale * left_bias_vec2[i]); + if constexpr (RoundType == 0) { + quant_value1 = static_cast(roundWithTiesToEven(quant_value1)); + quant_value2 = static_cast(roundWithTiesToEven(quant_value2)); + } else { + quant_value1 = static_cast(round(quant_value1)); + quant_value2 = static_cast(round(quant_value2)); + } + quant_value1 = quant_value1 > max_bound ? max_bound : quant_value1; + quant_value1 = quant_value1 < min_bound ? min_bound : quant_value1; + quant_value2 = quant_value2 > max_bound ? max_bound : quant_value2; + quant_value2 = quant_value2 < min_bound ? min_bound : quant_value2; + left_cache_vec[i] = static_cast(quant_value1 + 128.0f); + left_cache_vec[i + HALF_K_VEC_SIZE] = + static_cast(quant_value2 + 128.0f); + + quant_value1 = static_cast(scale * right_bias_vec1[i]); + quant_value2 = static_cast(scale * right_bias_vec2[i]); + if constexpr (RoundType == 0) { + quant_value1 = static_cast(roundWithTiesToEven(quant_value1)); + quant_value2 = static_cast(roundWithTiesToEven(quant_value2)); + } else { + quant_value1 = static_cast(round(quant_value1)); + quant_value2 = static_cast(round(quant_value2)); + } + quant_value1 = quant_value1 > max_bound ? max_bound : quant_value1; + quant_value1 = quant_value1 < min_bound ? min_bound : quant_value1; + quant_value2 = quant_value2 > max_bound ? max_bound : quant_value2; + quant_value2 = quant_value2 < min_bound ? 
min_bound : quant_value2; + right_cache_vec[i] = static_cast(quant_value1 + 128.0f); + right_cache_vec[i + HALF_K_VEC_SIZE] = + static_cast(quant_value2 + 128.0f); + } + // write k + // 大分块 lane_id / 4 / 2 + // 上下 lane_id / 4 % 2 + // 左16还是右16 (block_offset % 16) / 8 + // 小偏移 lane_id % 4 * 2 + const int left_start_block_16 = + block_offset / 16 * 16 + block_offset % 8 + lane_id / 4 % 2 * 8; + const uint32_t left_tgt_cache_idx = + block_idx * gqa_group_size * block_size * HeadDim + + kv_head_idx * block_size * HeadDim + left_start_block_16 * HeadDim + + lane_id / 4 / 2 * 32 + (block_offset % 16) / 8 * 16 + lane_id % 4 * 4; + + const int right_lane_id = lane_id + 16; + const int right_start_block_16 = + block_offset / 16 * 16 + block_offset % 8 + right_lane_id / 4 % 2 * 8; + const uint32_t right_tgt_cache_idx = + block_idx * gqa_group_size * block_size * HeadDim + + kv_head_idx * block_size * HeadDim + right_start_block_16 * HeadDim + + right_lane_id / 4 / 2 * 32 + (block_offset % 16) / 8 * 16 + right_lane_id % 4 * 4; + Store(left_cache_vec, &key_cache[left_tgt_cache_idx]); + Store(right_cache_vec, &key_cache[right_tgt_cache_idx]); + } + } else { + // v + constexpr int K_VEC_SIZE = 4; + constexpr int HALF_K_VEC_SIZE = 2; + using LoadKVResT = AlignedVector; + using LoadKVT = AlignedVector; + using LoadT = AlignedVector; + using LoadBiasT = AlignedVector; + LoadKVResT cache_vec; + LoadT src_vec1, src_vec2; + LoadBiasT bias_vec1, bias_vec2; + + const T* qkv_now = quant_qkv + start_token_idx * hidden_size; + const int head_bias = lane_id / 4 * 16 + lane_id % 4 * 2; + const int bias_idx = head_idx * HeadDim + head_bias; + Load(&qkv_now[bias_idx], &src_vec1); + Load(&qkv_now[bias_idx + 8], &src_vec2); + + T scale = __ldg(&cache_v_scales[kv_head_idx]); + +#pragma unroll + for (uint32_t i = 0; i < HALF_K_VEC_SIZE; i++) { + float quant_value1 = static_cast(scale * src_vec1[i]); + float quant_value2 = static_cast(scale * src_vec2[i]); + if constexpr (RoundType == 0) { + quant_value1 = static_cast(roundWithTiesToEven(quant_value1)); + quant_value2 = static_cast(roundWithTiesToEven(quant_value2)); + } else { + quant_value1 = static_cast(round(quant_value1)); + quant_value2 = static_cast(round(quant_value2)); + } + quant_value1 = quant_value1 > max_bound ? max_bound : quant_value1; + quant_value1 = quant_value1 < min_bound ? min_bound : quant_value1; + quant_value2 = quant_value2 > max_bound ? max_bound : quant_value2; + quant_value2 = quant_value2 < min_bound ? 
min_bound : quant_value2; + cache_vec[i] = static_cast(quant_value1 + 128.0f); + cache_vec[i + HALF_K_VEC_SIZE] = + static_cast(quant_value2 + 128.0f); + } + + // write v transpose + // 大分块 block_offset / 16 / 2 * 32 + // 大上下 lane_id / 4 * 16 * block_size + lane_id % 4 * 2 + // 小上下 block_offset / 16 % 2 * block_size + // 左16还是右16 + // 小偏移 + const uint32_t base_tgt_cache_idx = + block_idx * gqa_group_size * HeadDim * block_size + + kv_head_idx * HeadDim * block_size + + (lane_id / 4 * 16 + lane_id % 4 * 2) * block_size + + block_offset / 16 % 2 * 8 * block_size + block_offset / 16 / 2 * 32; + const uint32_t tgt_cache_idx1 = base_tgt_cache_idx + + block_offset % 8 / 2 * 4 // per 4 + + block_offset % 16 / 8 * 2 // per 2 + + block_offset % 2; // per 1 + const uint32_t tgt_cache_idx2 = tgt_cache_idx1 + block_size; + const uint32_t tgt_cache_idx3 = tgt_cache_idx1 + 16; + const uint32_t tgt_cache_idx4 = tgt_cache_idx3 + block_size; + value_cache[tgt_cache_idx1] = cache_vec[0]; + value_cache[tgt_cache_idx2] = cache_vec[1]; + value_cache[tgt_cache_idx3] = cache_vec[2]; + value_cache[tgt_cache_idx4] = cache_vec[3]; + } + } +} + +template +__global__ void append_decode_cache_int8_neox_rope_kernel( + const int* __restrict__ quant_qkv, // [bsz, num_heads + 2 * gqa_group_size, head_size] + uint8_t* __restrict__ key_cache, // [num_blocks, gqa_group_size, block_size, head_size // 2] + uint8_t* __restrict__ value_cache, // [num_blocks, gqa_group_size, block_size, head_size // 2] + T* __restrict__ qkv_out, + const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] + const int* __restrict__ padding_offsets, // [num_tokens] + const int* __restrict__ cum_offsets, + const int* __restrict__ seq_lens, // [bsz] + const int* __restrict__ seq_lens_encoder, // [bsz] + const float* __restrict__ cos_emb, + const float* __restrict__ sin_emb, + const float* __restrict__ qkv_out_scales, // [num_head + 2 * gqa_group_size, dim_head] + const T* __restrict__ qkv_biases, // [num_head + 2 * gqa_group_size, dim_head] + const T* __restrict__ cache_k_scales, + const T* __restrict__ cache_v_scales, + const int max_seq_len, + const int max_blocks_per_seq, + const int num_heads, + const int block_size, + const float max_bound, + const float min_bound, + const int gqa_group_size) { + static_assert(HeadDim == 128, "just support HeadDim be 128 now!"); + static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!"); + constexpr int NUM_WARPS = 4; + const int tid = threadIdx.x; + const int wid = tid / 32; + const int lane_id = tid % 32; + const int by = blockIdx.y; + const int bid = blockIdx.x, head_idx = blockIdx.y * NUM_WARPS + wid; + int q_head_idx, k_head_idx, v_idx; + // q : dequant + add_bias + rope + write + // k : dequant + add_bias + rope + quant + write + // v : dequant + add_bias + quant + write + // kv在0位置全补0 + const int64_t hidden_size = (num_heads + 2 * gqa_group_size) * HeadDim; + constexpr int half_head_size = HeadDim / 2; + const int start_token_idx = bid * max_seq_len - __ldg(&cum_offsets[bid]); + if (seq_lens_encoder[bid] > 0) return; + const int write_seq_id = seq_lens[bid]; + if (write_seq_id == 0) return; + const int* block_table_now = nullptr; + + block_table_now = block_tables + bid * max_blocks_per_seq; + const int block_idx = __ldg(&block_table_now[write_seq_id / block_size]); + const int block_offset = write_seq_id % block_size; + + if (head_idx < num_heads) { + // q + using LoadT = AlignedVector; + using LoadBiasT = AlignedVector; + using LoadOutScaleT = AlignedVector; + constexpr int 
HalfVecSize = VecSize / 2; + using LoadEmbT = AlignedVector; + + // LoadT src_vec; + LoadT left_vec; + LoadT right_vec; + // LoadBiasT bias_vec; + LoadBiasT left_bias_vec; + LoadBiasT right_bias_vec; + // LoadOutScaleT out_scale_vec; + LoadOutScaleT left_out_scale_vec; + LoadOutScaleT right_out_scale_vec; + LoadEmbT cos_emb_vec; + LoadEmbT sin_emb_vec; + const int* qkv_now = quant_qkv + start_token_idx * hidden_size; + T* qkv_out_now = qkv_out + start_token_idx * hidden_size; +#pragma unroll + for (uint32_t head_bias = lane_id * VecSize; head_bias < half_head_size; + head_bias += 32 * VecSize) { + + // const int bias_idx = head_idx * HeadDim + head_bias; + const int bias_idx_left = head_idx * HeadDim + head_bias; + const int bias_idx_right = bias_idx_left + half_head_size; + + // Load(&qkv_now[bias_idx], &src_vec); + Load(&qkv_now[bias_idx_left], &left_vec); + Load(&qkv_now[bias_idx_right], &right_vec); + + if (qkv_biases) { + // Load(&qkv_biases[bias_idx], &bias_vec); + Load(&qkv_biases[bias_idx_left], &left_bias_vec); + Load(&qkv_biases[bias_idx_right], &right_bias_vec); + } + // Load(&qkv_out_scales[bias_idx], &out_scale_vec); + Load(&qkv_out_scales[bias_idx_left], + &left_out_scale_vec); + Load(&qkv_out_scales[bias_idx_right], + &right_out_scale_vec); + + // q rope + const uint32_t emb_idx = write_seq_id * HeadDim + head_bias; + Load(&cos_emb[emb_idx], &cos_emb_vec); + Load(&sin_emb[emb_idx], &sin_emb_vec); + +#pragma unroll + for (int i = 0; i < VecSize; i++) { + // dequant + add_bias + rope + float input_left = static_cast(left_vec[i]); + float input_right = static_cast(right_vec[i]); + input_left = qkv_biases ? input_left * left_out_scale_vec[i] + + static_cast(left_bias_vec[i]) + : input_left * left_out_scale_vec[i]; + input_right = qkv_biases ? 
input_right * right_out_scale_vec[i] + + static_cast(right_bias_vec[i]) + : input_right * right_out_scale_vec[i]; + const float cos_tmp = cos_emb_vec[i]; + const float sin_tmp = sin_emb_vec[i]; + left_bias_vec[i] = + static_cast(input_left * cos_tmp - input_right * sin_tmp); + right_bias_vec[i] = + static_cast(input_right * cos_tmp + input_left * sin_tmp); + } + // Store(bias_vec, &qkv_out_now[bias_idx]); + Store(left_bias_vec, &qkv_out_now[bias_idx_left]); + Store(right_bias_vec, &qkv_out_now[bias_idx_right]); + } + } else if (head_idx < num_heads + 2 * gqa_group_size) { + // k v + constexpr int KV_VEC_SIZE = 16 / sizeof(uint8_t); // 16 + using LoadPadKVT = AlignedVector; + const uint32_t kv_head_idx = (head_idx - num_heads) % gqa_group_size; + if (block_offset == 0) { + // pad zero for this kv_head_idx for this block + LoadPadKVT pad_cache_vec; + *(reinterpret_cast(pad_cache_vec.val)) = make_uint4(0, 0, 0, 0); + if (head_idx < num_heads + gqa_group_size) { + constexpr int num_vecs_per_head_dim = HeadDim / KV_VEC_SIZE; + constexpr int num_token_each_time = 32 / num_vecs_per_head_dim; + const uint32_t tgt_idx = + (block_idx * gqa_group_size + kv_head_idx) * block_size * HeadDim + + lane_id % num_vecs_per_head_dim * KV_VEC_SIZE; + for (int block_i = lane_id / num_vecs_per_head_dim; + block_i < block_size; + block_i += num_token_each_time) { + Store(pad_cache_vec, + &key_cache[tgt_idx + block_i * HeadDim]); + } + } else { + const int num_vecs_per_head_dim = block_size / KV_VEC_SIZE; + const int num_token_each_time = 32 / num_vecs_per_head_dim; + const uint32_t tgt_idx = + (block_idx * gqa_group_size + kv_head_idx) * HeadDim * block_size + + lane_id % num_vecs_per_head_dim * KV_VEC_SIZE; + for (int block_i = lane_id / num_vecs_per_head_dim; block_i < HeadDim; + block_i += num_token_each_time) { + Store( + pad_cache_vec, &value_cache[tgt_idx + block_i * block_size]); + } + } + } + if (head_idx < num_heads + gqa_group_size) { + // k + const int head_bias = lane_id / 4 * 16 + lane_id % 4 * 2; + if (head_bias < half_head_size) { + constexpr int K_VEC_SIZE = 4; + constexpr int HALF_K_VEC_SIZE = 2; + using LoadKVResT = AlignedVector; + using LoadKVT = AlignedVector; + using LoadT = AlignedVector; + using LoadBiasT = AlignedVector; + using LoadOutScaleT = AlignedVector; + using LoadEmbT = AlignedVector; + + LoadKVResT left_cache_vec, right_cache_vec; + LoadT left_src_vec1, left_src_vec2, right_src_vec1, right_src_vec2; + LoadBiasT left_bias_vec1, left_bias_vec2, right_bias_vec1, right_bias_vec2; + LoadOutScaleT left_out_scale_vec1, left_out_scale_vec2, right_out_scale_vec1, right_out_scale_vec2; + LoadEmbT cos_emb_vec1, cos_emb_vec2; + LoadEmbT sin_emb_vec1, sin_emb_vec2; + + const int* qkv_now = quant_qkv + start_token_idx * hidden_size; + const int left_bias_idx = head_idx * HeadDim + head_bias; + const int right_bias_idx = left_bias_idx + half_head_size; + + Load(&qkv_now[left_bias_idx], &left_src_vec1); + Load(&qkv_now[left_bias_idx + 8], &left_src_vec2); + Load(&qkv_now[right_bias_idx], &right_src_vec1); + Load(&qkv_now[right_bias_idx + 8], &right_src_vec2); + if (qkv_biases) { + Load(&qkv_biases[left_bias_idx], &left_bias_vec1); + Load(&qkv_biases[left_bias_idx + 8], &left_bias_vec2); + Load(&qkv_biases[right_bias_idx], &right_bias_vec1); + Load(&qkv_biases[right_bias_idx + 8], &right_bias_vec2); + } + Load(&qkv_out_scales[left_bias_idx], &left_out_scale_vec1); + Load(&qkv_out_scales[left_bias_idx + 8], + &left_out_scale_vec2); + Load(&qkv_out_scales[right_bias_idx], &right_out_scale_vec1); + 
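// NOTE: a sketch of the steps that follow in this key-head branch, inferred from
// the surrounding kernel code (exact vector template arguments are not visible in
// this flattened diff, so the types named here are assumptions):
//   1. dequant: x = (float)src * qkv_out_scale (+ qkv_bias when provided), per channel.
//   2. NeoX rotate-half RoPE: channel c is paired with channel c + HeadDim/2,
//        left'  = left  * cos - right * sin
//        right' = right * cos + left  * sin
//      e.g. cos = 0.8, sin = 0.6, left = 1.0, right = 2.0  ->  left' = -0.4, right' = 2.2
//   3. int8 quant: round(cache_k_scale * x), with RoundType == 0 selecting
//      roundWithTiesToEven and otherwise round(); the result is clamped to
//      [min_bound, max_bound] (passed as -127 / 127 by the dispatcher below)
//      and stored as uint8_t with a +128 offset.
//      e.g. cache_k_scale = 4.0, x = 1.3  ->  round(5.2) = 5  ->  stored byte 133.
// The index arithmetic at the end of this branch scatters the four bytes per lane
// into a tiled [block_size, HeadDim] cache layout; the Chinese comments there
// translate roughly to "major 32-wide block", "upper/lower 8 rows",
// "left or right 16 columns", and "minor offset".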
Load(&qkv_out_scales[right_bias_idx + 8], + &right_out_scale_vec2); + + T scale; + const uint32_t emb_idx = write_seq_id * HeadDim + head_bias; + Load(&cos_emb[emb_idx], &cos_emb_vec1); + Load(&cos_emb[emb_idx + 8], &cos_emb_vec2); + Load(&sin_emb[emb_idx], &sin_emb_vec1); + Load(&sin_emb[emb_idx + 8], &sin_emb_vec2); + scale = __ldg(&cache_k_scales[kv_head_idx]); + #pragma unroll + for(int i = 0; i < HALF_K_VEC_SIZE; i++) { + float input_left = static_cast(left_src_vec1[i]); + float input_right = static_cast(right_src_vec1[i]); + input_left = qkv_biases ? input_left * left_out_scale_vec1[i] + + static_cast(left_bias_vec1[i]) + : input_left * left_out_scale_vec1[i]; + input_right = qkv_biases ? input_right * right_out_scale_vec1[i] + + static_cast(right_bias_vec1[i]) + : input_right * right_out_scale_vec1[i]; + + float cos_tmp = cos_emb_vec1[i]; + float sin_tmp = sin_emb_vec1[i]; + left_bias_vec1[i] = + static_cast(input_left * cos_tmp - input_right * sin_tmp); + right_bias_vec1[i] = + static_cast(input_right * cos_tmp + input_left * sin_tmp); + + input_left = static_cast(left_src_vec2[i]); + input_right = static_cast(right_src_vec2[i]); + input_left = qkv_biases ? input_left * left_out_scale_vec2[i] + + static_cast(left_bias_vec2[i]) + : input_left * left_out_scale_vec2[i]; + input_right = qkv_biases ? input_right * right_out_scale_vec2[i] + + static_cast(right_bias_vec2[i]) + : input_right * right_out_scale_vec2[i]; + cos_tmp = cos_emb_vec2[i]; + sin_tmp = sin_emb_vec2[i]; + left_bias_vec2[i] = + static_cast(input_left * cos_tmp - input_right * sin_tmp); + right_bias_vec2[i] = + static_cast(input_right * cos_tmp + input_left * sin_tmp); + + float quant_value1 = static_cast(scale * left_bias_vec1[i]); + float quant_value2 = static_cast(scale * left_bias_vec2[i]); + if constexpr (RoundType == 0) { + quant_value1 = static_cast(roundWithTiesToEven(quant_value1)); + quant_value2 = static_cast(roundWithTiesToEven(quant_value2)); + } else { + quant_value1 = static_cast(round(quant_value1)); + quant_value2 = static_cast(round(quant_value2)); + } + quant_value1 = quant_value1 > max_bound ? max_bound : quant_value1; + quant_value1 = quant_value1 < min_bound ? min_bound : quant_value1; + quant_value2 = quant_value2 > max_bound ? max_bound : quant_value2; + quant_value2 = quant_value2 < min_bound ? min_bound : quant_value2; + left_cache_vec[i] = static_cast(quant_value1 + 128.0f); + left_cache_vec[i + HALF_K_VEC_SIZE] = + static_cast(quant_value2 + 128.0f); + + quant_value1 = static_cast(scale * right_bias_vec1[i]); + quant_value2 = static_cast(scale * right_bias_vec2[i]); + if constexpr (RoundType == 0) { + quant_value1 = static_cast(roundWithTiesToEven(quant_value1)); + quant_value2 = static_cast(roundWithTiesToEven(quant_value2)); + } else { + quant_value1 = static_cast(round(quant_value1)); + quant_value2 = static_cast(round(quant_value2)); + } + quant_value1 = quant_value1 > max_bound ? max_bound : quant_value1; + quant_value1 = quant_value1 < min_bound ? min_bound : quant_value1; + quant_value2 = quant_value2 > max_bound ? max_bound : quant_value2; + quant_value2 = quant_value2 < min_bound ? 
min_bound : quant_value2; + right_cache_vec[i] = static_cast(quant_value1 + 128.0f); + right_cache_vec[i + HALF_K_VEC_SIZE] = + static_cast(quant_value2 + 128.0f); + } + // write k + // 大分块 lane_id / 4 / 2 + // 上下 lane_id / 4 % 2 + // 左16还是右16 (block_offset % 16) / 8 + // 小偏移 lane_id % 4 * 2 + const int left_start_block_16 = + block_offset / 16 * 16 + block_offset % 8 + lane_id / 4 % 2 * 8; + const uint32_t left_tgt_cache_idx = + block_idx * gqa_group_size * block_size * HeadDim + + kv_head_idx * block_size * HeadDim + left_start_block_16 * HeadDim + + lane_id / 4 / 2 * 32 + (block_offset % 16) / 8 * 16 + lane_id % 4 * 4; + + const int right_lane_id = lane_id + 16; + const int right_start_block_16 = + block_offset / 16 * 16 + block_offset % 8 + right_lane_id / 4 % 2 * 8; + const uint32_t right_tgt_cache_idx = + block_idx * gqa_group_size * block_size * HeadDim + + kv_head_idx * block_size * HeadDim + right_start_block_16 * HeadDim + + right_lane_id / 4 / 2 * 32 + (block_offset % 16) / 8 * 16 + right_lane_id % 4 * 4; + Store(left_cache_vec, &key_cache[left_tgt_cache_idx]); + Store(right_cache_vec, &key_cache[right_tgt_cache_idx]); + } + } else { + // v + constexpr int K_VEC_SIZE = 4; + constexpr int HALF_K_VEC_SIZE = 2; + using LoadKVResT = AlignedVector; + using LoadKVT = AlignedVector; + using LoadT = AlignedVector; + using LoadBiasT = AlignedVector; + using LoadOutScaleT = AlignedVector; + LoadKVResT cache_vec; + LoadT src_vec1, src_vec2; + LoadBiasT bias_vec1, bias_vec2; + LoadOutScaleT out_scale_vec1, out_scale_vec2; + + const int* qkv_now = quant_qkv + start_token_idx * hidden_size; + const int head_bias = lane_id / 4 * 16 + lane_id % 4 * 2; + const int bias_idx = head_idx * HeadDim + head_bias; + Load(&qkv_now[bias_idx], &src_vec1); + Load(&qkv_now[bias_idx + 8], &src_vec2); + if (qkv_biases) { + Load(&qkv_biases[bias_idx], &bias_vec1); + Load(&qkv_biases[bias_idx + 8], &bias_vec2); + } + Load(&qkv_out_scales[bias_idx], &out_scale_vec1); + Load(&qkv_out_scales[bias_idx + 8], + &out_scale_vec2); + + T scale = __ldg(&cache_v_scales[kv_head_idx]); + + float input_left = static_cast(src_vec1[0]); + float input_right = static_cast(src_vec1[1]); + input_left = qkv_biases ? input_left * out_scale_vec1[0] + + static_cast(bias_vec1[0]) + : input_left * out_scale_vec1[0]; + input_right = qkv_biases ? input_right * out_scale_vec1[1] + + static_cast(bias_vec1[1]) + : input_right * out_scale_vec1[1]; + + bias_vec1[0] = static_cast(input_left); + bias_vec1[1] = static_cast(input_right); + + input_left = static_cast(src_vec2[0]); + input_right = static_cast(src_vec2[1]); + input_left = qkv_biases ? input_left * out_scale_vec2[0] + + static_cast(bias_vec2[0]) + : input_left * out_scale_vec2[0]; + input_right = qkv_biases ? input_right * out_scale_vec2[1] + + static_cast(bias_vec2[1]) + : input_right * out_scale_vec2[1]; + + bias_vec2[0] = static_cast(input_left); + bias_vec2[1] = static_cast(input_right); + +#pragma unroll + for (uint32_t i = 0; i < HALF_K_VEC_SIZE; i++) { + float quant_value1 = static_cast(scale * bias_vec1[i]); + float quant_value2 = static_cast(scale * bias_vec2[i]); + if constexpr (RoundType == 0) { + quant_value1 = static_cast(roundWithTiesToEven(quant_value1)); + quant_value2 = static_cast(roundWithTiesToEven(quant_value2)); + } else { + quant_value1 = static_cast(round(quant_value1)); + quant_value2 = static_cast(round(quant_value2)); + } + quant_value1 = quant_value1 > max_bound ? max_bound : quant_value1; + quant_value1 = quant_value1 < min_bound ? 
min_bound : quant_value1; + quant_value2 = quant_value2 > max_bound ? max_bound : quant_value2; + quant_value2 = quant_value2 < min_bound ? min_bound : quant_value2; + cache_vec[i] = static_cast(quant_value1 + 128.0f); + cache_vec[i + HALF_K_VEC_SIZE] = + static_cast(quant_value2 + 128.0f); + } + + // write v transpose + // 大分块 block_offset / 16 / 2 * 32 + // 大上下 lane_id / 4 * 16 * block_size + lane_id % 4 * 2 + // 小上下 block_offset / 16 % 2 * block_size + // 左16还是右16 + // 小偏移 + const uint32_t base_tgt_cache_idx = + block_idx * gqa_group_size * HeadDim * block_size + + kv_head_idx * HeadDim * block_size + + (lane_id / 4 * 16 + lane_id % 4 * 2) * block_size + + block_offset / 16 % 2 * 8 * block_size + block_offset / 16 / 2 * 32; + const uint32_t tgt_cache_idx1 = base_tgt_cache_idx + + block_offset % 8 / 2 * 4 // per 4 + + block_offset % 16 / 8 * 2 // per 2 + + block_offset % 2; // per 1 + const uint32_t tgt_cache_idx2 = tgt_cache_idx1 + block_size; + const uint32_t tgt_cache_idx3 = tgt_cache_idx1 + 16; + const uint32_t tgt_cache_idx4 = tgt_cache_idx3 + block_size; + value_cache[tgt_cache_idx1] = cache_vec[0]; + value_cache[tgt_cache_idx2] = cache_vec[1]; + value_cache[tgt_cache_idx3] = cache_vec[2]; + value_cache[tgt_cache_idx4] = cache_vec[3]; + } + } +} + + +template +__global__ void append_decode_cache_int4_rope_kernel( + const T* __restrict__ quant_qkv, // [bsz, num_heads + 2 * gqa_group_size, + // head_size] + uint8_t* __restrict__ key_cache, // [num_blocks, gqa_group_size, + // block_size, head_size // 2] + uint8_t* __restrict__ value_cache, // [num_blocks, gqa_group_size, + // block_size, head_size // 2] + T* __restrict__ qkv_out, + const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] + const int* __restrict__ padding_offsets, // [num_tokens] + const int* __restrict__ cum_offsets, + const int* __restrict__ seq_lens, // [bsz] + const int* __restrict__ seq_lens_encoder, // [bsz] + const float* __restrict__ cos_emb, + const float* __restrict__ sin_emb, + const T* __restrict__ cache_k_scale, + const T* __restrict__ cache_v_scale, + const T* __restrict__ cache_k_zero_points, + const T* __restrict__ cache_v_zero_points, + const int max_seq_len, + const int max_blocks_per_seq, + const int num_heads, + const int block_size, + const float max_bound, + const float min_bound, + const int gqa_group_size) { + static_assert(HeadDim == 128, "just support HeadDim be 128 now!"); + static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!"); + constexpr int NUM_WARPS = 4; + const int tid = threadIdx.x; + const int wid = tid / 32; + const int lane_id = tid % 32; + const int bid = blockIdx.x, head_idx = blockIdx.y * NUM_WARPS + wid; + // q : dequant + add_bias + rope + write + // k : dequant + add_bias + rope + quant + write + // v : dequant + add_bias + quant + write + // kv在0位置全补0 + const int64_t hidden_size = (num_heads + 2 * gqa_group_size) * HeadDim; + constexpr int half_head_size = HeadDim / 2; + const int half_block_size = block_size / 2; + const int start_token_idx = bid * max_seq_len - __ldg(&cum_offsets[bid]); + if (seq_lens_encoder[bid] > 0) return; + const int write_seq_id = seq_lens[bid]; + if (write_seq_id == 0) return; + const int* block_table_now = nullptr; + + block_table_now = block_tables + bid * max_blocks_per_seq; + + const int block_idx = __ldg(&block_table_now[write_seq_id / block_size]); + const int block_offset = write_seq_id % block_size; + // if (layer_id == 0 && bid == 0 && head_idx == num_heads && wid == 0 && + // lane_id == 0) { + // printf("bid: %d, 
start_token_idx: %d, num_heads: %d, gqa_group_size: %d, + // head_idx: %d, block_idx: %d, block_offset: %d\n", + // bid, start_token_idx, (int)num_heads, (int)gqa_group_size, + // head_idx, block_idx, block_offset); + // } + // __syncwarp(); + + if (head_idx < num_heads) { + // q + using LoadT = AlignedVector; + using LoadBiasT = AlignedVector; + using LoadOutScaleT = AlignedVector; + constexpr int HalfVecSize = VecSize / 2; + using LoadEmbT = AlignedVector; + + LoadT src_vec; + LoadBiasT out_vec; + LoadEmbT cos_emb_vec; + LoadEmbT sin_emb_vec; + const T* qkv_now = quant_qkv + start_token_idx * hidden_size; + T* qkv_out_now = qkv_out + start_token_idx * hidden_size; +#pragma unroll + for (uint32_t head_bias = lane_id * VecSize; head_bias < HeadDim; + head_bias += 32 * VecSize) { + const int bias_idx = head_idx * HeadDim + head_bias; + Load(&qkv_now[bias_idx], &src_vec); + + // q rope + const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2; + Load(&cos_emb[emb_idx], &cos_emb_vec); + Load(&sin_emb[emb_idx], &sin_emb_vec); +#pragma unroll + for (int i = 0; i < HalfVecSize; i++) { + // dequant + add_bias + rope + float input_left = static_cast(src_vec[2 * i]); + float input_right = static_cast(src_vec[2 * i + 1]); + + const float cos_tmp = cos_emb_vec[i]; + const float sin_tmp = sin_emb_vec[i]; + out_vec[2 * i] = + static_cast(input_left * cos_tmp - input_right * sin_tmp); + out_vec[2 * i + 1] = + static_cast(input_right * cos_tmp + input_left * sin_tmp); + } + Store(out_vec, &qkv_out_now[bias_idx]); + } + } else if (head_idx < num_heads + 2 * gqa_group_size) { + // k + constexpr int KV_VEC_SIZE = 16 / sizeof(uint8_t); // 16 + using LoadPadKVT = AlignedVector; + const uint32_t kv_head_idx = (head_idx - num_heads) % gqa_group_size; + if (block_offset == 0) { + // pad zero for this kv_head_idx for this block + LoadPadKVT pad_cache_vec; + *(reinterpret_cast(pad_cache_vec.val)) = make_uint4(0, 0, 0, 0); + if (head_idx < num_heads + gqa_group_size) { + constexpr int num_vecs_per_head_dim = half_head_size / KV_VEC_SIZE; + constexpr int num_token_each_time = 32 / num_vecs_per_head_dim; + const uint32_t tgt_idx = (block_idx * gqa_group_size + kv_head_idx) * + block_size * half_head_size + + lane_id % num_vecs_per_head_dim * KV_VEC_SIZE; + for (int block_i = lane_id / num_vecs_per_head_dim; + block_i < block_size; + block_i += num_token_each_time) { + Store( + pad_cache_vec, &key_cache[tgt_idx + block_i * half_head_size]); + } + } else { + const int num_vecs_per_head_dim = half_block_size / KV_VEC_SIZE; + const int num_token_each_time = 32 / num_vecs_per_head_dim; + const uint32_t tgt_idx = (block_idx * gqa_group_size + kv_head_idx) * + HeadDim * half_block_size + + lane_id % num_vecs_per_head_dim * KV_VEC_SIZE; + for (int block_i = lane_id / num_vecs_per_head_dim; block_i < HeadDim; + block_i += num_token_each_time) { + Store( + pad_cache_vec, &value_cache[tgt_idx + block_i * half_block_size]); + } + } + } + + constexpr int K_VEC_SIZE = 4; + constexpr int HALF_K_VEC_SIZE = 2; + using LoadKVResT = AlignedVector; + using LoadT = AlignedVector; + using LoadBiasT = AlignedVector; + using LoadOutScaleT = AlignedVector; + using LoadScaleT = AlignedVector; + using LoadEmbT = AlignedVector; + LoadT src_vec1, src_vec2; + LoadBiasT out_vec1, out_vec2; + LoadScaleT scale_vec1, scale_vec2; + LoadScaleT zp_vec1, zp_vec2; + LoadEmbT cos_emb_vec1, cos_emb_vec2; + LoadEmbT sin_emb_vec1, sin_emb_vec2; + + const T* qkv_now = quant_qkv + start_token_idx * hidden_size; + const int head_bias = lane_id / 4 
* 16 + lane_id % 4 * 2; + const uint32_t cache_idx = kv_head_idx * HeadDim + head_bias; + const int bias_idx = head_idx * HeadDim + head_bias; + Load(&qkv_now[bias_idx], &src_vec1); + Load(&qkv_now[bias_idx + 8], &src_vec2); + if (head_idx < num_heads + gqa_group_size) { + const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2; + Load(&cos_emb[emb_idx], &cos_emb_vec1); + Load(&cos_emb[emb_idx + 4], &cos_emb_vec2); + Load(&sin_emb[emb_idx], &sin_emb_vec1); + Load(&sin_emb[emb_idx + 4], &sin_emb_vec2); + Load(&cache_k_scale[cache_idx], &scale_vec1); + Load(&cache_k_scale[cache_idx + 8], &scale_vec2); + Load(&cache_k_zero_points[cache_idx], &zp_vec1); + Load(&cache_k_zero_points[cache_idx + 8], &zp_vec2); + } else { + Load(&cache_v_scale[cache_idx], &scale_vec1); + Load(&cache_v_scale[cache_idx + 8], &scale_vec2); + Load(&cache_v_zero_points[cache_idx], &zp_vec1); + Load(&cache_v_zero_points[cache_idx + 8], &zp_vec2); + } + + float input_left = static_cast(src_vec1[0]); + float input_right = static_cast(src_vec1[1]); + if (head_idx < num_heads + gqa_group_size) { + float cos_tmp = cos_emb_vec1[0]; + float sin_tmp = sin_emb_vec1[0]; + out_vec1[0] = + static_cast(input_left * cos_tmp - input_right * sin_tmp); + out_vec1[1] = + static_cast(input_right * cos_tmp + input_left * sin_tmp); + } else { + out_vec1[0] = src_vec1[0]; + out_vec1[1] = src_vec1[1]; + } + + input_left = static_cast(src_vec2[0]); + input_right = static_cast(src_vec2[1]); + if (head_idx < num_heads + gqa_group_size) { + float cos_tmp = cos_emb_vec2[0]; + float sin_tmp = sin_emb_vec2[0]; + out_vec2[0] = + static_cast(input_left * cos_tmp - input_right * sin_tmp); + out_vec2[1] = + static_cast(input_right * cos_tmp + input_left * sin_tmp); + } else { + out_vec2[0] = src_vec2[0]; + out_vec2[1] = src_vec2[1]; + } + if (head_idx < num_heads + gqa_group_size) { + // quant + write k + // 大分块 lane_id / 4 / 4 + // 上下 lane_id / 4 % 4 / 2 + // 左16还是右16 lane_id / 4 % 2 + // 小偏移 lane_id % 4 * 2 + LoadKVResT cache_vec; + const int start_block_16 = + block_offset / 16 * 16 + block_offset % 8 + lane_id / 4 % 4 / 2 * 8; + const uint32_t tgt_cache_idx = + block_idx * gqa_group_size * block_size * half_head_size + + kv_head_idx * block_size * half_head_size + + start_block_16 * half_head_size + lane_id / 4 / 4 * 32 + + lane_id / 4 % 2 * 16 + lane_id % 4 * 4; + Load(&key_cache[tgt_cache_idx], &cache_vec); + // if (layer_id == 0 && bid == 0 && head_idx == num_heads && wid == 0) { + // for (int i = 0; i < 4; i++) { + // printf("lane_id: %d, before cache_vec[%d]: %d, tgt_cache_idx: + // %d\n", (int)lane_id, i, (int)cache_vec[i], (int)tgt_cache_idx); + // } + // } + // __syncwarp(); +#pragma unroll + for (uint32_t i = 0; i < HALF_K_VEC_SIZE; i++) { + float quant_value = + static_cast(scale_vec1[i] * out_vec1[i] + zp_vec1[i]); + if constexpr (RoundType == 0) { + quant_value = roundWithTiesToEven(quant_value); + } else { + quant_value = round(quant_value); + } + quant_value = quant_value > max_bound ? max_bound : quant_value; + quant_value = quant_value < min_bound ? 
min_bound : quant_value; + uint8_t uint_quant_value = static_cast(quant_value + 8.0f); + uint8_t ano_uint_quant_value = 0; + if (block_offset % 16 / 8 == 0) { + cache_vec[i] |= ((ano_uint_quant_value) | (uint_quant_value & 0x0F)); + } else { + cache_vec[i] |= ((uint_quant_value << 4) | (ano_uint_quant_value)); + } + } +#pragma unroll + for (uint32_t i = 0; i < HALF_K_VEC_SIZE; i++) { + float quant_value = + static_cast(scale_vec2[i] * out_vec2[i] + zp_vec2[i]); + if constexpr (RoundType == 0) { + quant_value = roundWithTiesToEven(quant_value); + } else { + quant_value = round(quant_value); + } + quant_value = quant_value > max_bound ? max_bound : quant_value; + quant_value = quant_value < min_bound ? min_bound : quant_value; + uint8_t uint_quant_value = static_cast(quant_value + 8.0f); + uint8_t ano_uint_quant_value = 0; + if (block_offset % 16 / 8 == 0) { + cache_vec[i + HALF_K_VEC_SIZE] |= + ((ano_uint_quant_value) | (uint_quant_value & 0x0F)); + } else { + cache_vec[i + HALF_K_VEC_SIZE] |= + ((uint_quant_value << 4) | (ano_uint_quant_value)); + } + } + Store(cache_vec, &key_cache[tgt_cache_idx]); + } else { + // quant + write v + // write v transpose + // 大分块 block_offset / 16 / 4 * 32 + // 大上下 lane_id / 4 * 16 * block_size + lane_id % 4 * 2 + // 小上下 block_offset / 16 % 4 / 2 * block_size + // 左16还是右16 block_offset / 16 % 2 * 16 + // 小偏移 + const uint32_t base_tgt_cache_idx = + block_idx * gqa_group_size * HeadDim * half_block_size + + kv_head_idx * HeadDim * half_block_size + + (lane_id / 4 * 16 + lane_id % 4 * 2) * half_block_size + + block_offset / 16 % 4 / 2 * 8 * half_block_size + + block_offset / 16 / 4 * 32 + block_offset / 16 % 2 * 16; + const uint32_t tgt_cache_idx1 = base_tgt_cache_idx + + block_offset % 8 / 2 * 4 // per 4 + + block_offset % 16 / 8 * 2 // per 2 + + block_offset % 2; // per 1 + const uint32_t tgt_cache_idx2 = tgt_cache_idx1 + half_block_size; + + float quant_value1 = + static_cast(scale_vec1[0] * out_vec1[0] + zp_vec1[0]); + float quant_value2 = + static_cast(scale_vec2[0] * out_vec2[0] + zp_vec2[0]); + if constexpr (RoundType == 0) { + quant_value1 = roundWithTiesToEven(quant_value1); + quant_value2 = roundWithTiesToEven(quant_value2); + } else { + quant_value1 = round(quant_value1); + quant_value2 = round(quant_value2); + } + quant_value1 = quant_value1 > max_bound ? max_bound : quant_value1; + quant_value1 = quant_value1 < min_bound ? min_bound : quant_value1; + quant_value2 = quant_value2 > max_bound ? max_bound : quant_value2; + quant_value2 = quant_value2 < min_bound ? min_bound : quant_value2; + uint8_t uint_quant_value1 = static_cast(quant_value1 + 8.0f); + uint8_t uint_quant_value2 = static_cast(quant_value2 + 8.0f); + value_cache[tgt_cache_idx1] = + (uint_quant_value2 << 4) | (uint_quant_value1 & 0x0F); + + quant_value1 = + static_cast(scale_vec1[1] * out_vec1[1] + zp_vec1[1]); + quant_value2 = + static_cast(scale_vec2[1] * out_vec2[1] + zp_vec2[1]); + if constexpr (RoundType == 0) { + quant_value1 = roundWithTiesToEven(quant_value1); + quant_value2 = roundWithTiesToEven(quant_value2); + } else { + quant_value1 = round(quant_value1); + quant_value2 = round(quant_value2); + } + quant_value1 = quant_value1 > max_bound ? max_bound : quant_value1; + quant_value1 = quant_value1 < min_bound ? min_bound : quant_value1; + quant_value2 = quant_value2 > max_bound ? max_bound : quant_value2; + quant_value2 = quant_value2 < min_bound ? 
min_bound : quant_value2; + uint_quant_value1 = static_cast(quant_value1 + 8.0f); + uint_quant_value2 = static_cast(quant_value2 + 8.0f); + value_cache[tgt_cache_idx2] = + (uint_quant_value2 << 4) | (uint_quant_value1 & 0x0F); + } + } +} + +template +__global__ void append_decode_cache_int4_rope_kernel( + const int* __restrict__ quant_qkv, // [bsz, num_heads + 2 * gqa_group_size, head_size] + uint8_t* __restrict__ key_cache, // [num_blocks, gqa_group_size, block_size, head_size // 2] + uint8_t* __restrict__ value_cache, // [num_blocks, gqa_group_size, block_size, head_size // 2] + T* __restrict__ qkv_out, + const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] + const int* __restrict__ padding_offsets, // [num_tokens] + const int* __restrict__ cum_offsets, + const int* __restrict__ seq_lens, // [bsz] + const int* __restrict__ seq_lens_encoder, // [bsz] + const float* __restrict__ cos_emb, + const float* __restrict__ sin_emb, + const float* __restrict__ qkv_out_scales, // [num_head + 2 * gqa_group_size, dim_head] + const T* __restrict__ qkv_biases, // [num_head + 2 * gqa_group_size, dim_head] + const T* __restrict__ cache_k_scale, + const T* __restrict__ cache_v_scale, + const T* __restrict__ cache_k_zero_points, + const T* __restrict__ cache_v_zero_points, + const int max_seq_len, + const int max_blocks_per_seq, + const int num_heads, + const int block_size, + const float max_bound, + const float min_bound, + const int gqa_group_size) { + static_assert(HeadDim == 128, "just support HeadDim be 128 now!"); + static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!"); + constexpr int NUM_WARPS = 4; + const int tid = threadIdx.x; + const int wid = tid / 32; + const int lane_id = tid % 32; + const int bid = blockIdx.x, head_idx = blockIdx.y * NUM_WARPS + wid; + // q : dequant + add_bias + rope + write + // k : dequant + add_bias + rope + quant + write + // v : dequant + add_bias + quant + write + // kv在0位置全补0 + const int64_t hidden_size = (num_heads + 2 * gqa_group_size) * HeadDim; + constexpr int half_head_size = HeadDim / 2; + const int half_block_size = block_size / 2; + const int start_token_idx = bid * max_seq_len - __ldg(&cum_offsets[bid]); + if (seq_lens_encoder[bid] > 0) return; + const int write_seq_id = seq_lens[bid]; + if (write_seq_id == 0) return; + const int* block_table_now = nullptr; + + block_table_now = block_tables + bid * max_blocks_per_seq; + + const int block_idx = __ldg(&block_table_now[write_seq_id / block_size]); + const int block_offset = write_seq_id % block_size; + // if (layer_id == 0 && bid == 0 && head_idx == num_heads && wid == 0 && + // lane_id == 0) { + // printf("bid: %d, start_token_idx: %d, num_heads: %d, gqa_group_size: %d, + // head_idx: %d, block_idx: %d, block_offset: %d\n", + // bid, start_token_idx, (int)num_heads, (int)gqa_group_size, + // head_idx, block_idx, block_offset); + // } + // __syncwarp(); + + if (head_idx < num_heads) { + // q + using LoadT = AlignedVector; + using LoadBiasT = AlignedVector; + using LoadOutScaleT = AlignedVector; + constexpr int HalfVecSize = VecSize / 2; + using LoadEmbT = AlignedVector; + + LoadT src_vec; + LoadBiasT bias_vec; + LoadOutScaleT out_scale_vec; + LoadEmbT cos_emb_vec; + LoadEmbT sin_emb_vec; + const int* qkv_now = quant_qkv + start_token_idx * hidden_size; + T* qkv_out_now = qkv_out + start_token_idx * hidden_size; +#pragma unroll + for (uint32_t head_bias = lane_id * VecSize; head_bias < HeadDim; + head_bias += 32 * VecSize) { + const int bias_idx = head_idx * HeadDim + head_bias; + 
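// NOTE: in this query branch the loads below fetch the int32 activations plus the
// per-channel out-scales and optional biases; the rotary tables are indexed at
// head_bias / 2 because the interleaved RoPE used here pairs adjacent channels
// (2*i, 2*i + 1), so one cos/sin entry serves a channel pair. The dequantized,
// rotated query is written back to qkv_out in T precision; only the K/V heads
// handled in the branch further down are quantized into the int4 cache.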
Load(&qkv_now[bias_idx], &src_vec); + if (qkv_biases) { + Load(&qkv_biases[bias_idx], &bias_vec); + } + Load(&qkv_out_scales[bias_idx], &out_scale_vec); + // q rope + const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2; + Load(&cos_emb[emb_idx], &cos_emb_vec); + Load(&sin_emb[emb_idx], &sin_emb_vec); +#pragma unroll + for (int i = 0; i < HalfVecSize; i++) { + // dequant + add_bias + rope + float input_left = static_cast(src_vec[2 * i]); + float input_right = static_cast(src_vec[2 * i + 1]); + input_left = qkv_biases ? input_left * out_scale_vec[2 * i] + + static_cast(bias_vec[2 * i]) + : input_left * out_scale_vec[2 * i]; + input_right = qkv_biases ? input_right * out_scale_vec[2 * i + 1] + + static_cast(bias_vec[2 * i + 1]) + : input_right * out_scale_vec[2 * i + 1]; + const float cos_tmp = cos_emb_vec[i]; + const float sin_tmp = sin_emb_vec[i]; + bias_vec[2 * i] = + static_cast(input_left * cos_tmp - input_right * sin_tmp); + bias_vec[2 * i + 1] = + static_cast(input_right * cos_tmp + input_left * sin_tmp); + } + Store(bias_vec, &qkv_out_now[bias_idx]); + } + } else if (head_idx < num_heads + 2 * gqa_group_size) { + // k + constexpr int KV_VEC_SIZE = 16 / sizeof(uint8_t); // 16 + using LoadPadKVT = AlignedVector; + const uint32_t kv_head_idx = (head_idx - num_heads) % gqa_group_size; + if (block_offset == 0) { + // pad zero for this kv_head_idx for this block + LoadPadKVT pad_cache_vec; + *(reinterpret_cast(pad_cache_vec.val)) = make_uint4(0, 0, 0, 0); + if (head_idx < num_heads + gqa_group_size) { + constexpr int num_vecs_per_head_dim = half_head_size / KV_VEC_SIZE; + constexpr int num_token_each_time = 32 / num_vecs_per_head_dim; + const uint32_t tgt_idx = (block_idx * gqa_group_size + kv_head_idx) * + block_size * half_head_size + + lane_id % num_vecs_per_head_dim * KV_VEC_SIZE; + for (int block_i = lane_id / num_vecs_per_head_dim; + block_i < block_size; + block_i += num_token_each_time) { + Store( + pad_cache_vec, &key_cache[tgt_idx + block_i * half_head_size]); + } + } else { + const int num_vecs_per_head_dim = half_block_size / KV_VEC_SIZE; + const int num_token_each_time = 32 / num_vecs_per_head_dim; + const uint32_t tgt_idx = (block_idx * gqa_group_size + kv_head_idx) * + HeadDim * half_block_size + + lane_id % num_vecs_per_head_dim * KV_VEC_SIZE; + for (int block_i = lane_id / num_vecs_per_head_dim; block_i < HeadDim; + block_i += num_token_each_time) { + Store( + pad_cache_vec, &value_cache[tgt_idx + block_i * half_block_size]); + } + } + } + + constexpr int K_VEC_SIZE = 4; + constexpr int HALF_K_VEC_SIZE = 2; + using LoadKVResT = AlignedVector; + using LoadT = AlignedVector; + using LoadBiasT = AlignedVector; + using LoadOutScaleT = AlignedVector; + using LoadScaleT = AlignedVector; + using LoadEmbT = AlignedVector; + LoadT src_vec1, src_vec2; + LoadBiasT bias_vec1, bias_vec2; + LoadOutScaleT out_scale_vec1, out_scale_vec2; + LoadScaleT scale_vec1, scale_vec2; + LoadScaleT zp_vec1, zp_vec2; + LoadEmbT cos_emb_vec1, cos_emb_vec2; + LoadEmbT sin_emb_vec1, sin_emb_vec2; + + const int* qkv_now = quant_qkv + start_token_idx * hidden_size; + const int head_bias = lane_id / 4 * 16 + lane_id % 4 * 2; + const uint32_t cache_idx = kv_head_idx * HeadDim + head_bias; + const int bias_idx = head_idx * HeadDim + head_bias; + Load(&qkv_now[bias_idx], &src_vec1); + Load(&qkv_now[bias_idx + 8], &src_vec2); + if (qkv_biases) { + Load(&qkv_biases[bias_idx], &bias_vec1); + Load(&qkv_biases[bias_idx + 8], &bias_vec2); + } + Load(&qkv_out_scales[bias_idx], &out_scale_vec1); + 
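// NOTE: sketch of the int4 path below, as far as it can be read from this
// flattened diff. After dequant (src * out_scale + bias) and, for K heads only,
// the interleaved RoPE, each value is quantized per channel with a zero point:
//   q      = clamp(round(scale * x + zero_point), -8, 7)   // bounds passed as 7.0f / -8.0f
//   nibble = (uint8_t)(q + 8)                               // 4-bit unsigned code
//   e.g. scale = 2.0, x = 1.2, zp = 0.5  ->  round(2.9) = 3  ->  nibble 0xB
// Two nibbles share one cache byte. The K write is a read-modify-write: the
// existing byte is loaded and the new nibble is OR-ed into the low or high half,
// apparently selected by block_offset % 16 / 8. The V write packs two freshly
// computed nibbles directly as (hi << 4) | (lo & 0x0F).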
Load(&qkv_out_scales[bias_idx + 8], + &out_scale_vec2); + if (head_idx < num_heads + gqa_group_size) { + const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2; + Load(&cos_emb[emb_idx], &cos_emb_vec1); + Load(&cos_emb[emb_idx + 4], &cos_emb_vec2); + Load(&sin_emb[emb_idx], &sin_emb_vec1); + Load(&sin_emb[emb_idx + 4], &sin_emb_vec2); + Load(&cache_k_scale[cache_idx], &scale_vec1); + Load(&cache_k_scale[cache_idx + 8], &scale_vec2); + Load(&cache_k_zero_points[cache_idx], &zp_vec1); + Load(&cache_k_zero_points[cache_idx + 8], &zp_vec2); + } else { + Load(&cache_v_scale[cache_idx], &scale_vec1); + Load(&cache_v_scale[cache_idx + 8], &scale_vec2); + Load(&cache_v_zero_points[cache_idx], &zp_vec1); + Load(&cache_v_zero_points[cache_idx + 8], &zp_vec2); + } + + float input_left = static_cast(src_vec1[0]); + float input_right = static_cast(src_vec1[1]); + input_left = qkv_biases ? + input_left * out_scale_vec1[0] + static_cast(bias_vec1[0]) + : input_left * out_scale_vec1[0]; + input_right = qkv_biases ? + input_right * out_scale_vec1[1] + static_cast(bias_vec1[1]) + : input_right * out_scale_vec1[1]; + if (head_idx < num_heads + gqa_group_size) { + float cos_tmp = cos_emb_vec1[0]; + float sin_tmp = sin_emb_vec1[0]; + bias_vec1[0] = + static_cast(input_left * cos_tmp - input_right * sin_tmp); + bias_vec1[1] = + static_cast(input_right * cos_tmp + input_left * sin_tmp); + } else { + bias_vec1[0] = static_cast(input_left); + bias_vec1[1] = static_cast(input_right); + } + + input_left = static_cast(src_vec2[0]); + input_right = static_cast(src_vec2[1]); + input_left = qkv_biases ? + input_left * out_scale_vec2[0] + static_cast(bias_vec2[0]) + : input_left * out_scale_vec2[0]; + input_right = qkv_biases ? + input_right * out_scale_vec2[1] + static_cast(bias_vec2[1]) + : input_right * out_scale_vec2[1]; + if (head_idx < num_heads + gqa_group_size) { + float cos_tmp = cos_emb_vec2[0]; + float sin_tmp = sin_emb_vec2[0]; + bias_vec2[0] = + static_cast(input_left * cos_tmp - input_right * sin_tmp); + bias_vec2[1] = + static_cast(input_right * cos_tmp + input_left * sin_tmp); + } else { + bias_vec2[0] = static_cast(input_left); + bias_vec2[1] = static_cast(input_right); + } + if (head_idx < num_heads + gqa_group_size) { + // quant + write k + // 大分块 lane_id / 4 / 4 + // 上下 lane_id / 4 % 4 / 2 + // 左16还是右16 lane_id / 4 % 2 + // 小偏移 lane_id % 4 * 2 + LoadKVResT cache_vec; + const int start_block_16 = + block_offset / 16 * 16 + block_offset % 8 + lane_id / 4 % 4 / 2 * 8; + const uint32_t tgt_cache_idx = + block_idx * gqa_group_size * block_size * half_head_size + + kv_head_idx * block_size * half_head_size + + start_block_16 * half_head_size + lane_id / 4 / 4 * 32 + + lane_id / 4 % 2 * 16 + lane_id % 4 * 4; + Load(&key_cache[tgt_cache_idx], &cache_vec); + // if (layer_id == 0 && bid == 0 && head_idx == num_heads && wid == 0) { + // for (int i = 0; i < 4; i++) { + // printf("lane_id: %d, before cache_vec[%d]: %d, tgt_cache_idx: + // %d\n", (int)lane_id, i, (int)cache_vec[i], (int)tgt_cache_idx); + // } + // } + // __syncwarp(); +#pragma unroll + for (uint32_t i = 0; i < HALF_K_VEC_SIZE; i++) { + float quant_value = + static_cast(scale_vec1[i] * bias_vec1[i] + zp_vec1[i]); + if constexpr (RoundType == 0) { + quant_value = roundWithTiesToEven(quant_value); + } else { + quant_value = round(quant_value); + } + quant_value = quant_value > max_bound ? max_bound : quant_value; + quant_value = quant_value < min_bound ? 
min_bound : quant_value; + uint8_t uint_quant_value = static_cast(quant_value + 8.0f); + uint8_t ano_uint_quant_value = 0; + if (block_offset % 16 / 8 == 0) { + cache_vec[i] |= ((ano_uint_quant_value) | (uint_quant_value & 0x0F)); + } else { + cache_vec[i] |= ((uint_quant_value << 4) | (ano_uint_quant_value)); + } + } +#pragma unroll + for (uint32_t i = 0; i < HALF_K_VEC_SIZE; i++) { + float quant_value = + static_cast(scale_vec2[i] * bias_vec2[i] + zp_vec2[i]); + if constexpr (RoundType == 0) { + quant_value = roundWithTiesToEven(quant_value); + } else { + quant_value = round(quant_value); + } + quant_value = quant_value > max_bound ? max_bound : quant_value; + quant_value = quant_value < min_bound ? min_bound : quant_value; + uint8_t uint_quant_value = static_cast(quant_value + 8.0f); + uint8_t ano_uint_quant_value = 0; + if (block_offset % 16 / 8 == 0) { + cache_vec[i + HALF_K_VEC_SIZE] |= + ((ano_uint_quant_value) | (uint_quant_value & 0x0F)); + } else { + cache_vec[i + HALF_K_VEC_SIZE] |= + ((uint_quant_value << 4) | (ano_uint_quant_value)); + } + } + // if (layer_id == 0 && bid == 0 && head_idx == num_heads && wid == 0) { + // for (int i = 0; i < 4; i++) { + // printf("lane_id: %d, after cache_vec[%d]: %d, tgt_cache_idx: %d\n", + // (int)lane_id, i, (int)cache_vec[i], (int)tgt_cache_idx); + // } + // } + // __syncwarp(); + Store(cache_vec, &key_cache[tgt_cache_idx]); + } else { + // quant + write v + // write v transpose + // 大分块 block_offset / 16 / 4 * 32 + // 大上下 lane_id / 4 * 16 * block_size + lane_id % 4 * 2 + // 小上下 block_offset / 16 % 4 / 2 * block_size + // 左16还是右16 block_offset / 16 % 2 * 16 + // 小偏移 + const uint32_t base_tgt_cache_idx = + block_idx * gqa_group_size * HeadDim * half_block_size + + kv_head_idx * HeadDim * half_block_size + + (lane_id / 4 * 16 + lane_id % 4 * 2) * half_block_size + + block_offset / 16 % 4 / 2 * 8 * half_block_size + + block_offset / 16 / 4 * 32 + block_offset / 16 % 2 * 16; + const uint32_t tgt_cache_idx1 = base_tgt_cache_idx + + block_offset % 8 / 2 * 4 // per 4 + + block_offset % 16 / 8 * 2 // per 2 + + block_offset % 2; // per 1 + const uint32_t tgt_cache_idx2 = tgt_cache_idx1 + half_block_size; + + float quant_value1 = + static_cast(scale_vec1[0] * bias_vec1[0] + zp_vec1[0]); + float quant_value2 = + static_cast(scale_vec2[0] * bias_vec2[0] + zp_vec2[0]); + if constexpr (RoundType == 0) { + quant_value1 = roundWithTiesToEven(quant_value1); + quant_value2 = roundWithTiesToEven(quant_value2); + } else { + quant_value1 = round(quant_value1); + quant_value2 = round(quant_value2); + } + quant_value1 = quant_value1 > max_bound ? max_bound : quant_value1; + quant_value1 = quant_value1 < min_bound ? min_bound : quant_value1; + quant_value2 = quant_value2 > max_bound ? max_bound : quant_value2; + quant_value2 = quant_value2 < min_bound ? min_bound : quant_value2; + uint8_t uint_quant_value1 = static_cast(quant_value1 + 8.0f); + uint8_t uint_quant_value2 = static_cast(quant_value2 + 8.0f); + value_cache[tgt_cache_idx1] = + (uint_quant_value2 << 4) | (uint_quant_value1 & 0x0F); + + quant_value1 = + static_cast(scale_vec1[1] * bias_vec1[1] + zp_vec1[1]); + quant_value2 = + static_cast(scale_vec2[1] * bias_vec2[1] + zp_vec2[1]); + if constexpr (RoundType == 0) { + quant_value1 = roundWithTiesToEven(quant_value1); + quant_value2 = roundWithTiesToEven(quant_value2); + } else { + quant_value1 = round(quant_value1); + quant_value2 = round(quant_value2); + } + quant_value1 = quant_value1 > max_bound ? 
max_bound : quant_value1; + quant_value1 = quant_value1 < min_bound ? min_bound : quant_value1; + quant_value2 = quant_value2 > max_bound ? max_bound : quant_value2; + quant_value2 = quant_value2 < min_bound ? min_bound : quant_value2; + uint_quant_value1 = static_cast(quant_value1 + 8.0f); + uint_quant_value2 = static_cast(quant_value2 + 8.0f); + value_cache[tgt_cache_idx2] = + (uint_quant_value2 << 4) | (uint_quant_value1 & 0x0F); + } + } +} + +template +__global__ void append_decode_cache_int4_neox_rope_kernel( + const T* __restrict__ quant_qkv, // [bsz, num_heads + 2 * gqa_group_size, + // head_size] + uint8_t* __restrict__ key_cache, // [num_blocks, gqa_group_size, + // block_size, head_size // 2] + uint8_t* __restrict__ value_cache, // [num_blocks, gqa_group_size, + // block_size, head_size // 2] + T* __restrict__ qkv_out, + const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] + const int* __restrict__ padding_offsets, // [num_tokens] + const int* __restrict__ cum_offsets, + const int* __restrict__ seq_lens, // [bsz] + const int* __restrict__ seq_lens_encoder, // [bsz] + const float* __restrict__ cos_emb, + const float* __restrict__ sin_emb, + const T* __restrict__ cache_k_scale, + const T* __restrict__ cache_v_scale, + const T* __restrict__ cache_k_zero_points, + const T* __restrict__ cache_v_zero_points, + const int max_seq_len, + const int max_blocks_per_seq, + const int num_heads, + const int block_size, + const float max_bound, + const float min_bound, + const int gqa_group_size) { + + } + +template +__global__ void append_decode_cache_int4_neox_rope_kernel( + const int* __restrict__ quant_qkv, // [bsz, num_heads + 2 * gqa_group_size, + // head_size] + uint8_t* __restrict__ key_cache, // [num_blocks, gqa_group_size, + // block_size, head_size // 2] + uint8_t* __restrict__ value_cache, // [num_blocks, gqa_group_size, + // block_size, head_size // 2] + T* __restrict__ qkv_out, + const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] + const int* __restrict__ padding_offsets, // [num_tokens] + const int* __restrict__ cum_offsets, + const int* __restrict__ seq_lens, // [bsz] + const int* __restrict__ seq_lens_encoder, // [bsz] + const float* __restrict__ cos_emb, + const float* __restrict__ sin_emb, + const float* __restrict__ qkv_out_scales, // [num_head + 2 * + // gqa_group_size, dim_head] + const T* __restrict__ qkv_biases, // [num_head + 2 * gqa_group_size, + // dim_head] + const T* __restrict__ cache_k_scale, + const T* __restrict__ cache_v_scale, + const T* __restrict__ cache_k_zero_points, + const T* __restrict__ cache_v_zero_points, + const int max_seq_len, + const int max_blocks_per_seq, + const int num_heads, + const int block_size, + const float max_bound, + const float min_bound, + const int gqa_group_size) { + + } \ No newline at end of file diff --git a/csrc/gpu/append_attn/decoder_write_cache_with_rope_kernel.cu b/csrc/gpu/append_attn/decoder_write_cache_with_rope_kernel.cu new file mode 100644 index 000000000000..6e3c74cb0db8 --- /dev/null +++ b/csrc/gpu/append_attn/decoder_write_cache_with_rope_kernel.cu @@ -0,0 +1,600 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "decoder_write_cache_with_rope_kernel.h" + +template +void append_decode_cache_rope(const QKV_TYPE* qkv, + T* key_cache, + T* value_cache, + T* qkv_out, + const int* block_tables, + const int* padding_offsets, + const int* cum_offsets, + const int* seq_lens, + const int* seq_lens_encoder, + const float* cos_emb, + const float* sin_emb, + const float* qkv_out_scales, + const T* qkv_biases, + const int max_seq_len, + const int max_blocks_per_seq, + const int num_heads, + const int kv_num_heads, + const int head_size, + const int block_size, + const int bsz, + const cudaStream_t& stream, + const bool use_neox_style) { + const uint32_t elem_nums = use_neox_style ? bsz * (num_heads + 2 * kv_num_heads) * head_size / 2 : bsz * (num_heads + 2 * kv_num_heads) * head_size; + + constexpr int PackSize = 16 / sizeof(T); + const int pack_num = elem_nums / PackSize; + const int blocksize = 128; + int grid_size = 1; + GetNumBlocks<128>(pack_num, &grid_size); + if (use_neox_style) { + if (qkv_out_scales) { + append_decode_cache_T_neox_rope_kernel + <<>>( + reinterpret_cast(qkv), + key_cache, + value_cache, + qkv_out, + block_tables, + padding_offsets, + cum_offsets, + seq_lens, + seq_lens_encoder, + cos_emb, + sin_emb, + qkv_out_scales, + qkv_biases, + max_seq_len, + max_blocks_per_seq, + num_heads, + head_size, + block_size, + elem_nums, + kv_num_heads); + } else { + append_decode_cache_T_neox_rope_kernel + <<>>( + reinterpret_cast(qkv), + key_cache, + value_cache, + qkv_out, + block_tables, + padding_offsets, + cum_offsets, + seq_lens, + seq_lens_encoder, + cos_emb, + sin_emb, + max_seq_len, + max_blocks_per_seq, + num_heads, + head_size, + block_size, + elem_nums, + kv_num_heads); + } + } else { + if (qkv_out_scales) { + append_decode_cache_T_rope_kernel + <<>>( + reinterpret_cast(qkv), + key_cache, + value_cache, + qkv_out, + block_tables, + padding_offsets, + cum_offsets, + seq_lens, + seq_lens_encoder, + cos_emb, + sin_emb, + qkv_out_scales, + qkv_biases, + max_seq_len, + max_blocks_per_seq, + num_heads, + head_size, + block_size, + elem_nums, + kv_num_heads); + } else { + append_decode_cache_T_rope_kernel + <<>>( + reinterpret_cast(qkv), + key_cache, + value_cache, + qkv_out, + block_tables, + padding_offsets, + cum_offsets, + seq_lens, + seq_lens_encoder, + cos_emb, + sin_emb, + max_seq_len, + max_blocks_per_seq, + num_heads, + head_size, + block_size, + elem_nums, + kv_num_heads); + } + } + +} + +template +void append_decode_cache_int8_rope(const QKV_TYPE* qkv, + uint8_t* key_cache, + uint8_t* value_cache, + T* qkv_out, + const int* block_tables, + const int* padding_offsets, + const int* cum_offsets, + const int* seq_lens, + const int* seq_lens_encoder, + const float* cos_emb, + const float* sin_emb, + const float* qkv_out_scales, + const T* qkv_biases, + const T* cache_k_scale, + const T* cache_v_scale, + const int max_seq_len, + const int max_blocks_per_seq, + const int num_heads, + const int kv_num_heads, + const int head_size, + const int block_size, + const int bsz, + const cudaStream_t& stream, + const bool use_neox_style) { + constexpr int num_warps = 4; + 
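// NOTE: launch-geometry sketch for this dispatcher, inferred from the kernel
// bodies above. The grid is (bsz, ceil(total_heads / num_warps)) with
// total_heads = num_heads + 2 * kv_num_heads, and each block runs num_warps (4)
// warps, so one 32-lane warp handles one head of one decode token
// (head_idx = blockIdx.y * NUM_WARPS + wid in the kernels). The execution
// configuration appears as "<<>>" in this diff because angle-bracketed text was
// stripped; presumably the original reads <<<grids, num_warps * 32, 0, stream>>>.
// The 127.0f / -127.0f arguments are the symmetric int8 clamp bounds applied
// before the +128 offset inside the kernels.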
const int all_warps = ((num_heads + 2 * kv_num_heads) + num_warps - 1) / + num_warps * num_warps; + dim3 grids(bsz, all_warps / num_warps); + if (use_neox_style) { + if (qkv_out_scales) { + append_decode_cache_int8_neox_rope_kernel + <<>>( + reinterpret_cast(qkv), + key_cache, + value_cache, + qkv_out, + block_tables, + padding_offsets, + cum_offsets, + seq_lens, + seq_lens_encoder, + cos_emb, + sin_emb, + qkv_out_scales, + qkv_biases, + cache_k_scale, + cache_v_scale, + max_seq_len, + max_blocks_per_seq, + num_heads, + block_size, + 127.0f, + -127.0f, + kv_num_heads); + } else { + append_decode_cache_int8_neox_rope_kernel + <<>>( + reinterpret_cast(qkv), + key_cache, + value_cache, + qkv_out, + block_tables, + padding_offsets, + cum_offsets, + seq_lens, + seq_lens_encoder, + cos_emb, + sin_emb, + cache_k_scale, + cache_v_scale, + max_seq_len, + max_blocks_per_seq, + num_heads, + block_size, + 127.0f, + -127.0f, + kv_num_heads); + } + } else { + if (qkv_out_scales) { + append_decode_cache_int8_rope_kernel + <<>>( + reinterpret_cast(qkv), + key_cache, + value_cache, + qkv_out, + block_tables, + padding_offsets, + cum_offsets, + seq_lens, + seq_lens_encoder, + cos_emb, + sin_emb, + qkv_out_scales, + qkv_biases, + cache_k_scale, + cache_v_scale, + max_seq_len, + max_blocks_per_seq, + num_heads, + block_size, + 127.0f, + -127.0f, + kv_num_heads); + } else { + append_decode_cache_int8_rope_kernel + <<>>( + reinterpret_cast(qkv), + key_cache, + value_cache, + qkv_out, + block_tables, + padding_offsets, + cum_offsets, + seq_lens, + seq_lens_encoder, + cos_emb, + sin_emb, + cache_k_scale, + cache_v_scale, + max_seq_len, + max_blocks_per_seq, + num_heads, + block_size, + 127.0f, + -127.0f, + kv_num_heads); + } + } +} + +template +void append_decode_cache_int4_rope(const QKV_TYPE* qkv, + uint8_t* key_cache, + uint8_t* value_cache, + T* qkv_out, + const int* block_tables, + const int* padding_offsets, + const int* cum_offsets, + const int* seq_lens, + const int* seq_lens_encoder, + const float* cos_emb, + const float* sin_emb, + const float* qkv_out_scales, + const T* qkv_biases, + const T* cache_k_scale, + const T* cache_v_scale, + const T* cache_k_zp, + const T* cache_v_zp, + const int max_seq_len, + const int max_blocks_per_seq, + const int num_heads, + const int kv_num_heads, + const int head_size, + const int block_size, + const int bsz, + const cudaStream_t& stream, + const bool use_neox_style) { + constexpr int num_warps = 4; + const int all_warps = ((num_heads + 2 * kv_num_heads) + num_warps - 1) / + num_warps * num_warps; + dim3 grids(bsz, all_warps / num_warps); + if (use_neox_style) { + if (qkv_out_scales) { + append_decode_cache_int4_neox_rope_kernel + <<>>( + reinterpret_cast(qkv), + key_cache, + value_cache, + qkv_out, + block_tables, + padding_offsets, + cum_offsets, + seq_lens, + seq_lens_encoder, + cos_emb, + sin_emb, + qkv_out_scales, + qkv_biases, + cache_k_scale, + cache_v_scale, + cache_k_zp, + cache_v_zp, + max_seq_len, + max_blocks_per_seq, + num_heads, + block_size, + 7.0f, + -8.0f, + kv_num_heads); + } else { + append_decode_cache_int4_neox_rope_kernel + <<>>( + reinterpret_cast(qkv), + key_cache, + value_cache, + qkv_out, + block_tables, + padding_offsets, + cum_offsets, + seq_lens, + seq_lens_encoder, + cos_emb, + sin_emb, + cache_k_scale, + cache_v_scale, + cache_k_zp, + cache_v_zp, + max_seq_len, + max_blocks_per_seq, + num_heads, + block_size, + 7.0f, + -8.0f, + kv_num_heads); + } + } else { + if (qkv_out_scales) { + append_decode_cache_int4_rope_kernel + <<>>( + 
reinterpret_cast(qkv), + key_cache, + value_cache, + qkv_out, + block_tables, + padding_offsets, + cum_offsets, + seq_lens, + seq_lens_encoder, + cos_emb, + sin_emb, + qkv_out_scales, + qkv_biases, + cache_k_scale, + cache_v_scale, + cache_k_zp, + cache_v_zp, + max_seq_len, + max_blocks_per_seq, + num_heads, + block_size, + 7.0f, + -8.0f, + kv_num_heads); + } else { + append_decode_cache_int4_rope_kernel + <<>>( + reinterpret_cast(qkv), + key_cache, + value_cache, + qkv_out, + block_tables, + padding_offsets, + cum_offsets, + seq_lens, + seq_lens_encoder, + cos_emb, + sin_emb, + cache_k_scale, + cache_v_scale, + cache_k_zp, + cache_v_zp, + max_seq_len, + max_blocks_per_seq, + num_heads, + block_size, + 7.0f, + -8.0f, + kv_num_heads); + } + } +} +template +void DecoderWriteCacheWithRoPEKernel( + const paddle::Tensor& qkv, + const paddle::Tensor& seq_lens, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& padding_offsets, + const paddle::Tensor& cum_offsets, + const paddle::Tensor& block_tables, + const paddle::optional& rotary_embs, + const paddle::optional& qkv_out_scales, + const paddle::optional& qkv_biases, + const paddle::optional& cache_k_scale, + const paddle::optional& cache_v_scale, + const paddle::optional& cache_k_zp, + const paddle::optional& cache_v_zp, + const std::string& cache_quant_type_str, + const bool use_neox_rotary_style, + const int max_seq_len, + const int num_heads, + const int kv_num_heads, + const int head_size, + cudaStream_t& stream, + paddle::Tensor* qkv_out, + paddle::Tensor* key_cache_out, + paddle::Tensor* value_cache_out) { + typedef cascade_attn_type_traits traits_; + typedef cascade_attn_type_traits qkt_nv_type_; + typedef typename traits_::type DataType_; + typedef typename qkt_nv_type_::type QKV_Data_TYPE; + const QKV_TYPE* qkv_ptr = qkv.data(); + auto qkv_dims = qkv.dims(); + const int max_blocks_per_seq = block_tables.dims()[1]; + const int bsz = cum_offsets.dims()[0]; + + // VLOG(1) << "gqa_group_size: " << gqa_group_size; + const int32_t block_size = key_cache_out->dims()[2]; + + const float* cos_emb = rotary_embs ? rotary_embs.get().data() : nullptr; + const float* sin_emb; + if (rotary_embs) { + sin_emb = use_neox_rotary_style ? rotary_embs.get().data() + max_seq_len * head_size : rotary_embs.get().data() + max_seq_len * head_size / 2; + } + if (cache_quant_type_str == "none") { + append_decode_cache_rope( + reinterpret_cast(qkv_ptr), + reinterpret_cast(key_cache_out->data()), + reinterpret_cast(value_cache_out->data()), + reinterpret_cast(qkv_out->data()), + block_tables.data(), + padding_offsets.data(), + cum_offsets.data(), + seq_lens.data(), + seq_lens_encoder.data(), + cos_emb, + sin_emb, + qkv_out_scales ? qkv_out_scales.get().data() : nullptr, + qkv_biases ? reinterpret_cast( + const_cast(qkv_biases.get().data())) : nullptr, + max_seq_len, + max_blocks_per_seq, + num_heads, + kv_num_heads, + head_size, + block_size, + bsz, + stream, + use_neox_rotary_style); + } else if (cache_quant_type_str == "cache_int8") { + append_decode_cache_int8_rope( + reinterpret_cast(qkv_ptr), + key_cache_out->data(), + value_cache_out->data(), + reinterpret_cast(qkv_out->data()), + block_tables.data(), + padding_offsets.data(), + cum_offsets.data(), + seq_lens.data(), + seq_lens_encoder.data(), + cos_emb, + sin_emb, + qkv_out_scales ? qkv_out_scales.get().data() : nullptr, + qkv_biases ? reinterpret_cast( + const_cast(qkv_biases.get().data())) : nullptr, + cache_k_scale ? 
reinterpret_cast( + const_cast(cache_k_scale.get().data())) : nullptr, + cache_v_scale ? reinterpret_cast( + const_cast(cache_v_scale.get().data())) : nullptr, + max_seq_len, + max_blocks_per_seq, + num_heads, + kv_num_heads, + head_size, + block_size, + bsz, + stream, + use_neox_rotary_style); + } else if (cache_quant_type_str == "cache_int4_zp") { + append_decode_cache_int4_rope( + reinterpret_cast(qkv_ptr), + key_cache_out->data(), + value_cache_out->data(), + reinterpret_cast(const_cast(qkv_out->data())), + block_tables.data(), + padding_offsets.data(), + cum_offsets.data(), + seq_lens.data(), + seq_lens_encoder.data(), + cos_emb, + sin_emb, + qkv_out_scales ? qkv_out_scales.get().data() : nullptr, + qkv_biases ? reinterpret_cast( + const_cast(qkv_biases.get().data())) : nullptr, + cache_k_scale ? reinterpret_cast( + const_cast(cache_k_scale.get().data())) : nullptr, + cache_v_scale ? reinterpret_cast( + const_cast(cache_v_scale.get().data())) : nullptr, + cache_k_zp ? reinterpret_cast( + const_cast(cache_k_zp.get().data())) : nullptr, + cache_v_zp ? reinterpret_cast( + const_cast(cache_v_zp.get().data())) : nullptr, + max_seq_len, + max_blocks_per_seq, + num_heads, + kv_num_heads, + head_size, + block_size, + bsz, + stream, + use_neox_rotary_style); + } else { + PD_THROW("append attention just support C16/C8/C4_zp now!"); + } +} + +template void DecoderWriteCacheWithRoPEKernel( + const paddle::Tensor& qkv, // [token_num, 3, num_head, head_dim] ([token_num, num_head + 2 * + // gqa_group_size, head_dim] if GQA) + const paddle::Tensor& seq_lens, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& padding_offsets, + const paddle::Tensor& cum_offsets, + const paddle::Tensor& block_tables, + const paddle::optional& rotary_embs, + const paddle::optional& qkv_out_scales, + const paddle::optional& qkv_biases, + const paddle::optional& cache_k_scale, + const paddle::optional& cache_v_scale, + const paddle::optional& cache_k_zp, + const paddle::optional& cache_v_zp, + const std::string& cache_quant_type_str, + const bool use_neox_rotary_style, + const int max_seq_len, + const int num_heads, + const int kv_num_heads, + const int head_size, + cudaStream_t& stream, + paddle::Tensor* qkv_out, + paddle::Tensor* key_cache_out, + paddle::Tensor* value_cache_out); + +template void DecoderWriteCacheWithRoPEKernel( + const paddle::Tensor& qkv, // [token_num, 3, num_head, head_dim] ([token_num, num_head + 2 * + // gqa_group_size, head_dim] if GQA) + const paddle::Tensor& seq_lens, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& padding_offsets, + const paddle::Tensor& cum_offsets, + const paddle::Tensor& block_tables, + const paddle::optional& rotary_embs, + const paddle::optional& qkv_out_scales, + const paddle::optional& qkv_biases, + const paddle::optional& cache_k_scale, + const paddle::optional& cache_v_scale, + const paddle::optional& cache_k_zp, + const paddle::optional& cache_v_zp, + const std::string& cache_quant_type_str, + const bool use_neox_rotary_style, + const int max_seq_len, + const int num_heads, + const int kv_num_heads, + const int head_size, + cudaStream_t& stream, + paddle::Tensor* qkv_out, + paddle::Tensor* key_cache_out, + paddle::Tensor* value_cache_out); \ No newline at end of file diff --git a/csrc/gpu/append_attn/decoder_write_cache_with_rope_kernel.h b/csrc/gpu/append_attn/decoder_write_cache_with_rope_kernel.h new file mode 100644 index 000000000000..9c5164ea5402 --- /dev/null +++ b/csrc/gpu/append_attn/decoder_write_cache_with_rope_kernel.h @@ 
-0,0 +1,43 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include "decoder_write_cache_with_rope_impl.cuh" + +template +void DecoderWriteCacheWithRoPEKernel( + const paddle::Tensor& qkv, // [token_num, 3, num_head, head_dim] ([token_num, num_head + 2 * + // gqa_group_size, head_dim] if GQA) + const paddle::Tensor& seq_lens, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& padding_offsets, + const paddle::Tensor& cum_offsets, + const paddle::Tensor& block_tables, + const paddle::optional& rotary_embs, + const paddle::optional& qkv_out_scales, + const paddle::optional& qkv_biases, + const paddle::optional& cache_k_scale, + const paddle::optional& cache_v_scale, + const paddle::optional& cache_k_zp, + const paddle::optional& cache_v_zp, + const std::string& cache_quant_type_str, + const bool use_neox_rotary_style, + const int max_seq_len, + const int num_heads, + const int kv_num_heads, + const int head_size, + cudaStream_t& stream, + paddle::Tensor* qkv_out, + paddle::Tensor* key_cache_out, + paddle::Tensor* value_cache_out); \ No newline at end of file diff --git a/csrc/gpu/append_attn/encoder_write_cache_with_rope_bfloat16_bfloat16_kernel.cu b/csrc/gpu/append_attn/encoder_write_cache_with_rope_bfloat16_bfloat16_kernel.cu new file mode 100644 index 000000000000..d359e2a5e204 --- /dev/null +++ b/csrc/gpu/append_attn/encoder_write_cache_with_rope_bfloat16_bfloat16_kernel.cu @@ -0,0 +1,42 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+#include "encoder_write_cache_with_rope_kernel.h" + +template void EncoderWriteCacheWithRopeKernel(const paddle::Tensor& qkv, // [token_num, 3, num_head, head_dim] ([token_num, num_head + 2 * gqa_group_size, head_dim] if GQA) + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& padding_offsets, + const paddle::Tensor& cum_offsets, + const paddle::Tensor& block_tables, + const paddle::Tensor& batch_ids, + const paddle::Tensor& tile_ids, + const paddle::optional& rotary_embs, + const paddle::optional& qkv_out_scales, + const paddle::optional& qkv_biases, + const paddle::optional& cache_k_scale, + const paddle::optional& cache_v_scale, + const paddle::optional& cache_k_zp, + const paddle::optional& cache_v_zp, + const std::string& cache_quant_type_str, + const int num_blocks, + const int max_seq_len, + const int num_heads, + const int kv_num_heads, + const int head_dim, + const bool use_neox_style, + cudaStream_t& stream, + paddle::Tensor *qkv_out, + paddle::Tensor *key_cache_out, + paddle::Tensor *value_cache_out); \ No newline at end of file diff --git a/csrc/gpu/append_attn/encoder_write_cache_with_rope_bfloat16_int_kernel.cu b/csrc/gpu/append_attn/encoder_write_cache_with_rope_bfloat16_int_kernel.cu new file mode 100644 index 000000000000..642c78322c20 --- /dev/null +++ b/csrc/gpu/append_attn/encoder_write_cache_with_rope_bfloat16_int_kernel.cu @@ -0,0 +1,42 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+#include "encoder_write_cache_with_rope_kernel.h" + +template void EncoderWriteCacheWithRopeKernel(const paddle::Tensor& qkv, // [token_num, 3, num_head, head_dim] ([token_num, num_head + 2 * gqa_group_size, head_dim] if GQA) + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& padding_offsets, + const paddle::Tensor& cum_offsets, + const paddle::Tensor& block_tables, + const paddle::Tensor& batch_ids, + const paddle::Tensor& tile_ids, + const paddle::optional& rotary_embs, + const paddle::optional& qkv_out_scales, + const paddle::optional& qkv_biases, + const paddle::optional& cache_k_scale, + const paddle::optional& cache_v_scale, + const paddle::optional& cache_k_zp, + const paddle::optional& cache_v_zp, + const std::string& cache_quant_type_str, + const int num_blocks, + const int max_seq_len, + const int num_heads, + const int kv_num_heads, + const int head_dim, + const bool use_neox_style, + cudaStream_t& stream, + paddle::Tensor *qkv_out, + paddle::Tensor *key_cache_out, + paddle::Tensor *value_cache_out); \ No newline at end of file diff --git a/csrc/gpu/append_attn/encoder_write_cache_with_rope_impl.cuh b/csrc/gpu/append_attn/encoder_write_cache_with_rope_impl.cuh new file mode 100644 index 000000000000..51bf255577ce --- /dev/null +++ b/csrc/gpu/append_attn/encoder_write_cache_with_rope_impl.cuh @@ -0,0 +1,1740 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+#pragma once + +#include "helper.h" +#include "mem_util.cuh" +#include "mma_tensor_op.cuh" +#include "utils.cuh" + +template +__global__ void VariableLengthRotaryKernel( + const int *qkv, + const float *cos_emb, // [1, 1, seq_len, dim_head / 2] + const float *sin_emb, + const int *padding_offsets, + const int *seq_lens, + const int *seq_lens_decoder, + const float *qkv_out_scales, // [3, num_head, dim_head] + const T *qkv_biases, // [3, num_head, dim_head] + T *qkv_out, + const int64_t elem_cnt, + const int num_head, + const int seq_len, + const int last_dim) { + using LoadT = AlignedVector; + using LoadBiasT = AlignedVector; + using LoadScaleT = AlignedVector; + constexpr int HalfVecSize = VecSize / 2; + using LoadEmbT = AlignedVector; + LoadT src_vec; + LoadBiasT bias_vec; + LoadScaleT out_scale_vec; + LoadEmbT cos_emb_vec; + LoadEmbT sin_emb_vec; + int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x; + const int half_lastdim = last_dim / 2; + const int hidden_size = num_head * last_dim; + const int offset = 3 * hidden_size; + for (int64_t linear_index = global_thread_idx * VecSize, + step = gridDim.x * blockDim.x * VecSize; + linear_index < elem_cnt; + linear_index += step) { + const int token_idx = linear_index / offset; + const int ori_token_idx = token_idx + padding_offsets[token_idx]; + const int ori_bi = ori_token_idx / seq_len; + if (seq_lens && seq_lens[ori_bi] == 0) continue; + const int bias = linear_index % offset; + const int qkv_id = bias / hidden_size; + const int qkv_bias = bias % hidden_size; + const int hi = qkv_bias / last_dim; + const int h_bias = qkv_bias % last_dim; + + const int ori_seq_id = ori_token_idx % seq_len + seq_lens_decoder[ori_bi]; + + const int emb_idx = ori_seq_id * half_lastdim + h_bias / 2; + const int bias_idx = qkv_id * hidden_size + hi * last_dim + h_bias; + const int64_t base_idx = token_idx * 3 * hidden_size + bias_idx; + Load(&qkv[base_idx], &src_vec); + if (qkv_biases) { + Load(&qkv_biases[bias_idx], &bias_vec); + } + Load(&qkv_out_scales[bias_idx], &out_scale_vec); + if (qkv_id < 2) { + Load(&cos_emb[emb_idx], &cos_emb_vec); + Load(&sin_emb[emb_idx], &sin_emb_vec); + } +#pragma unroll + for (int i = 0; i < HalfVecSize; i++) { + float input_left = static_cast(src_vec[2 * i]); + float input_right = static_cast(src_vec[2 * i + 1]); + // dequant + bias_add + input_left = qkv_biases ? input_left * out_scale_vec[2 * i] + + static_cast(bias_vec[2 * i]) : input_left * out_scale_vec[2 * i]; + input_right = qkv_biases ? 
input_right * out_scale_vec[2 * i + 1] + + static_cast(bias_vec[2 * i + 1]) : input_right * out_scale_vec[2 * i + 1]; + if (qkv_id < 2) { // qk rope + const float cos_tmp = cos_emb_vec[i]; + const float sin_tmp = sin_emb_vec[i]; + bias_vec[2 * i] = + static_cast(input_left * cos_tmp - input_right * sin_tmp); + bias_vec[2 * i + 1] = + static_cast(input_right * cos_tmp + input_left * sin_tmp); + } else { + bias_vec[2 * i] = static_cast(input_left); + bias_vec[2 * i + 1] = static_cast(input_right); + } + } + Store(bias_vec, &qkv_out[base_idx]); + } +} + +template +__global__ void VariableLengthRotaryKernel( + const T *qkv, + const float *cos_emb, // [1, 1, seq_len, dim_head / 2] + const float *sin_emb, + const int *padding_offsets, + const int *seq_lens, + const int *seq_lens_decoder, + T *qkv_out, + const int64_t elem_cnt, + const int num_head, + const int seq_len, + const int last_dim) { + using LoadT = AlignedVector; + constexpr int HalfVecSize = VecSize / 2; + using LoadEmbT = AlignedVector; + LoadT src_vec; + LoadEmbT cos_emb_vec; + LoadEmbT sin_emb_vec; + int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x; + const int half_lastdim = last_dim / 2; + const int hidden_size = num_head * last_dim; + const int offset = 2 * hidden_size; + for (int64_t linear_index = global_thread_idx * VecSize, + step = gridDim.x * blockDim.x * VecSize; + linear_index < elem_cnt; + linear_index += step) { + const int token_idx = linear_index / offset; + const int ori_token_idx = token_idx + padding_offsets[token_idx]; + const int ori_bi = ori_token_idx / seq_len; + if (seq_lens && seq_lens[ori_bi] == 0) continue; + const int bias = linear_index % offset; + const int qkv_id = bias / hidden_size; + const int qkv_bias = bias % hidden_size; + const int hi = qkv_bias / last_dim; + const int h_bias = qkv_bias % last_dim; + + const int ori_seq_id = ori_token_idx % seq_len + seq_lens_decoder[ori_bi]; + + const int emb_idx = ori_seq_id * half_lastdim + h_bias / 2; + const int64_t base_idx = token_idx * 3 * hidden_size + + qkv_id * hidden_size + hi * last_dim + h_bias; + Load(&qkv[base_idx], &src_vec); + Load(&cos_emb[emb_idx], &cos_emb_vec); + Load(&sin_emb[emb_idx], &sin_emb_vec); +#pragma unroll + for (int i = 0; i < HalfVecSize; i++) { + const float input_left = static_cast(src_vec[2 * i]); + const float input_right = static_cast(src_vec[2 * i + 1]); + const float cos_tmp = cos_emb_vec[i]; + const float sin_tmp = sin_emb_vec[i]; + src_vec[2 * i] = + static_cast(input_left * cos_tmp - input_right * sin_tmp); + src_vec[2 * i + 1] = + static_cast(input_right * cos_tmp + input_left * sin_tmp); + } + Store(src_vec, &qkv_out[base_idx]); + } +} + +template +__global__ void NeoxVariableLengthRotaryKernel( + const int *qkv, + const float *cos_emb, // [1, 1, seq_len, dim_head / 2] + const float *sin_emb, + const int *padding_offsets, + const int *seq_lens, + const int *seq_lens_decoder, + const float *qkv_out_scales, // [3, num_head, dim_head] + const T *qkv_biases, // [3, num_head, dim_head] + T *qkv_out, + const int64_t elem_cnt, + const int num_head, + const int seq_len, + const int last_dim) { + using LoadT = AlignedVector; + using LoadBiasT = AlignedVector; + using LoadScaleT = AlignedVector; + using LoadEmbT = AlignedVector; + LoadT left_vec; + LoadT right_vec; + LoadBiasT left_bias_vec; + LoadBiasT right_bias_vec; + LoadScaleT left_out_scale_vec; + LoadScaleT right_out_scale_vec; + LoadEmbT cos_emb_vec; + LoadEmbT sin_emb_vec; + int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x; + const 
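// Illustrative CPU reference for the rotation applied by the rotary kernels
// in this file: the non-Neox path rotates adjacent pairs (2i, 2i+1), while
// the Neox path rotates element i against element i + dim/2. In the kernels
// the cos/sin values come from the precomputed rotary_embs table; the
// base-10000 frequency formula below is the conventional RoPE choice and is
// only an assumption about how that table was filled.
#include <cmath>
#include <cstdio>
#include <vector>

void RopeInterleaved(std::vector<float>& x, int pos, float base = 10000.f) {
  const int dim = static_cast<int>(x.size());
  for (int i = 0; i < dim / 2; ++i) {
    const float freq = std::pow(base, -2.f * i / dim);
    const float c = std::cos(pos * freq), s = std::sin(pos * freq);
    const float l = x[2 * i], r = x[2 * i + 1];
    x[2 * i] = l * c - r * s;      // out_left  = l*cos - r*sin
    x[2 * i + 1] = r * c + l * s;  // out_right = r*cos + l*sin
  }
}

void RopeNeox(std::vector<float>& x, int pos, float base = 10000.f) {
  const int dim = static_cast<int>(x.size());
  const int half = dim / 2;
  for (int i = 0; i < half; ++i) {
    const float freq = std::pow(base, -2.f * i / dim);
    const float c = std::cos(pos * freq), s = std::sin(pos * freq);
    const float l = x[i], r = x[i + half];
    x[i] = l * c - r * s;
    x[i + half] = r * c + l * s;
  }
}

int main() {
  std::vector<float> q(8, 1.f), k(8, 1.f);
  RopeInterleaved(q, /*pos=*/3);
  RopeNeox(k, /*pos=*/3);
  std::printf("%f %f\n", q[0], k[0]);
  return 0;
}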
int half_lastdim = last_dim / 2; + const int hidden_size = num_head * half_lastdim; + const int full_hidden_size = num_head * last_dim; + const int offset = 3 * hidden_size; + for (int64_t linear_index = global_thread_idx * VecSize, + step = gridDim.x * blockDim.x * VecSize; + linear_index < elem_cnt; + linear_index += step) { + const int token_idx = linear_index / offset; + const int ori_token_idx = token_idx + padding_offsets[token_idx]; + const int ori_bi = ori_token_idx / seq_len; + if (seq_lens && seq_lens[ori_bi] == 0) continue; + const int bias = linear_index % offset; + const int qkv_id = bias / hidden_size; + const int qkv_bias = bias % hidden_size; + const int hi = qkv_bias / half_lastdim; + const int h_bias = qkv_bias % half_lastdim; + + const int ori_seq_id = ori_token_idx % seq_len + seq_lens_decoder[ori_bi]; + + const int emb_idx = ori_seq_id * last_dim + h_bias; + const int bias_idx_left = + qkv_id * full_hidden_size + hi * last_dim + h_bias; + const int bias_idx_right = bias_idx_left + half_lastdim; + const int base_idx_left = token_idx * 3 * full_hidden_size + bias_idx_left; + const int base_idx_right = base_idx_left + half_lastdim; + Load(&qkv[base_idx_left], &left_vec); + Load(&qkv[base_idx_right], &right_vec); + if (qkv_biases) { + Load(&qkv_biases[bias_idx_left], &left_bias_vec); + Load(&qkv_biases[bias_idx_right], &right_bias_vec); + } + Load(&qkv_out_scales[bias_idx_left], + &left_out_scale_vec); + Load(&qkv_out_scales[bias_idx_right], + &right_out_scale_vec); + if (qkv_id < 2) { + Load(&cos_emb[emb_idx], &cos_emb_vec); + Load(&sin_emb[emb_idx], &sin_emb_vec); + } +#pragma unroll + for (int i = 0; i < VecSize; i++) { + float input_left = static_cast(left_vec[i]); + float input_right = static_cast(right_vec[i]); + // dequant + bias_add + input_left = qkv_biases ? input_left * left_out_scale_vec[i] + + static_cast(left_bias_vec[i]) : input_left * left_out_scale_vec[i]; + input_right = qkv_biases ? 
input_right * right_out_scale_vec[i] + + static_cast(right_bias_vec[i]) : input_right * right_out_scale_vec[i]; + if (qkv_id < 2) { // qk rope + const float cos_tmp = cos_emb_vec[i]; + const float sin_tmp = sin_emb_vec[i]; + left_bias_vec[i] = + static_cast(input_left * cos_tmp - input_right * sin_tmp); + right_bias_vec[i] = + static_cast(input_right * cos_tmp + input_left * sin_tmp); + } else { + left_bias_vec[i] = static_cast(input_left); + right_bias_vec[i] = static_cast(input_right); + } + } + Store(left_bias_vec, &qkv_out[base_idx_left]); + Store(right_bias_vec, &qkv_out[base_idx_right]); + } +} + +template +__global__ void NeoxVariableLengthRotaryKernel( + const T *qkv, + const float *cos_emb, // [1, 1, seq_len, dim_head / 2] + const float *sin_emb, + const int *padding_offsets, + const int *seq_lens, + const int *seq_lens_decoder, + T *qkv_out, + const int64_t elem_cnt, + const int num_head, + const int seq_len, + const int last_dim) { + using LoadT = AlignedVector; + using LoadEmbT = AlignedVector; + LoadT left_vec; + LoadT right_vec; + LoadEmbT cos_emb_vec; + LoadEmbT sin_emb_vec; + int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x; + const int half_lastdim = last_dim / 2; + const int hidden_size = num_head * half_lastdim; + const int full_hidden_size = num_head * last_dim; + const int offset = 2 * hidden_size; + for (int64_t linear_index = global_thread_idx * VecSize, + step = gridDim.x * blockDim.x * VecSize; + linear_index < elem_cnt; + linear_index += step) { + const int token_idx = linear_index / offset; + const int ori_token_idx = token_idx + padding_offsets[token_idx]; + const int ori_bi = ori_token_idx / seq_len; + if (seq_lens && seq_lens[ori_bi] == 0) continue; + const int bias = linear_index % offset; + const int qkv_id = bias / hidden_size; + const int qkv_bias = bias % hidden_size; + const int hi = qkv_bias / half_lastdim; + const int h_bias = qkv_bias % half_lastdim; + + const int ori_seq_id = ori_token_idx % seq_len + seq_lens_decoder[ori_bi]; + + const int emb_idx = ori_seq_id * last_dim + h_bias; + const int base_idx_left = token_idx * 3 * full_hidden_size + + qkv_id * full_hidden_size + hi * last_dim + + h_bias; + const int base_idx_right = base_idx_left + half_lastdim; + + Load(&qkv[base_idx_left], &left_vec); + Load(&qkv[base_idx_right], &right_vec); + Load(&cos_emb[emb_idx], &cos_emb_vec); + Load(&sin_emb[emb_idx], &sin_emb_vec); +#pragma unroll + for (int i = 0; i < VecSize; i++) { + const float input_left = static_cast(left_vec[i]); + const float input_right = static_cast(right_vec[i]); + const float cos_tmp = cos_emb_vec[i]; + const float sin_tmp = sin_emb_vec[i]; + left_vec[i] = + static_cast(input_left * cos_tmp - input_right * sin_tmp); + right_vec[i] = + static_cast(input_right * cos_tmp + input_left * sin_tmp); + } + Store(left_vec, &qkv_out[base_idx_left]); + Store(right_vec, &qkv_out[base_idx_right]); + } +} + +template +__global__ void GQAVariableLengthRotaryKernel( + const int *qkv, + const float *cos_emb, // [1, 1, seq_len, dim_head / 2] + const float *sin_emb, + const int *padding_offsets, + const int *seq_lens, + const int *seq_lens_decoder, + const float *qkv_out_scales, // [3, q_num_head, dim_head] + const T *qkv_biases, // [3, q_num_head, dim_head] + T *qkv_out, + const int64_t elem_cnt, + const int q_num_head, + const int kv_num_head, + const int seq_len, + const int last_dim) { + using LoadT = AlignedVector; + using LoadBiasT = AlignedVector; + using LoadScaleT = AlignedVector; + constexpr int HalfVecSize = VecSize / 2; + 
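// Standalone sketch of the per-channel dequant + bias-add that the int-input
// rotary kernels above perform before rotating: the int32 GEMM output is
// scaled by qkv_out_scales for its channel and, when a bias tensor is
// present, the bias is added (matching the ternaries above). Names here are
// local to the sketch.
#include <cstdint>
#include <cstdio>

float DequantBiasAdd(int32_t x, float out_scale, const float* bias,
                     int channel) {
  const float v = static_cast<float>(x) * out_scale;
  return bias ? v + bias[channel] : v;
}

int main() {
  const float bias[2] = {0.25f, -0.5f};
  std::printf("%f\n", DequantBiasAdd(120, 0.01f, bias, 0));     // 1.45
  std::printf("%f\n", DequantBiasAdd(120, 0.01f, nullptr, 0));  // 1.20
  return 0;
}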
using LoadEmbT = AlignedVector; + LoadT src_vec; + LoadBiasT bias_vec; + LoadScaleT out_scale_vec; + LoadEmbT cos_emb_vec; + LoadEmbT sin_emb_vec; + int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x; + const int half_lastdim = last_dim / 2; + const int offset = (q_num_head + 2 * kv_num_head) * last_dim; + for (int64_t linear_index = global_thread_idx * VecSize, + step = gridDim.x * blockDim.x * VecSize; + linear_index < elem_cnt; + linear_index += step) { + const int token_idx = linear_index / offset; + const int ori_token_idx = token_idx + padding_offsets[token_idx]; + const int ori_bi = ori_token_idx / seq_len; + if (seq_lens[ori_bi] == 0) continue; + const int bias = linear_index % offset; + const int hi = bias / last_dim; + const int h_bias = bias % last_dim; + + const int ori_seq_id = ori_token_idx % seq_len + seq_lens_decoder[ori_bi]; + + const int64_t emb_idx = ori_seq_id * half_lastdim + h_bias / 2; + const int64_t bias_idx = hi * last_dim + h_bias; + const int64_t base_idx = token_idx * offset + bias_idx; + Load(&qkv[base_idx], &src_vec); + if (qkv_biases) { + Load(&qkv_biases[bias_idx], &bias_vec); + } + Load(&qkv_out_scales[bias_idx], &out_scale_vec); + if (hi < q_num_head + kv_num_head) { + Load(&cos_emb[emb_idx], &cos_emb_vec); + Load(&sin_emb[emb_idx], &sin_emb_vec); + } +#pragma unroll + for (int i = 0; i < HalfVecSize; i++) { + float input_left = static_cast(src_vec[2 * i]); + float input_right = static_cast(src_vec[2 * i + 1]); + // dequant + bias_add + input_left = qkv_biases ? input_left * out_scale_vec[2 * i] + + static_cast(bias_vec[2 * i]) : input_left * out_scale_vec[2 * i]; + input_right = qkv_biases ? input_right * out_scale_vec[2 * i + 1] + + static_cast(bias_vec[2 * i + 1]) : input_right * out_scale_vec[2 * i + 1]; + if (hi < q_num_head + kv_num_head) { // qk rope + const float cos_tmp = cos_emb_vec[i]; + const float sin_tmp = sin_emb_vec[i]; + bias_vec[2 * i] = + static_cast(input_left * cos_tmp - input_right * sin_tmp); + bias_vec[2 * i + 1] = + static_cast(input_right * cos_tmp + input_left * sin_tmp); + } else { + bias_vec[2 * i] = static_cast(input_left); + bias_vec[2 * i + 1] = static_cast(input_right); + } + } + Store(bias_vec, &qkv_out[base_idx]); + } +} + +template +__global__ void GQAVariableLengthRotaryKernel( + const T *qkv, + const float *cos_emb, // [1, 1, seq_len, dim_head / 2] + const float *sin_emb, + const int *padding_offsets, + const int *seq_lens, + const int *seq_lens_decoder, + T *qkv_out, + const int64_t elem_cnt, + const int q_num_head, + const int kv_num_head, + const int seq_len, + const int last_dim) { + using LoadT = AlignedVector; + constexpr int HalfVecSize = VecSize / 2; + using LoadEmbT = AlignedVector; + LoadT src_vec; + LoadEmbT cos_emb_vec; + LoadEmbT sin_emb_vec; + int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x; + const int half_lastdim = last_dim / 2; + const int offset = (q_num_head + kv_num_head) * last_dim; + for (int64_t linear_index = global_thread_idx * VecSize, + step = gridDim.x * blockDim.x * VecSize; + linear_index < elem_cnt; + linear_index += step) { + const int token_idx = linear_index / offset; + const int ori_token_idx = token_idx + padding_offsets[token_idx]; + const int ori_bi = ori_token_idx / seq_len; + if (seq_lens[ori_bi] == 0) continue; + const int bias = linear_index % offset; + const int hi = bias / last_dim; + const int h_bias = bias % last_dim; + + const int ori_seq_id = ori_token_idx % seq_len + seq_lens_decoder[ori_bi]; + + const int64_t emb_idx = ori_seq_id * 
half_lastdim + h_bias / 2; + // [token_num, q_num_head + 2 * kv_num_head, last_dim] + const int64_t base_idx = + token_idx * (q_num_head + 2 * kv_num_head) * last_dim + hi * last_dim + + h_bias; + Load(&qkv[base_idx], &src_vec); + Load(&cos_emb[emb_idx], &cos_emb_vec); + Load(&sin_emb[emb_idx], &sin_emb_vec); +#pragma unroll + for (int i = 0; i < HalfVecSize; i++) { + const float input_left = static_cast(src_vec[2 * i]); + const float input_right = static_cast(src_vec[2 * i + 1]); + const float cos_tmp = cos_emb_vec[i]; + const float sin_tmp = sin_emb_vec[i]; + src_vec[2 * i] = + static_cast(input_left * cos_tmp - input_right * sin_tmp); + src_vec[2 * i + 1] = + static_cast(input_right * cos_tmp + input_left * sin_tmp); + } + Store(src_vec, &qkv_out[base_idx]); + } +} + +template +__global__ void GQANeoxVariableLengthRotaryKernel( + const int *qkv, + const float *cos_emb, // [1, 1, seq_len, dim_head / 2] + const float *sin_emb, + const int *padding_offsets, + const int *seq_lens, + const int *seq_lens_decoder, + const float *qkv_out_scales, // [3, q_num_head, dim_head] + const T *qkv_biases, // [3, q_num_head, dim_head] + T *qkv_out, + const int64_t elem_cnt, + const int q_num_head, + const int kv_num_head, + const int seq_len, + const int last_dim) { + using LoadT = AlignedVector; + using LoadBiasT = AlignedVector; + using LoadScaleT = AlignedVector; + using LoadEmbT = AlignedVector; + LoadT left_vec; + LoadT right_vec; + LoadBiasT left_bias_vec; + LoadBiasT right_bias_vec; + LoadScaleT left_out_scale_vec; + LoadScaleT right_out_scale_vec; + LoadEmbT cos_emb_vec; + LoadEmbT sin_emb_vec; + int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x; + const int half_lastdim = last_dim / 2; + const int offset = (q_num_head + 2 * kv_num_head) * half_lastdim; + for (int64_t linear_index = global_thread_idx * VecSize, + step = gridDim.x * blockDim.x * VecSize; + linear_index < elem_cnt; + linear_index += step) { + const int token_idx = linear_index / offset; + const int ori_token_idx = token_idx + padding_offsets[token_idx]; + const int ori_bi = ori_token_idx / seq_len; + if (seq_lens && seq_lens[ori_bi] == 0) continue; + const int bias = linear_index % offset; + const int hi = bias / half_lastdim; + const int h_bias = bias % half_lastdim; + + const int ori_seq_id = ori_token_idx % seq_len + seq_lens_decoder[ori_bi]; + + const int emb_idx = ori_seq_id * last_dim + h_bias; + const int bias_idx_left = hi * last_dim + h_bias; + const int bias_idx_right = bias_idx_left + half_lastdim; + const int base_idx_left = token_idx * (q_num_head + 2 * kv_num_head) * last_dim + bias_idx_left; + const int base_idx_right = base_idx_left + half_lastdim; + Load(&qkv[base_idx_left], &left_vec); + Load(&qkv[base_idx_right], &right_vec); + if (qkv_biases) { + Load(&qkv_biases[bias_idx_left], &left_bias_vec); + Load(&qkv_biases[bias_idx_right], &right_bias_vec); + } + Load(&qkv_out_scales[bias_idx_left], + &left_out_scale_vec); + Load(&qkv_out_scales[bias_idx_right], + &right_out_scale_vec); + if (hi < (q_num_head + kv_num_head)) { + Load(&cos_emb[emb_idx], &cos_emb_vec); + Load(&sin_emb[emb_idx], &sin_emb_vec); + } +#pragma unroll + for (int i = 0; i < VecSize; i++) { + float input_left = static_cast(left_vec[i]); + float input_right = static_cast(right_vec[i]); + // dequant + bias_add + input_left = qkv_biases ? input_left * left_out_scale_vec[i] + + static_cast(left_bias_vec[i]) : input_left * left_out_scale_vec[i]; + input_right = qkv_biases ? 
input_right * right_out_scale_vec[i] + + static_cast(right_bias_vec[i]) : input_right * right_out_scale_vec[i]; + if (hi < (q_num_head + kv_num_head)) { // qk rope + const float cos_tmp = cos_emb_vec[i]; + const float sin_tmp = sin_emb_vec[i]; + left_bias_vec[i] = + static_cast(input_left * cos_tmp - input_right * sin_tmp); + right_bias_vec[i] = + static_cast(input_right * cos_tmp + input_left * sin_tmp); + } else { + left_bias_vec[i] = static_cast(input_left); + right_bias_vec[i] = static_cast(input_right); + } + } + Store(left_bias_vec, &qkv_out[base_idx_left]); + Store(right_bias_vec, &qkv_out[base_idx_right]); + } +} + +template +__global__ void GQANeoxVariableLengthRotaryKernel( + const T *qkv, + const float *cos_emb, // [1, 1, seq_len, dim_head / 2] + const float *sin_emb, + const int *padding_offsets, + const int *seq_lens, + const int *seq_lens_decoder, + const float *qkv_out_scales, + const T *qkv_biases, + T *qkv_out, + const int64_t elem_cnt, + const int q_num_head, + const int kv_num_head, + const int seq_len, + const int last_dim) { + using LoadT = AlignedVector; + using LoadEmbT = AlignedVector; + LoadT left_vec; + LoadT right_vec; + LoadEmbT cos_emb_vec; + LoadEmbT sin_emb_vec; + int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x; + const int half_lastdim = last_dim / 2; + const int offset = (q_num_head + kv_num_head) * half_lastdim; + for (int64_t linear_index = global_thread_idx * VecSize, + step = gridDim.x * blockDim.x * VecSize; + linear_index < elem_cnt; + linear_index += step) { + const int token_idx = linear_index / offset; + const int ori_token_idx = token_idx + padding_offsets[token_idx]; + const int ori_bi = ori_token_idx / seq_len; + if (seq_lens && seq_lens[ori_bi] == 0) continue; + const int bias = linear_index % offset; + const int hi = bias / half_lastdim; + const int h_bias = bias % half_lastdim; + + const int ori_seq_id = ori_token_idx % seq_len + seq_lens_decoder[ori_bi]; + + const int emb_idx = ori_seq_id * last_dim + h_bias; + const int base_idx_left = + token_idx * (q_num_head + 2 * kv_num_head) * last_dim + hi * last_dim + + h_bias; + const int base_idx_right = base_idx_left + half_lastdim; + + Load(&qkv[base_idx_left], &left_vec); + Load(&qkv[base_idx_right], &right_vec); + Load(&cos_emb[emb_idx], &cos_emb_vec); + Load(&sin_emb[emb_idx], &sin_emb_vec); +#pragma unroll + for (int i = 0; i < VecSize; i++) { + const float input_left = static_cast(left_vec[i]); + const float input_right = static_cast(right_vec[i]); + const float cos_tmp = cos_emb_vec[i]; + const float sin_tmp = sin_emb_vec[i]; + left_vec[i] = + static_cast(input_left * cos_tmp - input_right * sin_tmp); + right_vec[i] = + static_cast(input_right * cos_tmp + input_left * sin_tmp); + } + Store(left_vec, &qkv_out[base_idx_left]); + Store(right_vec, &qkv_out[base_idx_right]); + } +} + +template +__global__ void cache_kernel( + const T *__restrict__ qkv, // [num_tokens, num_heads + 2 * gqa_group_size, + // head_size] + T *__restrict__ key_cache, // [num_blocks, gqa_group_size, block_size, + // head_size] + T *__restrict__ value_cache, // [num_blocks, gqa_group_size, block_size, + // head_size] + const int *__restrict__ block_tables, // [bsz, max_blocks_per_seq] + const int *__restrict__ padding_offsets, // [num_tokens] + const int *__restrict__ seq_lens, // [bsz] + const int *__restrict__ seq_lens_decoder, // [bsz] + const int max_seq_len, + const int max_blocks_per_seq, + const int num_heads, + const int head_size, + const int block_size, + const uint32_t elem_cnt, + const int 
gqa_group_size) { + using LoadT = AlignedVector; + LoadT src_vec; + + uint32_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x; + const uint32_t hidden_size = gqa_group_size * head_size; + const uint32_t offset = 2 * hidden_size; + for (uint32_t linear_index = global_thread_idx * VecSize, + step = gridDim.x * blockDim.x * VecSize; + linear_index < elem_cnt; + linear_index += step) { + const uint32_t token_idx = linear_index / offset; + const uint32_t bias = linear_index % offset; + const uint32_t qkv_id = bias / hidden_size; // skip q + const uint32_t qkv_bias = bias % hidden_size; + const uint32_t hi = qkv_bias / head_size; + const uint32_t h_bias = qkv_bias % head_size; + const uint32_t ori_token_idx = token_idx + padding_offsets[token_idx]; + const uint32_t ori_bi = ori_token_idx / max_seq_len; + if (seq_lens[ori_bi] == 0) continue; + const uint32_t ori_seq_id = + ori_token_idx % max_seq_len + seq_lens_decoder[ori_bi]; + + const int32_t *block_table_now = nullptr; + + block_table_now = block_tables + ori_bi * max_blocks_per_seq; + + const uint32_t block_idx = block_table_now[ori_seq_id / block_size]; + const uint32_t block_offset = ori_seq_id % block_size; + + const uint32_t tgt_idx = + block_idx * gqa_group_size * block_size * head_size + + hi * block_size * head_size + block_offset * head_size + h_bias; + const uint32_t ori_idx = + token_idx * (num_heads + 2 * gqa_group_size) * head_size + + num_heads * head_size + qkv_id * hidden_size + hi * head_size + h_bias; + Load(&qkv[ori_idx], &src_vec); + if (qkv_id == 0) { + Store(src_vec, &key_cache[tgt_idx]); + } else { + Store(src_vec, &value_cache[tgt_idx]); + } + } +} + + +template +__global__ void append_write_cache_kv_c8_qkv( + uint8_t *__restrict__ cache_k, + uint8_t *__restrict__ cache_v, + const T *__restrict__ qkv_input, + const T *__restrict__ cache_k_scales, + const T *__restrict__ cache_v_scales, + const int *__restrict__ batch_ids, + const int *__restrict__ tile_ids, + const int *__restrict__ seq_lens_this_time, + const int *__restrict__ seq_lens_decoder, + const int *__restrict__ padding_offsets, + const int *__restrict__ cum_offsets, + const int *__restrict__ block_tables, + const int max_seq_len, + const int max_block_num_per_seq, + const int q_num_heads, + const int kv_num_heads) { + constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b(); + constexpr uint32_t pad_len = BLOCK_SIZE; + const uint32_t btid = blockIdx.x, kv_head_idx = blockIdx.z; // !!! 
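// Standalone sketch of the paged-cache addressing used by cache_kernel above:
// block_tables maps a sequence's logical block (seq_pos / block_size) to a
// physical block id, and K/V land in a cache laid out as
// [num_blocks, kv_num_heads, block_size, head_dim]. (The source offset in the
// fused QKV buffer additionally skips the leading num_heads query heads.)
// All sizes below are example values.
#include <cstdint>
#include <cstdio>

int64_t CacheOffset(const int* block_table_row, int seq_pos, int kv_head,
                    int dim, int kv_num_heads, int block_size, int head_dim) {
  const int block_idx = block_table_row[seq_pos / block_size];  // physical block
  const int block_off = seq_pos % block_size;                   // row in block
  return static_cast<int64_t>(block_idx) * kv_num_heads * block_size * head_dim +
         static_cast<int64_t>(kv_head) * block_size * head_dim +
         static_cast<int64_t>(block_off) * head_dim + dim;
}

int main() {
  // Logical blocks 0..3 of one sequence live in physical blocks {7, 2, 9, 4}.
  const int block_table_row[4] = {7, 2, 9, 4};
  // With block_size 64, position 70 falls in physical block 2, row 6.
  std::printf("%lld\n",
              static_cast<long long>(CacheOffset(block_table_row, 70, 1, 0,
                                                 4, 64, 128)));
  return 0;
}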
+ const T cache_k_scale = cache_k_scales[kv_head_idx]; + const T cache_v_scale = cache_v_scales[kv_head_idx]; + const uint32_t tid = threadIdx.x, wid = threadIdx.y; + const uint32_t batch_id = batch_ids[btid]; + const uint32_t tile_id = tile_ids[btid]; + const uint32_t seq_len_this_time = seq_lens_this_time[batch_id]; + if (seq_len_this_time <= 0) { + return; + } + const int *block_table_now = nullptr; + + block_table_now = block_tables + batch_id * max_block_num_per_seq; + + const uint32_t num_rows_per_block = + NUM_WARPS * num_frags_z * 16; // BLOCK_SIZE + const uint32_t start_len = seq_lens_decoder[batch_id]; + const uint32_t bf_pad_len = start_len % pad_len; + const uint32_t start_len_pad = start_len - bf_pad_len; + const uint32_t end_len = start_len + seq_len_this_time; + // const uint32_t end_len_pad_16 = div_up(end_len, num_rows_per_block) * + // num_rows_per_block; + const uint32_t tile_start = start_len_pad + tile_id * num_rows_per_block; + uint32_t chunk_start = tile_start + wid * num_frags_z * 16 + tid / 8; + + // 前 start_len % pad_len 部分不搬, + const uint32_t start_token_idx = + batch_id * max_seq_len - cum_offsets[batch_id]; + const uint32_t kv_batch_stride = (q_num_heads + 2 * kv_num_heads) * HEAD_DIM; + const uint32_t kv_h_stride = HEAD_DIM; + __shared__ T k_smem_ori[num_rows_per_block * HEAD_DIM]; + __shared__ T v_smem_ori[num_rows_per_block * HEAD_DIM]; + + smem_t k_smem(k_smem_ori); + smem_t v_smem(v_smem_ori); + + uint32_t kv_smem_offset_w = smem_t::get_permuted_offset( + wid * num_frags_z * 16 + tid / 8, tid % 8); // 4 * 8 per warp + /* + 1 | 2 + —————— + 3 | 4 + */ + uint32_t k_smem_offset_r = smem_t::get_permuted_offset( + wid * num_frags_z * 16 + 8 * (tid / 16) + tid % 8, (tid % 16) / 8); + + /* + 1 | 3 + —————— + 2 | 4 transpose + */ + constexpr uint32_t num_frags_v = num_frags_y / NUM_WARPS; + uint32_t v_smem_offset_r = smem_t::get_permuted_offset( + tid % 16, wid * num_frags_v * 2 + tid / 16); + + // load kv gmem to smem + const uint32_t real_start_token_idx = start_token_idx - bf_pad_len + + tile_id * num_rows_per_block + + wid * num_frags_z * 16 + tid / 8; + uint32_t k_read_idx = real_start_token_idx * kv_batch_stride + + (q_num_heads + kv_head_idx) * kv_h_stride + + tid % 8 * num_elems_per_128b(); + uint32_t v_read_idx = + real_start_token_idx * kv_batch_stride + + (q_num_heads + kv_num_heads + kv_head_idx) * kv_h_stride + + tid % 8 * num_elems_per_128b(); +#pragma unroll + for (uint32_t fz = 0; fz < num_frags_z; ++fz) { +#pragma unroll + for (uint32_t j = 0; j < 4; ++j) { +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y / 4; + ++fy) { // (num_frags_y * 16) / (8 * num_elems_per_128b()) + if (chunk_start >= start_len && chunk_start < end_len) { + k_smem.load_128b_async( + kv_smem_offset_w, qkv_input + k_read_idx, chunk_start < end_len); + v_smem.load_128b_async( + kv_smem_offset_w, qkv_input + v_read_idx, chunk_start < end_len); + } + kv_smem_offset_w = + k_smem.advance_offset_by_column<8>(kv_smem_offset_w, fy); + k_read_idx += 8 * num_elems_per_128b(); + v_read_idx += 8 * num_elems_per_128b(); + } + kv_smem_offset_w = + k_smem.advance_offset_by_row<4, num_vecs_per_head>(kv_smem_offset_w) - + 2 * num_frags_y; + chunk_start += 4; + k_read_idx += + 4 * kv_batch_stride - 2 * num_frags_y * num_elems_per_128b(); + v_read_idx += + 4 * kv_batch_stride - 2 * num_frags_y * num_elems_per_128b(); + } + } + commit_group(); + wait_group<0>(); + __syncthreads(); + + // mask, quant, store + using LoadKVT = AlignedVector; + LoadKVT cache_vec1; + LoadKVT cache_vec2; + + 
uint32_t chunk_start_k = tile_start + wid * num_frags_z * 16 + tid / 4; + uint32_t kv_frag[4]; + int block_id = __ldg(&block_table_now[tile_start / BLOCK_SIZE]); + const uint32_t write_n_stride = kv_num_heads * BLOCK_SIZE * HEAD_DIM; + const uint32_t write_h_stride = BLOCK_SIZE * HEAD_DIM; + const uint32_t write_b_stride = HEAD_DIM; + const uint32_t write_d_stride = BLOCK_SIZE; + uint32_t k_write_idx = block_id * write_n_stride + + kv_head_idx * write_h_stride + + (wid * num_frags_z * 16 + tid / 4) * write_b_stride + + tid % 4 * 4; // 4 * int8 = 8 * int4 = 32bit +#pragma unroll + for (uint32_t fz = 0; fz < num_frags_z; ++fz) { + uint32_t k_write_idx_now_z = k_write_idx + fz * 16 * write_b_stride; +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y; ++fy) { + uint32_t k_write_idx_now = k_write_idx_now_z + + fy % 2 * 8 * write_b_stride + + fy / 2 * 32; // + fy % 2 * 16; + // load + k_smem.ldmatrix_m8n8x4(k_smem_offset_r, kv_frag); + // quant + T *k_frag_T = reinterpret_cast(kv_frag); + if (bf_pad_len != 0) { + Load(cache_k + k_write_idx_now, &cache_vec1); + Load(cache_k + k_write_idx_now + 16, &cache_vec2); + } +#pragma unroll + for (uint32_t v_id = 0; v_id < 8; ++v_id) { + uint8_t uint_quant_value; + if (chunk_start_k + (v_id / 4) * 8 >= start_len && + chunk_start_k + (v_id / 4) * 8 < end_len) { + float quant_value = + static_cast(cache_k_scale * k_frag_T[v_id]); + quant_value = roundWithTiesToEven(quant_value); + quant_value = quant_value > 127.0f ? 127.0f : quant_value; + quant_value = quant_value < -127.0f ? -127.0f : quant_value; + uint_quant_value = static_cast(quant_value + 127.0f); + } else { + uint_quant_value = 0; + } + if (bf_pad_len != 0) { + if (v_id < 4) { + cache_vec1[v_id] |= uint_quant_value; + } else { + cache_vec2[v_id % 4] |= uint_quant_value; + } + } else { + if (v_id < 4) { + cache_vec1[v_id] = uint_quant_value; + } else { + cache_vec2[v_id - 4] = uint_quant_value; + } + } + } + // store + Store(cache_vec1, cache_k + k_write_idx_now); + Store(cache_vec2, cache_k + k_write_idx_now + 16); + k_smem_offset_r = k_smem.advance_offset_by_column<2>(k_smem_offset_r, fy); + } + k_smem_offset_r = + k_smem.advance_offset_by_row<16, num_vecs_per_head>(k_smem_offset_r) - + 2 * num_frags_y; + chunk_start_k += 16; + } + + uint32_t chunk_start_v = tile_start + tid % 4 * 2; + uint32_t v_write_idx = block_id * write_n_stride + + kv_head_idx * write_h_stride + + (wid * num_frags_v * 16 + tid / 4) * write_d_stride + + tid % 4 * 4; // 4 * int8 = 8 * int4 = 32bit + const uint32_t num_frags_z_v = num_frags_z * NUM_WARPS; +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_v; ++fy) { // !!! + uint32_t v_write_idx_now_v = v_write_idx + fy * 16 * write_d_stride; +#pragma unroll + for (uint32_t fz = 0; fz < num_frags_z_v; ++fz) { + uint32_t v_write_idx_now = v_write_idx_now_v + + fz % 2 * 8 * write_d_stride + + fz / 2 * 32; // + fz % 2 * 16; + // load + v_smem.ldmatrix_m8n8x4_trans(v_smem_offset_r, kv_frag); + // quant + T *v_frag_T = reinterpret_cast(kv_frag); + if (bf_pad_len != 0) { + Load(cache_v + v_write_idx_now, &cache_vec1); + Load(cache_v + v_write_idx_now + 16, &cache_vec2); + } +#pragma unroll + for (uint32_t v_id = 0; v_id < 8; ++v_id) { + uint8_t uint_quant_value; + if (chunk_start_v + v_id % 2 + (v_id % 4) / 2 * 8 >= start_len && + chunk_start_v + v_id % 2 + (v_id % 4) / 2 * 8 < end_len) { + float quant_value = + static_cast(cache_v_scale * v_frag_T[v_id]); + quant_value = roundWithTiesToEven(quant_value); + quant_value = quant_value > 127.0f ? 
127.0f : quant_value; + quant_value = quant_value < -127.0f ? -127.0f : quant_value; + uint_quant_value = static_cast(quant_value + 127.0f); + // store now + } else { + uint_quant_value = 0; + } + if (bf_pad_len != 0) { + if (v_id < 4) { + cache_vec1[v_id] |= uint_quant_value; + } else { + cache_vec2[v_id % 4] |= uint_quant_value; + } + } else { + if (v_id < 4) { + cache_vec1[v_id] = uint_quant_value; + } else { + cache_vec2[v_id - 4] = uint_quant_value; + } + } + } + // store + Store(cache_vec1, cache_v + v_write_idx_now); + Store(cache_vec2, cache_v + v_write_idx_now + 16); + chunk_start_v += 16; + v_smem_offset_r = + k_smem.advance_offset_by_row<16, num_vecs_per_head>(v_smem_offset_r); + } + v_smem_offset_r = k_smem.advance_offset_by_column<2>( + v_smem_offset_r, wid * num_frags_v + fy) - + 16 * num_frags_z_v * num_vecs_per_head; + chunk_start_v -= 16 * num_frags_z_v; + } +} + +// Write Cache KV in Append +template +__global__ void append_write_cache_kv_c4_qkv( + uint8_t *__restrict__ cache_k, + uint8_t *__restrict__ cache_v, + const T *__restrict__ qkv_input, + const T *__restrict__ cache_k_scales, + const T *__restrict__ cache_v_scales, + const T *__restrict__ cache_k_zero_points, + const T *__restrict__ cache_v_zero_points, + const int *__restrict__ batch_ids, + const int *__restrict__ tile_ids, + const int *__restrict__ seq_lens_this_time, + const int *__restrict__ seq_lens_decoder, + const int *__restrict__ padding_offsets, + const int *__restrict__ cum_offsets, + const int *__restrict__ block_tables, + const int max_seq_len, + const int max_block_num_per_seq, + const int q_num_heads, + const int kv_num_heads) { + constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b(); + constexpr uint32_t pad_len = BLOCK_SIZE; + const uint32_t btid = blockIdx.x, kv_head_idx = blockIdx.z; // !!! 
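// Standalone sketch of the per-element int8 cache quantization used by
// append_write_cache_kv_c8_qkv above: multiply by the per-head scale, round
// with ties to even, clamp to [-127, 127], then bias by +127 into uint8.
// Lanes outside [start_len, end_len) are written as 0, which is why a
// partially filled block (bf_pad_len != 0) can be merged with a bitwise OR.
// std::nearbyint in the default rounding mode stands in for the kernel's
// roundWithTiesToEven helper.
#include <cmath>
#include <cstdint>
#include <cstdio>

uint8_t QuantToCacheInt8(float x, float scale) {
  float q = x * scale;                                // per-head scale
  q = std::nearbyint(q);                              // ties to even
  q = q > 127.f ? 127.f : (q < -127.f ? -127.f : q);  // clamp
  return static_cast<uint8_t>(q + 127.f);             // bias into [0, 254]
}

int main() {
  std::printf("%u\n", QuantToCacheInt8(0.5f, 100.f));   // 50 -> 177
  std::printf("%u\n", QuantToCacheInt8(-3.0f, 100.f));  // clamped -> 0
  return 0;
}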
+ const uint32_t tid = threadIdx.x, wid = threadIdx.y; + const uint32_t batch_id = batch_ids[btid]; + const uint32_t tile_id = tile_ids[btid]; + const uint32_t seq_len_this_time = seq_lens_this_time[batch_id]; + if (seq_len_this_time <= 0) { + return; + } + const int *block_table_now = nullptr; + + block_table_now = block_tables + batch_id * max_block_num_per_seq; + + const uint32_t num_rows_per_block = + NUM_WARPS * num_frags_z * 16; // BLOCK_SIZE + const uint32_t start_len = seq_lens_decoder[batch_id]; + const uint32_t bf_pad_len = start_len % pad_len; + const uint32_t start_len_pad = start_len - bf_pad_len; + const uint32_t end_len = start_len + seq_len_this_time; + // const uint32_t end_len_pad_16 = div_up(end_len, num_rows_per_block) * + // num_rows_per_block; + const uint32_t tile_start = start_len_pad + tile_id * num_rows_per_block; + uint32_t chunk_start = tile_start + wid * num_frags_z * 16 + tid / 8; + + // 前 start_len % pad_len 部分不搬, + const uint32_t start_token_idx = + batch_id * max_seq_len - cum_offsets[batch_id]; + const uint32_t kv_batch_stride = (q_num_heads + 2 * kv_num_heads) * HEAD_DIM; + const uint32_t kv_h_stride = HEAD_DIM; + __shared__ T k_smem_ori[num_rows_per_block * HEAD_DIM]; + __shared__ T v_smem_ori[num_rows_per_block * HEAD_DIM]; + __shared__ T k_scale_smem[HEAD_DIM]; + __shared__ T v_scale_smem[HEAD_DIM]; + __shared__ T k_zero_point_smem[HEAD_DIM]; + __shared__ T v_zero_point_smem[HEAD_DIM]; + const T *cache_k_scale_now = cache_k_scales + kv_head_idx * HEAD_DIM; + const T *cache_k_zp_now = cache_k_zero_points + kv_head_idx * HEAD_DIM; + const T *cache_v_scale_now = cache_v_scales + kv_head_idx * HEAD_DIM; + const T *cache_v_zp_now = cache_v_zero_points + kv_head_idx * HEAD_DIM; +#pragma unroll + for (uint32_t i = wid * 32 + tid; i < HEAD_DIM; i += 128) { + k_scale_smem[i] = cache_k_scale_now[i]; + k_zero_point_smem[i] = cache_k_zp_now[i]; + v_scale_smem[i] = cache_v_scale_now[i]; + v_zero_point_smem[i] = cache_v_zp_now[i]; + } + + smem_t k_smem(k_smem_ori); + smem_t v_smem(v_smem_ori); + + uint32_t kv_smem_offset_w = smem_t::get_permuted_offset( + wid * num_frags_z * 16 + tid / 8, tid % 8); // 4 * 8 per warp + /* + 1 | 2 + —————— + 3 | 4 + */ + uint32_t k_smem_offset_r = smem_t::get_permuted_offset( + wid * num_frags_z * 16 + 8 * (tid / 16) + tid % 8, (tid % 16) / 8); + + /* + 1 | 3 + —————— + 2 | 4 transpose + */ + constexpr uint32_t num_frags_v = num_frags_y / NUM_WARPS; + uint32_t v_smem_offset_r = smem_t::get_permuted_offset( + tid % 16, + wid * num_frags_v * 2 + tid / 16); // wid * num_frags_v * 16 / 8 + + // load kv gmem to smem + const uint32_t real_start_token_idx = start_token_idx - bf_pad_len + + tile_id * num_rows_per_block + + wid * num_frags_z * 16 + tid / 8; + uint32_t k_read_idx = real_start_token_idx * kv_batch_stride + + (q_num_heads + kv_head_idx) * kv_h_stride + + tid % 8 * num_elems_per_128b(); + uint32_t v_read_idx = + real_start_token_idx * kv_batch_stride + + (q_num_heads + kv_num_heads + kv_head_idx) * kv_h_stride + + tid % 8 * num_elems_per_128b(); +#pragma unroll + for (uint32_t fz = 0; fz < num_frags_z; ++fz) { +#pragma unroll + for (uint32_t j = 0; j < 4; ++j) { +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y / 4; + ++fy) { // (num_frags_y * 16) / (8 * num_elems_per_128b()) + if (chunk_start >= start_len && chunk_start < end_len) { + k_smem + .load_128b_async( // can be kNoFill? + kv_smem_offset_w, + qkv_input + k_read_idx, + chunk_start < end_len); + v_smem + .load_128b_async( // can be kNoFill? 
+ kv_smem_offset_w, + qkv_input + v_read_idx, + chunk_start < end_len); + } + kv_smem_offset_w = + k_smem.advance_offset_by_column<8>(kv_smem_offset_w, fy); + k_read_idx += 8 * num_elems_per_128b(); + v_read_idx += 8 * num_elems_per_128b(); + } + kv_smem_offset_w = + k_smem.advance_offset_by_row<4, num_vecs_per_head>(kv_smem_offset_w) - + 2 * num_frags_y; + k_read_idx += + 4 * kv_batch_stride - 2 * num_frags_y * num_elems_per_128b(); + v_read_idx += + 4 * kv_batch_stride - 2 * num_frags_y * num_elems_per_128b(); + chunk_start += 4; + } + } + commit_group(); + wait_group<0>(); + __syncthreads(); +#ifdef DEBUG_WRITE_C4 + if (tid == 28 && wid == 1 && kv_head_idx == 0) { + printf("k\n"); + for (uint32_t i = 0; i < BLOCK_SIZE; ++i) { + for (uint32_t j = 0; j < HEAD_DIM; ++j) { + printf("%f ", (float)k_smem_ori[i * HEAD_DIM + j]); + } + printf("\n"); + } + printf("k end\n"); + printf("v\n"); + for (uint32_t i = 0; i < BLOCK_SIZE; ++i) { + for (uint32_t j = 0; j < HEAD_DIM; ++j) { + printf("%f ", (float)v_smem_ori[i * HEAD_DIM + j]); + } + printf("\n"); + } + printf("v end\n"); + } + __syncthreads(); +#endif + + // mask, quant, store + T cache_k_scale_frag[num_frags_y][4]; + T cache_k_zp_frag[num_frags_y][4]; +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y; ++fy) { + *(reinterpret_cast(&cache_k_scale_frag[fy][0])) = + *(reinterpret_cast(&k_scale_smem[fy * 16]) + tid % 4); + *(reinterpret_cast(&cache_k_scale_frag[fy][2])) = + *(reinterpret_cast(&k_scale_smem[fy * 16]) + tid % 4 + 4); + *(reinterpret_cast(&cache_k_zp_frag[fy][0])) = + *(reinterpret_cast(&k_zero_point_smem[fy * 16]) + tid % 4); + *(reinterpret_cast(&cache_k_zp_frag[fy][2])) = + *(reinterpret_cast(&k_zero_point_smem[fy * 16]) + tid % 4 + + 4); + } + T cache_v_scale_frag[num_frags_y][2]; + T cache_v_zp_frag[num_frags_y][2]; +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y; ++fy) { + cache_v_scale_frag[fy][0] = v_scale_smem[fy * 16 + tid / 4]; + cache_v_scale_frag[fy][1] = v_scale_smem[fy * 16 + tid / 4 + 8]; + cache_v_zp_frag[fy][0] = v_zero_point_smem[fy * 16 + tid / 4]; + cache_v_zp_frag[fy][1] = v_zero_point_smem[fy * 16 + tid / 4 + 8]; + } + + using LoadKVT = AlignedVector; + LoadKVT cache_vec; + + uint32_t chunk_start_k = tile_start + wid * num_frags_z * 16 + tid / 4; + uint32_t kv_frag[4]; + int block_id = __ldg(&block_table_now[tile_start / BLOCK_SIZE]); + const uint32_t write_n_stride = kv_num_heads * BLOCK_SIZE * HEAD_DIM / 2; + const uint32_t write_h_stride = BLOCK_SIZE * HEAD_DIM / 2; + const uint32_t write_b_stride = HEAD_DIM / 2; + const uint32_t write_d_stride = BLOCK_SIZE / 2; + uint32_t k_write_idx = block_id * write_n_stride + + kv_head_idx * write_h_stride + + (wid * num_frags_z * 16 + tid / 4) * write_b_stride + + tid % 4 * 4; // 4 * int8 = 8 * int4 = 32bit +#pragma unroll + for (uint32_t fz = 0; fz < num_frags_z; ++fz) { + uint32_t k_write_idx_now_z = k_write_idx + fz * 16 * write_b_stride; +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y; ++fy) { + uint32_t k_write_idx_now = k_write_idx_now_z + + (fy % 4) / 2 * 8 * write_b_stride + + fy / 4 * 32 + fy % 2 * 16; + // load + k_smem.ldmatrix_m8n8x4(k_smem_offset_r, kv_frag); + // quant + T *k_frag_T = reinterpret_cast(kv_frag); + // bf_pad_len为0表示新块写入,每个块第一次写入的时候不或,前后补0 + if (bf_pad_len != 0) { + Load(cache_k + k_write_idx_now, &cache_vec); + } + +#ifdef DEBUG_WRITE_C4 + if (tid == 28 && wid == 1 && kv_head_idx == 0) { +#pragma unroll + for (uint32_t t_id = 0; t_id < 8; ++t_id) { + printf( + "fy: %d, k_smem_offset_r: %d, k_write_idx_now: %d, 
load " + "k_frag_T[%d] = %f, cache_vec[%d] = %f, " + "cache_k_scale_frag[%d][%d] = %f, cache_k_zp_frag[%d][%d] = %f\n", + (int)fy, + (int)k_smem_offset_r, + (int)k_write_idx_now, + (int)t_id, + (float)k_frag_T[t_id], + (int)t_id % 4, + (float)cache_vec[tid % 4], + (int)fy, + (int)tid % 4, + (float)cache_k_scale_frag[fy][tid % 4], + (int)fy, + (int)tid % 4, + (float)cache_k_zp_frag[fy][tid % 4]); + } + } + __syncthreads(); +#endif + +#pragma unroll + for (uint32_t v_id = 0; v_id < 4; ++v_id) { + float quant_value1, quant_value2; + uint8_t uint_quant_value1, uint_quant_value2; + if (chunk_start_k >= start_len && chunk_start_k < end_len) { + quant_value1 = + static_cast(cache_k_scale_frag[fy][v_id] * k_frag_T[v_id] + + cache_k_zp_frag[fy][v_id]); + quant_value1 = roundWithTiesToEven(quant_value1); + quant_value1 = quant_value1 > 7.0f ? 7.0f : quant_value1; + quant_value1 = quant_value1 < -8.0f ? -8.0f : quant_value1; + uint_quant_value1 = static_cast(quant_value1 + 8.0f); + } else { + uint_quant_value1 = 0; + } + if (chunk_start_k + 8 >= start_len && chunk_start_k + 8 < end_len) { + quant_value2 = static_cast(cache_k_scale_frag[fy][v_id] * + k_frag_T[v_id + 4] + + cache_k_zp_frag[fy][v_id]); + quant_value2 = roundWithTiesToEven(quant_value2); + quant_value2 = quant_value2 > 7.0f ? 7.0f : quant_value2; + quant_value2 = quant_value2 < -8.0f ? -8.0f : quant_value2; + uint_quant_value2 = static_cast(quant_value2 + 8.0f); + } else { + uint_quant_value2 = 0; + } + if (bf_pad_len != 0) { + cache_vec[v_id] |= + (uint_quant_value2 << 4) | (uint_quant_value1 & 0x0F); + } else { + cache_vec[v_id] = + (uint_quant_value2 << 4) | (uint_quant_value1 & 0x0F); + } + } +#ifdef DEBUG_WRITE_C4 + if (tid == 0 && wid == 0) { +#pragma unroll + for (uint32_t t_id = 0; t_id < 4; ++t_id) { + printf("fy: %d, k_write_idx_now: %d, cache_vec[%d] = %f\n", + (int)fy, + (int)k_write_idx_now, + (int)t_id, + (float)cache_vec[t_id]); + } + } + __syncthreads(); +#endif + // store + Store(cache_vec, cache_k + k_write_idx_now); + k_smem_offset_r = k_smem.advance_offset_by_column<2>(k_smem_offset_r, fy); + } + k_smem_offset_r = + k_smem.advance_offset_by_row<16, num_vecs_per_head>(k_smem_offset_r) - + 2 * num_frags_y; + chunk_start_k += 16; + } + + uint32_t chunk_start_v = tile_start + tid % 4 * 2; + uint32_t v_write_idx = block_id * write_n_stride + + kv_head_idx * write_h_stride + + (wid * num_frags_v * 16 + tid / 4) * write_d_stride + + tid % 4 * 4; // 4 * int8 = 8 * int4 = 32bit + const uint32_t num_frags_z_v = num_frags_z * NUM_WARPS; +#ifdef DEBUG_WRITE_C4 + if (tid == 28 && wid == 1 && kv_head_idx == 0) { + printf( + "chunk_start_v: %d, v_write_idx: %d, num_frags_v: %d, num_frags_z_v: " + "%d, v_smem_offset_r: %d\n", + (int)chunk_start_v, + (int)v_write_idx, + int(num_frags_v), + (int)num_frags_z_v, + (int)v_smem_offset_r); + } + __syncthreads(); +#endif +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_v; ++fy) { // !!! 
+ uint32_t v_write_idx_now_v = v_write_idx + fy * 16 * write_d_stride; +#pragma unroll + for (uint32_t fz = 0; fz < num_frags_z_v; ++fz) { + uint32_t v_write_idx_now = v_write_idx_now_v + + (fz % 4) / 2 * 8 * write_d_stride + + fz / 4 * 32 + fz % 2 * 16; + // load +#ifdef DEBUG_WRITE_C4 + if (tid == 28 && wid == 1 && kv_head_idx == 0) { + printf("fy: %d, fz: %d, v_smem_offset_r: %d\n", + (int)fy, + (int)fz, + (int)v_smem_offset_r); + } + __syncthreads(); +#endif + v_smem.ldmatrix_m8n8x4_trans(v_smem_offset_r, kv_frag); + // quant + T *v_frag_T = reinterpret_cast(kv_frag); +#ifdef DEBUG_WRITE_C4 + if (tid == 28 && wid == 1 && kv_head_idx == 0) { + printf("fy: %d, fz: %d, v_write_idx_now: %d\n", + (int)fy, + (int)fz, + (int)v_write_idx_now); + for (int tii = 0; tii < 8; ++tii) { + printf("kv_frag[%d]: %f ", (int)tii, (float)kv_frag[tii]); + } + printf("\n"); + } + __syncthreads(); +#endif + if (bf_pad_len != 0) { + Load(cache_v + v_write_idx_now, &cache_vec); + } +#pragma unroll + for (uint32_t v_id = 0; v_id < 4; ++v_id) { + float quant_value1, quant_value2; + uint8_t uint_quant_value1, uint_quant_value2; + if (chunk_start_v + v_id % 2 + v_id / 2 * 8 >= start_len && + chunk_start_v + v_id % 2 + v_id / 2 * 8 < end_len) { + quant_value1 = static_cast( + cache_v_scale_frag[wid * num_frags_v + fy][0] * v_frag_T[v_id] + + cache_v_zp_frag[wid * num_frags_v + fy][0]); + quant_value1 = roundWithTiesToEven(quant_value1); + quant_value1 = quant_value1 > 7.0f ? 7.0f : quant_value1; + quant_value1 = quant_value1 < -8.0f ? -8.0f : quant_value1; + uint_quant_value1 = static_cast(quant_value1 + 8.0f); + quant_value2 = + static_cast(cache_v_scale_frag[wid * num_frags_v + fy][1] * + v_frag_T[v_id + 4] + + cache_v_zp_frag[wid * num_frags_v + fy][1]); + quant_value2 = roundWithTiesToEven(quant_value2); + quant_value2 = quant_value2 > 7.0f ? 7.0f : quant_value2; + quant_value2 = quant_value2 < -8.0f ? 
-8.0f : quant_value2; + uint_quant_value2 = static_cast(quant_value2 + 8.0f); + } else { + uint_quant_value1 = 0; + uint_quant_value2 = 0; + } +#ifdef DEBUG_WRITE_C4 + if (tid == 28 && wid == 1 && kv_head_idx == 0) { + printf( + "v_frag_T[%d]: %f, v_frag_T[%d]: %f, cache_v_scale_frag[%d][0]: " + "%f, cache_v_scale_frag[%d][1]: %f, uint_quant_value1: %d, " + "uint_quant_value2: %d\n", + (int)v_id, + (float)v_frag_T[v_id], + (int)(v_id + 4), + (float)v_frag_T[v_id + 4], + (int)(wid * num_frags_v + fy), + (float)cache_v_scale_frag[wid * num_frags_v + fy][0], + (int)(wid * num_frags_v + fy), + (float)cache_v_scale_frag[wid * num_frags_v + fy][1], + (int)uint_quant_value1, + (int)uint_quant_value1); + } + __syncthreads(); +#endif + if (bf_pad_len != 0) { + cache_vec[v_id] |= + (uint_quant_value2 << 4) | (uint_quant_value1 & 0x0F); + } else { + cache_vec[v_id] = + (uint_quant_value2 << 4) | (uint_quant_value1 & 0x0F); + } + } + // store + Store(cache_vec, cache_v + v_write_idx_now); + chunk_start_v += 16; + v_smem_offset_r = + v_smem.advance_offset_by_row<16, num_vecs_per_head>(v_smem_offset_r); + } + v_smem_offset_r = v_smem.advance_offset_by_column<2>( + v_smem_offset_r, wid * num_frags_v + fy) - + 16 * num_frags_z_v * num_vecs_per_head; + chunk_start_v -= 16 * num_frags_z_v; + } +} + +template +void rotary_qk_variable(T *qkv_out, // [token_num, 3, num_head, dim_head] + const QKV_TYPE *qkv_input, // qkv + const float *qkv_out_scales, // [3, num_head, dim_head] + const T *qkv_bias, + const float *rotary_emb, // [2, 1, 1, seq_len, dim_head / 2] + const int *padding_offsets, + const int *seq_lens, + const int *seq_lens_decoder, + const int token_num, + const int head_num, + const int seq_len, + const int input_output_len, + const int dim_head, + const cudaStream_t& stream, + bool use_neox_style = false) { + int64_t elem_nums = qkv_out_scales ? 
token_num * 3 * head_num * dim_head : token_num * 2 * head_num * dim_head; // for all q k v + if (use_neox_style) { + elem_nums /= 2; + } + + constexpr int PackSize = 16 / sizeof(T); + const int pack_num = elem_nums / PackSize; + const int blocksize = 128; + int grid_size = 1; + GetNumBlocks<128>(pack_num, &grid_size); + if (!use_neox_style) { + const float *cos_emb = rotary_emb; + const float *sin_emb = rotary_emb + input_output_len * dim_head / 2; + if (qkv_out_scales) { + VariableLengthRotaryKernel + <<>>( + reinterpret_cast(qkv_input), + cos_emb, + sin_emb, + padding_offsets, + seq_lens, + seq_lens_decoder, + qkv_out_scales, + qkv_bias, + qkv_out, + elem_nums, + head_num, + seq_len, + dim_head); + } else { + VariableLengthRotaryKernel + <<>>( + reinterpret_cast(qkv_input), + cos_emb, + sin_emb, + padding_offsets, + seq_lens, + seq_lens_decoder, + qkv_out, + elem_nums, + head_num, + seq_len, + dim_head); + } + } else { + const float *cos_emb = rotary_emb; + const float *sin_emb = rotary_emb + input_output_len * dim_head; + if (qkv_out_scales) { + NeoxVariableLengthRotaryKernel + <<>>( + reinterpret_cast(qkv_input), + cos_emb, + sin_emb, + padding_offsets, + seq_lens, + seq_lens_decoder, + qkv_out_scales, + qkv_bias, + qkv_out, + elem_nums, + head_num, + seq_len, + dim_head); + } else { + NeoxVariableLengthRotaryKernel + <<>>( + reinterpret_cast(qkv_input), + cos_emb, + sin_emb, + padding_offsets, + seq_lens, + seq_lens_decoder, + qkv_out, + elem_nums, + head_num, + seq_len, + dim_head); + } + } +} + +template +void gqa_rotary_qk_variable(T *qkv_out, // [token_num, 3, num_head, dim_head] + const QKV_TYPE *qkv_input, // qkv + const float *qkv_out_scales, // [3, num_head, dim_head] + const T *qkv_bias, + const float *rotary_emb, // [2, 1, 1, seq_len, dim_head / 2] + const int *padding_offsets, + const int *seq_lens, + const int *seq_lens_decoder, + const int token_num, + const int num_heads, + const int kv_num_heads, + const int seq_len, + const int input_output_len, + const int dim_head, + const cudaStream_t& stream, + bool use_neox_style = false) { + int64_t elem_nums = qkv_out_scales ? 
token_num * (num_heads + 2 * kv_num_heads) * dim_head : token_num * (num_heads + kv_num_heads) * dim_head; // for all q k v + if (use_neox_style) { + elem_nums /= 2; + } + + constexpr int PackSize = 16 / sizeof(T); + const int pack_num = elem_nums / PackSize; + const int blocksize = 128; + int grid_size = 1; + GetNumBlocks<128>(pack_num, &grid_size); + + if (!use_neox_style) { + const float *cos_emb = rotary_emb; + const float *sin_emb = rotary_emb + input_output_len * dim_head / 2; + if (qkv_out_scales) { + GQAVariableLengthRotaryKernel + <<>>( + reinterpret_cast(qkv_input), + cos_emb, + sin_emb, + padding_offsets, + seq_lens, + seq_lens_decoder, + qkv_out_scales, + qkv_bias, + qkv_out, + elem_nums, + num_heads, + kv_num_heads, + seq_len, + dim_head); + } else { + GQAVariableLengthRotaryKernel + <<>>( + reinterpret_cast(qkv_input), + cos_emb, + sin_emb, + padding_offsets, + seq_lens, + seq_lens_decoder, + qkv_out, + elem_nums, + num_heads, + kv_num_heads, + seq_len, + dim_head); + } + } else { + const float *cos_emb = rotary_emb; + const float *sin_emb = rotary_emb + input_output_len * dim_head; + if (qkv_out_scales) { + GQANeoxVariableLengthRotaryKernel + <<>>( + reinterpret_cast(qkv_input), + cos_emb, + sin_emb, + padding_offsets, + seq_lens, + seq_lens_decoder, + qkv_out_scales, + qkv_bias, + qkv_out, + elem_nums, + num_heads, + kv_num_heads, + seq_len, + dim_head); + } else { + GQANeoxVariableLengthRotaryKernel + <<>>( + reinterpret_cast(qkv_input), + cos_emb, + sin_emb, + padding_offsets, + seq_lens, + seq_lens_decoder, + qkv_out_scales, + qkv_bias, + qkv_out, + elem_nums, + num_heads, + kv_num_heads, + seq_len, + dim_head); + } + } +} + +template +void CascadeAppendWriteCacheKVQKV(const paddle::Tensor& qkv, // [token_num, 3, num_head, head_dim] ([token_num, num_head + 2 * gqa_group_size, head_dim] if GQA) + const paddle::Tensor& block_table, + const paddle::Tensor& padding_offsets, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const int max_seq_len, + const int num_heads, + const int head_dim, + const int kv_num_heads, + cudaStream_t& stream, + paddle::Tensor *key_cache_out, + paddle::Tensor *value_cache_out) { + auto qkv_dims = qkv.dims(); + const int max_blocks_per_seq = block_table.dims()[1]; + const int num_tokens = qkv_dims[0]; + + const int32_t block_size = key_cache_out->dims()[2]; + const uint32_t elem_nums = num_tokens * 2 * kv_num_heads * head_dim; // just k and v + constexpr int PackSize = 16 / sizeof(T); + const int pack_num = elem_nums / PackSize; + const int blocksize = 128; + int grid_size = 1; + GetNumBlocks<128>(pack_num, &grid_size); + cache_kernel<<>>( + reinterpret_cast(const_cast(qkv.data())), + reinterpret_cast(key_cache_out->data()), + reinterpret_cast(value_cache_out->data()), + block_table.data(), + padding_offsets.data(), + seq_lens_encoder.data(), + seq_lens_decoder.data(), + max_seq_len, + max_blocks_per_seq, + num_heads, + head_dim, + block_size, + elem_nums, + kv_num_heads); +} + +template +void CascadeAppendWriteCacheKVC8QKV(const paddle::Tensor &cache_k, // [max_block_num, num_heads, block_size, head_dim] + const paddle::Tensor &cache_v, // [max_block_num, num_heads, head_dim, block_size] + const paddle::Tensor &qkv, // [token_num, num_heads, head_dim] + const paddle::Tensor &cache_k_scale, // [num_kv_heads, head_dim] + const paddle::Tensor &cache_v_scale, // [num_kv_heads, head_dim] + const paddle::Tensor &seq_lens_this_time, + const paddle::Tensor &seq_lens_decoder, + const paddle::Tensor &padding_offsets, + 
const paddle::Tensor &cum_offsets, + const paddle::Tensor &block_table, + const paddle::Tensor &batch_ids, + const paddle::Tensor &tile_ids_per_batch, + int num_blocks_x_cpu, + int max_seq_len, + int q_num_heads, + int kv_num_heads, + cudaStream_t& stream, + paddle::Tensor *cache_k_out, + paddle::Tensor *cache_v_out) { + const auto &qkv_dims = qkv.dims(); + const auto &cache_k_dims = cache_k.dims(); + const auto &cum_offsets_dims = cum_offsets.dims(); + const uint32_t token_num = qkv_dims[0]; + const uint32_t num_heads = kv_num_heads; + const uint32_t bsz = cum_offsets_dims[0]; + const int max_block_num_per_seq = block_table.dims()[1]; + + const uint32_t pad_len = BLOCK_SIZE; + + constexpr uint32_t num_warps = 4; + constexpr uint32_t num_frags_z = BLOCK_SIZE / 16 / num_warps; + constexpr uint32_t num_frags_y = HEAD_DIM / 16; + constexpr uint32_t num_row_per_block = num_warps * num_frags_z * 16; + + dim3 grids(num_blocks_x_cpu, 1, num_heads); + dim3 blocks(32, num_warps); + + const uint32_t smem_size = (BLOCK_SIZE * HEAD_DIM) * sizeof(T) * 2; + auto kernel_fn = append_write_cache_kv_c8_qkv; + // if (smem_size >= 48 * 1024) { + cudaFuncSetAttribute( + kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + // } + kernel_fn<<>>( + cache_k_out->data(), + cache_v_out->data(), + qkv.data(), + cache_k_scale.data(), + cache_v_scale.data(), + batch_ids.data(), + tile_ids_per_batch.data(), + seq_lens_this_time.data(), + seq_lens_decoder.data(), + padding_offsets.data(), + cum_offsets.data(), + block_table.data(), + max_seq_len, + max_block_num_per_seq, + q_num_heads, + kv_num_heads + ); +} + +template +void CascadeAppendWriteCacheKVC4QKV(const paddle::Tensor &cache_k, // [max_block_num, num_heads, block_size, head_dim] + const paddle::Tensor &cache_v, // [max_block_num, num_heads, head_dim, block_size] + const paddle::Tensor &qkv, // [token_num, num_heads, head_dim] + const paddle::Tensor &cache_k_scale, // [num_kv_heads, head_dim] + const paddle::Tensor &cache_v_scale, // [num_kv_heads, head_dim] + const paddle::Tensor &cache_k_zp, // [num_kv_heads, head_dim] + const paddle::Tensor &cache_v_zp, // [num_kv_heads, head_dim] + const paddle::Tensor &seq_lens_this_time, + const paddle::Tensor &seq_lens_decoder, + const paddle::Tensor &padding_offsets, + const paddle::Tensor &cum_offsets, + const paddle::Tensor &block_table, + const paddle::Tensor &batch_ids, + const paddle::Tensor &tile_ids_per_batch, + int num_blocks_x_cpu, + int max_seq_len, + int q_num_heads, + int kv_num_heads, + cudaStream_t& stream, + paddle::Tensor *cache_k_out, + paddle::Tensor *cache_v_out) { + const auto &qkv_dims = qkv.dims(); + const auto &cache_k_dims = cache_k.dims(); + const auto &cum_offsets_dims = cum_offsets.dims(); + const uint32_t token_num = qkv_dims[0]; + const uint32_t num_heads = kv_num_heads; + const uint32_t bsz = cum_offsets_dims[0]; + const int max_block_num_per_seq = block_table.dims()[1]; + + const uint32_t pad_len = BLOCK_SIZE; + + constexpr uint32_t num_warps = 4; + constexpr uint32_t num_frags_z = BLOCK_SIZE / 16 / num_warps; + constexpr uint32_t num_frags_y = HEAD_DIM / 16; + constexpr uint32_t num_row_per_block = num_warps * num_frags_z * 16; + + dim3 grids(num_blocks_x_cpu, 1, num_heads); + dim3 blocks(32, num_warps); + + const uint32_t smem_size = (BLOCK_SIZE * HEAD_DIM) * sizeof(T) * 2 + HEAD_DIM * 4 * sizeof(T); + // VLOG(1) << "smem_size: " << smem_size / 1024 << "KB"; + auto kernel_fn = append_write_cache_kv_c4_qkv; + cudaFuncSetAttribute( + kernel_fn, 
cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + kernel_fn<<>>( + cache_k_out->data(), + cache_v_out->data(), + qkv.data(), + cache_k_scale.data(), + cache_v_scale.data(), + cache_k_zp.data(), + cache_v_zp.data(), + batch_ids.data(), + tile_ids_per_batch.data(), + seq_lens_this_time.data(), + seq_lens_decoder.data(), + padding_offsets.data(), + cum_offsets.data(), + block_table.data(), + max_seq_len, + max_block_num_per_seq, + q_num_heads, + kv_num_heads + ); +} \ No newline at end of file diff --git a/csrc/gpu/append_attn/encoder_write_cache_with_rope_kernel.h b/csrc/gpu/append_attn/encoder_write_cache_with_rope_kernel.h new file mode 100644 index 000000000000..4c7afeea007e --- /dev/null +++ b/csrc/gpu/append_attn/encoder_write_cache_with_rope_kernel.h @@ -0,0 +1,108 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once +#include "encoder_write_cache_with_rope_impl.cuh" + +template +void EncoderWriteCacheWithRopeKernel(const paddle::Tensor& qkv, // [token_num, 3, num_head, head_dim] ([token_num, num_head + 2 * gqa_group_size, head_dim] if GQA) + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& padding_offsets, + const paddle::Tensor& cum_offsets, + const paddle::Tensor& block_tables, + const paddle::Tensor& batch_ids, + const paddle::Tensor& tile_ids, + const paddle::optional& rotary_embs, + const paddle::optional& qkv_out_scales, + const paddle::optional& qkv_biases, + const paddle::optional& cache_k_scale, + const paddle::optional& cache_v_scale, + const paddle::optional& cache_k_zp, + const paddle::optional& cache_v_zp, + const std::string& cache_quant_type_str, + const int num_blocks, + const int max_seq_len, + const int num_heads, + const int kv_num_heads, + const int head_dim, + const bool use_neox_style, + cudaStream_t& stream, + paddle::Tensor *qkv_out, + paddle::Tensor *key_cache_out, + paddle::Tensor *value_cache_out) { + auto qkv_dims = qkv.dims(); + const uint32_t token_num = qkv_dims[0]; + + if (num_heads == kv_num_heads) { + rotary_qk_variable( + qkv_out->data(), + qkv.data(), + qkv_out_scales? qkv_out_scales.get().data() : nullptr, + qkv_biases? qkv_biases.get().data(): nullptr, + rotary_embs.get().data(), + padding_offsets.data(), + seq_lens_encoder.data(), + seq_lens_decoder.data(), + token_num, + num_heads, + max_seq_len, + rotary_embs.get().dims()[2], + head_dim, + stream, + use_neox_style + ); + } else { + gqa_rotary_qk_variable( + qkv_out->data(), + qkv.data(), + qkv_out_scales? qkv_out_scales.get().data() : nullptr, + qkv_biases? 
qkv_biases.get().data(): nullptr, + rotary_embs.get().data(), + padding_offsets.data(), + seq_lens_encoder.data(), + seq_lens_decoder.data(), + token_num, + num_heads, + kv_num_heads, + max_seq_len, + rotary_embs.get().dims()[2], + head_dim, + stream, + use_neox_style + ); + } + const auto &cache_k_dims = key_cache_out->dims(); + const uint32_t block_size = cache_k_dims[2]; + if (cache_quant_type_str == "none") { + CascadeAppendWriteCacheKVQKV(*qkv_out, block_tables, padding_offsets, seq_lens_encoder, seq_lens_decoder, + max_seq_len, num_heads, head_dim, kv_num_heads, stream, key_cache_out, value_cache_out); + } else if (cache_quant_type_str == "cache_int8") { + DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, + {DISPATCH_BLOCK_SIZE(block_size, BLOCK_SIZE, + {CascadeAppendWriteCacheKVC8QKV( + *key_cache_out, *value_cache_out, *qkv_out, cache_k_scale.get(), cache_v_scale.get(), seq_lens_this_time, + seq_lens_decoder, padding_offsets, cum_offsets, block_tables, batch_ids, tile_ids, num_blocks, max_seq_len, num_heads, kv_num_heads, stream, key_cache_out, value_cache_out);})}) + } else if (cache_quant_type_str == "cache_int4") { + DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, + {DISPATCH_BLOCK_SIZE(block_size, BLOCK_SIZE, + {CascadeAppendWriteCacheKVC4QKV( + *key_cache_out, *value_cache_out, *qkv_out, cache_k_scale.get(), cache_v_scale.get(), cache_k_zp.get(), cache_v_zp.get(), seq_lens_this_time, + seq_lens_decoder, padding_offsets, cum_offsets, block_tables, batch_ids, tile_ids, num_blocks, max_seq_len, num_heads, kv_num_heads, stream, key_cache_out, value_cache_out);})}) + } else { + PD_THROW( + "NOT supported cache_quant_type. " + "Only none, cache_int8 and cache_int4 are supported. "); + } +} \ No newline at end of file diff --git a/csrc/gpu/append_attn/get_block_shape_and_split_kv_block.cu b/csrc/gpu/append_attn/get_block_shape_and_split_kv_block.cu new file mode 100644 index 000000000000..95d4a9e349bb --- /dev/null +++ b/csrc/gpu/append_attn/get_block_shape_and_split_kv_block.cu @@ -0,0 +1,254 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
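+// Work-splitting overview (descriptive comment; the host-side sketch below is
+// illustrative only and not part of the op). split_q_block and split_kv_block
+// enumerate (batch_id, tile_id) pairs that the append-attention and
+// cache-write grids later consume: every batch contributes
+// ceil(rows / rows_per_block) tiles, and a batch that is still prefilling is
+// skipped when the decoder tiles are built. Single-threaded equivalent of
+// split_q_block:
+//
+//   int total_tiles = 0;
+//   for (int bid = 0; bid < bsz; ++bid) {
+//     int rows = seq_lens_q[bid] * gqa_group_size;
+//     if (seq_lens_encoder != nullptr && seq_lens_encoder[bid] > 0) rows = 0;
+//     const int tiles = (rows + num_rows_per_block - 1) / num_rows_per_block;
+//     for (int t = 0; t < tiles; ++t) {
+//       batch_ids[total_tiles] = bid;
+//       tile_ids_per_batch[total_tiles++] = t;
+//     }
+//   }
+//   *num_blocks_x = total_tiles;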
+ +#include "helper.h" +#include "paddle/extension.h" + +template +inline __device__ __host__ T div_up(T m, T n) { + return (m + n - 1) / n; +} + +__global__ void split_q_block(const int* __restrict__ seq_lens_q, + const int* __restrict__ seq_lens_encoder, + int* __restrict__ batch_ids, + int* __restrict__ tile_ids_per_batch, + int* __restrict__ num_blocks_x, + const int bsz, + const int num_rows_per_block, + const int gqa_group_size) { + if (threadIdx.x == 0) { + int gridx = 0; + int index = 0; + for (uint32_t bid = 0; bid < bsz; bid++) { + int seq_len = seq_lens_q[bid]; + if (seq_lens_encoder && seq_lens_encoder[bid] > 0) { + seq_len = 0; + } + const int loop_times = + div_up(seq_len * gqa_group_size, num_rows_per_block); + for (uint32_t tile_id = 0; tile_id < loop_times; tile_id++) { + batch_ids[index] = bid; + tile_ids_per_batch[index++] = tile_id; + } + gridx += loop_times; + } + *num_blocks_x = gridx; + } +} + +__global__ void split_kv_block(const int* __restrict__ seq_lens_decoder, + const int* __restrict__ seq_lens_encoder, + int* __restrict__ batch_ids, + int* __restrict__ tile_ids_per_batch, + int* __restrict__ num_blocks_x, + const int bsz, + const int pad_len, + const int num_row_per_block) { + if (threadIdx.x == 0) { + int gridx = 0; + int index = 0; + for (uint32_t bid = 0; bid < bsz; bid++) { + const int start_len = seq_lens_decoder[bid]; + int seq_len = seq_lens_encoder[bid] + start_len % pad_len; + if (seq_lens_encoder[bid] == 0) { + seq_len = 0; + } + const int loop_times = div_up(seq_len, num_row_per_block); + for (uint32_t tile_id = 0; tile_id < loop_times; tile_id++) { + batch_ids[index] = bid; + tile_ids_per_batch[index++] = tile_id; + } + gridx += loop_times; + } + *num_blocks_x = gridx; + } +} + +std::vector GetBlockShapeAndSplitKVBlock( + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& max_enc_len_this_time, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& cum_offsets, + const int encoder_block_shape_q, + const int decoder_block_shape_q, + const int gqa_group_size, + const int block_size) { + auto stream = seq_lens_encoder.stream(); + int bsz = cum_offsets.shape()[0]; + + // decoder + const uint32_t decoder_max_tile_size_per_bs_q = + div_up((1 * gqa_group_size), decoder_block_shape_q); + auto decoder_batch_ids = + GetEmptyTensor({bsz * decoder_max_tile_size_per_bs_q}, + paddle::DataType::INT32, + seq_lens_encoder.place()); + auto decoder_tile_ids_per_batch = + GetEmptyTensor({bsz * decoder_max_tile_size_per_bs_q}, + paddle::DataType::INT32, + seq_lens_encoder.place()); + auto decoder_num_blocks_x = + GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_encoder.place()); + split_q_block<<<1, 32, 0, stream>>>(seq_lens_this_time.data(), + seq_lens_encoder.data(), + decoder_batch_ids.data(), + decoder_tile_ids_per_batch.data(), + decoder_num_blocks_x.data(), + bsz, + decoder_block_shape_q, + gqa_group_size); + auto decoder_num_blocks_x_cpu = + decoder_num_blocks_x.copy_to(paddle::CPUPlace(), false); + + int max_enc_len_this_time_data = max_enc_len_this_time.data()[0]; + if (max_enc_len_this_time_data <= 0) { + auto encoder_batch_ids = + paddle::full({1}, -1, paddle::DataType::INT32, paddle::GPUPlace()); + auto encoder_tile_ids_per_batch = + paddle::full({1}, -1, paddle::DataType::INT32, paddle::GPUPlace()); + auto encoder_num_blocks_x_cpu = + paddle::full({1}, -1, paddle::DataType::INT32, paddle::CPUPlace()); + auto kv_batch_ids = + paddle::full({1}, -1, paddle::DataType::INT32, 
paddle::GPUPlace()); + auto kv_tile_ids_per_batch = + paddle::full({1}, -1, paddle::DataType::INT32, paddle::GPUPlace()); + auto kv_num_blocks_x_cpu = + paddle::full({1}, -1, paddle::DataType::INT32, paddle::CPUPlace()); + + return {encoder_batch_ids, + encoder_tile_ids_per_batch, + encoder_num_blocks_x_cpu, /*cpu*/ + kv_batch_ids, + kv_tile_ids_per_batch, + kv_num_blocks_x_cpu, /*cpu*/ + decoder_batch_ids, + decoder_tile_ids_per_batch, + decoder_num_blocks_x_cpu /*cpu*/}; + } + + // encoder + const uint32_t encoder_max_tile_size_per_bs_q = div_up( + (max_enc_len_this_time_data * gqa_group_size), encoder_block_shape_q); + auto encoder_batch_ids = + GetEmptyTensor({bsz * encoder_max_tile_size_per_bs_q}, + paddle::DataType::INT32, + seq_lens_encoder.place()); + auto encoder_tile_ids_per_batch = + GetEmptyTensor({bsz * encoder_max_tile_size_per_bs_q}, + paddle::DataType::INT32, + seq_lens_encoder.place()); + auto encoder_num_blocks_x = + GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_encoder.place()); + split_q_block<<<1, 32, 0, stream>>>(seq_lens_encoder.data(), + nullptr, + encoder_batch_ids.data(), + encoder_tile_ids_per_batch.data(), + encoder_num_blocks_x.data(), + bsz, + encoder_block_shape_q, + gqa_group_size); + auto encoder_num_blocks_x_cpu = + encoder_num_blocks_x.copy_to(paddle::CPUPlace(), false); + + // kv + const uint32_t max_tile_size_per_bs_kv = + div_up(max_enc_len_this_time_data, block_size); + auto kv_batch_ids = GetEmptyTensor({bsz * max_tile_size_per_bs_kv}, + paddle::DataType::INT32, + seq_lens_encoder.place()); + auto kv_tile_ids_per_batch = GetEmptyTensor({bsz * max_tile_size_per_bs_kv}, + paddle::DataType::INT32, + seq_lens_encoder.place()); + auto kv_num_blocks_x = + GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_encoder.place()); + split_kv_block<<<1, 32, 0, stream>>>(seq_lens_decoder.data(), + seq_lens_encoder.data(), + kv_batch_ids.data(), + kv_tile_ids_per_batch.data(), + kv_num_blocks_x.data(), + bsz, + block_size, + block_size); + auto kv_num_blocks_x_cpu = kv_num_blocks_x.copy_to(paddle::CPUPlace(), false); + return {encoder_batch_ids, + encoder_tile_ids_per_batch, + encoder_num_blocks_x_cpu, /*cpu*/ + kv_batch_ids, + kv_tile_ids_per_batch, + kv_num_blocks_x_cpu, /*cpu*/ + decoder_batch_ids, + decoder_tile_ids_per_batch, + decoder_num_blocks_x_cpu /*cpu*/}; +} + +std::vector GetBlockShapeAndSplitKVBlockInferDtype( + const paddle::DataType& seq_lens_encoder_dtype, + const paddle::DataType& seq_lens_decoder_dtype, + const paddle::DataType& max_enc_len_this_time_dtype, + const paddle::DataType& seq_lens_this_time_dtype, + const paddle::DataType& cum_offsets_dtype) { + return {paddle::DataType::INT32, + paddle::DataType::INT32, + paddle::DataType::INT32, + paddle::DataType::INT32, + paddle::DataType::INT32, + paddle::DataType::INT32, + paddle::DataType::INT32, + paddle::DataType::INT32, + paddle::DataType::INT32}; +} + +std::vector> GetBlockShapeAndSplitKVBlockInferShape( + const std::vector& seq_lens_encoder_shape, + const std::vector& seq_lens_decoder_shape, + const std::vector& max_enc_len_this_time_shape, + const std::vector& seq_lens_this_time_shape, + const std::vector& cum_offsets_shape) { + std::vector dynamic_shape = {-1}; + + return {dynamic_shape, + dynamic_shape, + {1}, + dynamic_shape, + dynamic_shape, + {1}, + dynamic_shape, + dynamic_shape, + {1}}; +} + +PD_BUILD_OP(get_block_shape_and_split_kv_block) + .Inputs({"seq_lens_encoder", + "seq_lens_decoder", + "max_enc_len_this_time", + "seq_lens_this_time", + "cum_offsets"}) + 
.Outputs({"encoder_batch_ids", + "encoder_tile_ids_per_batch", + "encoder_num_blocks", + "kv_batch_ids", + "kv_tile_ids_per_batch", + "kv_num_blocks", + "decoder_batch_ids", + "decoder_tile_ids_per_batch", + "decoder_num_blocks"}) + .Attrs({"encoder_block_shape_q: int", + "decoder_block_shape_q: int", + "gqa_group_size: int", + "block_size: int"}) + .SetKernelFn(PD_KERNEL(GetBlockShapeAndSplitKVBlock)) + .SetInferShapeFn(PD_INFER_SHAPE(GetBlockShapeAndSplitKVBlockInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(GetBlockShapeAndSplitKVBlockInferDtype)); diff --git a/csrc/gpu/append_attn/mem_util.cuh b/csrc/gpu/append_attn/mem_util.cuh new file mode 100644 index 000000000000..787f1992a3fb --- /dev/null +++ b/csrc/gpu/append_attn/mem_util.cuh @@ -0,0 +1,261 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once +#include +#include + +enum class SharedMemFillMode { + kFillZero, + kNoFill +}; + +enum class PrefetchMode { + kNoPrefetch, + kPrefetch +}; + +template +__device__ __forceinline__ void ldmatrix_m8n8x4_impl(uint32_t* R, T* smem_ptr) { + uint32_t smem_int_ptr = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];\n" + : "=r"(R[0]), "=r"(R[1]), "=r"(R[2]), "=r"(R[3]) + : "r"(smem_int_ptr)); +} + +template +__device__ __forceinline__ void ldmatrix_m8n8x4_trans_impl(uint32_t* R, T* smem_ptr) { + uint32_t smem_int_ptr = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile("ldmatrix.sync.aligned.trans.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];\n" + : "=r"(R[0]), "=r"(R[1]), "=r"(R[2]), "=r"(R[3]) + : "r"(smem_int_ptr)); +} + +__device__ __forceinline__ void commit_group() { + asm volatile("cp.async.commit_group;\n" ::); +} + +template +__device__ __forceinline__ void wait_group() { + asm volatile("cp.async.wait_group %0;\n" ::"n"(n)); +} + +template +__device__ __forceinline__ void load_128b(T* smem_ptr, const T* gmem_ptr) { + uint32_t smem_int_ptr = static_cast(__cvta_generic_to_shared(smem_ptr)); + if constexpr (prefetch_mode == PrefetchMode::kPrefetch) { + asm volatile("cp.async.cg.shared.global.L2::128B [%0], [%1], %2, %3;\n" ::"r"(smem_int_ptr), + "l"(gmem_ptr), "n"(16), "r"(16)); + } else { + asm volatile("cp.async.cg.shared.global [%0], [%1], %2, %3;\n" ::"r"(smem_int_ptr), + "l"(gmem_ptr), "n"(16), "r"(16)); + } +} + +template +__device__ __forceinline__ void pred_load_128b(T* smem_ptr, const T* gmem_ptr, bool predicate) { + uint32_t smem_int_ptr = static_cast(__cvta_generic_to_shared(smem_ptr)); + if constexpr (fill_mode == SharedMemFillMode::kFillZero) { + int src_in_bytes = predicate ? 
16 : 0; + if constexpr (prefetch_mode == PrefetchMode::kPrefetch) { + asm volatile("cp.async.cg.shared.global.L2::128B [%0], [%1], %2, %3;\n" ::"r"(smem_int_ptr), + "l"(gmem_ptr), "n"(16), "r"(src_in_bytes)); + } else { + asm volatile("cp.async.cg.shared.global [%0], [%1], %2, %3;\n" ::"r"(smem_int_ptr), + "l"(gmem_ptr), "n"(16), "r"(src_in_bytes)); + } + } else { + if constexpr (prefetch_mode == PrefetchMode::kPrefetch) { + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.cg.shared.global.L2::128B [%1], [%2], %3;\n" + "}\n" ::"r"((int)predicate), + "r"(smem_int_ptr), "l"(gmem_ptr), "n"(16)); + } else { + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.cg.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)predicate), + "r"(smem_int_ptr), "l"(gmem_ptr), "n"(16)); + } + } +} + +template +__device__ __forceinline__ void pred_load_64b(T* smem_ptr, const T* gmem_ptr, bool predicate) { + uint32_t smem_int_ptr = static_cast(__cvta_generic_to_shared(smem_ptr)); + if constexpr (fill_mode == SharedMemFillMode::kFillZero) { + int src_in_bytes = predicate ? 8 : 0; + asm volatile("cp.async.ca.shared.global [%0], [%1], %2, %3;\n" ::"r"(smem_int_ptr), + "l"(gmem_ptr), "n"(8), "r"(src_in_bytes)); + } else { + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.ca.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)predicate), + "r"(smem_int_ptr), "l"(gmem_ptr), "n"(8)); + } +} + +template +__device__ __forceinline__ void pred_load_32b(T* smem_ptr, const T* gmem_ptr, bool predicate) { + uint32_t smem_int_ptr = static_cast(__cvta_generic_to_shared(smem_ptr)); + if constexpr (fill_mode == SharedMemFillMode::kFillZero) { + int src_in_bytes = predicate ? 4 : 0; + asm volatile("cp.async.ca.shared.global [%0], [%1], %2, %3;\n" ::"r"(smem_int_ptr), + "l"(gmem_ptr), "n"(4), "r"(src_in_bytes)); + } else { + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.ca.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)predicate), + "r"(smem_int_ptr), "l"(gmem_ptr), "n"(4)); + } +} + +template +__device__ __forceinline__ void load(T* smem_ptr, const T* gmem_ptr) { + static_assert(num_bits == 128, "num_bits must be 128"); + load_128b(smem_ptr, gmem_ptr); +} + +template +__device__ __forceinline__ void pred_load(T* smem_ptr, const T* gmem_ptr, bool predicate) { + static_assert(num_bits == 128 || num_bits == 64 || num_bits == 32, "num_bits must be 128, 64 or 32."); + if constexpr (num_bits == 128) { + pred_load_128b(smem_ptr, gmem_ptr, predicate); + } else if constexpr (num_bits == 64) { + pred_load_64b(smem_ptr, gmem_ptr, predicate); + } else if constexpr (num_bits == 32) { + pred_load_32b(smem_ptr, gmem_ptr, predicate); + } +} + +using b32_t = uint32_t; +using b64_t = uint2; +using b128_t = uint4; + +template +constexpr __host__ __device__ __forceinline__ uint32_t num_elems_per_128b() { + return sizeof(b128_t) / sizeof(T); +} + +struct smem_t { + // The base pointer. 
+ b128_t* base; + __device__ __forceinline__ smem_t() : base(nullptr) {} + template + __device__ __forceinline__ smem_t(T* base) : base((b128_t*)base) {} + + + template + static __device__ __forceinline__ uint32_t get_permuted_offset(uint32_t i, uint32_t j) { + if constexpr (inv_stride <= 1) { + return i * stride + (j ^ (i % 8)); + } else { + return i / inv_stride * 8 + ((j + (i % inv_stride) * stride)) ^ ((i / inv_stride) % 8); + } + } + + template + static __device__ __forceinline__ uint32_t advance_offset_by_column(uint32_t offset, + uint32_t step_idx) { + if constexpr (row_stride == 2) { + static_assert(step_size == 2, "Unsupported step size"); + return offset + step_size; + } else if constexpr (row_stride == 4) { + static_assert(step_size == 2 || step_size == 4, "Unsupported step size"); + if constexpr (step_size == 2) { + return (offset ^ 0x2) + (step_idx % 2 == 1) * 4; + } else { + return offset + step_size; + } + } else { + static_assert(step_size == 2 || step_size == 4 || step_size % 8 == 0, "Unsupported step size"); + if constexpr (step_size == 2) { + return (offset ^ (0x2 + (0x4 * (step_idx % 2 == 1)))) + (step_idx % 4 == 3) * 8; + } else if constexpr (step_size == 4) { + return (offset ^ 0x4) + (step_idx % 2 == 1) * 8; + } else { + // step_size % 8 == 0 + return offset + step_size; + } + } + } + + template + static __device__ __forceinline__ uint32_t advance_offset_by_row(uint32_t offset) { + if constexpr (row_stride == 2) { + static_assert(step_size == 16 || step_size % 32 == 0, "Unsupported step size"); + if constexpr (step_size == 16) { + return (offset ^ 0x4) + step_size * row_stride; + } else { + // step_size % 32 == 0 + return offset + step_size * row_stride; + } + } else if constexpr (row_stride == 4) { + static_assert(step_size == 8 || step_size % 16 == 0, "Unsupported step size"); + if constexpr (step_size == 8) { + return (offset ^ 0x4) + step_size * row_stride; + } else { + // step_size % 16 == 0 + return offset + step_size * row_stride; + } + } else { + static_assert(step_size == 4 || step_size % 8 == 0, "Unsupported step size"); + if constexpr (step_size == 4) { + return (offset ^ 0x4) + step_size * row_stride; + } else { + // step_size % 8 == 0 + return offset + step_size * row_stride; + } + } + } + + __device__ __forceinline__ void ldmatrix_m8n8x4(uint32_t offset, uint32_t* R) { + b128_t* smem_ptr = base + offset; + ldmatrix_m8n8x4_impl(R, smem_ptr); + } + + __device__ __forceinline__ void ldmatrix_m8n8x4_trans(uint32_t offset, uint32_t* R) { + b128_t* smem_ptr = base + offset; + ldmatrix_m8n8x4_trans_impl(R, smem_ptr); + } + + template + __device__ __forceinline__ void load_128b_async(uint32_t offset, const T* gptr, bool predicate) { + b128_t* smem_ptr = base + offset; + pred_load_128b( + smem_ptr, reinterpret_cast(gptr), predicate); + } + + template + __device__ __forceinline__ void load_128b_async(uint32_t offset, const T* gptr) { + b128_t* smem_ptr = base + offset; + load_128b( + smem_ptr, reinterpret_cast(gptr)); + } + + template + __device__ __forceinline__ void store_128b(uint32_t offset, T* gptr) { + *reinterpret_cast(gptr) = *(base + offset); + } +}; diff --git a/csrc/gpu/append_attn/mma_tensor_op.cuh b/csrc/gpu/append_attn/mma_tensor_op.cuh new file mode 100644 index 000000000000..77e3577aca0d --- /dev/null +++ b/csrc/gpu/append_attn/mma_tensor_op.cuh @@ -0,0 +1,188 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include +#include +#include + +enum class MMAMode { + kInit = 0U, + kInplaceUpdate = 1U, +}; + +template +__device__ __forceinline__ void mma_sync_m16n16k32_row_col_i8i8i32(int* C, // 8 + uint32_t* A, // 4 + uint32_t* B) { // 4 + if constexpr (mma_mode == MMAMode::kInit) { + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(C[0]), "=r"(C[1]), "=r"(C[2]), "=r"(C[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(0), "r"(0), + "r"(0), "r"(0)); + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(C[4]), "=r"(C[5]), "=r"(C[6]), "=r"(C[7]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[2]), "r"(B[3]), "r"(0), "r"(0), + "r"(0), "r"(0)); + } else { + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(C[0]), "=r"(C[1]), "=r"(C[2]), "=r"(C[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1]), + "r"(C[2]), "r"(C[3])); + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(C[4]), "=r"(C[5]), "=r"(C[6]), "=r"(C[7]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[2]), "r"(B[3]), "r"(C[4]), "r"(C[5]), + "r"(C[6]), "r"(C[7])); + } +} + +template +__device__ __forceinline__ void mma_sync_m16n16k16_row_col_f16f16f32(float* C, + uint32_t* A, + uint32_t* B) { + if constexpr (mma_mode == MMAMode::kInit) { + if constexpr (std::is_same::value) { // fp16 + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "f"(0.f), "f"(0.f), + "f"(0.f), "f"(0.f)); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[2]), "r"(B[3]), "f"(0.f), "f"(0.f), + "f"(0.f), "f"(0.f)); + } else { // bf16 + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "f"(0.f), "f"(0.f), + "f"(0.f), "f"(0.f)); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7]) 
+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[2]), "r"(B[3]), "f"(0.f), "f"(0.f), + "f"(0.f), "f"(0.f)); + } + } else { + if constexpr (std::is_same::value) { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "f"(C[0]), "f"(C[1]), + "f"(C[2]), "f"(C[3])); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[2]), "r"(B[3]), "f"(C[4]), "f"(C[5]), + "f"(C[6]), "f"(C[7])); + } else { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "f"(C[0]), "f"(C[1]), + "f"(C[2]), "f"(C[3])); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[2]), "r"(B[3]), "f"(C[4]), "f"(C[5]), + "f"(C[6]), "f"(C[7])); + } + } +} + +template +__device__ __forceinline__ void rowsum_f16f16f32(float* d, DType* s) { + static_assert(sizeof(DType) == 2, "DType must be 16bit floating data type"); + uint32_t* s_u32 = (uint32_t*)(s); + if constexpr (std::is_same::value) { + asm volatile( + "{\n" + ".reg .f32 ph;\n" + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0, ph, %1, ph}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, 0., %9, 0.};\n" + "}\n" + : "=f"(d[0]), "=f"(d[1]) + : "r"(s_u32[0]), "r"(s_u32[1]), "r"(s_u32[2]), "r"(s_u32[3]), "r"(1006648320), + "r"(1006648320), "f"(d[0]), "f"(d[1])); + } else { + asm volatile( + "{\n" + ".reg .f32 ph;\n" + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0, ph, %1, ph}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, 0., %9, 0.};\n" + "}\n" + : "=f"(d[0]), "=f"(d[1]) + : "r"(s_u32[0]), "r"(s_u32[1]), "r"(s_u32[2]), "r"(s_u32[3]), "r"(1065369472), + "r"(1065369472), "f"(d[0]), "f"(d[1])); + } +} diff --git a/csrc/gpu/append_attn/utils.cuh b/csrc/gpu/append_attn/utils.cuh new file mode 100644 index 000000000000..a094a4c8f0f2 --- /dev/null +++ b/csrc/gpu/append_attn/utils.cuh @@ -0,0 +1,388 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
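+// utils.cuh bundles the helpers shared by the append-attention kernels:
+// framework-dtype -> CUDA-type traits, fast int8/int4 -> half/bf16
+// converters, vec_cast, roundWithTiesToEven, and the DISPATCH_* macros that
+// lift runtime values (head_dim, block_size, gqa group size, ...) into
+// compile-time constants. Illustrative use of the dispatch macros, mirroring
+// encoder_write_cache_with_rope_kernel.h (the kernel name is a placeholder):
+//
+//   DISPATCH_HEAD_DIM(head_dim, HEAD_DIM,
+//       {DISPATCH_BLOCK_SIZE(block_size, BLOCK_SIZE,
+//           {SomeCacheWriteKernel<T, HEAD_DIM, BLOCK_SIZE>(...);})})
+//
+// Inside the braced body HEAD_DIM and BLOCK_SIZE are constexpr. Unsupported
+// values either hit PD_THROW (switch-style macros such as DISPATCH_HEAD_DIM)
+// or silently skip the body (if-chain macros such as DISPATCH_BLOCK_SIZE).
+//
+// On the int4 converters below: each 4-bit value n is spliced next to a magic
+// high byte (0x43 for bf16, 0x64 for half), so the unpacked lanes read as
+// 128 + n and 1024 + n respectively; unlike convert_int8, no bias is
+// subtracted here, so the constant offset is presumably folded into the
+// per-channel cache scale/zero-point on the dequant side.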
+#pragma once + +#include +#include + +__forceinline__ __host__ __device__ int div_up(int a, int b) { + return (a + b - 1) / b; +} + +enum PosEncMode { + kNonePos, + kRoPE, + kAliBi +}; + +enum CacheType { + CacheT, + CacheInt8Hw, + CacheInt4CwZp +}; + +template +struct cascade_attn_type_traits { + using type = T; +}; + +template<> +struct cascade_attn_type_traits { + using type = __nv_bfloat16; +}; + +template<> +struct cascade_attn_type_traits { + using type = half; +}; + +template +struct cascade_attn_nv_type2_traits { + using type = T; +}; + +template<> +struct cascade_attn_nv_type2_traits<__nv_bfloat16> { + using type = __nv_bfloat162; +}; + +template<> +struct cascade_attn_nv_type2_traits { + using type = half2; +}; + +template +struct vec_traits { + using type = b128_t; +}; + +template<> +struct vec_traits { + using type = b64_t; +}; + +template<> +struct vec_traits { + using type = b32_t; +}; + +template +struct cache_type_traits { + using type = T; +}; + +template +struct cache_type_traits { + using type = uint8_t; +}; + +template +struct cache_type_traits { + using type = uint8_t; +}; + +__device__ __forceinline__ uint32_t sub_if_greater_or_zero(uint32_t x, uint32_t y) { + return (x > y) ? x - y : 0U; +} + +/******************************FASTER CAST*********************************/ +inline __device__ static void convert_int8(__nv_bfloat16* result, const uint32_t& source) { // 4 int8 each time + uint32_t* bf16_result_ptr = reinterpret_cast(result); + uint32_t const i8s = reinterpret_cast(source); + + static constexpr uint32_t fp32_base = 0x4B000000; + float fp32_intermediates[4]; + + uint32_t* fp32_intermediates_casted = + reinterpret_cast(fp32_intermediates); + fp32_intermediates_casted[0] = __byte_perm(i8s, fp32_base, 0x7650); + fp32_intermediates_casted[1] = __byte_perm(i8s, fp32_base, 0x7651); + fp32_intermediates_casted[2] = __byte_perm(i8s, fp32_base, 0x7652); + fp32_intermediates_casted[3] = __byte_perm(i8s, fp32_base, 0x7653); + +#pragma unroll + for (int ii = 0; ii < 4; ++ii) { + fp32_intermediates[ii] -= 8388736.f;// (8388608.f + 128.f); + } + +#pragma unroll + for (int ii = 0; ii < 2; ++ii) { + bf16_result_ptr[ii] = __byte_perm(fp32_intermediates_casted[2 * ii + 0], + fp32_intermediates_casted[2 * ii + 1], + 0x7632); + } +} + +inline __device__ static void convert_int8(half* result, const uint32_t& source) { // 4 int8 each time + uint32_t* fp16_result_ptr = reinterpret_cast(result); + uint32_t const i8s = reinterpret_cast(source); + static constexpr uint32_t mask_for_elt_01 = 0x5150; + static constexpr uint32_t mask_for_elt_23 = 0x5352; + static constexpr uint32_t start_byte_for_fp16 = 0x64646464; + + asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(fp16_result_ptr[0]) : "r"(i8s), "n"(start_byte_for_fp16), "n"(mask_for_elt_01)); + asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(fp16_result_ptr[1]) : "r"(i8s), "n"(start_byte_for_fp16), "n"(mask_for_elt_23)); + + static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(fp16_result_ptr[0]) : "r"(fp16_result_ptr[0]), "r"(I8s_TO_F16s_MAGIC_NUM)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(fp16_result_ptr[1]) : "r"(fp16_result_ptr[1]), "r"(I8s_TO_F16s_MAGIC_NUM)); +} + +inline __device__ static void convert_int4(__nv_bfloat16* result, const uint32_t& source) { // 8 int4 each time + uint32_t* bf16_result_ptr = reinterpret_cast(result); + + static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint32_t MASK = 0x0f0f0f0f; // 0xf -> 0b1111 select 0,4 + 
static constexpr uint32_t I4s_TO_FP32s_MAGIC_NUM = 0x43434343; + static constexpr uint32_t mask_for_elt_01 = 0x5150; + static constexpr uint32_t mask_for_elt_23 = 0x5352; + + uint32_t tmp1 = source & MASK; // 0 1 2 3 + uint32_t tmp2 = source >> 4 & MASK; // 4 5 6 7 + + bf16_result_ptr[0] = __byte_perm(tmp1, + I4s_TO_FP32s_MAGIC_NUM, + mask_for_elt_01); // 0 1 + bf16_result_ptr[1] = __byte_perm(tmp1, + I4s_TO_FP32s_MAGIC_NUM, + mask_for_elt_23); // 2 3 + bf16_result_ptr[2] = __byte_perm(tmp2, + I4s_TO_FP32s_MAGIC_NUM, + mask_for_elt_01); // 4 5 + bf16_result_ptr[3] = __byte_perm(tmp2, + I4s_TO_FP32s_MAGIC_NUM, + mask_for_elt_23); // 6 7 +} + +inline __device__ static void convert_int4(half* result, const uint32_t& source) { // 7 5 3 1 6 4 2 0 + uint32_t* fp16_result_ptr = reinterpret_cast(result); + + static constexpr uint32_t MASK = 0x0f0f0f0f; // 0xf -> 0b1111 select 0,1; 7 5 3 1 6 4 2 0 + static constexpr uint32_t I4s_TO_FP32s_MAGIC_NUM = 0x64646464; + static constexpr uint32_t mask_for_elt_01 = 0x5150; + static constexpr uint32_t mask_for_elt_23 = 0x5352; + + uint32_t tmp1 = source & MASK; // 0 1 2 3 + uint32_t tmp2 = source >> 4 & MASK; // 4 5 6 7 + fp16_result_ptr[0] = __byte_perm(tmp1, + I4s_TO_FP32s_MAGIC_NUM, + mask_for_elt_01); // 0 1 + fp16_result_ptr[1] = __byte_perm(tmp1, + I4s_TO_FP32s_MAGIC_NUM, + mask_for_elt_23); // 2 3 + fp16_result_ptr[2] = __byte_perm(tmp2, + I4s_TO_FP32s_MAGIC_NUM, + mask_for_elt_01); // 4 5 + fp16_result_ptr[3] = __byte_perm(tmp2, + I4s_TO_FP32s_MAGIC_NUM, + mask_for_elt_23); // 6 7 +} + +/******************* vec_t type cast *******************/ + +template +__forceinline__ __host__ __device__ void vec_cast(dst_t* dst, const src_t* src) { +#pragma unroll + for (size_t i = 0; i < vec_size; ++i) { + dst[i] = src[i]; + } +} + +template +__forceinline__ __host__ __device__ void vec_cast(float* dst, const half* src) { +#pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + ((float2*)dst)[i] = __half22float2(((half2*)src)[i]); + } +} + +template +__forceinline__ __host__ __device__ void vec_cast(half* dst, const float* src) { +#pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + ((half2*)dst)[i] = __float22half2_rn(((float2*)src)[i]); + } +} + +template +__forceinline__ __host__ __device__ void vec_cast(float* dst, const nv_bfloat16* src) { +#pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + ((float2*)dst)[i] = __bfloat1622float2(((nv_bfloat162*)src)[i]); + } +} + +template +__forceinline__ __host__ __device__ void vec_cast(nv_bfloat16* dst, const float* src) { +#pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + ((nv_bfloat162*)dst)[i] = __float22bfloat162_rn(((float2*)src)[i]); + } +} + +#define CHECK_CUDA_CALL(func, ...) \ + { \ + cudaError_t e = (func); \ + if (e != cudaSuccess) { \ + std::cerr << "CUDA Error: " << cudaGetErrorString(e) << " (" << e << ") " << __FILE__ \ + << ": line " << __LINE__ << " at function " << STR(func) << std::endl; \ + return e; \ + } \ + } + +#define DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, ...) \ + switch (head_dim) { \ + case 128: { \ + constexpr size_t HEAD_DIM = 128; \ + __VA_ARGS__ \ + break; \ + } \ + default: { \ + PD_THROW("not support the head_dim"); \ + } \ + } + +#define DISPATCH_NUM_STAGE(num_stage, NUM_STAGE, ...) \ + if (num_stage == 2) { \ + constexpr size_t NUM_STAGE = 2; \ + __VA_ARGS__ \ + } + +#define DISPATCH_CACHE_TYPE(cache_type, cache_type_now, cache_bytes, ...) 
\ + if (cache_type == 0) { \ + constexpr CacheType cache_type_now = CacheType::CacheT; \ + constexpr size_t cache_bytes = 16; \ + __VA_ARGS__ \ + } else if (cache_type == 1) { \ + constexpr CacheType cache_type_now = CacheType::CacheInt8Hw; \ + constexpr size_t cache_bytes = 8; \ + __VA_ARGS__ \ + } else if (cache_type == 2) { \ + constexpr CacheType cache_type_now = CacheType::CacheInt4CwZp; \ + constexpr size_t cache_bytes = 4; \ + __VA_ARGS__ \ + } + +#define DISPATCH_DEAL_EACH_TIME(deal_each_time, DEAL_EACH_TIME, ...) \ + if (deal_each_time == 32) { \ + constexpr size_t DEAL_EACH_TIME = 32; \ + __VA_ARGS__ \ + } else if (deal_each_time == 64) { \ + constexpr size_t DEAL_EACH_TIME = 64; \ + __VA_ARGS__ \ + } + +#define DISPATCH_NUM_THREADS(num_threads, NUM_THREADS, ...) \ + if (num_threads == 128) { \ + constexpr size_t NUM_THREADS = 128; \ + __VA_ARGS__ \ + } else if (num_threads == 256) { \ + constexpr size_t NUM_THREADS = 256; \ + __VA_ARGS__ \ + } + +#define DISPATCH_GQA_GROUP_SIZE(group_size, GROUP_SIZE, ...) \ + if (group_size == 1) { \ + constexpr size_t GROUP_SIZE = 1; \ + __VA_ARGS__ \ + } else if (group_size == 2) { \ + constexpr size_t GROUP_SIZE = 2; \ + __VA_ARGS__ \ + } else if (group_size == 3) { \ + constexpr size_t GROUP_SIZE = 3; \ + __VA_ARGS__ \ + } else if (group_size == 4) { \ + constexpr size_t GROUP_SIZE = 4; \ + __VA_ARGS__ \ + } else if (group_size == 5) { \ + constexpr size_t GROUP_SIZE = 5; \ + __VA_ARGS__ \ + } else if (group_size == 6) { \ + constexpr size_t GROUP_SIZE = 6; \ + __VA_ARGS__ \ + } else if (group_size == 7) { \ + constexpr size_t GROUP_SIZE = 7; \ + __VA_ARGS__ \ + } else if (group_size == 8) { \ + constexpr size_t GROUP_SIZE = 8; \ + __VA_ARGS__ \ + } else if (group_size == 12) { \ + constexpr size_t GROUP_SIZE = 12; \ + __VA_ARGS__ \ + } + +#define DISPATCH_BLOCKSHAPE_Q(block_shape_q, BLOCK_SHAPE_Q, NUM_WARP_Q, ...) \ + if (block_shape_q <= 16) { \ + constexpr size_t BLOCK_SHAPE_Q = 16; \ + constexpr size_t NUM_WARP_Q = 1; \ + __VA_ARGS__ \ + } else if (block_shape_q <= 32) { \ + constexpr size_t BLOCK_SHAPE_Q = 32; \ + constexpr size_t NUM_WARP_Q = 1; \ + __VA_ARGS__ \ + } else if (block_shape_q <= 64) { \ + constexpr size_t BLOCK_SHAPE_Q = 64; \ + constexpr size_t NUM_WARP_Q = 4; \ + __VA_ARGS__ \ + } else { \ + constexpr size_t BLOCK_SHAPE_Q = 128; \ + constexpr size_t NUM_WARP_Q = 4; \ + __VA_ARGS__ \ + } + +#define DISPATCH_CAUSAL(causal, CAUSAL, ...) \ + if (causal) { \ + constexpr bool CAUSAL = true; \ + __VA_ARGS__ \ + } + +#define DISPATCH_ENABLE_PREFILL(enable_prefill, ENABLE_PREFILL, ...) \ + if (enable_prefill) { \ + constexpr bool ENABLE_PREFILL = 1; \ + __VA_ARGS__ \ + } else { \ + constexpr bool ENABLE_PREFILL = 0; \ + __VA_ARGS__ \ + } + +#define DISPATCH_BLOCK_SIZE(block_size, BLOCK_SIZE, ...) \ + if (block_size == 64) { \ + constexpr size_t BLOCK_SIZE = 64; \ + __VA_ARGS__ \ + } + +#define DISPATCH_BLOCKSHAPE_Q_SYSTEM(block_shape_q, BLOCK_SHAPE_Q, NUM_WARP_Q, ...) \ + if (block_shape_q <= 16) { \ + constexpr size_t BLOCK_SHAPE_Q = 16; \ + constexpr size_t NUM_WARP_Q = 1; \ + __VA_ARGS__ \ + } else if (block_shape_q <= 32) { \ + constexpr size_t BLOCK_SHAPE_Q = 32; \ + constexpr size_t NUM_WARP_Q = 1; \ + __VA_ARGS__ \ + } + +template +inline HOSTDEVICE T roundWithTiesToEven(T x) { + T xLower = floor(x); + T xUpper = ceil(x); + // x is in interval [xl,xu]. Choose closest of two bounds, breaking ties to + // even. + T dLower = x - xLower; + T dUpper = xUpper - x; + return static_cast( + (dLower == dUpper ? 
fmod(xLower, 2.0F) == 0.0F : dLower < dUpper) + ? xLower + : xUpper); +} \ No newline at end of file diff --git a/csrc/gpu/cutlass_kernels/cutlass_helper.h b/csrc/gpu/cutlass_kernels/cutlass_helper.h index 1374800a3d27..27594898c95b 100644 --- a/csrc/gpu/cutlass_kernels/cutlass_helper.h +++ b/csrc/gpu/cutlass_kernels/cutlass_helper.h @@ -67,7 +67,7 @@ class CutlassGemmConfigMannager { if(!file.good()){ throw std::runtime_error("cutlass gemm_best_config can not be found, please set gemm_best_config'path as FLAGS_use_cutlass_device_best_config_path, or unset FLAGS_use_cutlass_device_best_config_path to tune gemm_best_config"); } - json_ = readJsonFromFile(config_file_path); + json_ = ReadJsonFromFile(config_file_path); load_initialized_ = true; save_initialized_ = false; } @@ -82,7 +82,7 @@ class CutlassGemmConfigMannager { new_file << json_.dump(4); new_file.close(); } else { - nlohmann::json old_json = readJsonFromFile(config_file_path); + nlohmann::json old_json = ReadJsonFromFile(config_file_path); for (auto it = json_.begin(); it != json_.end(); ++it) { old_json[it.key()] = it.value(); } diff --git a/csrc/gpu/fused_get_rope.cu b/csrc/gpu/fused_get_rope.cu index af34c0f075ba..a7e1a3fcc429 100644 --- a/csrc/gpu/fused_get_rope.cu +++ b/csrc/gpu/fused_get_rope.cu @@ -36,6 +36,8 @@ union Pack { T elem[N]; }; +constexpr int kBlockSize = 256; + __global__ __launch_bounds__(kBlockSize) void fused_get_rotary_embedding_neox(const int64_t* position_ids, const int32_t bsz, const int32_t max_seq_length, @@ -162,7 +164,7 @@ std::vector GetRoPE(const paddle::Tensor& input_ids, assert(head_dim % 2 == 0); const int32_t elem_cnt = batch_size * max_seq_length * head_dim / 2; int32_t grid_size = 1; - GetNumBlocks(elem_cnt, &grid_size); + GetNumBlocks(elem_cnt, &grid_size); if (use_neox) { fused_get_rotary_embedding_neox<<>> ( position_ids.data(), diff --git a/csrc/gpu/helper.h b/csrc/gpu/helper.h index ceccbd4ee4a0..64b09dc3a7b7 100644 --- a/csrc/gpu/helper.h +++ b/csrc/gpu/helper.h @@ -14,7 +14,6 @@ #pragma once -#include "paddle/extension.h" #include #include #include @@ -37,14 +36,28 @@ namespace cub = hipcub; #endif #include #include + +#include "paddle/extension.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/allocator.h" #include "nlohmann/json.hpp" using json = nlohmann::json; -constexpr int kBlockSize = 256; -constexpr int kNumWaves = 16; +#define CUDA_CHECK(call) \ + do { \ + const cudaError_t error_code = call; \ + if (error_code != cudaSuccess) { \ + std::printf("at %s:%d - %s.\n", \ + __FILE__, \ + __LINE__, \ + cudaGetErrorString(error_code)); \ + exit(1); \ + } \ + } while (0) #ifdef PADDLE_WITH_HIP +template inline hipError_t GetNumBlocks(int64_t n, int* num_blocks) { int dev; { @@ -66,6 +79,7 @@ inline hipError_t GetNumBlocks(int64_t n, int* num_blocks) { return hipSuccess; } #else +template inline cudaError_t GetNumBlocks(int64_t n, int* num_blocks) { int dev; { @@ -159,9 +173,19 @@ HOSTDEVICE inline void Store(const AlignedVector& vec, T* addr) { *addr_vec = vec; } +template +HOSTDEVICE inline void Store(const AlignedVector<__nv_bfloat16, Size>& vec, int8_t* addr) { + printf("Error: Store __nv_bfloat16 to int8_t is not supported!"); +} + +template +HOSTDEVICE inline void Store(const AlignedVector& vec, int8_t* addr) { + printf("Error: Store half to int8_t is not supported!"); +} + constexpr int VEC_16B = 16; -inline json readJsonFromFile(const std::string& filePath) { +inline json ReadJsonFromFile(const std::string& filePath) { std::ifstream file(filePath); if 
(!file.is_open()) { throw std::runtime_error("Unable to open file: " + filePath); @@ -170,4 +194,13 @@ inline json readJsonFromFile(const std::string& filePath) { json j; file >> j; return j; -} \ No newline at end of file +} + +// place must be an existing place object and cannot use paddle::CPUPlace() or paddle::GPUPlace() +inline paddle::Tensor GetEmptyTensor(const common::DDim& dims, const paddle::DataType& dtype, const paddle::Place& place){ + auto* allocator = paddle::GetAllocator(place); + phi::DenseTensor dense_tensor; + dense_tensor.Resize(dims); + dense_tensor.AllocateFrom(allocator, dtype, dense_tensor.numel() * phi::SizeOf(dtype)); + return paddle::Tensor(std::make_shared(dense_tensor)); +} diff --git a/csrc/setup_cuda.py b/csrc/setup_cuda.py index 00a4b205a12e..9740093dde5a 100644 --- a/csrc/setup_cuda.py +++ b/csrc/setup_cuda.py @@ -107,6 +107,13 @@ def get_gencode_flags(): "./gpu/dequant_int8.cu", "./gpu/flash_attn_bwd.cc", "./gpu/tune_cublaslt_gemm.cu", + "./gpu/append_attention.cu", + "./gpu/append_attn/get_block_shape_and_split_kv_block.cu", + "./gpu/append_attn/append_attention_bfloat16_bfloat16_kernel.cu", + "./gpu/append_attn/append_attention_bfloat16_int8_kernel.cu", + "./gpu/append_attn/encoder_write_cache_with_rope_bfloat16_bfloat16_kernel.cu", + "./gpu/append_attn/encoder_write_cache_with_rope_bfloat16_int_kernel.cu", + "./gpu/append_attn/decoder_write_cache_with_rope_kernel.cu", "./gpu/sample_kernels/top_p_sampling_reject.cu", ] diff --git a/llm/predict/predictor.py b/llm/predict/predictor.py index 6c2bb09af5b7..4ada51bb6e2e 100644 --- a/llm/predict/predictor.py +++ b/llm/predict/predictor.py @@ -121,6 +121,8 @@ class PredictorArgument: }, ) + append_attn: bool = field(default=False, metadata={"help": "whether use append attention"}) + chat_template: str = field( default=None, metadata={ @@ -139,6 +141,10 @@ def total_max_length(self): else: return 8192 # Maximum sequence length. 
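+    # Note: append attention builds on the block-attention execution path, so
+    # __post_init__ below turns block_attn on automatically whenever
+    # append_attn is requested.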
+ def __post_init__(self): + if self.append_attn: + self.block_attn = True + @dataclass class ModelArgument: @@ -1261,6 +1267,7 @@ def create_predictor( elif predictor_args.block_attn: config.max_seq_len = predictor_args.total_max_length config.block_size = predictor_args.block_size + config.append_attn = predictor_args.append_attn from paddlenlp.experimental.transformers import ( LlamaForCausalLMBlockInferenceModel as LlamaInferenceModel, ) @@ -1300,6 +1307,7 @@ def create_predictor( if predictor_args.block_attn: config.max_seq_len = predictor_args.total_max_length config.block_size = predictor_args.block_size + config.append_attn = predictor_args.append_attn from paddlenlp.experimental.transformers import ( MixtralForCausalLMBlockInferenceModel as MixtralInferenceModel, ) @@ -1369,6 +1377,7 @@ def create_predictor( config.block_size = predictor_args.block_size config.max_seq_len = predictor_args.total_max_length + config.append_attn = predictor_args.append_attn else: from paddlenlp.experimental.transformers import ( BloomForCausalLMInferenceModel as BloomInferenceModel, @@ -1397,6 +1406,7 @@ def create_predictor( if predictor_args.block_attn: config.max_seq_len = predictor_args.total_max_length config.block_size = predictor_args.block_size + config.append_attn = predictor_args.append_attn from paddlenlp.experimental.transformers import ( Qwen2MoeForCausalLMBlockInferenceModel as Qwen2MoeInferenceModel, ) @@ -1423,6 +1433,7 @@ def create_predictor( if predictor_args.block_attn: config.max_seq_len = predictor_args.total_max_length config.block_size = predictor_args.block_size + config.append_attn = predictor_args.append_attn from paddlenlp.experimental.transformers import ( Qwen2ForCausalLMBlockInferenceModel as Qwen2InferenceModel, ) diff --git a/paddlenlp/experimental/transformers/fused_transformer_layers.py b/paddlenlp/experimental/transformers/fused_transformer_layers.py index 03ae178d6086..7cf2c01da246 100644 --- a/paddlenlp/experimental/transformers/fused_transformer_layers.py +++ b/paddlenlp/experimental/transformers/fused_transformer_layers.py @@ -55,7 +55,6 @@ from paddlenlp_ops import ( dequant_int8, encode_rotary_qk, - gemm_dequant, qkv_transpose_split, quant_int8, rebuild_padding, @@ -72,6 +71,9 @@ "FusedMultiTransformerPostLayernorm", "FusedMultiTransformerWeightOnly", "FusedMultiTransformerWeightOnlyPostLayernorm", + "FusedAppendMultiTransformer", + "FusedAppendMultiTransformerWeightOnly", + "FusedAppendMultiTransformerA8W8", "FusedBlockMultiTransformer", "FusedBlockMultiTransformerWeightOnly", "FusedBlockMultiTransformerA8W8", @@ -265,6 +267,7 @@ def __init__( rank_id=-1, moe_config=MoeConfig(), avx_config=AvxConfig(), + append_attn=False, ): self.embed_dim = embed_dim self.num_heads = num_heads @@ -342,6 +345,8 @@ def __init__( self.moe_config = moe_config self.avx_config = avx_config + self.append_attn = append_attn + class FusedMultiTransformerBase(Layer): def __init__(self, config: FusedMultiTransformerConfig): @@ -359,10 +364,20 @@ def __init__(self, config: FusedMultiTransformerConfig): # self.normalize_before = normalize_before self._dtype = self._helper.get_default_dtype() + if self._dtype == "bfloat16": + self._fuse_kernel_compute_dtype = "bf16" + elif self._dtype == "float16": + self._fuse_kernel_compute_dtype = "fp16" + elif self._dtype == "float32": + self._fuse_kernel_compute_dtype = "fp32" + else: + raise ValueError( + "FusedMultiTransformer just support float32, float16 and bfloat16 as default dtype, but received {}".format( + self._dtype + ) + ) self._epsilon 
= config.epsilon self._residual_alpha = config.residual_alpha - self._trans_qkvw = config.trans_qkvw - self._ring_id = config.ring_id self.nranks = config.nranks self.norm_type = config.norm_type if self.norm_type == "layernorm": @@ -541,12 +556,16 @@ def __init__(self, config: FusedMultiTransformerConfig): dtype=self._helper.get_default_dtype(), ) + cache_scale_dtype = "float32" + if self.config.append_attn: + cache_scale_dtype = paddle.get_default_dtype() + cache_k_scale = None if cache_k_scale_attr: cache_k_scale = self.create_parameter( shape=[self.kv_num_heads], attr=cache_k_scale_attr, - dtype="float32", + dtype=cache_scale_dtype, is_bias=False, ) @@ -555,7 +574,7 @@ def __init__(self, config: FusedMultiTransformerConfig): cache_v_scale = self.create_parameter( shape=[self.kv_num_heads], attr=cache_v_scale_attr, - dtype="float32", + dtype=cache_scale_dtype, is_bias=False, ) @@ -564,7 +583,7 @@ def __init__(self, config: FusedMultiTransformerConfig): cache_k_out_scale = self.create_parameter( shape=[self.kv_num_heads], attr=cache_k_out_scale_attr, - dtype="float32", + dtype=cache_scale_dtype, is_bias=False, ) @@ -573,7 +592,7 @@ def __init__(self, config: FusedMultiTransformerConfig): cache_v_out_scale = self.create_parameter( shape=[self.kv_num_heads], attr=cache_v_out_scale_attr, - dtype="float32", + dtype=cache_scale_dtype, is_bias=False, ) @@ -1073,6 +1092,36 @@ def forward( kwargs["max_enc_len_this_time"] = max_enc_len_this_time kwargs["max_dec_len_this_time"] = max_dec_len_this_time + if self.config.append_attn: + kwargs["encoder_block_shape_q"] = 64 + kwargs["decoder_block_shape_q"] = 16 + kwargs["max_partition_size"] = 32768 + kwargs["encoder_max_partition_size"] = 32768 + + from paddlenlp_ops import get_block_shape_and_split_kv_block + + ( + kwargs["encoder_batch_ids"], + kwargs["encoder_tile_ids_per_batch"], + kwargs["encoder_num_blocks"], + kwargs["kv_batch_ids"], + kwargs["kv_tile_ids_per_batch"], + kwargs["kv_num_blocks"], + kwargs["decoder_batch_ids"], + kwargs["decoder_tile_ids_per_batch"], + kwargs["decoder_num_blocks"], + ) = get_block_shape_and_split_kv_block( + kwargs.get("seq_lens_encoder", None), + kwargs.get("seq_lens_decoder", None), + max_enc_len_this_time, + kwargs.get("seq_lens_this_time", None), + kwargs.get("cum_offsets", None), + kwargs.get("encoder_block_shape_q", 64), + kwargs.get("decoder_block_shape_q", 16), + self.num_heads // self.kv_num_heads, + kwargs.get("block_size", 64), + ) + residual_input = src for i in range(self.num_layers): qkv_out, residual_input = self.compute_qkv(src, residual_input, i) @@ -1701,19 +1750,6 @@ def __init__(self, config: FusedMultiTransformerConfig): self.quant_min_bound = config.quant_min_bound # self.use_gemm_dequant = False - if self._dtype == "bfloat16": - self._fuse_kernel_compute_dtype = "bf16" - elif self._dtype == "float16": - self._fuse_kernel_compute_dtype = "fp16" - elif self._dtype == "float32": - self._fuse_kernel_compute_dtype = "fp32" - else: - raise ValueError( - "FusedMultiTransformer just support float32, float16 and bfloat16 as default dtype, but received {}".format( - self._dtype - ) - ) - self.qkv_out_scales = [] self.linear_out_scales = [] self.ffn1_out_scales = [] @@ -2056,6 +2092,8 @@ def compute_out_linear(self, fmha_out, i): out_linear_out = dequant_int8(out_linear_out, self.linear_out_scales[i], self._dtype) else: try: + from paddlenlp_ops import gemm_dequant + out_linear_out = gemm_dequant( fmha_out, self.linear_weights[i], self.linear_out_scales[i], self._dtype ) @@ -2115,6 +2153,8 @@ def 
compute_ffn2(self, ffn1_out, i): ffn2_out = dequant_int8(ffn2_out, self.ffn2_out_scales[i], self._dtype) else: try: + from paddlenlp_ops import gemm_dequant + ffn2_out = gemm_dequant(ffn1_out, self.ffn2_weights[i], self.ffn2_out_scales[i], self._dtype) except: ffn2_out = paddle.matmul(ffn1_out, self.ffn2_weights[i], False, True) @@ -2272,11 +2312,100 @@ def post_process(self, **kwargs): return out +class FusedAppendMultiTransformer(FusedMultiTransformerBase): + def __init__(self, config: FusedMultiTransformerConfig): + super().__init__(config) + + def compute_attn( + self, + time_step, + qkv, + padding_offset, + seq_lens, + input_ids, + rotary_embs, + rotary_emb_dims, + caches, + pre_caches, + pre_caches_length, + attn_mask, + i, + **kwargs, + ): + from paddlenlp_ops import append_attention + + fmha_out = append_attention( + qkv, + caches[2 * i], + caches[2 * i + 1], + kwargs.get("seq_lens_encoder", None), + kwargs.get("seq_lens_decoder", None), + kwargs.get("seq_lens_this_time", None), + kwargs.get("padding_offsets", None), + kwargs.get("cum_offsets", None), + kwargs.get("block_tables", None), + kwargs.get("encoder_batch_ids", None), + kwargs.get("encoder_tile_ids_per_batch", None), + kwargs.get("encoder_num_blocks", None), + kwargs.get("kv_batch_ids", None), + kwargs.get("kv_tile_ids_per_batch", None), + kwargs.get("kv_num_blocks", None), + kwargs.get("decoder_batch_ids", None), + kwargs.get("decoder_tile_ids_per_batch", None), + kwargs.get("decoder_num_blocks", None), + kwargs.get("max_enc_len_this_time", None), + kwargs.get("max_dec_len_this_time", None), + rotary_embs, + None, # attn_mask + None, # qkv_bias + None, # qkv_out_scales + None, # cache_k_quant_scales + None, # cache_v_quant_scales + None, # cache_k_dequant_scales + None, # cache_v_dequant_scales + None, # cache_k_zp + None, # cache_v_zp + None, # out_shifts + None, # out_smooths + self._fuse_kernel_compute_dtype, + "none", # cache_quant_type + self.use_neox_rotary_style, + kwargs.get("max_input_length", -1), + 0.0, # out_linear_in_scale + kwargs.get("encoder_block_shape_q", 64), + kwargs.get("decoder_block_shape_q", 16), + kwargs.get("max_partition_size", 32768), + kwargs.get("encoder_max_partition_size", 32768), + 5, # speculate_max_draft_token_num + True, # causal + True, # enable_prefill + )[0] + out_linear_out = self.compute_out_linear(fmha_out, i) + + return out_linear_out + + def post_process(self, **kwargs): + multi_block_output = kwargs.get("multi_block_output", None) + cum_offsets = kwargs.get("cum_offsets", None) + seq_lens_encoder = kwargs.get("seq_lens_encoder", None) + seq_lens_decoder = kwargs.get("seq_lens_decoder", None) + max_input_length = kwargs.get("max_input_length", -1) + + out = rebuild_padding_v2(multi_block_output, cum_offsets, seq_lens_decoder, seq_lens_encoder, max_input_length) + + return out + + class FusedBlockMultiTransformerWeightOnly(FusedBlockMultiTransformer, FusedMultiTransformerWeightOnly): def __init__(self, config: FusedMultiTransformerConfig): super().__init__(config) +class FusedAppendMultiTransformerWeightOnly(FusedAppendMultiTransformer, FusedMultiTransformerWeightOnly): + def __init__(self, config: FusedMultiTransformerConfig): + super().__init__(config) + + class FusedBlockMultiTransformerA8W8(FusedBlockMultiTransformer, FusedMultiTransformerA8W8): def __init__(self, config: FusedMultiTransformerConfig): super().__init__(config) @@ -2351,6 +2480,101 @@ def compute_attn( return out_linear_out +class FusedAppendMultiTransformerA8W8(FusedAppendMultiTransformer, 
FusedMultiTransformerA8W8): + def __init__(self, config: FusedMultiTransformerConfig): + super().__init__(config) + + def compute_attn( + self, + time_step, + qkv, + padding_offset, + seq_lens, + input_ids, + rotary_embs, + rotary_emb_dims, + caches, + pre_caches, + pre_caches_length, + attn_mask, + i, + **kwargs, + ): + k_quant_scales = kwargs.get("k_quant_scales", None) + v_quant_scales = kwargs.get("v_quant_scales", None) + k_dequant_scales = kwargs.get("k_dequant_scales", None) + v_dequant_scales = kwargs.get("v_dequant_scales", None) + cache_k_zps = kwargs.get("cache_k_zp", None) + cache_v_zps = kwargs.get("cache_v_zp", None) + + cache_quant_type_str = "none" + if self.config.cachekv_int8_type == "static": + k_quant_scales = self.cache_k_scales + v_quant_scales = self.cache_v_scales + k_dequant_scales = self.cache_k_out_scales + v_dequant_scales = self.cache_v_out_scales + cache_quant_type_str = "cache_int8" + + k_quant_scale = k_quant_scales[i] if k_quant_scales is not None else None + v_quant_scale = v_quant_scales[i] if v_quant_scales is not None else None + k_dequant_scale = k_dequant_scales[i] if k_dequant_scales is not None else None + v_dequant_scale = v_dequant_scales[i] if v_dequant_scales is not None else None + cache_k_zp = cache_k_zps[i] if cache_k_zps is not None else None + cache_v_zp = cache_v_zps[i] if cache_v_zps is not None else None + + from paddlenlp_ops import append_attention + + fmha_out = append_attention( + qkv, + caches[2 * i], + caches[2 * i + 1], + kwargs.get("seq_lens_encoder", None), + kwargs.get("seq_lens_decoder", None), + kwargs.get("seq_lens_this_time", None), + kwargs.get("padding_offsets", None), + kwargs.get("cum_offsets", None), + kwargs.get("block_tables", None), + kwargs.get("encoder_batch_ids", None), + kwargs.get("encoder_tile_ids_per_batch", None), + kwargs.get("encoder_num_blocks", None), + kwargs.get("kv_batch_ids", None), + kwargs.get("kv_tile_ids_per_batch", None), + kwargs.get("kv_num_blocks", None), + kwargs.get("decoder_batch_ids", None), + kwargs.get("decoder_tile_ids_per_batch", None), + kwargs.get("decoder_num_blocks", None), + kwargs.get("max_enc_len_this_time", None), + kwargs.get("max_dec_len_this_time", None), + rotary_embs, + None, # attn_mask + self.qkv_biases[i] if len(self.qkv_biases) > 0 else None, + self.qkv_out_scales[i], + k_quant_scale, + v_quant_scale, + k_dequant_scale, + v_dequant_scale, + cache_k_zp, + cache_v_zp, + self.linear_shifts[i] if len(self.linear_shifts) > 0 else None, + self.linear_smooths[i] if len(self.linear_smooths) > 0 else None, + self._fuse_kernel_compute_dtype, + cache_quant_type_str, + self.use_neox_rotary_style, + kwargs.get("max_input_length", -1), + self.act_scales["out_linear_in_scale"][i], + kwargs.get("encoder_block_shape_q", 64), + kwargs.get("decoder_block_shape_q", 16), + kwargs.get("max_partition_size", 32768), + kwargs.get("encoder_max_partition_size", 32768), + 5, # speculate_max_draft_token_num + True, # causal + True, # enable_prefill + )[0] + out_linear_out = self.compute_out_linear(fmha_out, i) + + return out_linear_out + + class FusedBlockMultiTransformerFP8(Layer): def __init__(self, config: FusedMultiTransformerConfig): """""" @@ -2371,8 +2595,6 @@ def __init__(self, config: FusedMultiTransformerConfig): self._dtype = self._helper.get_default_dtype() self._epsilon = config.epsilon self._residual_alpha = config.residual_alpha - self._trans_qkvw = config.trans_qkvw - self._ring_id = config.ring_id self.nranks = config.nranks self.norm_type = config.norm_type if self.norm_type == 
"layernorm": @@ -2556,12 +2778,16 @@ def __init__(self, config: FusedMultiTransformerConfig): is_bias=True, ) + cache_scale_dtype = "float32" + if self.config.append_attn: + cache_scale_dtype = paddle.get_default_dtype() + cache_k_scale = None if cache_k_scale_attr: cache_k_scale = self.create_parameter( shape=[self.kv_num_heads], attr=cache_k_scale_attr, - dtype="float32", + dtype=cache_scale_dtype, is_bias=False, ) @@ -2570,7 +2796,7 @@ def __init__(self, config: FusedMultiTransformerConfig): cache_v_scale = self.create_parameter( shape=[self.kv_num_heads], attr=cache_v_scale_attr, - dtype="float32", + dtype=cache_scale_dtype, is_bias=False, ) @@ -2579,7 +2805,7 @@ def __init__(self, config: FusedMultiTransformerConfig): cache_k_out_scale = self.create_parameter( shape=[self.kv_num_heads], attr=cache_k_out_scale_attr, - dtype="float32", + dtype=cache_scale_dtype, is_bias=False, ) @@ -2588,7 +2814,7 @@ def __init__(self, config: FusedMultiTransformerConfig): cache_v_out_scale = self.create_parameter( shape=[self.kv_num_heads], attr=cache_v_out_scale_attr, - dtype="float32", + dtype=cache_scale_dtype, is_bias=False, ) diff --git a/paddlenlp/experimental/transformers/llama/modeling.py b/paddlenlp/experimental/transformers/llama/modeling.py index d55db079a6e6..7be5249940c1 100644 --- a/paddlenlp/experimental/transformers/llama/modeling.py +++ b/paddlenlp/experimental/transformers/llama/modeling.py @@ -37,6 +37,9 @@ ) from paddlenlp.experimental.transformers.fused_transformer_layers import ( AvxConfig, + FusedAppendMultiTransformer, + FusedAppendMultiTransformerA8W8, + FusedAppendMultiTransformerWeightOnly, FusedBlockMultiTransformer, FusedBlockMultiTransformerA8W8, FusedBlockMultiTransformerFP8, @@ -633,6 +636,7 @@ def __init__(self, config: LlamaConfig): norm_type="rmsnorm", use_neox_rotary_style=self.use_neox, rank_id=config.tensor_parallel_rank, + append_attn=config.append_attn, ) else: @@ -680,6 +684,7 @@ def __init__(self, config: LlamaConfig): cachekv_int8_type=config.cachekv_int8_type, rank_id=config.tensor_parallel_rank, trans_qkvw=(False if paddle.is_compiled_with_rocm() and "a8w8" in self.quant_type else True), + append_attn=config.append_attn, ) self.set_transformer_block(transformer_config) @@ -876,6 +881,8 @@ def set_state_dict(self, state_dict): ffn_hidden_size=self.intermediate_size, num_key_value_heads=self.num_key_value_heads, mp_size=self.config.tensor_parallel_degree, + concat_qkv=True, + concat_ffn1=True, ) self.transformer_block.weight_scales = weight_scales_loader.scale self.transformer_block.act_scales = act_scale_loader.scale @@ -1097,16 +1104,24 @@ def set_state_dict(self, state_dict): dtype=paddle.get_default_dtype(), ) self.transformer_block.linear_shifts[idx].set_value( - paddle.to_tensor(state_dict["llama.layers.{}.self_attn.o_proj.shift_bias".format(idx)]) + paddle.to_tensor(state_dict["llama.layers.{}.self_attn.o_proj.shift_bias".format(idx)]).astype( + paddle.get_default_dtype() + ) ) self.transformer_block.linear_smooths[idx].set_value( - paddle.to_tensor(state_dict["llama.layers.{}.self_attn.o_proj.smooth_weight".format(idx)]) + paddle.to_tensor( + state_dict["llama.layers.{}.self_attn.o_proj.smooth_weight".format(idx)] + ).astype(paddle.get_default_dtype()) ) self.transformer_block.ffn2_shifts[idx].set_value( - paddle.to_tensor(state_dict["llama.layers.{}.mlp.down_proj.shift_bias".format(idx)]) + paddle.to_tensor(state_dict["llama.layers.{}.mlp.down_proj.shift_bias".format(idx)]).astype( + paddle.get_default_dtype() + ) ) 
self.transformer_block.ffn2_smooths[idx].set_value( - paddle.to_tensor(state_dict["llama.layers.{}.mlp.down_proj.smooth_weight".format(idx)]) + paddle.to_tensor(state_dict["llama.layers.{}.mlp.down_proj.smooth_weight".format(idx)]).astype( + paddle.get_default_dtype() + ) ) if self.shift: @@ -1233,7 +1248,10 @@ def set_state_dict(self, state_dict): for k, v in cache_scales_loader.scale.items(): for i_layer, weight_scale in enumerate(v): - weight_scale = weight_scale.astype("float32") + if self.config.append_attn: + weight_scale = paddle.to_tensor(weight_scale).cast(paddle.get_default_dtype()) + else: + weight_scale = weight_scale.astype("float32") if k == "cache_k_scale": self.transformer_block.cache_k_scales[i_layer].set_value(weight_scale) elif k == "cache_v_scale": @@ -1352,7 +1370,10 @@ def set_state_dict_fp8(self, state_dict: dict[str, np.ndarray | paddle.Tensor], ) for k, v in cache_scales_loader.scale.items(): for i_layer, weight_scale in enumerate(v): - weight_scale = weight_scale.astype("float32") + if self.config.append_attn: + weight_scale = paddle.to_tensor(weight_scale).cast(paddle.get_default_dtype()) + else: + weight_scale = weight_scale.astype("float32") if k == "cache_k_scale": self.transformer_block.cache_k_scales[i_layer].set_value(weight_scale) elif k == "cache_v_scale": @@ -1576,19 +1597,30 @@ def _set_var(var, ndarray): @register_base_model class LlamaBlockInferenceModel(LlamaInferenceModel): def __init__(self, config: LlamaConfig): + self.append_attn = config.append_attn super().__init__(config) self.max_seq_len = config.max_seq_len self.block_size = config.block_size def set_transformer_block(self, transformer_config): - if self.use_weight_only: - self.transformer_block = FusedBlockMultiTransformerWeightOnly(transformer_config) - elif self.quant_type == "a8w8" or self.quant_type == "a8w8c8": - self.transformer_block = FusedBlockMultiTransformerA8W8(transformer_config) - elif "fp8" in self.quant_type: - self.transformer_block = FusedBlockMultiTransformerFP8(transformer_config) + if not self.append_attn: + if self.use_weight_only: + self.transformer_block = FusedBlockMultiTransformerWeightOnly(transformer_config) + elif self.quant_type == "a8w8" or self.quant_type == "a8w8c8": + self.transformer_block = FusedBlockMultiTransformerA8W8(transformer_config) + elif "fp8" in self.quant_type: + self.transformer_block = FusedBlockMultiTransformerFP8(transformer_config) + else: + self.transformer_block = FusedBlockMultiTransformer(transformer_config) else: - self.transformer_block = FusedBlockMultiTransformer(transformer_config) + if self.use_weight_only: + self.transformer_block = FusedAppendMultiTransformerWeightOnly(transformer_config) + elif self.quant_type == "a8w8" or self.quant_type == "a8w8c8": + self.transformer_block = FusedAppendMultiTransformerA8W8(transformer_config) + # elif "fp8" in self.quant_type: + # self.transformer_block = FusedAppendMultiTransformerFP8(transformer_config) + else: + self.transformer_block = FusedAppendMultiTransformer(transformer_config) def remove_padding(self, input_ids, seq_lens_this_time): cum_offsets_now = paddle.cumsum(self.max_seq_len - seq_lens_this_time) diff --git a/paddlenlp/experimental/transformers/mixtral/modeling.py b/paddlenlp/experimental/transformers/mixtral/modeling.py index d8ba9198394b..2ff03559d374 100644 --- a/paddlenlp/experimental/transformers/mixtral/modeling.py +++ b/paddlenlp/experimental/transformers/mixtral/modeling.py @@ -29,8 +29,9 @@ WeightScalesLoader, ) from 
paddlenlp.experimental.transformers.fused_transformer_layers import ( + FusedAppendMultiTransformer, + FusedAppendMultiTransformerWeightOnly, FusedBlockMultiTransformer, - FusedBlockMultiTransformerA8W8, FusedBlockMultiTransformerWeightOnly, FusedMultiTransformerA8W8, FusedMultiTransformerBase, @@ -341,6 +342,7 @@ def __init__(self, config: MixtralConfig): rank_id=config.tensor_parallel_rank, trans_qkvw=(False if paddle.is_compiled_with_rocm() and "a8w8" in self.quant_type else True), moe_config=moe_config, + append_attn=config.append_attn, ) self.set_transformer_block(transformer_config) @@ -494,6 +496,7 @@ def forward( @paddle.no_grad() def set_state_dict(self, state_dict): + self.transformer_block.init_weight() unfused_state_dict = {} head_size = self.hidden_size // self.num_attention_heads split_fn = split_param_func() @@ -840,7 +843,10 @@ def set_state_dict(self, state_dict): ) for k, v in cache_scales_loader.scale.items(): for i_layer, weight_scale in enumerate(v): - weight_scale = weight_scale.astype("float32") + if self.config.append_attn: + weight_scale = paddle.to_tensor(weight_scale).cast(paddle.get_default_dtype()) + else: + weight_scale = weight_scale.astype("float32") if k == "cache_k_scale": self.transformer_block.cache_k_scales[i_layer].set_value(weight_scale) elif k == "cache_v_scale": @@ -1062,17 +1068,22 @@ def set_state_dict(self, state_dict): @register_base_model class MixtralBlockInferenceModel(MixtralInferenceModel): def __init__(self, config: MixtralConfig): + self.append_attn = config.append_attn super().__init__(config) self.max_seq_len = config.max_seq_len self.block_size = config.block_size def set_transformer_block(self, transformer_config): - if self.use_weight_only: - self.transformer_block = FusedBlockMultiTransformerWeightOnly(transformer_config) - elif "a8w8" in self.quant_type: - self.transformer_block = FusedBlockMultiTransformerA8W8(transformer_config) + if not self.append_attn: + if self.use_weight_only: + self.transformer_block = FusedBlockMultiTransformerWeightOnly(transformer_config) + else: + self.transformer_block = FusedBlockMultiTransformer(transformer_config) else: - self.transformer_block = FusedBlockMultiTransformer(transformer_config) + if self.use_weight_only: + self.transformer_block = FusedAppendMultiTransformerWeightOnly(transformer_config) + else: + self.transformer_block = FusedAppendMultiTransformer(transformer_config) def remove_padding(self, input_ids, seq_lens_this_time): cum_offsets_now = paddle.cumsum(self.max_seq_len - seq_lens_this_time) diff --git a/paddlenlp/experimental/transformers/qwen2/modeling.py b/paddlenlp/experimental/transformers/qwen2/modeling.py index b8b748ac5991..5e2dc739e2c7 100644 --- a/paddlenlp/experimental/transformers/qwen2/modeling.py +++ b/paddlenlp/experimental/transformers/qwen2/modeling.py @@ -36,6 +36,9 @@ WeightScalesLoader, ) from paddlenlp.experimental.transformers.fused_transformer_layers import ( + FusedAppendMultiTransformer, + FusedAppendMultiTransformerA8W8, + FusedAppendMultiTransformerWeightOnly, FusedBlockMultiTransformer, FusedBlockMultiTransformerA8W8, FusedBlockMultiTransformerFP8, @@ -337,6 +340,7 @@ def __init__(self, config: Qwen2Config): norm_type="rmsnorm", use_neox_rotary_style=self.use_neox, rank_id=config.tensor_parallel_rank, + append_attn=config.append_attn, ) else: @@ -384,6 +388,7 @@ def __init__(self, config: Qwen2Config): cachekv_int8_type=config.cachekv_int8_type, rank_id=config.tensor_parallel_rank, trans_qkvw=(False if paddle.is_compiled_with_rocm() and "a8w8" in 
self.quant_type else True), + append_attn=config.append_attn, ) self.set_transformer_block(transformer_config) @@ -453,6 +458,8 @@ def set_state_dict(self, state_dict): ffn_hidden_size=self.intermediate_size, num_key_value_heads=self.num_key_value_heads, mp_size=self.config.tensor_parallel_degree, + concat_qkv=True, + concat_ffn1=True, ) self.transformer_block.weight_scales = weight_scales_loader.scale self.transformer_block.act_scales = act_scale_loader.scale @@ -466,6 +473,7 @@ def set_state_dict(self, state_dict): self.norm.weight.set_value(paddle.to_tensor(state_dict["qwen2.norm.weight"]).cast(self.norm.weight.dtype)) for idx in range(self.num_layers): + logger.info(f"set state for layer {idx}") unfused_state_dict = {} ln_scale = paddle.to_tensor(state_dict["qwen2.layers.{}.input_layernorm.weight".format(idx)]).cast( self.transformer_block.ln_scales[idx].dtype @@ -704,16 +712,24 @@ def set_state_dict(self, state_dict): dtype=paddle.get_default_dtype(), ) self.transformer_block.linear_shifts[idx].set_value( - paddle.to_tensor(state_dict["qwen2.layers.{}.self_attn.o_proj.shift_bias".format(idx)]) + paddle.to_tensor(state_dict["qwen2.layers.{}.self_attn.o_proj.shift_bias".format(idx)]).astype( + paddle.get_default_dtype() + ) ) self.transformer_block.linear_smooths[idx].set_value( - paddle.to_tensor(state_dict["qwen2.layers.{}.self_attn.o_proj.smooth_weight".format(idx)]) + paddle.to_tensor( + state_dict["qwen2.layers.{}.self_attn.o_proj.smooth_weight".format(idx)] + ).astype(paddle.get_default_dtype()) ) self.transformer_block.ffn2_shifts[idx].set_value( - paddle.to_tensor(state_dict["qwen2.layers.{}.mlp.down_proj.shift_bias".format(idx)]) + paddle.to_tensor(state_dict["qwen2.layers.{}.mlp.down_proj.shift_bias".format(idx)]).astype( + paddle.get_default_dtype() + ) ) self.transformer_block.ffn2_smooths[idx].set_value( - paddle.to_tensor(state_dict["qwen2.layers.{}.mlp.down_proj.smooth_weight".format(idx)]) + paddle.to_tensor(state_dict["qwen2.layers.{}.mlp.down_proj.smooth_weight".format(idx)]).astype( + paddle.get_default_dtype() + ) ) if self.shift: @@ -825,7 +841,10 @@ def set_state_dict(self, state_dict): for k, v in cache_scales_loader.scale.items(): for i_layer, weight_scale in enumerate(v): - weight_scale = weight_scale.astype("float32") + if self.config.append_attn: + weight_scale = paddle.to_tensor(weight_scale).cast(paddle.get_default_dtype()) + else: + weight_scale = weight_scale.astype("float32") if k == "cache_k_scale": self.transformer_block.cache_k_scales[i_layer].set_value(weight_scale) elif k == "cache_v_scale": @@ -946,7 +965,10 @@ def set_state_dict_fp8(self, state_dict: dict[str, np.ndarray | paddle.Tensor], ) for k, v in cache_scales_loader.scale.items(): for i_layer, weight_scale in enumerate(v): - weight_scale = weight_scale.astype("float32") + if self.config.append_attn: + weight_scale = paddle.to_tensor(weight_scale).cast(paddle.get_default_dtype()) + else: + weight_scale = weight_scale.astype("float32") if k == "cache_k_scale": self.transformer_block.cache_k_scales[i_layer].set_value(weight_scale) elif k == "cache_v_scale": @@ -1492,19 +1514,30 @@ def set_state_dict(self, state_dict): @register_base_model class Qwen2BlockInferenceModel(Qwen2InferenceModel): def __init__(self, config: Qwen2Config): + self.append_attn = config.append_attn super().__init__(config) self.max_seq_len = config.max_seq_len self.block_size = config.block_size def set_transformer_block(self, transformer_config): - if self.use_weight_only: - self.transformer_block = 
FusedBlockMultiTransformerWeightOnly(transformer_config) - elif self.quant_type == "a8w8" or self.quant_type == "a8w8c8": - self.transformer_block = FusedBlockMultiTransformerA8W8(transformer_config) - elif "fp8" in self.quant_type: - self.transformer_block = FusedBlockMultiTransformerFP8(transformer_config) + if not self.append_attn: + if self.use_weight_only: + self.transformer_block = FusedBlockMultiTransformerWeightOnly(transformer_config) + elif self.quant_type == "a8w8" or self.quant_type == "a8w8c8": + self.transformer_block = FusedBlockMultiTransformerA8W8(transformer_config) + elif "fp8" in self.quant_type: + self.transformer_block = FusedBlockMultiTransformerFP8(transformer_config) + else: + self.transformer_block = FusedBlockMultiTransformer(transformer_config) else: - self.transformer_block = FusedBlockMultiTransformer(transformer_config) + if self.use_weight_only: + self.transformer_block = FusedAppendMultiTransformerWeightOnly(transformer_config) + elif self.quant_type == "a8w8" or self.quant_type == "a8w8c8": + self.transformer_block = FusedAppendMultiTransformerA8W8(transformer_config) + # elif "fp8" in self.quant_type: + # self.transformer_block = FusedAppendMultiTransformerFP8(transformer_config) + else: + self.transformer_block = FusedAppendMultiTransformer(transformer_config) def remove_padding(self, input_ids, seq_lens_this_time): cum_offsets_now = paddle.cumsum(self.max_seq_len - seq_lens_this_time) diff --git a/paddlenlp/experimental/transformers/qwen2_moe/modeling.py b/paddlenlp/experimental/transformers/qwen2_moe/modeling.py index 4b5d0b7469dd..0f010c20738c 100644 --- a/paddlenlp/experimental/transformers/qwen2_moe/modeling.py +++ b/paddlenlp/experimental/transformers/qwen2_moe/modeling.py @@ -22,6 +22,8 @@ from paddle.nn.quant import weight_quantize from paddlenlp.experimental.transformers.fused_transformer_layers import ( + FusedAppendMultiTransformer, + FusedAppendMultiTransformerWeightOnly, FusedBlockMultiTransformer, FusedBlockMultiTransformerWeightOnly, FusedMultiTransformerBase, @@ -256,6 +258,7 @@ def __init__(self, config: Qwen2MoeConfig): use_neox_rotary_style=self.use_neox, rank_id=config.tensor_parallel_rank, moe_config=moe_config, + append_attn=config.append_attn, ) self.set_transformer_block(transformer_config) @@ -761,15 +764,22 @@ def set_state_dict(self, state_dict): @register_base_model class Qwen2MoeBlockInferenceModel(Qwen2MoeInferenceModel): def __init__(self, config: Qwen2MoeConfig): + self.append_attn = config.append_attn super().__init__(config) self.max_seq_len = config.max_seq_len self.block_size = config.block_size def set_transformer_block(self, transformer_config): - if self.use_weight_only: - self.transformer_block = FusedBlockMultiTransformerWeightOnly(transformer_config) + if not self.append_attn: + if self.use_weight_only: + self.transformer_block = FusedBlockMultiTransformerWeightOnly(transformer_config) + else: + self.transformer_block = FusedBlockMultiTransformer(transformer_config) else: - self.transformer_block = FusedBlockMultiTransformer(transformer_config) + if self.use_weight_only: + self.transformer_block = FusedAppendMultiTransformerWeightOnly(transformer_config) + else: + self.transformer_block = FusedAppendMultiTransformer(transformer_config) def remove_padding(self, input_ids, seq_lens_this_time): cum_offsets_now = paddle.cumsum(self.max_seq_len - seq_lens_this_time) diff --git a/paddlenlp/experimental/transformers/utils.py b/paddlenlp/experimental/transformers/utils.py index d24904c1f31b..8ca0a5cafaf1 100644 --- 
a/paddlenlp/experimental/transformers/utils.py
+++ b/paddlenlp/experimental/transformers/utils.py
@@ -108,6 +108,8 @@ def __init__(
         ffn_hidden_size,
         num_key_value_heads=-1,
         mp_size=1,
+        concat_qkv=False,
+        concat_ffn1=False,
     ):
         self.key_map = key_map_dict
         self.scale = {}
@@ -126,6 +128,13 @@ def __init__(
             n = num_head * dim_head
             self.scale[scale_type] = np.full([num_of_layers, n], fill_value=0.1, dtype="float32")
+        # concat qkv and ffn1
+        if concat_qkv:
+            self.scale["qkv_weight_scale"] = np.full([num_of_layers, qkv_out_size // mp_size], fill_value=0.1, dtype="float32")
+
+        if concat_ffn1:
+            self.scale["ffn1_weight_scale"] = np.full([num_of_layers, ffn_hidden_size * 2 // mp_size], fill_value=0.1, dtype="float32")
+

 class EmptyCacheScale:
     """
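For context on the new `concat_qkv` / `concat_ffn1` switches above: when the Q/K/V projections (and the gate/up halves of FFN1) run as single fused GEMMs, the per-channel weight-scale tables have to be allocated at the fused output width and split across tensor-parallel ranks. A self-contained sketch of that allocation, reusing the patch's key names and placeholder fill value of 0.1; the helper function and the toy numbers are illustrative assumptions, not the real `WeightScalesLoader`:

```python
import numpy as np

def init_fused_weight_scales(num_layers, num_heads, num_kv_heads, dim_head, ffn_hidden_size, mp_size=1):
    # Fused QKV projection width: num_heads query heads plus 2 * num_kv_heads K/V heads.
    qkv_out_size = (num_heads + 2 * num_kv_heads) * dim_head
    return {
        # One scale per output channel of the fused QKV GEMM on this rank.
        "qkv_weight_scale": np.full([num_layers, qkv_out_size // mp_size], 0.1, dtype="float32"),
        # Gate and up projections concatenated -> 2 * ffn_hidden_size channels.
        "ffn1_weight_scale": np.full([num_layers, ffn_hidden_size * 2 // mp_size], 0.1, dtype="float32"),
    }

# Toy configuration (hypothetical numbers), single tensor-parallel rank.
scales = init_fused_weight_scales(num_layers=2, num_heads=8, num_kv_heads=8, dim_head=64, ffn_hidden_size=2048)
print({k: v.shape for k, v in scales.items()})  # qkv: (2, 1536), ffn1: (2, 4096)
```

With grouped-query attention the fused QKV width is `(num_heads + 2 * num_kv_heads) * dim_head`, which is why `qkv_out_size` differs from `3 * hidden_size` whenever `num_key_value_heads != num_heads`.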