Add fused scale mask bias softmax #9867
Conversation
oneflow/core/autograd/gradient_funcs/fused_scale_mask_bias_softmax.cpp
Suggestion: add some documentation here describing what this op does.
-> Maybe<void> {
  const float scale = ctx->Attr<float>("scale");
  CHECK_LE_OR_RETURN(scale, 1.);
For the shape-checking logic below, could you add some comments briefly explaining which shapes of x, mask, and bias are supported?
OK 👌
Speed stats:
View latest API docs preview at: https://staging.oneflow.info/docs/Oneflow-Inc/oneflow/pr/9867/
This PR fuses the multi-step computation
softmax(x * scale + mask + bias)
into a single kernel, reducing the number of memory accesses to improve efficiency. It targets the general attention scenario: x is the result of the query-key matmul, typically of shape [batch_size, num_heads, seq_len_q, seq_len_kv]; scale = 1/sqrt(head_size); mask and bias have shapes [batch_size, 1, 1, seq_len_kv] and [1, num_heads, seq_len_q, seq_len_kv], respectively. In some scenarios (e.g. AlphaFold) the shape of the input mask may differ:
This PR also adds dedicated handling for those special cases.
For half-precision types (fp16 and bf16), flash attention can be used instead, but flash attention does not yet support tf32 or fp32; support there should continue to improve over time.
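For clarity, here is a minimal unfused reference sketch of the computation being fused, written with NumPy rather than the new kernel; the function name reference_scale_mask_bias_softmax and the concrete sizes are illustrative only, not part of this PR.

```python
import numpy as np

def reference_scale_mask_bias_softmax(x, mask, bias, scale):
    # Unfused reference: every step materializes a full tensor, which is
    # exactly the extra memory traffic the fused kernel avoids.
    y = x * scale + mask + bias              # mask/bias broadcast over x
    y = y - y.max(axis=-1, keepdims=True)    # numerically stable softmax
    e = np.exp(y)
    return e / e.sum(axis=-1, keepdims=True)

# General attention case with the shapes described above (sizes are arbitrary).
batch_size, num_heads, seq_len_q, seq_len_kv, head_size = 2, 4, 8, 8, 16
x = np.random.randn(batch_size, num_heads, seq_len_q, seq_len_kv).astype(np.float32)
mask = np.random.randn(batch_size, 1, 1, seq_len_kv).astype(np.float32)
bias = np.random.randn(1, num_heads, seq_len_q, seq_len_kv).astype(np.float32)
# scale <= 1, consistent with the CHECK_LE_OR_RETURN(scale, 1.) check above.
out = reference_scale_mask_bias_softmax(x, mask, bias, scale=1.0 / np.sqrt(head_size))
print(out.shape)  # (2, 4, 8, 8)
```

The fused op is expected to produce the same result in a single pass over the data instead of one pass per elementwise step plus the softmax.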