
Add fused scale mask bias softmax #9867

Merged
merged 88 commits into master from add_fused_msa_softmax on Mar 4, 2023

Conversation

@ofhwei (Contributor) commented Feb 14, 2023

This PR fuses the multi-step computation softmax(x * scale + mask + bias) into a single operation, reducing the number of memory accesses to improve efficiency. It targets the common attention pattern: x is the result of the query-key matmul, with shape typically [batch_size, num_heads, seq_len_q, seq_len_kv]; scale = 1/sqrt(head_size); mask and bias have shapes [batch_size, 1, 1, seq_len_kv] and [1, num_heads, seq_len_q, seq_len_kv], respectively.

In some scenarios (e.g. AlphaFold) the input mask may have a different shape:

  1. global attention: x.shape=[b, h, s], mask.shape=[b, 1, s];
  2. template pointwise attention: x.shape=[s, s, 1, n_templ], mask.shape=[1, 1, 1, n_templ], etc.

This PR also adds dedicated handling for these special cases.

For half-precision types (fp16 and bf16), flash attention can be used instead; however, flash attention does not yet support tf32 or fp32. Support is expected to improve over time.
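As a sanity check on the semantics, the fused op should be numerically equivalent to the unfused sequence of elementwise ops followed by a row softmax. Here is a minimal NumPy sketch of that reference (illustrative only, not OneFlow's kernel; the shapes mirror the attention case above):

```python
import numpy as np

def fused_scale_mask_bias_softmax_ref(x, scale, mask=None, bias=None):
    """Unfused reference for softmax(x * scale + mask + bias) along the
    last axis. mask and bias broadcast against x."""
    y = x * scale
    if mask is not None:
        y = y + mask
    if bias is not None:
        y = y + bias
    y = y - y.max(axis=-1, keepdims=True)  # subtract row max for stability
    e = np.exp(y)
    return e / e.sum(axis=-1, keepdims=True)

# attention-style shapes from the PR description (illustrative sizes)
b, h, sq, skv, head_size = 2, 4, 8, 8, 16
rng = np.random.default_rng(0)
x = rng.standard_normal((b, h, sq, skv)).astype(np.float32)
mask = rng.standard_normal((b, 1, 1, skv)).astype(np.float32)
bias = rng.standard_normal((1, h, sq, skv)).astype(np.float32)

out = fused_scale_mask_bias_softmax_ref(x, 1.0 / np.sqrt(head_size), mask, bias)
assert out.shape == (b, h, sq, skv)
assert np.allclose(out.sum(axis=-1), 1.0, atol=1e-5)  # rows sum to 1
```

The fused kernel computes the same result in a single pass instead of materializing the intermediate scaled/masked/biased tensor.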

@Ldpe2G (Contributor) commented Feb 20, 2023

Suggest adding documentation here describing how fused_msa_softmax differs from an ordinary softmax.

@ofhwei ofhwei changed the title Add fused msa softmax Add fused scale mask bias softmax Feb 20, 2023
-> Maybe<void> {
  const float scale = ctx->Attr<float>("scale");
  // scale is expected to be 1/sqrt(head_size), so it must not exceed 1.
  CHECK_LE_OR_RETURN(scale, 1.);

Contributor

For the shape-check logic below, please add some comments briefly explaining which shapes of x, mask, and bias are supported.
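The supported-shape rule the reviewer is asking to document could be sketched roughly as follows (an illustrative broadcast-compatibility check in Python, not the PR's actual C++ logic; function name is hypothetical):

```python
def check_broadcastable(x_shape, other_shape):
    """Illustrative check: mask/bias must have the same rank as x and
    each of its dims must equal the corresponding x dim or be 1."""
    if len(other_shape) != len(x_shape):
        return False
    return all(o == d or o == 1 for o, d in zip(other_shape, x_shape))

# shapes from the PR description
assert check_broadcastable([2, 4, 8, 8], [2, 1, 1, 8])   # mask
assert check_broadcastable([2, 4, 8, 8], [1, 4, 8, 8])   # bias
assert not check_broadcastable([2, 4, 8, 8], [3, 1, 1, 8])
```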

@ofhwei (Contributor, Author)

OK 👌

@ofhwei ofhwei requested a review from oneflow-ci-bot March 4, 2023 01:48
@github-actions bot commented Mar 4, 2023

Speed stats:
GPU Name: GeForce GTX 1080 

❌ OneFlow resnet50 time: 141.2ms (= 14119.4ms / 100, input_shape=[16, 3, 224, 224])
PyTorch resnet50 time: 144.5ms (= 14449.6ms / 100, input_shape=[16, 3, 224, 224])
❌ Relative speed: 1.02 (= 144.5ms / 141.2ms)

OneFlow resnet50 time: 82.4ms (= 8242.7ms / 100, input_shape=[8, 3, 224, 224])
PyTorch resnet50 time: 88.9ms (= 8890.2ms / 100, input_shape=[8, 3, 224, 224])
✔️ Relative speed: 1.08 (= 88.9ms / 82.4ms)

OneFlow resnet50 time: 50.8ms (= 10165.4ms / 200, input_shape=[4, 3, 224, 224])
PyTorch resnet50 time: 59.5ms (= 11909.7ms / 200, input_shape=[4, 3, 224, 224])
✔️ Relative speed: 1.17 (= 59.5ms / 50.8ms)

OneFlow resnet50 time: 33.9ms (= 6772.5ms / 200, input_shape=[2, 3, 224, 224])
PyTorch resnet50 time: 45.5ms (= 9099.2ms / 200, input_shape=[2, 3, 224, 224])
✔️ Relative speed: 1.34 (= 45.5ms / 33.9ms)

OneFlow resnet50 time: 25.7ms (= 5146.7ms / 200, input_shape=[1, 3, 224, 224])
PyTorch resnet50 time: 45.4ms (= 9079.9ms / 200, input_shape=[1, 3, 224, 224])
✔️ Relative speed: 1.76 (= 45.4ms / 25.7ms)

OneFlow swin dataloader time: 0.236s (= 47.129s / 200, num_workers=1)
PyTorch swin dataloader time: 0.150s (= 30.085s / 200, num_workers=1)
Relative speed: 0.638 (= 0.150s / 0.236s)

OneFlow swin dataloader time: 0.068s (= 13.530s / 200, num_workers=4)
PyTorch swin dataloader time: 0.045s (= 8.931s / 200, num_workers=4)
Relative speed: 0.660 (= 0.045s / 0.068s)

OneFlow swin dataloader time: 0.044s (= 8.878s / 200, num_workers=8)
PyTorch swin dataloader time: 0.023s (= 4.579s / 200, num_workers=8)
Relative speed: 0.516 (= 0.023s / 0.044s)

❌ OneFlow resnet50 time: 154.3ms (= 15426.9ms / 100, input_shape=[16, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 164.3ms (= 16428.0ms / 100, input_shape=[16, 3, 224, 224], ddp, world size=2)
❌ Relative speed: 1.06 (= 164.3ms / 154.3ms)

OneFlow resnet50 time: 93.2ms (= 9323.5ms / 100, input_shape=[8, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 103.4ms (= 10343.0ms / 100, input_shape=[8, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.11 (= 103.4ms / 93.2ms)

OneFlow resnet50 time: 60.8ms (= 12162.5ms / 200, input_shape=[4, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 78.5ms (= 15698.3ms / 200, input_shape=[4, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.29 (= 78.5ms / 60.8ms)

OneFlow resnet50 time: 43.1ms (= 8620.6ms / 200, input_shape=[2, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 71.3ms (= 14267.9ms / 200, input_shape=[2, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.66 (= 71.3ms / 43.1ms)

OneFlow resnet50 time: 37.6ms (= 7528.9ms / 200, input_shape=[1, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 67.7ms (= 13539.4ms / 200, input_shape=[1, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.80 (= 67.7ms / 37.6ms)

@github-actions bot commented Mar 4, 2023

View latest API docs preview at: https://staging.oneflow.info/docs/Oneflow-Inc/oneflow/pr/9867/

@ofhwei ofhwei merged commit 7d07caf into master Mar 4, 2023
@ofhwei ofhwei deleted the add_fused_msa_softmax branch March 4, 2023 05:09
7 participants