
[Paddle Inference] Add masked multihead attention kernel and export API. #55344

Merged
28 commits merged into PaddlePaddle:develop on Aug 15, 2023

Conversation

Contributor
@xiaoxiaohehe001 commented Jul 11, 2023

PR types

Others

PR changes

OPs

Description

  • Support masked multihead attention for the transformer decoder stage.
  • Support dtypes: float, float16, and bfloat16.
  • Support int8 qkv out scale and out-linear in scale.
  • Export Python API: from paddle.incubate.nn.functional import masked_multihead_attention (a usage sketch follows below).

Pcard-71502
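
A minimal usage sketch of the exported API, assuming the tensor shapes documented in this PR; the keyword names and return packing follow my reading of the op definition and may differ in detail:

import paddle
from paddle.incubate.nn.functional import masked_multihead_attention

# Illustrative single-step decoding shapes (values made up for the sketch).
# Requires a CUDA build of Paddle.
batch_size, num_head, head_dim, max_seq_len = 2, 6, 32, 6

# x packs the fused qkv output of one decode step: [bsz, 3 * num_head * head_dim].
x = paddle.randn([batch_size, 3 * num_head * head_dim]).astype("float16")
# cache_kv holds past keys/values: [2, bsz, num_head, max_seq_len, head_dim].
cache_kv = paddle.zeros(
    [2, batch_size, num_head, max_seq_len, head_dim], dtype="float16"
)
# src_mask: [bsz, 1, 1, sequence_length].
src_mask = paddle.zeros([batch_size, 1, 1, max_seq_len], dtype="float16")

# Per the op definition, this yields out (and cache_kv_out, with cache_kv
# updated in place per the inplace spec); the exact return packing is assumed.
outputs = masked_multihead_attention(x, cache_kv=cache_kv, src_mask=src_mask)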

@paddle-bot

paddle-bot bot commented Jul 11, 2023

Your PR has been submitted. Thanks for your contribution to the open-source project!
Please watch for the upcoming CI results. See the Paddle CI Manual for details.

@xiaoxiaohehe001 xiaoxiaohehe001 changed the title [Paddle Inference] support_mmha for inference. [Paddle Inference] Add MMHAKernel for inference and export mmha API. Jul 13, 2023
@xiaoxiaohehe001 xiaoxiaohehe001 changed the title [Paddle Inference] Add MMHAKernel for inference and export mmha API. [Paddle Inference] Add MMHAKernel and export mmha API. Jul 13, 2023
@xiaoxiaohehe001 xiaoxiaohehe001 changed the title [Paddle Inference] Add MMHAKernel and export mmha API. [Paddle Inference] Add masked multihead attention kernel and export API. Jul 13, 2023
@@ -1541,6 +1541,17 @@
data_type : logits
backward : margin_cross_entropy_grad

- op : masked_multihead_attention_
Contributor

Naming it decoder_masked_multihead_attention_ feels more appropriate.

Contributor

Because it integrates rotary_embedding, attention, etc., it should be named fused_masked_multihead_attention according to the naming regulations.

out->set_dtype(x_dtype);
}

PADDLE_ENFORCE_EQ(
Contributor

I think we should check that the seq_len dimension of x must be 1.

@@ -1541,6 +1541,17 @@
data_type : logits
backward : margin_cross_entropy_grad

- op : masked_multihead_attention_
args : (Tensor x, Tensor bias, Tensor src_mask, Tensor sequence_lengths, Tensor rotary_tensor, Tensor beam_cache_offset, Tensor cache_kv, Tensor qkv_out_scale, Tensor out_linear_shift, Tensor out_linear_smooth, int beam_size, int rotary_emb_dims, bool mask_broadcast_num_heads=true, bool compute_bias=false, bool use_neox_rotary_style=false, float out_linear_in_scale=-1, int quant_round_type=1, float quant_max_bound=127.0, float quant_min_bound=-127.0)
output : Tensor(out), Tensor(cache_kv_out), Tensor(beam_cache_offset_out)
Contributor

out_linear_in_scale, out_linear_shift, and out_linear_smooth are all parts fused on top of standard attention. These inputs need documentation, or consider renaming them.

const float quant_min_bound,
MetaTensor* out,
MetaTensor* cache_kv_out,
MetaTensor* beam_cache_offset_out) {
Contributor

The meaning of the beam_cache_offset output is also unclear.

src_mask (Tensor): The src_mask tensor. the shape is `[batch\_size, 1, 1, sequence\_length]`.
sequence_lengths (Tensor, optional): The sequence_lengths tensor. the shape is `[batch\_size, 1]`.
rotary_tensor (Tensor, optional): The rotary_tensor tensor. the shape is `[batch\_size, 1]`.
beam_cache_offset (Tensor, optional): The rotary_tensor tensor. the shape is `[batch\_size, beam\_size, max\_seq\_len + max\_dec\_len]`.
Contributor

The docstring here looks wrong (beam_cache_offset is described as "The rotary_tensor tensor").

rotary_tensor (Tensor, optional): The rotary_tensor tensor. the shape is `[batch\_size, 1]`.
beam_cache_offset (Tensor, optional): The rotary_tensor tensor. the shape is `[batch\_size, beam\_size, max\_seq\_len + max\_dec\_len]`.
cache_kvs (list(Tensor)|tuple(Tensor)): The cache structure tensors for the generation model. The shape is `[2, bsz, num\_head, max\_seq\_len, head\_dim]`.
rotary_tensor (Tensor, optional): The rotary_tensor tensor. the shape is `[batch\_size, 1, 1, sequence\_length, dim_head]`.
Contributor

This is duplicated (rotary_tensor is documented twice).

bias (Tensor, optional): The bias tensor of qkv, the shape is `[3, num\_head, dim\_head]`.
src_mask (Tensor): The src_mask tensor. the shape is `[batch\_size, 1, 1, sequence\_length]`.
sequence_lengths (Tensor, optional): The sequence_lengths tensor. the shape is `[batch\_size, 1]`.
rotary_tensor (Tensor, optional): The rotary_tensor tensor. the shape is `[batch\_size, 1]`.
Contributor

The dtype of rotary_tensor is not restricted here, but the kernel restricts it to float.
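
A tiny guard in the direction this comment suggests, assuming the kernel only accepts float32 rotary tensors (the shape values are illustrative):

import paddle

# Hypothetical rotary tensor built for illustration only.
rotary_tensor = paddle.randn([2, 1, 1, 6, 32]).astype("float16")
# Cast up front, since the kernel is said to restrict rotary_tensor to float.
if rotary_tensor.dtype != paddle.float32:
    rotary_tensor = paddle.cast(rotary_tensor, "float32")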

self.num_head = 6
self.dim_head = 32
self.beam_size = 1
self.max_seq_len = 6
Contributor

The seq_len in this unit test is too small to cover realistic model inputs.

np.testing.assert_allclose(
paddle_mmha_out[0].numpy(),
paddle_naive_rmsnorm[0].numpy(),
rtol=5e-2,
Contributor

Isn't a relative tolerance of 5e-2 too large?

@qingqing01 qingqing01 self-requested a review July 17, 2023 11:42
@paddle-ci-bot

paddle-ci-bot bot commented Jul 30, 2023

Sorry to inform you that 8041fad's CIs passed more than 7 days ago. To prevent PR conflicts, you need to re-run all CIs manually.

#include "paddle/fluid/operators/fused/fmha_ref.h"
#include "paddle/fluid/operators/fused/fused_dropout_helper.h"
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
#include "paddle/fluid/platform/dynload/cublasLt.h"
Contributor

  1. Not all of these includes are actually used, right? At least op_registry.h seems unused.
  2. Please watch for the same issue in other files.

Contributor Author

Done

#include "paddle/fluid/distributed/collective/process_group.h"
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#endif
Contributor

nccl is probably not used either; please clean up the unused includes.

Contributor Author

Done

const int quant_round_type = 1,
const float quant_max_bound = 127.0f,
const float quant_min_bound = -127.0f) {
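// Assumed reading of this branch (not stated in the PR): a non-null
// dequant_qkv_scales means the qkv input is int8 and must be dequantized,
// and quant_fmha_out_scale > 0 means the attention output is re-quantized.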
if (dequant_qkv_scales != nullptr && quant_fmha_out_scale > 0) {
Contributor

Better to add comments for the different branches.

# limitations under the License.

from paddle import _C_ops
from paddle.fluid.layer_helper import LayerHelper
Contributor

from paddle.framework import LayerHelper

Contributor Author

Done

quant_min_bound=-127.0,
):
r"""
Multi-head attention for text summarization.
Contributor

text generation

qkv_out_scale=None,
out_linear_shift=None,
out_linear_smooth=None,
seq_len=1,
Contributor

The meaning of seq_len is not explained below.

Contributor Author

Done~

rotary_emb_dims (int, optional): The rotary_emb_dims. Default 0.
use_neox_rotary_style (bool, optional): A flag indicating whether neox_rotary_style is needed or not. Default False.
out_linear_in_scale (float, optional): The out_linear_in_scale.
quant_round_type (int, optional): The quant_round_type. Default 1.
Contributor

Which round types are available?
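
For reference, a hedged numpy sketch of the two rounding conventions such quant parameters are commonly taken to select; this 0/1 mapping is my assumption, not confirmed by the PR:

import numpy as np

def quant_round(x, quant_round_type=1):
    # Assumed mapping: 0 -> round half to even, 1 -> round half away from zero.
    if quant_round_type == 0:
        return np.rint(x)  # ties go to the nearest even integer
    return np.sign(x) * np.floor(np.abs(x) + 0.5)  # ties go away from zero

print(quant_round(np.array([0.5, 1.5, 2.5]), 0))  # [0. 2. 2.]
print(quant_round(np.array([0.5, 1.5, 2.5]), 1))  # [1. 2. 3.]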


import paddle
from paddle.fluid import core
from paddle.fluid.layer_helper import LayerHelper
Contributor

Same as above: avoid fluid APIs unless necessary.

Contributor Author

Done~

from paddle.framework import in_dynamic_mode


def mmha_wrapper(
Contributor

The interface is already wrapped above; this extra wrapper seems unnecessary.

Contributor Author

Done~

@@ -0,0 +1,552 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Contributor

Why is this placed in the legacy_test folder?

Aurelius84 previously approved these changes Aug 14, 2023
quant_max_bound=127.0,
quant_min_bound=-127.0,
):
r"""
Contributor

A newly added API needs corresponding Chinese documentation under docs~

Contributor Author

OK, it will be added later.


Args:
x (Tensor): The input tensor could be 2-D tensor. Its shape is [batch_size, 3 * num_head * head_dim].
cache_kvs (list(Tensor)|tuple(Tensor)): The cache structure tensors for the generation model. Its shape is [2, batch_size, num_head, max_seq_len, head_dim].
Contributor

Suggested change
cache_kvs (list(Tensor)|tuple(Tensor)): The cache structure tensors for the generation model. Its shape is [2, batch_size, num_head, max_seq_len, head_dim].
cache_kvs (list(Tensor)|tuple(Tensor), optional): The cache structure tensors for the generation model. Its shape is [2, batch_size, num_head, max_seq_len, head_dim].

Contributor Author

cache_kvs is not an optional input.

Contributor

But it has a default value; cache_kv=None is written above.

Comment on lines 70 to 72
namespace plat = paddle::platform;
using float16 = plat::float16;
using bfloat16 = plat::bfloat16;
Contributor

There should be no paddle::platform namespace under phi; phi::dtype can be used here instead.

Contributor Author

Done~

Comment on lines 17 to 18
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/platform/profiler.h"
Contributor

After phi is built as a standalone shared library, it is no longer allowed to include headers from the fluid directory.

Contributor Author
@xiaoxiaohehe001 Aug 14, 2023

Got it, Done~

Comment on lines +25 to +26
template <typename T, typename Context>
void MMHAKernel(const Context& dev_ctx,
Contributor

Fusion kernels should not declare themselves in a header file, to avoid being used everywhere.

Contributor Author

OK

Contributor

This looks like a pure GPU implementation; it doesn't need to live in the impl directory. Just write it directly in the GPU kernel's .cu file.

Contributor Author

OK

qingqing01 previously approved these changes Aug 14, 2023
@xiaoxiaohehe001 xiaoxiaohehe001 dismissed stale reviews from qingqing01 and Aurelius84 via 5470349 August 14, 2023 13:57
#pragma once

#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/fusion/gpu/masked_multihead_attention_utils.h"
Contributor

Is this header unnecessary now? It pulls in the glog header, which gets blocked.

Contributor Author

This will be handled in a follow-up PR.

@@ -45,4 +46,5 @@
'variable_length_memory_efficient_attention',
"fused_rms_norm",
"fused_layer_norm",
"masked_multihead_attention",
Contributor

Because it integrates rotary_embedding, attention, etc., it should be named fused_masked_multihead_attention according to the naming regulations.

func : masked_multihead_attention
data_type : cache_kv
optional : src_mask, cum_offsets, sequence_lengths, rotary_tensor, beam_cache_offset, qkv_out_scale, out_shift, out_smooth
inplace : (cache_kv -> cache_kv_out), (beam_cache_offset -> beam_cache_offset_out)
Contributor

A backward op (e.g., backward: fused_masked_multihead_attention_grad) should be added to compute gradients according to the regulations; otherwise it cannot be used for training.

Contributor Author
@xiaoxiaohehe001 Aug 15, 2023

masked_multihead_attention is currently used only for inference; whether to add the backward pass later needs further discussion.

Contributor
@jiahy0825 left a comment

LGTM for including "logging.h" in paddle/phi/kernels/fusion/gpu/masked_multihead_attention_utils.h temporarily.
Please create another PR to remove this line later.

Collaborator
@raindrops2sea left a comment

LGTM

@qingqing01 qingqing01 merged commit 989c5e8 into PaddlePaddle:develop Aug 15, 2023
@unittest.skipIf(
not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
)
class TestLayerNormStaticInt8Op(unittest.TestCase):
Contributor

Is this class name wrong? (TestLayerNormStaticInt8Op in an MMHA test)
