diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 83d8cdae1ed3..f37c155c0fce 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -97,10 +97,9 @@ def compile_context(self, runtime_shape: int | None = None): compilation (e.g. partition rules, pass context).""" with pass_context(runtime_shape): if self.compilation_config.use_inductor_graph_partition: - inductor_partition_ops = resolve_defined_ops( + with inductor_partition_rule_context( self.compilation_config.splitting_ops - ) - with inductor_partition_rule_context(inductor_partition_ops): + ): yield else: yield diff --git a/vllm/compilation/partition_rules.py b/vllm/compilation/partition_rules.py index cea4f9a81637..094b86dcb4aa 100644 --- a/vllm/compilation/partition_rules.py +++ b/vllm/compilation/partition_rules.py @@ -3,15 +3,12 @@ import contextlib import logging -from typing import TYPE_CHECKING +import torch from torch._library.utils import lookup_op from vllm.logger import init_logger -if TYPE_CHECKING: - import torch - logger = init_logger(__name__) @@ -56,47 +53,35 @@ def resolve_defined_ops(op_names: list[str]) -> list["torch._ops.OpOverload"]: @contextlib.contextmanager -def inductor_partition_rule_context(overloads: list["torch._ops.OpOverload"]): +def inductor_partition_rule_context(splitting_ops: list[str]): """Context manager to temporarily register Inductor partition rules. Registers custom partition rules for specified operators, forcing the Inductor scheduler to partition the graph at these operators. The rules are automatically restored to their previous state on exit. - Note: Callers should use resolve_defined_ops() to convert operator names - to OpOverload objects before calling this function. - Args: - overloads: List of resolved operator overload objects. + splitting_ops: List of operator names to partition on. """ - if not overloads: + if not splitting_ops: logger.debug("No partition ops provided; skipping rule registration.") yield return - from torch._inductor.scheduler import ( # type: ignore - _custom_should_partition_fns, - register_should_partition_rule, - ) - - def _always_partition(*_args, **_kwargs): - return True - # Save current state before registering - saved_rules = _custom_should_partition_fns.copy() - for overload in overloads: - register_should_partition_rule( - overload, - _always_partition, - ) + saved_splitting_ops: list[str] = list( + torch._inductor.config.custom_should_partition_ops + ) + torch._inductor.config.custom_should_partition_ops = splitting_ops - logger.debug("Registered inductor partition rules for %d operators", len(overloads)) + logger.debug( + "Registered inductor partition rules for %d operators", len(splitting_ops) + ) try: yield finally: # Clear and restore previous state - _custom_should_partition_fns.clear() - _custom_should_partition_fns.update(saved_rules) + torch._inductor.config.custom_should_partition_ops = saved_splitting_ops logger.debug("Restored previous partition rules state.") diff --git a/vllm/compilation/pass_manager.py b/vllm/compilation/pass_manager.py index 3bc35a8f7198..dfda2adf1d3b 100644 --- a/vllm/compilation/pass_manager.py +++ b/vllm/compilation/pass_manager.py @@ -113,27 +113,6 @@ def configure(self, config: VllmConfig): self.post_cleanup = PostCleanupPass(config) self.fix_functionalization = FixFunctionalizationPass(config) - # [HACK: Bug with Inductor graph partition and torch.compile cache] - # In PyTorch 2.9, torch.compile has a bug where the graph - # partition is not taken into account during caching. - # Because vLLM's Mode.VLLM_COMPILE is the only mode that uses - # Inductor graph partition, and VLLM_COMPILE implies there - # is a PostGradPassManager, we put the list of operators to graph - # partition into the PostGradPassManager's uuid (which - # then gets incorporated into Inductor's FX graph cache key). - # Remove this hack whenever torch.compile fixes it. - - # This is the list of operators that vLLM asks Inductor to split. - self.inductor_splitting_ops = [] - if ( - config.compilation_config.use_inductor_graph_partition - and config.compilation_config.splitting_ops is not None - ): - # Sort them so we're not dependent on the ordering. - self.inductor_splitting_ops = sorted( - config.compilation_config.splitting_ops - ) - def add(self, pass_: InductorPass): assert isinstance(pass_, InductorPass) self.passes.append(pass_) @@ -144,16 +123,9 @@ def uuid(self): affects compilation caching. Its uuid depends on the UUIDs of all dependent passes and the pass config. See InductorPass for more info. """ - state = { - "pass_config": self.pass_config.uuid(), - "passes": [], - "inductor_splitting_ops": [], - } + state = {"pass_config": self.pass_config.uuid(), "passes": []} for pass_ in self.passes: state["passes"].append(pass_.uuid()) state["passes"].append(self.fix_functionalization.uuid()) - # See [HACK: Bug with Inductor graph partition and torch.compile cache] - state["inductor_splitting_ops"].extend(self.inductor_splitting_ops) - return InductorPass.hash_dict(state) diff --git a/vllm/env_override.py b/vllm/env_override.py index ae3e4e751bd9..14dae2850c35 100644 --- a/vllm/env_override.py +++ b/vllm/env_override.py @@ -272,7 +272,6 @@ def should_partition_patched(self, node, should_log: bool = False) -> bool: from torch._inductor.scheduler import ( BaseSchedulerNode, FusedSchedulerNode, - _custom_should_partition_fns, ) from torch._inductor.utils import ( _unstable_customized_partition_wrapper, @@ -283,9 +282,21 @@ def should_partition_patched(self, node, should_log: bool = False) -> bool: # Allow users to manually specify if a node should be partitioned # Can only do this for FallbackKernels ir_node = node.node - if isinstance(ir_node, ir.FallbackKernel): - operator = ir_node.op_overload - if operator is not None and operator in _custom_should_partition_fns: + if isinstance(ir_node, torch._inductor.ir.FallbackKernel) and ( + op := ir_node.op_overload + ): + op_overload_packet_name = op.name() + op_overload_name = ( + f"{op_overload_packet_name}.{op._overloadname}" + if isinstance(op, torch._ops.OpOverload) + else op_overload_packet_name + ) + if ( + op_overload_packet_name + in torch._inductor.config.custom_should_partition_ops + or op_overload_name in torch._inductor.config.custom_should_partition_ops + ): + assert isinstance(op, torch._ops.OpOverload) return True # When not using cudagraphs, keep all kernels in the `call` function @@ -355,6 +366,13 @@ def _update_scheduler_patched(self) -> None: if is_torch_equal("2.9.0"): from torch._inductor.codegen.wrapper import PythonWrapperCodegen from torch._inductor.graph import GraphLowering + from torch.utils._config_module import _Config, _ConfigEntry + + # `custom_should_partition_ops` is a new config after 2.9.0. So this would + # not overwrite any user configs. + torch._inductor.config._config["custom_should_partition_ops"] = _ConfigEntry( + _Config(default=[]) + ) PythonWrapperCodegen.memory_plan_reuse = memory_plan_reuse_patched GraphLowering._update_scheduler = _update_scheduler_patched