[Paddle Inference] Add masked multihead attention kernel and export API. #55344
@@ -1541,6 +1541,17 @@
     data_type : logits
     backward : margin_cross_entropy_grad

 - op : masked_multihead_attention_
   args : (Tensor x, Tensor bias, Tensor src_mask, Tensor sequence_lengths, Tensor rotary_tensor, Tensor beam_cache_offset, Tensor cache_kv, Tensor qkv_out_scale, Tensor out_linear_shift, Tensor out_linear_smooth, int beam_size, int rotary_emb_dims, bool mask_broadcast_num_heads=true, bool compute_bias=false, bool use_neox_rotary_style=false, float out_linear_in_scale=-1, int quant_round_type=1, float quant_max_bound=127.0, float quant_min_bound=-127.0)
   output : Tensor(out), Tensor(cache_kv_out), Tensor(beam_cache_offset_out)

Review comment: out_linear_in_scale, out_linear_shift, and out_linear_smooth are all parts fused on top of the standard attention. Descriptions of these inputs should be added, or consider renaming them.

   infer_meta :
     func : MaskedMultiheadAttentionInferMeta
   kernel :
     func : masked_multihead_attention
     data_type : cache_kv
   optional : bias, src_mask, sequence_lengths, rotary_tensor, beam_cache_offset, qkv_out_scale, out_linear_shift, out_linear_smooth
   inplace : (cache_kv -> cache_kv_out), (beam_cache_offset -> beam_cache_offset_out)

Review comment: should add OP(eg.

Review comment: masked_multihead_attention is currently only used for inference; whether to add a backward op later still needs further discussion.

 - op : masked_select
   args : (Tensor x, Tensor mask)
   output : Tensor (out)
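For reference, below is a rough Python sketch of how the exported op might be called. The import path paddle.incubate.nn.functional.masked_multihead_attention, the keyword names, and the cache_kv layout [2, batch_size, num_head, max_seq_len, dim_head] are assumptions inferred from the YAML signature and the InferMeta checks in this diff, not confirmed by the PR itself.

# Hypothetical usage sketch; the import path, keyword names, and cache
# layout are assumptions inferred from the op signature, not confirmed
# by this PR.
import paddle
from paddle.incubate.nn import functional as F

bsz, num_head, dim_head, max_seq_len = 2, 32, 128, 1024

# x packs the fused QKV projection for one decoding step:
# (batch_size, 3, num_head, dim_head), as enforced by InferMeta below.
x = paddle.randn([bsz, 3, num_head, dim_head], dtype="float16")

# cache_kv must be 5-D with a leading dim of 2 (one slot for K, one
# for V); the remaining layout is an assumption.
cache_kv = paddle.zeros([2, bsz, num_head, max_seq_len, dim_head],
                        dtype="float16")

# Matches the YAML output spec: (out, cache_kv_out, beam_cache_offset_out).
out, cache_kv_out, _ = F.masked_multihead_attention(
    x,
    cache_kv=cache_kv,
    beam_size=1,
    rotary_emb_dims=0,
    out_linear_in_scale=-1.0)  # <= 0 keeps out in x's dtype (no int8 quant)
print(out.shape)  # [2, 1, 32, 128] since sequence_lengths is not passed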
@@ -3624,5 +3624,73 @@ void WeightOnlyMatmulInferMeta(const MetaTensor& x,
   out->set_dtype(x.dtype());
 }

 void MaskedMultiheadAttentionInferMeta(const MetaTensor& x,
                                        const MetaTensor& bias,
                                        const MetaTensor& src_mask,
                                        const MetaTensor& sequence_lengths,
                                        const MetaTensor& rotary_tensor,
                                        const MetaTensor& beam_cache_offset,
                                        const MetaTensor& cache_kv,
                                        const MetaTensor& qkv_out_scale,
                                        const MetaTensor& out_linear_shift,
                                        const MetaTensor& out_linear_smooth,
                                        int beam_size,
                                        int rotary_emb_dims,
                                        const bool mask_broadcast_num_heads,
                                        const bool compute_bias,
                                        const bool use_neox_rotary_style,
                                        const float out_linear_in_scale,
                                        const int quant_round_type,
                                        const float quant_max_bound,
                                        const float quant_min_bound,
                                        MetaTensor* out,
                                        MetaTensor* cache_kv_out,
                                        MetaTensor* beam_cache_offset_out) {

Review comment: the meaning of the beam_cache_offset output is also unclear.

   auto x_dims = x.dims();
   auto cache_kv_dims = cache_kv.dims();
   auto x_dtype = x.dtype();
   int bsz = x_dims[0];
   int num_head = x_dims[2];
   int dim_head = x_dims[3];

   // When per-sample sequence lengths are given, the seq dim of 1 is squeezed.
   if (sequence_lengths) {
     out->set_dims({bsz, num_head, dim_head});
   } else {
     out->set_dims({bsz, 1, num_head, dim_head});
   }
   // A positive out_linear_in_scale means the output is quantized to int8.
   if (out_linear_in_scale > 0) {
     out->set_dtype(DataType::INT8);
   } else {
     out->set_dtype(x_dtype);
   }

   PADDLE_ENFORCE_EQ(
       x_dims.size(),
       4,
       errors::InvalidArgument("The dimensions of x must be 4 "
                               "(batch_size, 3, num_head, dim_head), "
                               "but received dimensions of "
                               "Input is [%d]",
                               x_dims.size()));

Review comment: I think we should also check that the seq_len dimension of x must be 1.

   PADDLE_ENFORCE_EQ(
       cache_kv_dims.size(),
       5,
       errors::InvalidArgument("The cache_kv must be 5 dims, but got %d",
                               cache_kv_dims.size()));
   PADDLE_ENFORCE_EQ(
       cache_kv_dims[0],
       2,
       errors::InvalidArgument("The first dim of cache_kv must be 2, but got %d",
                               cache_kv_dims[0]));

   cache_kv_out->set_dims(cache_kv_dims);
   cache_kv_out->set_dtype(cache_kv.dtype());

   if (beam_cache_offset) {
     beam_cache_offset_out->set_dims(beam_cache_offset.dims());
     beam_cache_offset_out->set_dtype(beam_cache_offset.dtype());
   }
 }

 }  // namespace phi
 PD_REGISTER_INFER_META_FN(batch_norm_infer, phi::BatchNormInferInferMeta);
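The shape and dtype rules above can be summarized in a short Python transcription (illustrative only; it mirrors the C++ InferMeta logic shown in this diff):

# Python transcription of MaskedMultiheadAttentionInferMeta (illustrative).
def masked_mha_infer_shapes(x_shape, x_dtype, cache_kv_shape,
                            has_sequence_lengths, out_linear_in_scale):
    assert len(x_shape) == 4, "x must be (batch_size, 3, num_head, dim_head)"
    assert len(cache_kv_shape) == 5, "cache_kv must be 5-D"
    assert cache_kv_shape[0] == 2, "first dim of cache_kv must be 2 (K and V)"

    bsz, _, num_head, dim_head = x_shape
    # With per-sample sequence lengths, the seq dim of 1 is squeezed out.
    out_shape = ([bsz, num_head, dim_head] if has_sequence_lengths
                 else [bsz, 1, num_head, dim_head])
    # A positive out_linear_in_scale quantizes the output to int8.
    out_dtype = "int8" if out_linear_in_scale > 0 else x_dtype
    # cache_kv_out is updated in place, so it keeps its shape and dtype.
    return out_shape, out_dtype, cache_kv_shape

print(masked_mha_infer_shapes([2, 3, 32, 128], "float16",
                              [2, 2, 32, 1024, 128], False, -1.0))
# -> ([2, 1, 32, 128], 'float16', [2, 2, 32, 1024, 128])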
Review comment: I feel decoder_masked_multihead_attention_ would be a more appropriate name.

Review comment: because it integrates rotary_embedding, attention, etc., it should be named fused_masked_multihead_attention according to the naming conventions.