
Conversation

@Chenyaaang Chenyaaang (Contributor) commented Oct 9, 2025

Fix TPU torch compile error after #26113

This PR includes:

  1. Set the compilation backend to openxla on the TPU platform (see the sketch after this list).
  2. Make sure TPU uses forward_tpu when dispatching custom ops.
  3. Bypass some backend checks that require either eager or inductor, since those checks only apply to non-TPU platforms.
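
For context, the TPU-side change boils down to forcing the openxla Dynamo backend and keeping custom ops enabled inside the TPU platform's config hook. The snippet below is an illustrative sketch only, not the exact diff in this PR (check_and_update_config is the existing hook name in vllm/platforms/tpu.py; the rest is simplified):

    # Illustrative sketch only, not the exact diff in this PR.
    # In vLLM, TpuPlatform derives from the Platform base class.
    class TpuPlatform:

        @classmethod
        def check_and_update_config(cls, vllm_config) -> None:
            compilation_config = vllm_config.compilation_config

            # The default Dynamo backend is inductor; on TPU the graph has to
            # be lowered through openxla instead.
            compilation_config.backend = "openxla"

            # Keep custom ops enabled so dispatch resolves to forward_tpu
            # rather than the torch-native path inductor would trace.
            if (compilation_config.custom_ops.count("none")
                    + compilation_config.custom_ops.count("all") == 0):
                compilation_config.custom_ops.append("all")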

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

I've reviewed the changes and they look good for fixing the TPU compilation issue. The logic to enable custom ops and set the openxla backend for TPU is correct. I have one suggestion to improve the user experience by adding a log message when the backend is overridden, which is consistent with other parts of the code.
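
For illustration, the suggested log line could look roughly like the sketch below. This is hypothetical and not code from the PR; it assumes a standard module-level logger and the surrounding compilation_config variable:

    import logging

    logger = logging.getLogger(__name__)

    # Hypothetical sketch of the suggested log message; not code from this PR.
    logger.info("Overriding compilation backend from %s to openxla on TPU",
                compilation_config.backend)
    compilation_config.backend = "openxla"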

@ProExpertProg ProExpertProg (Collaborator) left a comment

Thanks for catching this; sorry we didn't catch it during review. This got broken because the logic is complicated. Could we improve it?

    compilation_config.backend = "openxla"

Suggested change:

    # Note: the default backend is set to inductor now;
    # we want to overwrite it to openxla to execute the ops properly on TPU.
    compilation_config.backend = "openxla"
Collaborator:

Can you make this a property of current_platform? Perhaps current_platform.default_dynamo_backend?
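
A minimal sketch of that idea, assuming a simple class attribute on the platform interface (default_dynamo_backend is only a proposed name here and does not exist in the codebase):

    # Hypothetical sketch: "default_dynamo_backend" is a proposed name,
    # not an existing attribute in the codebase.
    class Platform:
        default_dynamo_backend: str = "inductor"

    class TpuPlatform(Platform):
        default_dynamo_backend: str = "openxla"

    # The config code could then avoid the hard-coded string:
    # compilation_config.backend = current_platform.default_dynamo_backend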

Comment on lines 326 to 342
    # If user does not set custom ops via none or all set it here based on
    # compilation level and backend.
    if (
        self.compilation_config.custom_ops.count("none")
        + self.compilation_config.custom_ops.count("all")
        == 0
    ):
        from vllm.platforms import current_platform

        if (
            self.compilation_config.level > 0
            and self.compilation_config.backend != "eager"
            and not current_platform.is_tpu()
        ):
            self.compilation_config.custom_ops.append("none")
        else:
            self.compilation_config.custom_ops.append("all")
Collaborator:

Let's just try to move this after the platform-specific update, and TPU can set custom ops the way it wants to.

Contributor (Author):

Yes, we can move the self.compilation_config.custom_ops logic after current_platform.check_and_update_config(self), and the condition would become self.compilation_config.backend not in ["eager", "openxla"]. Does that look better to you?
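
In other words, after the move the common block would read roughly as follows (a sketch of this proposal, not code from the PR; self and current_platform refer to the surrounding config code):

    # Sketch of the proposed ordering and condition; not code from the PR.
    current_platform.check_and_update_config(self)  # may set backend/custom_ops first

    if (
        self.compilation_config.custom_ops.count("none")
        + self.compilation_config.custom_ops.count("all")
        == 0
    ):
        if (
            self.compilation_config.level > 0
            and self.compilation_config.backend not in ["eager", "openxla"]
        ):
            self.compilation_config.custom_ops.append("none")
        else:
            self.compilation_config.custom_ops.append("all")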

Collaborator:

I actually think that tpu platform should just do:

        if (
            self.compilation_config.custom_ops.count("none")
            + self.compilation_config.custom_ops.count("all")
            == 0
        ):
            self.compilation_config.custom_ops.append("all")

And then the common logic won't even be hit because custom_ops already contains "all" (once the logic is moved after).

Contributor (Author):

After some investigation, I'd prefer not to move current_platform.check_and_update_config(self) up before this part. I'm concerned that the logic inside each platform's check_and_update_config may depend on earlier initialization, and since my change only concerns TPU, I don't want to add potential risk for other platforms.

@Chenyaaang Chenyaaang force-pushed the torch-compile-error branch from c1c65d0 to 280b389 on October 9, 2025 00:26
Signed-off-by: Chenyaaang <chenyangli@google.com>
@Chenyaaang Chenyaaang force-pushed the torch-compile-error branch from 280b389 to 144aeb5 on October 9, 2025 00:26
@ProExpertProg (Collaborator):
The original PR was reverted; we can include these fixes in the unrevert.

@Chenyaaang (Contributor, Author):
> The original PR was reverted; we can include these fixes in the unrevert.

Sounds good, I'll go ahead and close this PR. Can you also cc me on the new PR? Thanks!


Labels

tpu: Related to Google TPUs
