@@ -37,29 +37,44 @@ def resolve_defined_ops(op_names: list[str]) -> list[torch._ops.OpOverload]:
3737 Returns:
3838 List of successfully resolved operator overloads
3939 """
40- resolved = []
40+ resolved : list [ torch . _ops . OpOverload ] = []
4141 for op_name in op_names :
42- try :
43- op = lookup_op (op_name )
44- # Handle OpOverloadPacket: get .default if available
45- if hasattr (op , "default" ):
46- resolved .append (op .default )
42+ overload : torch ._ops .OpOverload | None = None
43+ candidate_names = [op_name ]
44+
45+ # When the caller omits an explicit overload (e.g. "namespace::op"),
46+ # also try the conventional ".default" suffix.
47+ if "." not in op_name .split ("::" )[- 1 ]:
48+ candidate_names .append (f"{ op_name } .default" )
49+
50+ for candidate in candidate_names :
51+ try :
52+ op = lookup_op (candidate )
53+ except Exception :
54+ continue
55+
56+ # lookup_op may return either an OpOverload (desired) or an
57+ # OpOverloadPacket (collection of overloads).
58+ if hasattr (op , "overloads" ):
59+ overloads = list (op .overloads ())
60+ if "default" in overloads :
61+ overload = op .default
62+ elif len (overloads ) == 1 :
63+ overload = getattr (op , overloads [0 ])
64+ else :
65+ logger .warning (
66+ "Operator '%s' has multiple overloads (%s); please "
67+ "specify the desired overload explicitly." ,
68+ candidate ,
69+ ", " .join (overloads ),
70+ )
4771 else :
48- resolved .append (op )
49- except Exception :
50- # If lookup fails and no overload specified, try with .default
51- if "." not in op_name .split ("::" )[- 1 ]:
52- try :
53- op = lookup_op (f"{ op_name } .default" )
54- resolved .append (op )
55- continue
56- except Exception :
57- pass
58- # Skip operators that don't exist
59- logger .warning (
60- "Failed to resolve operator for Inductor partition: %s" , op_name
61- )
62- continue
72+ overload = op # Already an OpOverload
73+
74+ if overload is not None :
75+ break
76+
77+ resolved .append (overload )
6378
6479 return resolved
6580
0 commit comments