
Conversation

@ProExpertProg (Collaborator) commented Oct 13, 2025

In-progress PR to test inductor partitioning in CI.

Includes:

  • turning on inductor partitioning by default (see the config sketch at the end of this description)

Past fixes in this PR now in main:

Past fixes no longer necessary:
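
For context on the first bullet above, a minimal sketch of toggling inductor partitioning explicitly (assuming vLLM's CompilationConfig.use_inductor_graph_partition field and an illustrative model name; the defaulting logic this PR adds may differ):

from vllm import LLM
from vllm.config import CompilationConfig

# Opt in to inductor graph partitioning explicitly; this PR aims to make
# this the default on torch >= 2.9. The model name is just an example.
llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",
    compilation_config=CompilationConfig(use_inductor_graph_partition=True),
)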


mergify bot commented Oct 13, 2025

Documentation preview: https://vllm--26738.org.readthedocs.build/en/26738/

mergify bot added the documentation (Improvements or additions to documentation), ci/build, llama (Related to Llama models), and rocm (Related to AMD ROCm) labels Oct 13, 2025
gemini-code-assist bot (Contributor) left a comment


Code Review

This pull request upgrades PyTorch to version 2.9 and introduces tests for Inductor graph partitioning. The changes span CI configuration, Dockerfiles, build scripts, and test files to support the new PyTorch version and its features. My review focuses on the correctness and maintainability of these changes. I've identified two high-severity issues: a monkey patch with a misleading comment, and an undocumented change to a compilation flag's default behavior. Both should be addressed to ensure code clarity and prevent unexpected behavior for users.

Comment on lines 29 to 31
# Copied from torch._inductor.scheduler.Scheduler.should_partition. Patches
# [this code](https://github.com/pytorch/pytorch/blob/ecb53078faf86ca1b33277df33b82985675bb011/torch/_inductor/scheduler.py#L4712-L4724)
# so that we always return True.

Severity: high

The comment here is misleading. It states that the patch makes should_partition 'always return True', but the patched function does not always return True. It only changes the behavior for the case where torch._inductor.config.triton.cudagraphs is False. The original implementation returns False in this case, preventing partitioning, while the patch returns True to enable it.

This monkey patch on a PyTorch internal function is a significant change and should be documented with precision to ensure future maintainability. Please update the comment to accurately describe what the patch does and why it's necessary. For example:

# Copied from torch._inductor.scheduler.Scheduler.should_partition. Patches
# [this code](https://github.com/pytorch/pytorch/blob/ecb53078faf86ca1b33277df33b82985675bb011/torch/_inductor/scheduler.py#L4712-L4715)
# to force partitioning even when CUDA graphs are disabled. The original
# implementation returns False in this case, which prevents partitioning.
# This change is necessary to enable Inductor graph partitioning for vLLM's
# piecewise CUDAGraph mode, which may operate without `config.triton.cudagraphs`
# being globally enabled.
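For readers unfamiliar with the mechanism under discussion, a minimal sketch of how such a monkey patch is typically wired up (the wrapper function name and call site are illustrative, not necessarily what this PR does):

from torch._inductor.scheduler import Scheduler

def apply_should_partition_patch() -> None:
    # Replace the method on the class itself so that every Scheduler
    # instantiated afterwards uses the patched partitioning policy.
    Scheduler.should_partition = should_partition_patched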

VLLM_DP_RANK_LOCAL: int = -1
VLLM_DP_SIZE: int = 1
-VLLM_USE_STANDALONE_COMPILE: bool = False
+VLLM_USE_STANDALONE_COMPILE: bool = True

Severity: high

This change modifies the default behavior of VLLM_USE_STANDALONE_COMPILE from False to True. This is a significant change as it alters the default compilation path for users on PyTorch >= 2.8. While this might be intentional for the PyTorch 2.9 upgrade, it's a change in default behavior that should be clearly communicated. Could you confirm if this is the intended new default? If so, please consider adding a note to the release documentation about this change and updating the related comment on lines 493-494.
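
A hedged sketch of what the flipped default looks like in vLLM's env-var table (vllm/envs.py resolves variables lazily through lambdas; the exact parsing may differ):

import os

environment_variables = {
    # Default changed from "0" to "1": standalone compile is now on
    # unless the user explicitly exports VLLM_USE_STANDALONE_COMPILE=0.
    "VLLM_USE_STANDALONE_COMPILE": lambda: os.environ.get(
        "VLLM_USE_STANDALONE_COMPILE", "1"
    ) == "1",
}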

vllm/envs.py (outdated)
Comment on lines 493 to 501
# In torch <= 2.7 we ignore this flag; in torch >= 2.8 this is
# disabled by default.

Severity: high

This comment is now outdated due to the change in the default value of VLLM_USE_STANDALONE_COMPILE to True. It should be updated to reflect that standalone compilation is now enabled by default for PyTorch >= 2.8.

Suggested change:
-# In torch <= 2.7 we ignore this flag; in torch >= 2.8 this is
-# disabled by default.
+# In torch <= 2.7 we ignore this flag; in torch >= 2.8 this is
+# enabled by default.


chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.


Comment on lines 28 to 333
def should_partition_patched(self, node, should_log: bool = False) -> bool:
    # Copied from torch._inductor.scheduler.Scheduler.should_partition. Patches
    # [this code](https://github.com/pytorch/pytorch/blob/ecb53078faf86ca1b33277df33b82985675bb011/torch/_inductor/scheduler.py#L4712-L4724)
    # so that we always return True.
    """Return True if we should partition the inductor graph on this node"""

    import torch
    import torch._inductor.ir as ir
    from torch._inductor.scheduler import (
        BaseSchedulerNode,
        FusedSchedulerNode,
        _custom_should_partition_fns,
    )
    from torch._inductor.utils import (
        _unstable_customized_partition_wrapper,
        is_cudagraph_unsafe_op,
        maybe_log_cudagraph_partition,
    )

    # Allow users to manually specify if a node should be partitioned
    # Can only do this for FallbackKernels
    ir_node = node.node
    if isinstance(ir_node, ir.FallbackKernel):
        operator = ir_node.op_overload
        if operator is not None and operator in _custom_should_partition_fns:
            return True

    # When not using cudagraphs, keep all kernels in the `call` function
    # instead of graph partition functions, since graph partition only brings
    # benefit to cudagraph
    if (
        not torch._inductor.config.triton.cudagraphs
        and _unstable_customized_partition_wrapper.wrapper is None
    ):
        return True

    # avoid duplicating logs when should_partition is called multiple times
    # on the same node
    def noop_log(msg: str, node: BaseSchedulerNode | None) -> None:
        return

    log_partition_reason = maybe_log_cudagraph_partition if should_log else noop_log

    if isinstance(node, FusedSchedulerNode):
        return any(self.should_partition(snode) for snode in node.snodes)

    assert node.node is not None

    if not node.is_gpu():
        log_partition_reason("non gpu ops", node=node)
        return True

    if isinstance(node.node, ir.DeviceCopy):
        log_partition_reason("DeviceCopy ops", node=node)
        return True

    if isinstance(node.node, ir.Conditional):
        log_partition_reason("Conditional ops", node=node)
        return True

    if getattr(node.node, "unbacked_bindings", None):
        log_partition_reason("unbacked binding ops", node=node)
        return True

    if is_cudagraph_unsafe_op(node.node):
        log_partition_reason("CUDAGraph-unsafe custom ops", node=node)
        return True

    return False


P1: Patched scheduler still skips inductor partitioning

The new should_partition_patched is commented as forcing Inductor to always partition, but it still ends with return False. When none of the earlier conditions match, the monkey-patched method behaves exactly like the upstream implementation and refuses to partition, so the patch has no effect in the cases the commit is trying to fix. As a result, cudagraph-unsafe ops can still be left unpartitioned and the referenced issue remains unresolved.
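
To make the mismatch concrete, a literal reading of the patched function's comment would amount to the following (shown only to illustrate the point above, not as the PR's actual resolution):

def should_partition_always(self, node, should_log: bool = False) -> bool:
    # What "always return True" would literally mean: partition on every
    # node. The patch above instead adds an early return for the
    # cudagraphs-disabled branch and still falls through to return False.
    return True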



mergify bot commented Oct 13, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @ProExpertProg.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify bot added the needs-rebase label Oct 13, 2025
ProExpertProg force-pushed the luka/2.9-inductor-partition-test-with-fix branch 2 times, most recently from ee24dd3 to a654b46 on October 14, 2025 13:43
mergify bot removed the needs-rebase label Oct 14, 2025
ProExpertProg force-pushed the luka/2.9-inductor-partition-test-with-fix branch 2 times, most recently from 44de6d0 to a46b7b3 on October 14, 2025 15:24
ProExpertProg force-pushed the luka/2.9-inductor-partition-test-with-fix branch from 37947bd to a7d1db9 on October 14, 2025 19:25

mergify bot commented Oct 15, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @ProExpertProg.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

ProExpertProg force-pushed the luka/2.9-inductor-partition-test-with-fix branch from eff712d to 67c4d19 on October 15, 2025 03:47
mergify bot removed the needs-rebase label Oct 15, 2025
@BoyuanFeng (Contributor) commented:

this (Quantization test) may relate to #26878

ProExpertProg force-pushed the luka/2.9-inductor-partition-test-with-fix branch 3 times, most recently from 4e2976b to e811cb5 on October 15, 2025 05:06

mergify bot commented Oct 15, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @ProExpertProg.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify bot added the needs-rebase label Oct 15, 2025
ProExpertProg changed the title from "[DO NOT MERGE] 2.9, Inductor partition unit tests, monkeypatch fix" to "[DO NOT MERGE] 2.9, Inductor partition, standalone compile, monkeypatch fix(es)" Oct 15, 2025
@ProExpertProg (Collaborator, Author) commented:

@BoyuanFeng The LoRA test failure seems relevant: perhaps we need to review how we deal with weak_ref_tensors under inductor partitioning. I'm also not sure why there's a CPU tensor there.
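
For context, a rough sketch of what weak-referencing a tensor means here (illustrative only; vLLM's real helper goes through a custom op and also handles nested structures):

import torch

def weak_ref_tensor_sketch(t: torch.Tensor) -> torch.Tensor:
    # Rebuild a tensor over the same storage. The real helper ensures the
    # new tensor does not keep the original allocation alive (which matters
    # for CUDA graph memory reuse); this Python approximation only shows
    # the aliasing, since a view still holds a storage reference.
    return torch.as_strided(t, t.size(), t.stride())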

ProExpertProg force-pushed the luka/2.9-inductor-partition-test-with-fix branch from e811cb5 to b4518de on October 15, 2025 15:43
mergify bot removed the needs-rebase label Oct 15, 2025
ProExpertProg force-pushed the luka/2.9-inductor-partition-test-with-fix branch from b4518de to 6f1222c on October 16, 2025 00:41

mergify bot commented Oct 16, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @ProExpertProg.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork


Signed-off-by: ProExpertProg <lgovedic@redhat.com>
ProExpertProg force-pushed the luka/2.9-inductor-partition-test-with-fix branch from f2bde49 to 69bef4b on October 29, 2025 18:26