[Kernels] Add an inductor pass to rewrite and fuse collective communication ops with gemms #9886
Conversation
This pull request has merge conflicts that must be resolved before it can be merged.
looking forward to this one!
device_group = group.device_group
rank = group.rank_in_group

if use_flux:
Could we maybe use a better abstraction than if statements based on use_flux?
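One way to address this review comment would be a small backend registry, so the flux-vs-default decision is made once instead of branching on `use_flux` at every call site. The sketch below is purely illustrative: the class and function names (`FusedGemmBackend`, `FluxBackend`, `get_fused_gemm_backend`) are hypothetical and the returned callables are stand-ins for the real fused-GEMM ops.

```python
from typing import Callable, Dict


class FusedGemmBackend:
    """Hypothetical interface for a GEMM + collective fusion backend."""
    name = "base"

    def fused_gemm_func(self) -> Callable:
        raise NotImplementedError


class FluxBackend(FusedGemmBackend):
    name = "flux"

    def fused_gemm_func(self) -> Callable:
        # In the real pass this would return the flux fused op;
        # a placeholder callable is used here.
        return lambda *args, **kwargs: ("flux", args, kwargs)


class DefaultBackend(FusedGemmBackend):
    name = "default"

    def fused_gemm_func(self) -> Callable:
        return lambda *args, **kwargs: ("default", args, kwargs)


_BACKENDS: Dict[str, FusedGemmBackend] = {
    b.name: b for b in (FluxBackend(), DefaultBackend())
}


def get_fused_gemm_backend(use_flux: bool) -> FusedGemmBackend:
    """Resolve the backend once, instead of if/else at each rewrite site."""
    return _BACKENDS["flux" if use_flux else "default"]
```

The rewrite pass would then ask the resolved backend for `fused_gemm_func()` rather than testing `use_flux` inline.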
fused_node = graph.call_function(fused_gemm_func, kwargs=kwargs)

graph.inserting_after(fused_node)
result_node_new = graph.call_function(operator.getitem, (fused_node, 0))
residual_node_new = graph.call_function(operator.getitem, (fused_node, 1))
my_residual_node_new = graph.call_function(operator.getitem, (fused_node, 2))
I think multi-output match has a utility that emits a function and tuple accessors.
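In the spirit of that suggestion, the repeated `operator.getitem` calls could be factored into one helper. This is a hypothetical sketch (the name `insert_tuple_getitems` is not a real pattern-matcher API), showing the mechanics with plain `torch.fx`:

```python
import operator

import torch
import torch.fx as fx


def insert_tuple_getitems(graph: fx.Graph, node: fx.Node, num_outputs: int):
    """Insert getitem accessors for each output of a multi-output node.

    The accessors are placed immediately after `node`, mirroring the
    inserting_after / call_function(operator.getitem, ...) pattern above.
    """
    with graph.inserting_after(node):
        return [
            graph.call_function(operator.getitem, (node, i))
            for i in range(num_outputs)
        ]


# Usage on a toy graph: torch.split returns a tuple, so it stands in for
# the multi-output fused GEMM node.
g = fx.Graph()
x = g.placeholder("x")
split = g.call_function(torch.split, (x, 1))
a, b = insert_tuple_getitems(g, split, 2)
g.output((a, b))
gm = fx.GraphModule(torch.nn.Module(), g)
out = gm(torch.tensor([3.0, 4.0]))
```

With a helper like this, the three accessor nodes in the snippet above collapse to a single call returning a list.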
res_replacements.append(residual_node_new)
my_res_replacements.append(my_residual_node_new)
Any reason we save all of the residuals instead of just the previous one?
if gemm_1 is None or gemm_2 is None:
    raise ValueError("Missing 'val' in gemm weights meta data")
Wouldn't it be simpler to just do `meta["val"]`?
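The reviewer's point, sketched in isolation: a direct `meta["val"]` lookup already fails loudly on a missing key, so the explicit `None` check buys only a custom error message. Here `meta` stands in for a node's `.meta` dict; both helpers are illustrative, not vLLM code.

```python
def get_gemm_weight_explicit(meta: dict):
    """Current pattern: .get() plus an explicit None check."""
    val = meta.get("val")
    if val is None:
        raise ValueError("Missing 'val' in gemm weights meta data")
    return val


def get_gemm_weight_simple(meta: dict):
    """Suggested pattern: plain indexing raises KeyError on a missing key."""
    return meta["val"]
```

The trade-off is the error type and message: `KeyError: 'val'` versus a descriptive `ValueError`.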
Signed-off-by: Bill Nell <bill@neuralmagic.com>
This pull request has been automatically marked as stale because it has not had any activity within 90 days. It will be automatically closed if no further activity occurs within 30 days. Leave a comment if you feel this pull request should remain open. Thank you!
This pull request has been automatically closed due to inactivity. Please feel free to reopen if you intend to continue working on it. Thank you!
Add an inductor pass to rewrite and fuse collective communication ops with gemms
See #9883 for version that includes llama hacks.
TODO:
- torch._inductor.ir.ExternKernel.__str__ (pytorch/pytorch#139501)

cc @tlrmchlsmth, @ProExpertProg, @SageMoore, @youkaichao
Requires a special config to run:
Some benchmark results: