[Paddle Inference] Add masked multihead attention kernel and export API. #55344

Merged Aug 15, 2023 · 28 commits (support_mmha → develop)

Commits
bcff596 support_mmha (yangjianfengo1, Jul 11, 2023)
ac5ca55 add_python_api (yangjianfengo1, Jul 13, 2023)
ff0dd45 add_api_doc (yangjianfengo1, Jul 13, 2023)
057023d fix_doc_error (xiaoxiaohehe001, Jul 13, 2023)
3dec850 fix_infermeta (xiaoxiaohehe001, Jul 13, 2023)
29f3f9e add_infermeta (xiaoxiaohehe001, Jul 14, 2023)
e9c6ee3 add_bf16_cuda_check (xiaoxiaohehe001, Jul 14, 2023)
199aca4 fix_bf16 (xiaoxiaohehe001, Jul 16, 2023)
2da0edd fix_ci_bloat16 (xiaoxiaohehe001, Jul 17, 2023)
b31fe6c add_bf16_check (xiaoxiaohehe001, Jul 17, 2023)
2c2697e fix_bfloat16 (xiaoxiaohehe001, Jul 17, 2023)
371268f fix_bfloat16 (xiaoxiaohehe001, Jul 17, 2023)
0ce5b0e fix_ci_windows (xiaoxiaohehe001, Jul 17, 2023)
ad3c325 fix_ci_windows (xiaoxiaohehe001, Jul 17, 2023)
edc8c5d fix_ci_windows (xiaoxiaohehe001, Jul 17, 2023)
42282f8 fix_ci_windows (xiaoxiaohehe001, Jul 17, 2023)
70ff789 fix_ci_windows_kernel_register (xiaoxiaohehe001, Jul 17, 2023)
601a9d3 fix_ci_windows_kernel_register (xiaoxiaohehe001, Jul 17, 2023)
2106c1a fix_ci_windows_kernel_register (xiaoxiaohehe001, Jul 17, 2023)
6f57cb7 fix_test_mmha (xiaoxiaohehe001, Jul 18, 2023)
21f6c84 add_cumoffsets (xiaoxiaohehe001, Jul 18, 2023)
8041fad remove_bias (xiaoxiaohehe001, Jul 20, 2023)
15b2fd5 delete_mmha_reshape_input_output (xiaoxiaohehe001, Aug 8, 2023)
58b358c fix_api_log (xiaoxiaohehe001, Aug 8, 2023)
bf750ec rename_delete_hfile (xiaoxiaohehe001, Aug 9, 2023)
37b4d72 add_license_nv (xiaoxiaohehe001, Aug 9, 2023)
b9907cc Merge branch 'develop' into support_mmha (xiaoxiaohehe001, Aug 14, 2023)
5470349 remove_fluid (xiaoxiaohehe001, Aug 14, 2023)
5 changes: 2 additions & 3 deletions paddle/fluid/operators/fused/fused_multi_transformer_op.cu.h
@@ -931,9 +931,8 @@ __global__ void masked_multihead_attention_kernel(

 #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
   if (bi == 0 && hi == 0 && tid == 0) {
-    printf("=======q_out=======\n");
-    for (int i = 0; i < Dh; ++i) printf("%f ", static_cast<float>(q_smem[i]));
-    printf("\n");
+    VLOG(0) << "=======q_out=======\n";
+    for (int i = 0; i < Dh; ++i) VLOG(0) << static_cast<float>(q_smem[i]);
   }
   __syncthreads();
 #endif
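For reference, masked_multihead_attention_kernel computes one decoder step of attention: the current token's query attends over all previously cached key/value positions. Below is a minimal NumPy sketch of that step (names and shapes are ours for illustration; the real kernel additionally fuses the rotary embedding, quantization, and beam-search cache handling listed in the op signature further down):

```python
import numpy as np

def mmha_decode_step(q, cache_k, cache_v, seq_len):
    """One masked-MHA decode step over a KV cache.

    q:       [bsz, num_head, dim_head]           query of the current token
    cache_k: [bsz, num_head, max_seq, dim_head]  cached keys
    cache_v: [bsz, num_head, max_seq, dim_head]  cached values
    seq_len: number of valid cached positions; the causal mask is implicit
             because only positions < seq_len participate.
    """
    k = cache_k[:, :, :seq_len, :]
    v = cache_v[:, :, :seq_len, :]
    scores = np.einsum("bhd,bhsd->bhs", q, k) / np.sqrt(q.shape[-1])
    probs = np.exp(scores - scores.max(-1, keepdims=True))
    probs /= probs.sum(-1, keepdims=True)        # softmax over cached positions
    out = np.einsum("bhs,bhsd->bhd", probs, v)   # weighted sum of values
    return out.reshape(q.shape[0], -1)           # [bsz, num_head * dim_head]
```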
11 changes: 11 additions & 0 deletions paddle/phi/api/yaml/ops.yaml
@@ -1616,6 +1616,17 @@
    data_type : logits
  backward : margin_cross_entropy_grad

- op : masked_multihead_attention_
[Review comment — Contributor] decoder_masked_multihead_attention_ feels like a more appropriate name.

[Review comment — Contributor] Since this op fuses rotary_embedding, attention, etc., it should be named fused_masked_multihead_attention according to the naming conventions.

  args : (Tensor x, Tensor cache_kv, Tensor src_mask, Tensor cum_offsets, Tensor sequence_lengths, Tensor rotary_tensor, Tensor beam_cache_offset, Tensor qkv_out_scale, Tensor out_shift, Tensor out_smooth, int seq_len, int rotary_emb_dims, bool use_neox_rotary_style=false, float out_scale=-1, int quant_round_type=1, float quant_max_bound=127.0, float quant_min_bound=-127.0)
  output : Tensor(out), Tensor(cache_kv_out), Tensor(beam_cache_offset_out)
[Review comment — Contributor] Variables like out_linear_in_scale, out_linear_shift, and out_linear_smooth are all pieces fused in beyond the standard attention; an explanation of these inputs should be added, or consider renaming them.

  infer_meta :
    func : MaskedMultiheadAttentionInferMeta
  kernel :
    func : masked_multihead_attention
    data_type : cache_kv
  optional : src_mask, cum_offsets, sequence_lengths, rotary_tensor, beam_cache_offset, qkv_out_scale, out_shift, out_smooth
  inplace : (cache_kv -> cache_kv_out), (beam_cache_offset -> beam_cache_offset_out)
[Review comment — Contributor] A backward op (e.g. backward : fused_masked_multihead_attention_grad) should be added to compute gradients, per the conventions; otherwise the op cannot be used for training.

[Review comment — Author @xiaoxiaohehe001, Aug 15, 2023] masked_multihead_attention is currently used only for inference; whether to add a backward pass later still needs discussion.


- op : masked_select
  args : (Tensor x, Tensor mask)
  output : Tensor (out)
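As a usage sketch of the exported Python API (the import path, keyword names, return structure, and x layout below are assumptions inferred from this yaml entry and the add_python_api commit; they are not shown in this diff):

```python
import paddle
from paddle.incubate.nn.functional import masked_multihead_attention  # path assumed

bsz, num_head, dim_head, max_seq = 2, 32, 128, 1024

# x is assumed to hold the fused QKV projection of the single decode-step token.
x = paddle.randn([bsz, 3 * num_head * dim_head], dtype="float16")
# cache_kv layout matches the InferMeta checks below:
# [2, bsz, num_head, max_seq, dim_head], where dim 0 packs (key, value).
cache_kv = paddle.zeros([2, bsz, num_head, max_seq, dim_head], dtype="float16")

# Per the yaml, everything from src_mask onward is optional, cache_kv is
# updated in place, and out has shape [bsz, num_head * dim_head].
out = masked_multihead_attention(x, cache_kv, seq_len=max_seq)
print(out.shape)  # [2, 4096]
```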
64 changes: 64 additions & 0 deletions paddle/phi/infermeta/multiary.cc
@@ -3983,5 +3983,69 @@ void WeightOnlyMatmulInferMeta(const MetaTensor& x,
  out->set_dtype(x.dtype());
}

void MaskedMultiheadAttentionInferMeta(const MetaTensor& x,
                                       const MetaTensor& cache_kv,
                                       const MetaTensor& src_mask,
                                       const MetaTensor& cum_offsets,
                                       const MetaTensor& sequence_lengths,
                                       const MetaTensor& rotary_tensor,
                                       const MetaTensor& beam_cache_offset,
                                       const MetaTensor& qkv_out_scale,
                                       const MetaTensor& out_shift,
                                       const MetaTensor& out_smooth,
                                       int seq_len,
                                       int rotary_emb_dims,
                                       const bool use_neox_rotary_style,
                                       const float out_scale,
                                       const int quant_round_type,
                                       const float quant_max_bound,
                                       const float quant_min_bound,
                                       MetaTensor* out,
                                       MetaTensor* cache_kv_out,
                                       MetaTensor* beam_cache_offset_out) {
[Review comment — Contributor] The meaning of the beam_cache_offset output is also unclear.

  int bsz = x.dims()[0];
  auto x_dtype = x.dtype();
  auto cache_kv_dims = cache_kv.dims();
  int num_head = cache_kv.dims()[2];
  int dim_head = cache_kv.dims()[4];

[Review comment — Contributor] It seems we should also check that the seq_len dimension of x is 1.

  PADDLE_ENFORCE_EQ(
      cache_kv_dims.size(),
      5,
      errors::InvalidArgument("The cache_kv must be 5 dims, but got %d",
                              cache_kv_dims.size()));

  PADDLE_ENFORCE_EQ(
      cache_kv_dims[0],
      2,
      errors::InvalidArgument("The first dim of cache_kv must be 2, but got %d",
                              cache_kv_dims[0]));

  if (rotary_tensor) {
    PADDLE_ENFORCE_EQ(
        rotary_tensor.dtype(),
        DataType::FLOAT32,
        errors::InvalidArgument(
            "The dtype of rotary_tensor must be float32, but got %d",
            rotary_tensor.dtype()));
  }

  out->set_dims({bsz, num_head * dim_head});

  if (out_scale > 0) {
    out->set_dtype(DataType::INT8);
  } else {
    out->set_dtype(x_dtype);
  }

  cache_kv_out->set_dims(cache_kv_dims);
  cache_kv_out->set_dtype(cache_kv.dtype());

  if (beam_cache_offset) {
    beam_cache_offset_out->set_dims(beam_cache_offset.dims());
    beam_cache_offset_out->set_dtype(beam_cache_offset.dtype());
  }
}

}  // namespace phi
PD_REGISTER_INFER_META_FN(batch_norm_infer, phi::BatchNormInferInferMeta);
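To make the shape/dtype contract above concrete, here is the same logic as a small Python sketch (ours, for illustration only; the x layout in the example follows the fused-QKV assumption noted earlier):

```python
def mmha_infer_meta(x_shape, x_dtype, cache_kv_shape, out_scale):
    """Mirrors MaskedMultiheadAttentionInferMeta's shape/dtype rules."""
    assert len(cache_kv_shape) == 5, "cache_kv must be 5-D"
    assert cache_kv_shape[0] == 2, "dim 0 packs (key, value)"
    bsz = x_shape[0]
    num_head, dim_head = cache_kv_shape[2], cache_kv_shape[4]
    # Output is quantized to int8 when out_scale > 0, else keeps x's dtype.
    out_dtype = "int8" if out_scale > 0 else x_dtype
    return [bsz, num_head * dim_head], out_dtype

# cache_kv [2, 8, 32, 1024, 128], float16 x, no quantization (out_scale=-1):
assert mmha_infer_meta([8, 12288], "float16", [2, 8, 32, 1024, 128], -1.0) == \
    ([8, 4096], "float16")
```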
21 changes: 21 additions & 0 deletions paddle/phi/infermeta/multiary.h
@@ -773,4 +773,25 @@ void FusedRopeInferMeta(const MetaTensor& q,
                        MetaTensor* out_k,
                        MetaTensor* out_v);

void MaskedMultiheadAttentionInferMeta(const MetaTensor& x,
                                       const MetaTensor& cache_kv,
                                       const MetaTensor& src_mask,
                                       const MetaTensor& cum_offsets,
                                       const MetaTensor& sequence_lengths,
                                       const MetaTensor& rotary_tensor,
                                       const MetaTensor& beam_cache_offset,
                                       const MetaTensor& qkv_out_scale,
                                       const MetaTensor& out_shift,
                                       const MetaTensor& out_smooth,
                                       int seq_len,
                                       int rotary_emb_dims,
                                       const bool use_neox_rotary_style,
                                       const float out_scale,
                                       const int quant_round_type,
                                       const float quant_max_bound,
                                       const float quant_min_bound,
                                       MetaTensor* out,
                                       MetaTensor* cache_kv_out,
                                       MetaTensor* beam_cache_offset_out);

} // namespace phi