-
-
Notifications
You must be signed in to change notification settings - Fork 4.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Kernel] Add torch custom op for all_reduce #7755
Conversation
👋 Hi! Thank you for contributing to the vLLM project. Once the PR is approved and ready to go, please make sure to run full CI as it is required to merge (or just use auto-merge). To run full CI, you can do one of these:
🚀 |
will take a look later! |
@@ -291,7 +288,33 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: | |||
ipex.distributed.all_reduce(input_, group=self.device_group) | |||
else: | |||
torch.distributed.all_reduce(input_, group=self.device_group) | |||
return input_ | |||
|
|||
def out_of_place_ar(self, input_: torch.Tensor) -> torch.Tensor: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
where do you call this function?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
|
||
|
||
def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor: | ||
if get_tp_group().should_run_out_of_place_ar(input_): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
how does Dynamo work with this condition?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I tested this by running the pytest tests/distributed/test_basic_distributed_correctness.py
test with dynamo enabled. Everything runs to completion without any errors when fullgraph is True or False.
Is there something specific you are worried about? I'm also happy to run additional tests.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can you try to follow
vllm/tests/tpu/test_compilation.py
Line 13 in baa5467
with depyf.prepare_debug(temp_dir): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
okay, I think all these conditions turn into guards.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think moving should_custom_ar
to python makes sense to me. But I need to know more about how Dynamo treats the selection logic you use currently.
sorry for the long wait!
No worries. Thanks for taking a look :) |
close as it has been reworked in #8526 . thanks for your pioneering work! @SageMoore 🙏 |
This PR encapsulates all_reduce inside of a torch custom op. This allows torch.dynamo to run on graphs containing all_reduce.
The current implementation of all_reduce mixes in-place and out-of-place kernels for the sake having a unified interface. Unfortunately, this won't work with torch compile. The problem with the current implementation is that the output tensor can either share a buffer with the input tensor or have its own buffer. This inconsistency is not supported in torch compile. To get around this I've added in-place and out-of-place all_reduce apis to the GroupCoordinator. This allows us to have a custom op for each one and to do the dispatching outside of the ops.
I had some difficulty getting the should_custom_ar kernel working with dynamo. The issue seemed to be centered around its non-tensor return type. For the sake of expediency I just moved that kernel to python as a workaround since it doesn't have any users outside of CustomAllReduce. I'm happy to look into this more if others feel strongly that it should remain in cuda.