
[Paddle Inference] Add masked multihead attention kernel and export API. #55344

Merged — 28 commits from support_mmha into develop on Aug 15, 2023. The diff below shows changes from 5 of the 28 commits.

Commits (28):
- bcff596 support_mmha (yangjianfengo1, Jul 11, 2023)
- ac5ca55 add_python_api (yangjianfengo1, Jul 13, 2023)
- ff0dd45 add_api_doc (yangjianfengo1, Jul 13, 2023)
- 057023d fix_doc_error (xiaoxiaohehe001, Jul 13, 2023)
- 3dec850 fix_infermeta (xiaoxiaohehe001, Jul 13, 2023)
- 29f3f9e add_infermeta (xiaoxiaohehe001, Jul 14, 2023)
- e9c6ee3 add_bf16_cuda_check (xiaoxiaohehe001, Jul 14, 2023)
- 199aca4 fix_bf16 (xiaoxiaohehe001, Jul 16, 2023)
- 2da0edd fix_ci_bloat16 (xiaoxiaohehe001, Jul 17, 2023)
- b31fe6c add_bf16_check (xiaoxiaohehe001, Jul 17, 2023)
- 2c2697e fix_bfloat16 (xiaoxiaohehe001, Jul 17, 2023)
- 371268f fix_bfloat16 (xiaoxiaohehe001, Jul 17, 2023)
- 0ce5b0e fix_ci_windows (xiaoxiaohehe001, Jul 17, 2023)
- ad3c325 fix_ci_windows (xiaoxiaohehe001, Jul 17, 2023)
- edc8c5d fix_ci_windows (xiaoxiaohehe001, Jul 17, 2023)
- 42282f8 fix_ci_windows (xiaoxiaohehe001, Jul 17, 2023)
- 70ff789 fix_ci_windows_kernel_register (xiaoxiaohehe001, Jul 17, 2023)
- 601a9d3 fix_ci_windows_kernel_register (xiaoxiaohehe001, Jul 17, 2023)
- 2106c1a fix_ci_windows_kernel_register (xiaoxiaohehe001, Jul 17, 2023)
- 6f57cb7 fix_test_mmha (xiaoxiaohehe001, Jul 18, 2023)
- 21f6c84 add_cumoffsets (xiaoxiaohehe001, Jul 18, 2023)
- 8041fad remove_bias (xiaoxiaohehe001, Jul 20, 2023)
- 15b2fd5 delete_mmha_reshape_input_output (xiaoxiaohehe001, Aug 8, 2023)
- 58b358c fix_api_log (xiaoxiaohehe001, Aug 8, 2023)
- bf750ec rename_delete_hfile (xiaoxiaohehe001, Aug 9, 2023)
- 37b4d72 add_license_nv (xiaoxiaohehe001, Aug 9, 2023)
- b9907cc Merge branch 'develop' into support_mmha (xiaoxiaohehe001, Aug 14, 2023)
- 5470349 remove_fluid (xiaoxiaohehe001, Aug 14, 2023)
paddle/fluid/operators/fused/fused_multi_transformer_op.cu.h — 2 additions, 3 deletions

```diff
@@ -931,9 +931,8 @@ __global__ void masked_multihead_attention_kernel(

 #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
   if (bi == 0 && hi == 0 && tid == 0) {
-    printf("=======q_out=======\n");
-    for (int i = 0; i < Dh; ++i) printf("%f ", static_cast<float>(q_smem[i]));
-    printf("\n");
+    VLOG(0) << "=======q_out=======\n";
+    for (int i = 0; i < Dh; ++i) VLOG(0) << static_cast<float>(q_smem[i]);
   }
   __syncthreads();
 #endif
```
paddle/phi/api/yaml/ops.yaml — 11 additions

```diff
@@ -1541,6 +1541,17 @@
     data_type : logits
   backward : margin_cross_entropy_grad

+- op : masked_multihead_attention_
+  args : (Tensor x, Tensor bias, Tensor src_mask, Tensor sequence_lengths, Tensor rotary_tensor, Tensor beam_cache_offset, Tensor cache_kv, Tensor qkv_out_scale, Tensor out_linear_shift, Tensor out_linear_smooth, int beam_size, int rotary_emb_dims, bool mask_broadcast_num_heads=true, bool compute_bias=false, bool use_neox_rotary_style=false, float out_linear_in_scale=-1, int quant_round_type=1, float quant_max_bound=127.0, float quant_min_bound=-127.0)
+  output : Tensor(out), Tensor(cache_kv_out), Tensor(beam_cache_offset_out)
+  infer_meta :
+    func : MaskedMultiheadAttentionInferMeta
+  kernel :
+    func : masked_multihead_attention
+    data_type : cache_kv
+  optional : bias, src_mask, sequence_lengths, rotary_tensor, beam_cache_offset, qkv_out_scale, out_linear_shift, out_linear_smooth
+  inplace : (cache_kv -> cache_kv_out), (beam_cache_offset -> beam_cache_offset_out)
+
 - op : masked_select
   args : (Tensor x, Tensor mask)
   output : Tensor (out)
```

Review comments on this hunk:

- Contributor (on the op name): I think decoder_masked_multihead_attention_ would be a more appropriate name.
- Contributor (on the op name): Because it integrates rotary_embedding, attention, and so on, it should be named fused_masked_multihead_attention according to the naming conventions.
- Contributor (on the args): out_linear_in_scale, out_linear_shift, and out_linear_smooth are all parts fused in on top of standard attention; descriptions of these inputs should be added, or their names reconsidered.
- Contributor (on the inplace line): A backward op (e.g. backward: fused_masked_multihead_attention_grad) should be added to compute gradients, per the conventions; otherwise the op cannot be used for training.
- Author @xiaoxiaohehe001 (Aug 15, 2023): masked_multihead_attention is currently used only for inference; whether to add a backward op later still needs discussion.
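The exported Python API itself is not shown in this diff (it lands in the add_python_api / add_api_doc commits), but the YAML spec above is enough to sketch how a call might look. Everything in the sketch below — the import path, the keyword names, and the cache_kv layout beyond its required leading dim of 2 — is an assumption pieced together from the YAML args, not the PR's confirmed interface:

```python
# Hypothetical usage sketch inferred from the ops.yaml spec above; the real
# exported name, signature, and defaults may differ.
import paddle

bsz, num_head, dim_head, max_seq = 2, 32, 128, 1024

# x carries the fused QKV projection for the current decode step and must be
# rank 4: [bsz, 3, num_head, dim_head] (enforced by the infer-meta check below).
x = paddle.randn([bsz, 3, num_head, dim_head], dtype="float16")

# cache_kv must be rank 5 with a leading dim of 2 (the K and V caches); the
# order of the remaining dims is assumed here, not enforced in this diff.
cache_kv = paddle.zeros([2, bsz, num_head, max_seq, dim_head], dtype="float16")

# Assumed import path for the API added by this PR.
from paddle.incubate.nn.functional import masked_multihead_attention

# Outputs follow the YAML `output` entry; cache_kv is updated in place
# (inplace : cache_kv -> cache_kv_out).
out, cache_kv_out, beam_cache_offset_out = masked_multihead_attention(
    x, cache_kv=cache_kv, beam_size=1, rotary_emb_dims=0)
# Without sequence_lengths, out has shape [bsz, 1, num_head, dim_head] and
# keeps x's dtype (out_linear_in_scale defaults to -1, so no INT8 output).
```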


paddle/phi/infermeta/multiary.cc — 68 additions

```diff
@@ -3624,5 +3624,73 @@ void WeightOnlyMatmulInferMeta(const MetaTensor& x,
   out->set_dtype(x.dtype());
 }

+void MaskedMultiheadAttentionInferMeta(const MetaTensor& x,
+                                       const MetaTensor& bias,
+                                       const MetaTensor& src_mask,
+                                       const MetaTensor& sequence_lengths,
+                                       const MetaTensor& rotary_tensor,
+                                       const MetaTensor& beam_cache_offset,
+                                       const MetaTensor& cache_kv,
+                                       const MetaTensor& qkv_out_scale,
+                                       const MetaTensor& out_linear_shift,
+                                       const MetaTensor& out_linear_smooth,
+                                       int beam_size,
+                                       int rotary_emb_dims,
+                                       const bool mask_broadcast_num_heads,
+                                       const bool compute_bias,
+                                       const bool use_neox_rotary_style,
+                                       const float out_linear_in_scale,
+                                       const int quant_round_type,
+                                       const float quant_max_bound,
+                                       const float quant_min_bound,
+                                       MetaTensor* out,
+                                       MetaTensor* cache_kv_out,
+                                       MetaTensor* beam_cache_offset_out) {
+  auto x_dims = x.dims();
+  auto cache_kv_dims = cache_kv.dims();
+  auto x_dtype = x.dtype();
+  int bsz = x_dims[0];
+  int num_head = x_dims[2];
+  int dim_head = x_dims[3];
+
+  if (sequence_lengths) {
+    out->set_dims({bsz, num_head, dim_head});
+  } else {
+    out->set_dims({bsz, 1, num_head, dim_head});
+  }
+  if (out_linear_in_scale > 0) {
+    out->set_dtype(DataType::INT8);
+  } else {
+    out->set_dtype(x_dtype);
+  }
+
+  PADDLE_ENFORCE_EQ(
+      x_dims.size(),
+      4,
+      errors::InvalidArgument("The dimensions of x must be 4 "
+                              "(batch_size, 3, num_head, dim_head), "
+                              "but received dimensions of "
+                              "Input is [%d]",
+                              x_dims.size()));
+  PADDLE_ENFORCE_EQ(
+      cache_kv_dims.size(),
+      5,
+      errors::InvalidArgument("The cache_kv must be 5 dims, but got %d",
+                              cache_kv_dims.size()));
+  PADDLE_ENFORCE_EQ(
+      cache_kv_dims[0],
+      2,
+      errors::InvalidArgument("The first dim of cache_kv must be 2, but got %d",
+                              cache_kv_dims[0]));
+
+  cache_kv_out->set_dims(cache_kv_dims);
+  cache_kv_out->set_dtype(cache_kv.dtype());
+
+  if (beam_cache_offset) {
+    beam_cache_offset_out->set_dims(beam_cache_offset.dims());
+    beam_cache_offset_out->set_dtype(beam_cache_offset.dtype());
+  }
+}
+
 }  // namespace phi
 PD_REGISTER_INFER_META_FN(batch_norm_infer, phi::BatchNormInferInferMeta);
```

Review comments on this hunk:

- Contributor (on beam_cache_offset_out): The meaning of the beam_cache_offset output is also unclear.
- Contributor (on the x_dims check): I think we should also check that the seq_len dimension of x must be 1.
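To make the shape and dtype branching above easier to scan, here is a tiny standalone sketch — plain Python, an illustrative reproduction rather than PR code — that mirrors what MaskedMultiheadAttentionInferMeta computes for out:

```python
# Mirrors the out-shape/dtype logic of MaskedMultiheadAttentionInferMeta
# (illustrative reproduction, not code from this PR).
def infer_out(x_shape, has_sequence_lengths, out_linear_in_scale, x_dtype):
    assert len(x_shape) == 4, "x must be [bsz, 3, num_head, dim_head]"
    bsz, _, num_head, dim_head = x_shape
    if has_sequence_lengths:                   # seq dim is dropped
        out_shape = [bsz, num_head, dim_head]
    else:                                      # seq dim is kept as 1
        out_shape = [bsz, 1, num_head, dim_head]
    out_dtype = "int8" if out_linear_in_scale > 0 else x_dtype
    return out_shape, out_dtype

assert infer_out([8, 3, 32, 128], True, -1.0, "float16") == ([8, 32, 128], "float16")
assert infer_out([8, 3, 32, 128], False, 0.5, "float16") == ([8, 1, 32, 128], "int8")
```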
paddle/phi/infermeta/multiary.h — 23 additions

```diff
@@ -706,4 +706,27 @@ void FusedRopeInferMeta(const MetaTensor& q,
                         MetaTensor* out_k,
                         MetaTensor* out_v);

+void MaskedMultiheadAttentionInferMeta(const MetaTensor& x,
+                                       const MetaTensor& bias,
+                                       const MetaTensor& src_mask,
+                                       const MetaTensor& sequence_lengths,
+                                       const MetaTensor& rotary_tensor,
+                                       const MetaTensor& beam_cache_offset,
+                                       const MetaTensor& cache_kv,
+                                       const MetaTensor& qkv_out_scale,
+                                       const MetaTensor& out_linear_shift,
+                                       const MetaTensor& out_linear_smooth,
+                                       int beam_size,
+                                       int rotary_emb_dims,
+                                       const bool mask_broadcast_num_heads,
+                                       const bool compute_bias,
+                                       const bool use_neox_rotary_style,
+                                       const float out_linear_in_scale,
+                                       const int quant_round_type,
+                                       const float quant_max_bound,
+                                       const float quant_min_bound,
+                                       MetaTensor* out,
+                                       MetaTensor* cache_kv_out,
+                                       MetaTensor* beam_cache_offset_out);
+
 }  // namespace phi
```