[WIP] support deepseek-v3 #9769

Draft · wants to merge 39 commits into base: develop

Commits (39):
f3c1336  support deepseek-v3  (yuanlehome, Jan 13, 2025)
09f94b9  support head_dim=192,256 for append_attn c16  (lizhenyun01, Jan 13, 2025)
dbfd416  update 0113  (yuanlehome, Jan 13, 2025)
f1894f2  Merge pull request #2 from lizhenyun01/append_attn_headdim  (yuanlehome, Jan 14, 2025)
e475e3b  attention run  (yuanlehome, Jan 14, 2025)
c9b33cd  refine code  (yuanlehome, Jan 14, 2025)
dc92a3b  add softmax_scale  (yuanlehome, Jan 14, 2025)
6d59b51  support weight_only_int8  (yuanlehome, Jan 14, 2025)
07c6bb2  refine code  (yuanlehome, Jan 14, 2025)
73ad324  support tp  (yuanlehome, Jan 14, 2025)
cbe3623  delete test_append_attn  (yuanlehome, Jan 14, 2025)
f5d5d24  add splited fused_moe from ziyuan  (yuanlehome, Jan 15, 2025)
9657580  fix repe for deepseek-v3  (lizhenyun01, Jan 15, 2025)
feabdb8  add deepseek-v3 class  (yuanlehome, Jan 15, 2025)
acc025f  Merge pull request #4 from lizhenyun01/fix_rope  (yuanlehome, Jan 15, 2025)
583e17f  fix wint8 precision and refine code  (yuanlehome, Jan 15, 2025)
69823b6  fix wint4, big diff  (yuanlehome, Jan 16, 2025)
2877f2c  add e_score_correction_bias  (yuanlehome, Jan 16, 2025)
9fd7a1d  fix head_dim  (yuanlehome, Jan 17, 2025)
abb63f0  fix v3 verify  (yuanlehome, Jan 19, 2025)
ac0b93e  Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleNLP i…  (yuanlehome, Jan 20, 2025)
47fd938  fix d2s  (yuanlehome, Jan 20, 2025)
57f6dce  fix v3 verify  (yuanlehome, Jan 21, 2025)
f1cac13  support qk_head_dim != v_head_dim  (yuanlehome, Jan 21, 2025)
271be86  fix wint8 v_head_dim  (yuanlehome, Jan 21, 2025)
d36f5bb  fix rope  (lizhenyun01, Jan 21, 2025)
d6e6dff  Merge pull request #7 from lizhenyun01/fix_rope  (yuanlehome, Jan 21, 2025)
d9283ca  fix qwen2  (yuanlehome, Jan 21, 2025)
239564b  mla use position_ids only  (yuanlehome, Jan 21, 2025)
f86e4f7  remove control flow  (yuanlehome, Jan 21, 2025)
c2007ee  remove gpu concat  (yuanlehome, Jan 21, 2025)
1e022b0  fix norm weight dtype  (yuanlehome, Jan 21, 2025)
3ba4a11  remove all_reduce in fused_moe  (yuanlehome, Jan 22, 2025)
0c79260  Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleNLP i…  (yuanlehome, Jan 23, 2025)
94bbe7a  fix static run  (yuanlehome, Jan 23, 2025)
d76e357  refine rope code  (yuanlehome, Jan 24, 2025)
a5a16d4  compute position_ids use custom op  (yuanlehome, Jan 24, 2025)
a8f3839  Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleNLP i…  (yuanlehome, Jan 24, 2025)
9621f28  fuse rope  (yuanlehome, Jan 26, 2025)
37 changes: 27 additions & 10 deletions csrc/gpu/append_attention.cu
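At a glance, the diff makes two coupled changes: it threads a new softmax_scale attribute from the op registration down through every encoder/decoder dispatch, and it generalizes the shape bookkeeping so the QK head dim (from key_cache) and the V head dim (a new head_dims_v, from value_cache) may differ, as DeepSeek-V3's MLA attention requires. Short illustrative sketches follow the affected hunks below.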
@@ -56,6 +56,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
const std::string& cache_quant_type_str,
const bool use_neox_rotary_style,
const int max_input_length,
+  const float softmax_scale,
const float quant_max_bound,
const float quant_min_bound,
const float out_linear_in_scale,
@@ -97,21 +98,21 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
if (out_linear_in_scale > 0.0) {
if (fabs(quant_max_bound - 127.0f) < 0.000001) {
fmha_out = GetEmptyTensor(
-  {meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
+  {meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims_v},
paddle::DataType::INT8,
qkv.place());
}
else if (fabs(quant_max_bound - 448.0f) < 0.000001) {
fmha_out = GetEmptyTensor(
-  {meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
+  {meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims_v},
paddle::DataType::FLOAT8_E4M3FN,
qkv.place());
}else{
PD_THROW("Only supported attr of quant_max_bound in ['127.0', '448.0'].");
}
} else {
fmha_out = GetEmptyTensor(
-  {meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
+  {meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims_v},
D,
qkv.place());
}
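The three allocation branches above now size fmha_out with head_dims_v instead of head_dims: the attention output softmax(QK^T)V inherits V's head dim, which with MLA no longer matches the QK head dim. A minimal sketch of the shape relationship, with assumed (not PR-provided) dimensions:

```cpp
#include <cstdint>
#include <iostream>

int main() {
  // Assumed illustrative values; the PR's commits mention QK head dims of
  // 192/256 for append_attn, and MLA uses a smaller V head dim.
  const int64_t token_nums  = 1024;
  const int64_t q_num_heads = 32;
  const int64_t head_dims   = 192;  // QK head dim, key_cache dims[3]
  const int64_t head_dims_v = 128;  // V head dim,  value_cache dims[3]

  // Per head, the scores softmax(QK^T) are [seq, seq] whatever the head dim;
  // multiplying by V gives [seq, head_dims_v]. So the fused output buffer is:
  std::cout << "fmha_out: [" << token_nums << ", "
            << q_num_heads * head_dims_v << "]\n";  // not q_num_heads * head_dims
  (void)head_dims;  // only affects QK^T, not the output width
  return 0;
}
```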
@@ -203,6 +204,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
encoder_block_shape_q,
max_input_length,
max_enc_len_this_time_data,
+  softmax_scale,
quant_max_bound,
quant_min_bound,
out_linear_in_scale,
@@ -240,6 +242,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
encoder_block_shape_q,
max_input_length,
max_enc_len_this_time_data,
+  softmax_scale,
quant_max_bound,
quant_min_bound,
out_linear_in_scale,
@@ -282,6 +285,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
encoder_block_shape_q,
max_input_length,
max_enc_len_this_time_data,
+  softmax_scale,
quant_max_bound,
quant_min_bound,
out_linear_in_scale,
@@ -428,6 +432,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
decoder_block_shape_q,
max_input_length,
max_len_kv_data,
+  softmax_scale,
quant_max_bound,
quant_min_bound,
out_linear_in_scale,
@@ -465,6 +470,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
decoder_block_shape_q,
max_input_length,
max_len_kv_data,
+  softmax_scale,
quant_max_bound,
quant_min_bound,
out_linear_in_scale,
@@ -508,6 +514,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
decoder_block_shape_q,
max_input_length,
max_len_kv_data,
+  softmax_scale,
quant_max_bound,
quant_min_bound,
out_linear_in_scale,
@@ -565,6 +572,7 @@ std::vector<paddle::Tensor> AppendAttention(
const std::string& cache_quant_type_str,
const bool use_neox_rotary_style,
const int max_input_length,
+  const float softmax_scale,
const float quant_max_bound,
const float quant_min_bound,
const float out_linear_in_scale,
@@ -578,9 +586,10 @@ std::vector<paddle::Tensor> AppendAttention(
meta_data.token_nums = qkv_dims[0];
meta_data.kv_num_heads = key_cache_dims[1];
meta_data.head_dims = key_cache_dims[3];
-  const int total_num_head =
-      qkv_dims[qkv_dims.size() - 1] / meta_data.head_dims;
-  meta_data.q_num_heads = total_num_head - 2 * meta_data.kv_num_heads;
+  meta_data.head_dims_v = value_cache.dims()[3];
+  const int q_hidden_size =
+      qkv_dims[qkv_dims.size() - 1] - meta_data.kv_num_heads * (meta_data.head_dims + meta_data.head_dims_v);
+  meta_data.q_num_heads = q_hidden_size / meta_data.head_dims;

meta_data.max_blocks_per_seq = block_tables.dims()[1];
meta_data.block_size = key_cache.dims()[2];
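With distinct QK/V head dims, the fused QKV width is q_num_heads * head_dims + kv_num_heads * (head_dims + head_dims_v), so the old total_num_head = width / head_dims trick no longer works; the new code strips the K+V portion first. A small self-contained check of the arithmetic, with assumed numbers (192/128 are illustrative, not read from the PR):

```cpp
#include <cstdint>
#include <iostream>

int main() {
  // Assumed example configuration.
  const int64_t kv_num_heads = 8;
  const int64_t head_dims    = 192;  // QK head dim, from key_cache dims[3]
  const int64_t head_dims_v  = 128;  // V head dim, from value_cache dims[3]
  const int64_t q_num_heads  = 32;

  // Last-axis width of the fused QKV activation.
  const int64_t qkv_last_dim =
      q_num_heads * head_dims + kv_num_heads * (head_dims + head_dims_v);

  // Old recovery: assumes Q, K, V share one head dim -- breaks here.
  const int64_t total_num_head_old = qkv_last_dim / head_dims;  // 45, not 48

  // New recovery, as in this diff: remove the K+V portion, divide by QK dim.
  const int64_t q_hidden_size =
      qkv_last_dim - kv_num_heads * (head_dims + head_dims_v);

  std::cout << "old: " << total_num_head_old - 2 * kv_num_heads  // 29 (wrong)
            << "  new: " << q_hidden_size / head_dims << "\n";   // 32 (correct)
  return 0;
}
```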
@@ -626,6 +635,7 @@ std::vector<paddle::Tensor> AppendAttention(
cache_quant_type_str,
use_neox_rotary_style,
max_input_length,
+  softmax_scale,
quant_max_bound,
quant_min_bound,
out_linear_in_scale,
@@ -672,6 +682,7 @@ std::vector<paddle::Tensor> AppendAttention(
cache_quant_type_str,
use_neox_rotary_style,
max_input_length,
+  softmax_scale,
quant_max_bound,
quant_min_bound,
out_linear_in_scale,
@@ -719,6 +730,7 @@ std::vector<paddle::Tensor> AppendAttention(
cache_quant_type_str,
use_neox_rotary_style,
max_input_length,
+  softmax_scale,
quant_max_bound,
quant_min_bound,
out_linear_in_scale,
@@ -764,6 +776,7 @@ std::vector<paddle::Tensor> AppendAttention(
cache_quant_type_str,
use_neox_rotary_style,
max_input_length,
+  softmax_scale,
quant_max_bound,
quant_min_bound,
out_linear_in_scale,
@@ -821,10 +834,12 @@ std::vector<std::vector<int64_t>> AppendAttentionInferShape(
const paddle::optional<std::vector<int64_t>>& out_linear_smooths_shape) {
const int token_num = qkv_shape[0];
const int kv_num_heads = key_cache_shape[1];
-  const int head_dim = key_cache_shape[3];
-  const int total_num_head = qkv_shape[qkv_shape.size() - 1] / head_dim;
-  const int num_heads = total_num_head - 2 * kv_num_heads;
-  return {{token_num, num_heads * head_dim}, qkv_shape};
+  const int head_dim_qk = key_cache_shape[3];
+  const int head_dim_v = value_cache_shape[3];
+  const int q_hidden_size =
+      qkv_shape[qkv_shape.size() - 1] - kv_num_heads * (head_dim_qk + head_dim_v);
+  const int num_heads = q_hidden_size / head_dim_qk;
+  return {{token_num, num_heads * head_dim_v}, qkv_shape};
}

std::vector<paddle::DataType> AppendAttentionInferDtype(
@@ -865,6 +880,7 @@ std::vector<paddle::DataType> AppendAttentionInferDtype(
const std::string& cache_quant_type_str,
const bool use_neox_rotary_style,
const int max_input_length,
+  const float softmax_scale,
const float quant_max_bound,
const float quant_min_bound,
const float out_linear_in_scale,
@@ -941,6 +957,7 @@ PD_BUILD_OP(append_attention)
"cache_quant_type: std::string",
"use_neox_rotary_style: bool",
"max_input_length: int",
"softmax_scale: float",
"quant_max_bound: float",
"quant_min_bound: float",
"out_linear_in_scale: float",