
scaled_dot_product_attention api #55242

Conversation

@liuzhenhai93 (Contributor) commented Jul 7, 2023

PR types

Others

PR changes

APIs

Description

scaled_dot_product_attention api
card-72806
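
For context, here is a minimal usage sketch of the API this PR adds, based on the signature shown in the review below. The shapes are illustrative, and running it assumes a GPU build of Paddle with flash attention support (the kernel requires float16 or bfloat16 inputs):

```python
import paddle
import paddle.nn.functional as F

# Inputs use the flash-attention layout: [batch, seq_len, num_heads, head_dim].
q = paddle.rand([2, 128, 8, 64]).astype("float16")
k = paddle.rand([2, 128, 8, 64]).astype("float16")
v = paddle.rand([2, 128, 8, 64]).astype("float16")

# attn_mask must be None in this PR (see the assert discussed below).
out = F.scaled_dot_product_attention(
    q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False
)
print(out.shape)  # [2, 128, 8, 64]
```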

@paddle-bot bot commented Jul 7, 2023

Your PR has been submitted. Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.

@paddle-bot bot commented Jul 7, 2023

✅ This PR's description meets the template requirements!
Please wait for other CI results.

sneaxiy previously approved these changes Jul 10, 2023
@paddle-ci-bot bot commented Jul 17, 2023

Sorry to inform you that d419bc7's CIs passed more than 7 days ago. To prevent PR conflicts, you need to re-run all CIs manually.

Review thread on python/paddle/nn/functional/flash_attention.py (outdated):

where : ``Q``, ``K``, and ``V`` represent the three input parameters of the attention module.
The dimensions of the three parameters are the same.
``d`` represents the size of the last dimension of the three parameters.
Contributor:
Here, Q, K, and V denote the three input parameters of the attention module, all sharing identical dimensions. d represents the size of the last dimension of these three parameters.

Contributor Author:

In mathematical formulas, "where" is the conventional word to use.

Contributor:

OK. My suggested change was generated with ChatGPT, for reference only.

Contributor Author:

👌
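
For reference, the docstring text under discussion describes the standard scaled dot-product attention computation:

```latex
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^{T}}{\sqrt{d}}\right) V
```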

The dtype can be float16 or bfloat16.

Examples:
.. code-block:: python
Contributor:

The framework is adopting xdoctest; while you're at it, the example code could be converted to the xdoctest-supported format. See #55295.

@liuzhenhai93 (Contributor Author) commented Jul 20, 2023

What does the xdoctest-supported format look like? Is there a demo or a clear specification?

Contributor:

Please see the changes in the PR I linked.
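
For reference, the xdoctest-style format under discussion prefixes each example line with a `>>> ` prompt so the snippet can be executed and its output checked, with directives such as `# xdoctest: +SKIP` controlling execution. A rough illustration (not the exact snippet from #55295):

```rst
Examples:
    .. code-block:: python

        >>> # xdoctest: +SKIP
        >>> import paddle
        >>> q = paddle.rand([1, 128, 2, 32]).astype("float16")
        >>> out = paddle.nn.functional.scaled_dot_product_attention(q, q, q)
        >>> print(out.shape)
        [1, 128, 2, 32]
```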

jzhang533 previously approved these changes Jul 24, 2023

@jzhang533 left a comment:
LGTM

jzhang533 previously approved these changes Jul 25, 2023
@@ -407,4 +407,57 @@ def flash_attn_unpadded(
     return out, softmax if return_softmax else None


-scaled_dot_product_attention = flash_attention
+def scaled_dot_product_attention(
+    query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
Contributor:
To be consistent with other APIs, there should be a trailing `name=None` parameter.
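
A sketch of the adjusted parameter list the reviewer is asking for (body elided; `name=None` follows the framework-wide convention for naming the output variable):

```python
def scaled_dot_product_attention(
    query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, name=None
):
    ...
```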

>>> print(output)
>>> # xdoctest: -SKIP
"""
assert attn_mask is None, "attn_mask is not supported yet"
Contributor:
Since attn_mask is not currently supported, add a TODO comment indicating that it will be supported later.

Contributor:

Work to support attn_mask is already underway and depends on this PR being merged.

Contributor:

OK
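
A minimal sketch of the TODO annotation the reviewer suggested; the wording and owner tag are hypothetical:

```python
# TODO(liuzhenhai93): remove this assert once attn_mask support lands
# in the follow-up work mentioned above.
assert attn_mask is None, "attn_mask is not supported yet"
```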

@Xreki merged commit b19dfb8 into PaddlePaddle:develop on Aug 2, 2023
@liuzhenhai93 deleted the develop_scaleed_dot_product_attention_api branch on October 7, 2023