[torch.compile] Enable attention and allreduce fusion without custom ops enabled #24604
Conversation
My current understanding is that when we pattern match against the torch native implementation of a custom operator, we register a pattern in Inductor using that native implementation. I'm worried that this approach might be fragile. When the torch native implementation is passed through torch.compile, various graph passes can transform it, so by the time it reaches the post-grad phase (where vLLM’s pattern matching currently happens), the structure may look different.
For example, with rms_norm, it seems we’d need to modify the implementation in a non-trivial way to make it pattern match. I don't know if this is an issue in practice, but it suggests that this scheme could unintentionally constrain how custom operators need to be authored — in ways we might not fully understand yet.
It might be more robust to preserve the custom operator as-is (i.e., avoid decomposing it into torch native ops) and then perform pattern matching directly on the custom operator itself. That would make the process less sensitive to internal graph transformations.
I did see that you wanted this in for the release. Was there a specific reason? If we are turning on the allreduce+rmsnorm fusion by default, for example, then could the fusion instead imply "+rms_norm"?
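For context, a minimal sketch of the mechanism under discussion, assuming Inductor's `torch._inductor.pattern_matcher` API (the shapes and the fused replacement are placeholders, not vLLM's actual pass). The pattern is traced from a torch-native implementation into a fixed sequence of aten ops, which is exactly why earlier graph transformations can break the match:

```python
# Sketch only: registering a fusion pattern against a torch-native RMSNorm
# decomposition with Inductor's pattern matcher. The pattern is traced into a
# sequence of aten ops, so any earlier pass that rewrites this sequence
# (decompositions, CSE, dtype casts, ...) can silently prevent the match.
import torch
from torch._inductor.pattern_matcher import (
    PatternMatcherPass,
    fwd_only,
    register_replacement,
)

patterns = PatternMatcherPass()

def rms_norm_pattern(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # Torch-native RMSNorm (eps hard-coded to keep the sketch simple).
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    return x * torch.rsqrt(variance + 1e-6) * weight

def rms_norm_replacement(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # Stand-in for the fused kernel a real pass would emit
    # (e.g. a custom op under torch.ops.*).
    return torch.nn.functional.rms_norm(x, (x.shape[-1],), weight, eps=1e-6)

register_replacement(
    rms_norm_pattern,
    rms_norm_replacement,
    [torch.empty(8, 64), torch.empty(64)],  # example inputs used to trace the pattern
    fwd_only,
    patterns,
)
# A post-grad custom pass would then call patterns.apply(graph).
```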
The reason this is needed is that it lets us do fusion without having to enable custom ops (-O.custom_ops=["+quant_fp8"]). Enabling custom ops leads to lost performance, as demonstrated in the PR description; that's because there are 4 quant ops per layer, one per matmul. I agree this is a somewhat fragile approach. I would be happy to work on a "lowering" approach where we preserve the high-level structure of ops until later. The downside is that it would require more work (I think), and we might lose access to optimizations that currently happen before our passes. But I think it wouldn't hurt Inductor in general to have a more explicit sense of converting between higher-level and lower-level representations (or we could just move where our custom passes happen). We can tie this work into the "autotuning custom op implementations" idea, as done in pytorch/pytorch#164212.
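For illustration, a hedged sketch of what this enables through the Python API. The field names mirror the `-O.` flags used in the benchmarks below; the exact import path of `PassConfig` and this constructor usage are assumptions, not a verified invocation:

```python
# Sketch, not a verified invocation: enable the attention and allreduce fusion
# passes while keeping custom ops disabled (torch-native QuantFP8 / RMSNorm).
# Field names are taken from the -O flags used in the PR description below.
from vllm import LLM
from vllm.config import CompilationConfig, PassConfig

llm = LLM(
    model="redhatai/meta-llama-3.1-70B-Instruct-FP8",
    compilation_config=CompilationConfig(
        custom_ops=["none"],  # previously fusion required e.g. ["+quant_fp8", "+rms_norm"]
        pass_config=PassConfig(
            enable_attn_fusion=True,
            enable_fi_allreduce_fusion=True,
            enable_noop=True,
        ),
    ),
)
```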
As discussed offline, we are going to proceed by merging this PR. After PTC, we will move our custom op matching passes to
view/slice noop eliminations were upstreamed to PyTorch, so I'm wondering if this is sufficient: pytorch/pytorch#151095, pytorch/pytorch#151175
@BoyuanFeng wouldn't that run after
Purpose
This PR enables matching the torch-native implementations of the custom ops QuantFP8 and RMSNorm. On main, fusion currently requires enabling custom ops, but they are slower than their torch counterparts, so the benefit of the custom fusion passes is reduced. We add "matcher util" objects that can be called inside patterns and automatically get traced to the same fx nodes as the custom op they correspond to, in both enabled and disabled form.
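To make the mechanism concrete, here is a rough sketch of the idea; the class and field names are illustrative, not the actual vLLM API:

```python
# Illustrative sketch of a "matcher util" (names are hypothetical, not vLLM's API):
# a single callable used inside a fusion pattern that traces to whichever fx nodes
# the model graph will actually contain, depending on whether the custom op is enabled.
import torch
import torch.nn.functional as F

class MatcherRMSNorm:
    def __init__(self, epsilon: float, custom_op_enabled: bool):
        self.epsilon = epsilon
        self.custom_op_enabled = custom_op_enabled

    def __call__(self, x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
        if self.custom_op_enabled:
            # Enabled form: the graph contains one opaque high-level op
            # (here F.rms_norm stands in for the custom kernel).
            return F.rms_norm(x, (x.shape[-1],), weight, eps=self.epsilon)
        # Disabled form: the torch-native decomposition that appears in the graph.
        variance = x.float().pow(2).mean(dim=-1, keepdim=True)
        return (x.float() * torch.rsqrt(variance + self.epsilon)).to(x.dtype) * weight

# One pattern definition then covers both regimes:
rms_norm = MatcherRMSNorm(epsilon=1e-6, custom_op_enabled=False)

def rmsnorm_quant_pattern(x, weight, scale):
    # RMSNorm followed by an FP8-style quant; the fusion pass would replace the
    # matched subgraph with a fused kernel.
    return (rms_norm(x, weight) / scale).to(torch.float8_e4m3fn)
```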
This PR also adds debugging utilities and E2E fusion tests that verify fusions happen in models end-to-end, not just in unit tests.
Test Plan
Unit tests, added more fusion E2E tests.
Test Result
All tests pass.
Performance numbers
Below are B200 numbers (with FlashInfer) from `vllm bench serve` on the following serve command. We test the following regimes with the corresponding additional arguments:

1. `none`: `-O.custom_ops='["none"]' -O.pass_config={"enable_fi_allreduce_fusion":false,"enable_attn_fusion":false,"enable_noop":true}`
2. `none_fusion_attention`: `-O.custom_ops='["none"]' -O.pass_config={"enable_fi_allreduce_fusion":false,"enable_attn_fusion":true,"enable_noop":true}`
3. `none_fusion_attention_allreduce`: `-O.custom_ops='["none"]' -O.pass_config={"enable_fi_allreduce_fusion":true,"enable_attn_fusion":true,"enable_noop":true}`
4. `rms_quant`: `-O.custom_ops='["none", "+quant_fp8", "+rms_norm"]' -O.pass_config={"enable_fi_allreduce_fusion":false,"enable_attn_fusion":false,"enable_noop":true}`
5. `rms_quant_fusion_attention`: `-O.custom_ops='["none", "+quant_fp8", "+rms_norm"]' -O.pass_config={"enable_fi_allreduce_fusion":false,"enable_attn_fusion":true,"enable_noop":true}`
6. `rms_quant_fusion_attention_allreduce`: `-O.custom_ops='["none", "+quant_fp8", "+rms_norm"]' -O.pass_config={"enable_fi_allreduce_fusion":true,"enable_attn_fusion":true,"enable_noop":true}`

Regimes 2 (`none_fusion_attention`) and 3 (`none_fusion_attention_allreduce`) are newly possible with this PR. On main, results are similar except those two are worse, as fusion cannot happen without custom ops enabled.

redhatai/meta-llama-3.1-70B-Instruct-FP8 (TP=1): Past QPS=10 the server is overloaded, so latency spikes and becomes much more variable. Also note that allreduce fusion is a no-op for TP=1.
[Charts: TTFT Median (ms), TPOT Median (ms), ITL Median (ms) per regime, TP=1]
redhatai/meta-llama-3.1-70B-Instruct-FP8 (TP=4): Note that allreduce fusion reduces TPOT at low QPS but increases it at high QPS, and it increases TTFT across the board; this will be addressed in #24248 and #24252.
[Charts: TTFT Median (ms), TPOT Median (ms), ITL Median (ms) per regime, TP=4]