Merged
5 changes: 2 additions & 3 deletions vllm/compilation/backends.py
@@ -97,10 +97,9 @@ def compile_context(self, runtime_shape: int | None = None):
         compilation (e.g. partition rules, pass context)."""
         with pass_context(runtime_shape):
             if self.compilation_config.use_inductor_graph_partition:
-                inductor_partition_ops = resolve_defined_ops(
+                with inductor_partition_rule_context(
                     self.compilation_config.splitting_ops
-                )
-                with inductor_partition_rule_context(inductor_partition_ops):
+                ):
                     yield
             else:
                 yield
39 changes: 12 additions & 27 deletions vllm/compilation/partition_rules.py
@@ -3,15 +3,12 @@

 import contextlib
 import logging
-from typing import TYPE_CHECKING
 
+import torch
 from torch._library.utils import lookup_op
 
 from vllm.logger import init_logger
 
-if TYPE_CHECKING:
-    import torch
-
 logger = init_logger(__name__)
 
 
@@ -56,47 +53,35 @@ def resolve_defined_ops(op_names: list[str]) -> list["torch._ops.OpOverload"]:


 @contextlib.contextmanager
-def inductor_partition_rule_context(overloads: list["torch._ops.OpOverload"]):
+def inductor_partition_rule_context(splitting_ops: list[str]):
Collaborator: While we're at it, could we remove resolve_op_overloads from the Dynamo partition path as well and use string names for ops there too?

Contributor (author): Yes, I will update the Dynamo partition path in a follow-up PR.

"""Context manager to temporarily register Inductor partition rules.

     Registers custom partition rules for specified operators, forcing the
     Inductor scheduler to partition the graph at these operators. The rules
     are automatically restored to their previous state on exit.
 
-    Note: Callers should use resolve_defined_ops() to convert operator names
-    to OpOverload objects before calling this function.
-
     Args:
-        overloads: List of resolved operator overload objects.
+        splitting_ops: List of operator names to partition on.
     """
-    if not overloads:
+    if not splitting_ops:
         logger.debug("No partition ops provided; skipping rule registration.")
         yield
         return
 
-    from torch._inductor.scheduler import (  # type: ignore
-        _custom_should_partition_fns,
-        register_should_partition_rule,
-    )
-
-    def _always_partition(*_args, **_kwargs):
-        return True
-
-    # Save current state before registering
-    saved_rules = _custom_should_partition_fns.copy()
-
-    for overload in overloads:
-        register_should_partition_rule(
-            overload,
-            _always_partition,
-        )
+    saved_splitting_ops: list[str] = list(
+        torch._inductor.config.custom_should_partition_ops
+    )
+    torch._inductor.config.custom_should_partition_ops = splitting_ops
Contributor: Is this safe? I believe this config knob was only added recently (last week). Users on older versions of PyTorch would hit an error here.

Contributor (author): This is only accessed on PyTorch 2.9.0 and 2.9.1, and PyTorch 2.9.1 will include the config. PyTorch 2.9.0 is patched here: https://github.com/vllm-project/vllm/pull/27702/files#diff-ffb2dbecaf604d1ecf2c216667a1f81425d1a593be868c0be141a62f6503b0c6R371-R375

Tests passed for both PyTorch versions.

-    logger.debug("Registered inductor partition rules for %d operators", len(overloads))
+    logger.debug(
+        "Registered inductor partition rules for %d operators", len(splitting_ops)
+    )

     try:
         yield
     finally:
         # Clear and restore previous state
-        _custom_should_partition_fns.clear()
-        _custom_should_partition_fns.update(saved_rules)
+        torch._inductor.config.custom_should_partition_ops = saved_splitting_ops
         logger.debug("Restored previous partition rules state.")
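
A minimal usage sketch of the reworked context manager (the op name is illustrative, and this assumes PyTorch 2.9.1+, or 2.9.0 with the env_override patch further down, so that torch._inductor.config.custom_should_partition_ops exists):

import torch

from vllm.compilation.partition_rules import inductor_partition_rule_context

# Illustrative op name; in practice the list comes from
# compilation_config.splitting_ops.
with inductor_partition_rule_context(["vllm::unified_attention"]):
    # While the context is active, Inductor's config lists the ops at which
    # the scheduler should partition the graph.
    print(torch._inductor.config.custom_should_partition_ops)

# On exit the previous value is restored, even if the body raised.
print(torch._inductor.config.custom_should_partition_ops)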
30 changes: 1 addition & 29 deletions vllm/compilation/pass_manager.py
@@ -113,27 +113,6 @@ def configure(self, config: VllmConfig):
         self.post_cleanup = PostCleanupPass(config)
         self.fix_functionalization = FixFunctionalizationPass(config)
 
-        # [HACK: Bug with Inductor graph partition and torch.compile cache]
-        # In PyTorch 2.9, torch.compile has a bug where the graph
-        # partition is not taken into account during caching.
-        # Because vLLM's Mode.VLLM_COMPILE is the only mode that uses
-        # Inductor graph partition, and VLLM_COMPILE implies there
-        # is a PostGradPassManager, we put the list of operators to graph
-        # partition into the PostGradPassManager's uuid (which
-        # then gets incorporated into Inductor's FX graph cache key).
-        # Remove this hack whenever torch.compile fixes it.
-
-        # This is the list of operators that vLLM asks Inductor to split.
-        self.inductor_splitting_ops = []
-        if (
-            config.compilation_config.use_inductor_graph_partition
-            and config.compilation_config.splitting_ops is not None
-        ):
-            # Sort them so we're not dependent on the ordering.
-            self.inductor_splitting_ops = sorted(
-                config.compilation_config.splitting_ops
-            )
 
     def add(self, pass_: InductorPass):
         assert isinstance(pass_, InductorPass)
         self.passes.append(pass_)
@@ -144,16 +123,9 @@ def uuid(self):
         affects compilation caching. Its uuid depends on the UUIDs of all
         dependent passes and the pass config. See InductorPass for more info.
         """
-        state = {
-            "pass_config": self.pass_config.uuid(),
-            "passes": [],
-            "inductor_splitting_ops": [],
-        }
+        state = {"pass_config": self.pass_config.uuid(), "passes": []}
        for pass_ in self.passes:
             state["passes"].append(pass_.uuid())
         state["passes"].append(self.fix_functionalization.uuid())
 
-        # See [HACK: Bug with Inductor graph partition and torch.compile cache]
-        state["inductor_splitting_ops"].extend(self.inductor_splitting_ops)
-
         return InductorPass.hash_dict(state)
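
The uuid above is just a stable hash of this state dict; a minimal sketch of how such a dict can become a cache-key component (the real InductorPass.hash_dict may differ in its details):

import hashlib
import json


def hash_dict_sketch(state: dict) -> str:
    # Serialize deterministically, then hash, so identical pass configurations
    # map to the same Inductor FX graph cache key component.
    encoded = json.dumps(state, sort_keys=True).encode()
    return hashlib.sha256(encoded).hexdigest()


print(hash_dict_sketch({"pass_config": "cfg-uuid", "passes": ["uuid1", "uuid2"]}))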
26 changes: 22 additions & 4 deletions vllm/env_override.py
@@ -272,7 +272,6 @@ def should_partition_patched(self, node, should_log: bool = False) -> bool:
     from torch._inductor.scheduler import (
         BaseSchedulerNode,
         FusedSchedulerNode,
-        _custom_should_partition_fns,
     )
     from torch._inductor.utils import (
         _unstable_customized_partition_wrapper,
@@ -283,9 +282,21 @@ def should_partition_patched(self, node, should_log: bool = False) -> bool:
     # Allow users to manually specify if a node should be partitioned
     # Can only do this for FallbackKernels
     ir_node = node.node
-    if isinstance(ir_node, ir.FallbackKernel):
-        operator = ir_node.op_overload
-        if operator is not None and operator in _custom_should_partition_fns:
+    if isinstance(ir_node, torch._inductor.ir.FallbackKernel) and (
+        op := ir_node.op_overload
+    ):
+        op_overload_packet_name = op.name()
+        op_overload_name = (
+            f"{op_overload_packet_name}.{op._overloadname}"
+            if isinstance(op, torch._ops.OpOverload)
+            else op_overload_packet_name
+        )
+        if (
+            op_overload_packet_name
+            in torch._inductor.config.custom_should_partition_ops
+            or op_overload_name in torch._inductor.config.custom_should_partition_ops
+        ):
+            assert isinstance(op, torch._ops.OpOverload)
             return True
 
     # When not using cudagraphs, keep all kernels in the `call` function
@@ -355,6 +366,13 @@ def _update_scheduler_patched(self) -> None:
 if is_torch_equal("2.9.0"):
     from torch._inductor.codegen.wrapper import PythonWrapperCodegen
     from torch._inductor.graph import GraphLowering
+    from torch.utils._config_module import _Config, _ConfigEntry
+
+    # `custom_should_partition_ops` is a new config after 2.9.0. So this would
+    # not overwrite any user configs.
+    torch._inductor.config._config["custom_should_partition_ops"] = _ConfigEntry(
+        _Config(default=[])
+    )
 
     PythonWrapperCodegen.memory_plan_reuse = memory_plan_reuse_patched
     GraphLowering._update_scheduler = _update_scheduler_patched
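
For reference, should_partition_patched above matches a fallback kernel's op against the config list by both its packet-level name and its fully qualified overload name. A small self-contained sketch of what those strings look like for a stock ATen op (the list contents here are an illustrative stand-in for torch._inductor.config.custom_should_partition_ops):

import torch

op = torch.ops.aten.mm.default  # an OpOverload

name = op.name()                          # "aten::mm"
full_name = f"{name}.{op._overloadname}"  # "aten::mm.default"

# Illustrative list; a node is partitioned if either form appears in it.
custom_should_partition_ops = ["aten::mm"]

should_partition = (
    name in custom_should_partition_ops or full_name in custom_should_partition_ops
)
print(name, full_name, should_partition)  # aten::mm aten::mm.default True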