
[Typing][C-85] Add type annotations for python/paddle/incubate/nn/layer/fused_transformer.py #67178

Merged · 7 commits · Aug 9, 2024

Conversation

@enkilee (Contributor) commented Aug 8, 2024

PR Category

User Experience

PR Types

Improvements

Description

Add type hints for the public API.

C-85 python/paddle/incubate/nn/layer/fused_transformer.py
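
For illustration, a hedged sketch of the annotation style this kind of typing PR adds; FusedDemoLayer and its signature below are made up for the example and are not the actual contents of fused_transformer.py:

from __future__ import annotations

import paddle
from paddle import Tensor, nn


class FusedDemoLayer(nn.Layer):
    """Hypothetical layer, used only to show the annotation pattern."""

    def __init__(self, d_model: int, name: str | None = None) -> None:
        super().__init__()
        self.d_model = d_model

    def forward(self, src: Tensor, attn_mask: Tensor | None = None) -> Tensor:
        # A real fused layer would run fused attention + FFN here; this one just
        # echoes the input so the sketch stays runnable on CPU.
        return src


layer = FusedDemoLayer(128)
out = layer(paddle.rand((2, 4, 128)))
print(out.shape)  # [2, 4, 128]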

@megemini

@paddle-bot (bot) commented Aug 8, 2024

Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI first. See the Paddle CI Manual for details.

@luotao1 added the contributor (External developers) and HappyOpenSource (happy open source issues & PRs) labels on Aug 8, 2024
@megemini (Contributor) commented Aug 8, 2024

Also, the example code fails in CI, and it fails for me locally as well ...

import paddle
from paddle.incubate.nn import FusedMultiTransformer
paddle.device.set_device('gpu')

# encoder input: [batch_size, src_len, d_model]
enc_input = paddle.rand((2, 4, 128))
# self attention mask: [batch_size, 1, src_len, src_len]
attn_mask = paddle.rand((2, 1, 4, 4))
encoder_layers = FusedMultiTransformer(128, 2, 512, num_layers=1)
enc_output = encoder_layers(enc_input, attn_mask)
print(enc_output.shape)

@SigureMo should we ask the dev team to take a look, or just skip it / file an issue?

@SigureMo (Member) commented Aug 8, 2024

Only the fp16 kernel is left now:

PD_REGISTER_STRUCT_KERNEL(fused_multi_transformer,
                          GPU,
                          ALL_LAYOUT,
                          ops::FusedMultiTransformerOpKernel,
                          phi::dtype::float16,
                          phi::dtype::bfloat16) {}
#else
PD_REGISTER_STRUCT_KERNEL(fused_multi_transformer,
                          GPU,
                          ALL_LAYOUT,
                          ops::FusedMultiTransformerOpKernel,
                          phi::dtype::float16) {}

It looks like this was changed in #64125. Could you try fp16 inputs?

@megemini (Contributor) commented Aug 8, 2024

It looks like this was changed in #64125. Could you try fp16 inputs?

No luck ... If only the inputs are changed:

The layer reads the default dtype directly, self._dtype = self._helper.get_default_dtype(), and then continues initialization with it ~

import paddle
from paddle.incubate.nn import FusedMultiTransformer
paddle.device.set_device('gpu')

# encoder input: [batch_size, src_len, d_model]
enc_input = paddle.rand((2, 4, 128)).astype('float16')
# self attention mask: [batch_size, 1, src_len, src_len]
attn_mask = paddle.rand((2, 1, 4, 4)).astype('float16')
encoder_layers = FusedMultiTransformer(128, 2, 512, num_layers=1)
enc_output = encoder_layers(enc_input, attn_mask)
print(enc_output.shape)

which leads to a dtype mismatch:

ValueError: (InvalidArgument) The type of data we are trying to retrieve (float16) does not match the type of data (float32) currently contained in the container.
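
A minimal sketch (not the actual fused_transformer.py code) of why casting only the inputs is not enough: the parameters are created with the process-wide default dtype, so float32 weights meet float16 inputs and the container type check above fails.

import paddle

print(paddle.get_default_dtype())           # 'float32' unless changed beforehand
# Parameters are created with the default dtype, mirroring
# self._dtype = self._helper.get_default_dtype() in the layer.
w = paddle.create_parameter([4, 4], dtype=paddle.get_default_dtype())
x = paddle.rand((2, 4)).astype('float16')   # only the input is cast
print(w.dtype, x.dtype)                     # paddle.float32 vs paddle.float16 -> mismatch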

If the default dtype is set instead:

import paddle
from paddle.incubate.nn import FusedMultiTransformer
paddle.device.set_device('gpu')
paddle.set_default_dtype('float16')

# encoder input: [batch_size, src_len, d_model]
enc_input = paddle.rand((2, 4, 128))
# self attention mask: [batch_size, 1, src_len, src_len]
attn_mask = paddle.rand((2, 1, 4, 4))
encoder_layers = FusedMultiTransformer(128, 2, 512, num_layers=1)
enc_output = encoder_layers(enc_input, attn_mask)
print(enc_output.shape)

it still fails at runtime:

OSError: (External) Error in Flash-Attention, detail information is: `is_sm8x || is_sm90` check failed at /paddle/third_party/flashattn/csrc/capi/flash_attn.cu:681
  [Hint: Expected status == true, but received status:0 != true:1.] (at /paddle/paddle/phi/kernels/gpu/flash_attn_utils.h:360)
  [operator < fused_multi_transformer > error]
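
As a hedged aside, the is_sm8x || is_sm90 check suggests this flash attention path only runs on Ampere (SM 8.x) or Hopper (SM 9.0) GPUs; the compute capability of the local card can be checked like this (assuming a CUDA build of Paddle):

import paddle

# Prints e.g. (7, 0) for V100 or (8, 0) for A100; anything below (8, 0) would
# trip the is_sm8x || is_sm90 check above.
if paddle.device.is_compiled_with_cuda():
    print(paddle.device.cuda.get_device_capability())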

The unit tests only cover the static graph mode, so there is nothing to use as a reference either 🫠

@SigureMo (Member) commented Aug 8, 2024

I've reported this to the relevant folks. Let's skip it for now; it shouldn't block this task.

@SigureMo (Member) commented Aug 8, 2024

import paddle
from paddle.incubate.nn import FusedMultiTransformer
paddle.device.set_device('gpu')

paddle.set_default_dtype('float16')
# encoder input: [batch_size, src_len, d_model]
enc_input = paddle.rand((2, 4, 128)).astype('float16')
# self attention mask: [batch_size, 1, src_len, src_len]
attn_mask = paddle.rand((2, 1, 4, 4))
encoder_layers = FusedMultiTransformer(128, 2, 512, num_layers=1)
enc_output = encoder_layers(enc_input, attn_mask)
print(enc_output.shape)

The relevant folks report that this snippet runs for them. My local build is without flash attention, so could you try it on CI? If it still fails, skip it for now.

@megemini (Contributor) commented Aug 8, 2024

CI failed ~ @enkilee please SKIP the example code; give "requires building with flash attention" as the reason ~

@enkilee (Contributor, Author) commented Aug 9, 2024

CI failed ~ @enkilee please SKIP the example code; give "requires building with flash attention" as the reason ~

Got it.
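
For reference, a rough sketch of what the requested skip could look like inside the docstring example (assuming the xdoctest-style +SKIP(reason) directive used in Paddle docstrings; _demo_docstring and the exact reason string are placeholders):

def _demo_docstring() -> None:
    """
    Examples:
        .. code-block:: python

            >>> # doctest: +SKIP('Need to compile with flash attention')
            >>> import paddle
            >>> from paddle.incubate.nn import FusedMultiTransformer
            >>> paddle.device.set_device('gpu')
    """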


@megemini (Contributor) left a comment

LGTM ~ 🤟

@SigureMo (Member) left a comment

LGTMeow 🐾

@SigureMo merged commit b6c6a4e into PaddlePaddle:develop on Aug 9, 2024 (31 checks passed)
Jeff114514 pushed a commit to Jeff114514/Paddle that referenced this pull request Aug 14, 2024