Skip to content

Commit 05d8de5

Browse files
fix error
1 parent db74ab1 commit 05d8de5

File tree

2 files changed

+37
-22
lines changed

2 files changed

+37
-22
lines changed

tests/compile/piecewise/test_simple.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ def test_simple_piecewise_compile(use_inductor):
143143

144144
@torch.inference_mode()
145145
@pytest.mark.parametrize("splitting_ops", [["silly::attention"], []])
146-
def test_simple_inductor_graph_partition(splitting_ops):
146+
def test_simple_inductor_graph_partition(monkeypatch, splitting_ops):
147147
if not is_torch_equal_or_newer("2.9.0.dev"):
148148
pytest.skip("inductor graph partition is only available in PyTorch 2.9+")
149149

vllm/compilation/partition_rules.py

Lines changed: 36 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -37,29 +37,44 @@ def resolve_defined_ops(op_names: list[str]) -> list[torch._ops.OpOverload]:
3737
Returns:
3838
List of successfully resolved operator overloads
3939
"""
40-
resolved = []
40+
resolved: list[torch._ops.OpOverload] = []
4141
for op_name in op_names:
42-
try:
43-
op = lookup_op(op_name)
44-
# Handle OpOverloadPacket: get .default if available
45-
if hasattr(op, "default"):
46-
resolved.append(op.default)
42+
overload: torch._ops.OpOverload | None = None
43+
candidate_names = [op_name]
44+
45+
# When the caller omits an explicit overload (e.g. "namespace::op"),
46+
# also try the conventional ".default" suffix.
47+
if "." not in op_name.split("::")[-1]:
48+
candidate_names.append(f"{op_name}.default")
49+
50+
for candidate in candidate_names:
51+
try:
52+
op = lookup_op(candidate)
53+
except Exception:
54+
continue
55+
56+
# lookup_op may return either an OpOverload (desired) or an
57+
# OpOverloadPacket (collection of overloads).
58+
if hasattr(op, "overloads"):
59+
overloads = list(op.overloads())
60+
if "default" in overloads:
61+
overload = op.default
62+
elif len(overloads) == 1:
63+
overload = getattr(op, overloads[0])
64+
else:
65+
logger.warning(
66+
"Operator '%s' has multiple overloads (%s); please "
67+
"specify the desired overload explicitly.",
68+
candidate,
69+
", ".join(overloads),
70+
)
4771
else:
48-
resolved.append(op)
49-
except Exception:
50-
# If lookup fails and no overload specified, try with .default
51-
if "." not in op_name.split("::")[-1]:
52-
try:
53-
op = lookup_op(f"{op_name}.default")
54-
resolved.append(op)
55-
continue
56-
except Exception:
57-
pass
58-
# Skip operators that don't exist
59-
logger.warning(
60-
"Failed to resolve operator for Inductor partition: %s", op_name
61-
)
62-
continue
72+
overload = op # Already an OpOverload
73+
74+
if overload is not None:
75+
break
76+
77+
resolved.append(overload)
6378

6479
return resolved
6580

0 commit comments

Comments
 (0)