Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 48 additions & 20 deletions vllm/compilation/partition_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,29 +74,57 @@ def inductor_partition_rule_context(overloads: list["torch._ops.OpOverload"]):
yield
return

from torch._inductor.scheduler import ( # type: ignore
_custom_should_partition_fns,
register_should_partition_rule,
)
try:
# Try new API first using torch._inductor.config.custom_should_partition_ops
from torch._inductor import config # type: ignore

op_names = []
for overload in overloads:
op_names.append(str(overload._schema.name))

def _always_partition(*_args, **_kwargs):
return True
# Save current state before registering
saved_ops = config.custom_should_partition_ops.copy()

# Save current state before registering
saved_rules = _custom_should_partition_fns.copy()
config.custom_should_partition_ops.extend(op_names)

for overload in overloads:
register_should_partition_rule(
overload,
_always_partition,
logger.debug(
"Registered inductor partition ops for %d operators", len(op_names)
)

logger.debug("Registered inductor partition rules for %d operators", len(overloads))
try:
yield
finally:
# Restore previous state
config.custom_should_partition_ops.clear()
config.custom_should_partition_ops.extend(saved_ops)
logger.debug("Restored previous partition ops state.")
except (ImportError, AttributeError):
# Fall back to old API if new API is not available
from torch._inductor.scheduler import ( # type: ignore
_custom_should_partition_fns,
register_should_partition_rule,
)

try:
yield
finally:
# Clear and restore previous state
_custom_should_partition_fns.clear()
_custom_should_partition_fns.update(saved_rules)
logger.debug("Restored previous partition rules state.")
def _always_partition(*_args, **_kwargs):
return True

# Save current state before registering
saved_rules = _custom_should_partition_fns.copy()

for overload in overloads:
register_should_partition_rule(
overload,
_always_partition,
)

logger.debug(
"Registered inductor partition rules for %d operators", len(overloads)
)

try:
yield
finally:
# Clear and restore previous state
_custom_should_partition_fns.clear()
_custom_should_partition_fns.update(saved_rules)
logger.debug("Restored previous partition rules state.")