
fix FlashAttnOpInferSymbolicShape and FlashAttnInferMeta #63816

Merged 5 commits into PaddlePaddle:develop on Apr 26, 2024

Conversation

Hongqing-work (Contributor)

PR Category: Others
PR Types: Bug fixes

Description

Pcard-67164
This PR fixes FlashAttnOpInferSymbolicShape and FlashAttnInferMeta by adding shape inference for the softmax, softmax_lse, and seed_offset outputs.
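For context, a minimal sketch of the kind of shape computation this involves, assuming the CUDA flash-attention convention of rounding seqlen_q up to a multiple of 128 and a query layout of [batch_size, seqlen_q, num_heads, head_dim]; the helper names and exact output shapes below are illustrative, not taken verbatim from this PR:

#include <cstdint>
#include <vector>

// Round x up to the nearest multiple of 128 (CUDA kernel convention).
inline int64_t RoundUp128(int64_t x) { return (x + 127) / 128 * 128; }

// Plausible shape of softmax_lse: [batch_size, num_heads, seqlen_q_rounded],
// assuming q_dims is [batch_size, seqlen_q, num_heads, head_dim].
inline std::vector<int64_t> SoftmaxLseShape(const std::vector<int64_t>& q_dims) {
  return {q_dims[0], q_dims[2], RoundUp128(q_dims[1])};
}

// seed_offset holds the RNG seed and offset, so its shape is just [2].
inline std::vector<int64_t> SeedOffsetShape() { return {2}; }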

paddle-bot bot commented Apr 24, 2024

Your PR has been submitted. Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.

Comment on lines 388 to 390
  auto batch_size = q.dims()[0];
  auto num_heads = q.dims()[2];
  auto seqlen_q_rounded = round_multiple(q.dims()[1]);
Contributor

For a simple int like this, it's better to use the concrete data type. auto is not explicit enough here, which raises the reading cost and makes mistakes more likely.

Contributor Author

Done
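The revised lines presumably look something like this (a sketch only; the commit message below says just "use int for simple value", and in Paddle q.dims() indexes to int64_t):

// Post-review sketch: concrete integer types instead of auto.
const int64_t batch_size = q.dims()[0];
const int64_t num_heads = q.dims()[2];
const int64_t seqlen_q_rounded = round_multiple(q.dims()[1]);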

zyfncg previously approved these changes Apr 25, 2024
@@ -287,6 +290,35 @@ bool FlashAttnOpInferSymbolicShape(

  shape_analysis->SetShapeOrDataForValue(
      op->result(0), symbol::TensorShapeOrDataDimExprs(out_shape));

  auto round_multiple = [](symbol::DimExpr x) {
    auto m = symbol::DimExpr{128};
Collaborator

Why was 128 chosen here? Is it the 128 hard-coded in the kernel?

Contributor Author

The CUDA kernel rounds to multiples of 128, but the XPU kernel does not. For now this stays consistent with the CUDA kernel, and a comment has been added.

    auto m_minus_one = symbol::DimExpr{127};
    return (x + m_minus_one) / m * m;
  };
  auto batch_size_expr = q.shape()[0];
Collaborator

Could you add a check on the size of q.shape to guard against malformed inputs? As written, this would core dump directly.

Contributor Author

Done
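Taken together, the two review threads on this hunk amount to a round-up helper plus a defensive rank check. A standalone sketch with plain integers (the real code operates on symbol::DimExpr, and the assert stands in for whatever enforcement macro Paddle actually uses):

#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

// Round x up to the nearest multiple of m with integer arithmetic;
// (x + m - 1) / m * m is the same pattern as the lambda in the diff.
int64_t RoundMultiple(int64_t x, int64_t m = 128) {
  return (x + m - 1) / m * m;
}

int main() {
  // The rank check requested in review: q must be 4-D
  // ([batch_size, seqlen_q, num_heads, head_dim]) before dims 0-2 are read.
  std::vector<int64_t> q_dims = {8, 1000, 16, 64};
  assert(q_dims.size() == 4 && "flash_attn expects a 4-D query tensor");

  // Worked example: seqlen_q = 1000 rounds up to (1000 + 127) / 128 * 128 = 1024.
  std::cout << RoundMultiple(q_dims[1]) << std::endl;
  return 0;
}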

phlrain self-requested a review April 26, 2024 03:25
zyfncg merged commit 52db8e4 into PaddlePaddle:develop Apr 26, 2024
28 of 30 checks passed
runzhech pushed a commit to runzhech/Paddle that referenced this pull request Apr 30, 2024
…e#63816)

* fix FlashAttnOpInferSymbolicShape and FlashAttnInferMeta

* use int for simple value

* add check and constraint

* fix

* add shape constraint for attention_mask
Hongqing-work deleted the fix-flash-attn-infer-shape branch May 10, 2024 06:22
co63oc pushed a commit to co63oc/Paddle that referenced this pull request May 10, 2024
…e#63816)

* fix FlashAttnOpInferSymbolicShape and FlashAttnInferMeta

* use int for simple value

* add check and constraint

* fix

* add shape constraint for attention_mask