[Paddle Inference] Add masked multihead attention kernel and export API. #55344
Changes from all commits
@@ -1616,6 +1616,17 @@
   data_type : logits
   backward : margin_cross_entropy_grad

+- op : masked_multihead_attention_
+  args : (Tensor x, Tensor cache_kv, Tensor src_mask, Tensor cum_offsets, Tensor sequence_lengths, Tensor rotary_tensor, Tensor beam_cache_offset, Tensor qkv_out_scale, Tensor out_shift, Tensor out_smooth, int seq_len, int rotary_emb_dims, bool use_neox_rotary_style=false, float out_scale=-1, int quant_round_type=1, float quant_max_bound=127.0, float quant_min_bound=-127.0)
+  output : Tensor(out), Tensor(cache_kv_out), Tensor(beam_cache_offset_out)
+  infer_meta :
+    func : MaskedMultiheadAttentionInferMeta
+  kernel :
+    func : masked_multihead_attention
+    data_type : cache_kv
+  optional : src_mask, cum_offsets, sequence_lengths, rotary_tensor, beam_cache_offset, qkv_out_scale, out_shift, out_smooth
+  inplace : (cache_kv -> cache_kv_out), (beam_cache_offset -> beam_cache_offset_out)
+
 - op : masked_select
   args : (Tensor x, Tensor mask)
   output : Tensor (out)

Review comment (on the new inputs): out_linear_in_scale, out_linear_shift, and out_linear_smooth are all parts fused on top of the standard attention; an explanation of these inputs should be added, or consider renaming them.

Review comment: should add OP(eg.

Review comment: masked_multihead_attention is currently used only for inference; whether to add its backward op later still needs discussion.
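For context, here is a minimal sketch of how the exported API might be invoked from Python. The import path (paddle.incubate.nn.functional.masked_multihead_attention), the keyword names, and the return arity are assumptions inferred from the op spec above, not confirmed by this diff; the fused kernel is also assumed to require a CUDA build of Paddle.

# Hypothetical usage sketch; names and returns are assumptions inferred
# from the op spec above, not confirmed by this PR.
import paddle
from paddle.incubate.nn.functional import masked_multihead_attention

bsz, num_head, dim_head, max_seq_len = 2, 8, 64, 128

# One decode step: x carries the fused QKV projection of a single new
# token per sequence, hence 3 * num_head * dim_head features.
x = paddle.randn([bsz, 3 * num_head * dim_head]).astype("float16")

# KV cache laid out as [2, bsz, num_head, max_seq_len, dim_head]; dim 0
# holds K and V, matching the 5-D / first-dim-is-2 checks in the
# InferMeta diff below.
cache_kv = paddle.zeros([2, bsz, num_head, max_seq_len, dim_head],
                        dtype="float16")

outs = masked_multihead_attention(x, cache_kv, seq_len=max_seq_len,
                                  rotary_emb_dims=0)
# Returns are assumed to mirror the op outputs: out with shape
# [bsz, num_head * dim_head], cache_kv updated in place, and a beam
# cache offset only when beam-search inputs are provided.
print(outs[0].shape)  # [2, 512]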
@@ -3983,5 +3983,69 @@ void WeightOnlyMatmulInferMeta(const MetaTensor& x,
   out->set_dtype(x.dtype());
 }

+void MaskedMultiheadAttentionInferMeta(const MetaTensor& x,
+                                       const MetaTensor& cache_kv,
+                                       const MetaTensor& src_mask,
+                                       const MetaTensor& cum_offsets,
+                                       const MetaTensor& sequence_lengths,
+                                       const MetaTensor& rotary_tensor,
+                                       const MetaTensor& beam_cache_offset,
+                                       const MetaTensor& qkv_out_scale,
+                                       const MetaTensor& out_shift,
+                                       const MetaTensor& out_smooth,
+                                       int seq_len,
+                                       int rotary_emb_dims,
+                                       const bool use_neox_rotary_style,
+                                       const float out_scale,
+                                       const int quant_round_type,
+                                       const float quant_max_bound,
+                                       const float quant_min_bound,
+                                       MetaTensor* out,
+                                       MetaTensor* cache_kv_out,
+                                       MetaTensor* beam_cache_offset_out) {
+  int bsz = x.dims()[0];
+  auto x_dtype = x.dtype();
+  auto cache_kv_dims = cache_kv.dims();
+  int num_head = cache_kv.dims()[2];
+  int dim_head = cache_kv.dims()[4];
+
+  PADDLE_ENFORCE_EQ(
+      cache_kv_dims.size(),
+      5,
+      errors::InvalidArgument("The cache_kv must be 5 dims, but got %d",
+                              cache_kv_dims.size()));
+
+  PADDLE_ENFORCE_EQ(
+      cache_kv_dims[0],
+      2,
+      errors::InvalidArgument("The first dim of cache_kv must be 2, but got %d",
+                              cache_kv_dims[0]));
+
+  if (rotary_tensor) {
+    PADDLE_ENFORCE_EQ(
+        rotary_tensor.dtype(),
+        DataType::FLOAT32,
+        errors::InvalidArgument(
+            "The dtype of rotary_tensor must be float32, but got %d",
+            rotary_tensor.dtype()));
+  }
+
+  out->set_dims({bsz, num_head * dim_head});
+
+  if (out_scale > 0) {
+    out->set_dtype(DataType::INT8);
+  } else {
+    out->set_dtype(x_dtype);
+  }
+
+  cache_kv_out->set_dims(cache_kv_dims);
+  cache_kv_out->set_dtype(cache_kv.dtype());
+
+  if (beam_cache_offset) {
+    beam_cache_offset_out->set_dims(beam_cache_offset.dims());
+    beam_cache_offset_out->set_dtype(beam_cache_offset.dtype());
+  }
+}
+
 }  // namespace phi
 PD_REGISTER_INFER_META_FN(batch_norm_infer, phi::BatchNormInferInferMeta);

Review comment (on beam_cache_offset_out): the meaning of the beam_cache_offset output is also unclear.

Review comment (on the shape checks): it seems we should also check that the seq_len dimension of x is 1.
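To make the shape contract easier to scan, here is a small pure-Python paraphrase of the inference logic above; it is illustrative only and not part of the PR.

# Pure-Python paraphrase of MaskedMultiheadAttentionInferMeta (illustrative).
def masked_mha_infer_meta(x_shape, x_dtype, cache_kv_shape, out_scale=-1.0):
    # cache_kv is [2, bsz, num_head, max_seq_len, dim_head]; dim 0 holds K and V.
    assert len(cache_kv_shape) == 5, "cache_kv must be 5-D"
    assert cache_kv_shape[0] == 2, "first dim of cache_kv must be 2"
    bsz = x_shape[0]
    num_head, dim_head = cache_kv_shape[2], cache_kv_shape[4]
    out_shape = [bsz, num_head * dim_head]
    # out_scale > 0 signals a quantized output, hence INT8.
    out_dtype = "int8" if out_scale > 0 else x_dtype
    # cache_kv_out mirrors cache_kv because the op updates the cache in place
    # (see the inplace mapping in the YAML entry).
    return out_shape, out_dtype, list(cache_kv_shape)

# bsz=2, num_head=8, dim_head=64 -> out shape [2, 512]
print(masked_mha_infer_meta([2, 1536], "float16", [2, 2, 8, 128, 64]))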
Review comment (on the op name): decoder_masked_multihead_attention_ feels like a more appropriate name.

Review comment (on the op name): because it fuses rotary_embedding, attention, etc., it should be named fused_masked_multihead_attention according to the naming conventions.