@@ -163,7 +163,7 @@ def check_ir_node_bucketable(
163163
164164def _get_fx_node (
165165 snode_or_ir_node : Union ["scheduler.BaseSchedulerNode" , "ir.IRNode" ],
166- expected_op : Callable [[ Any ]] ,
166+ expected_op : Any ,
167167) -> torch .fx .Node :
168168 origins = None
169169 if isinstance (snode_or_ir_node , scheduler .BaseSchedulerNode ):
@@ -190,7 +190,7 @@ def _get_fx_node(
190190
191191def get_snode_process_group_info (
192192 snode : "scheduler.BaseSchedulerNode" ,
193- expected_op : Callable [[ Any ]] ,
193+ expected_op : Any ,
194194 resolve_pg : bool = False ,
195195) -> tuple [int , Union [str , ProcessGroup ]]:
196196 fx_node = _get_fx_node (snode , expected_op = expected_op )
@@ -248,7 +248,7 @@ def get_snode_tensor_info(
248248
249249def _estimate_bucketed_node_list (
250250 current_node_list : list ["scheduler.BaseSchedulerNode" ],
251- schedule_fallback_operation : Callable [[Any ]],
251+ schedule_fallback_operation : Callable [[Any ], Any ],
252252 group_size : int ,
253253 group_name : str ,
254254 name_to_buf : Dict [str , "scheduler.SchedulerBuffer" ],
@@ -272,7 +272,7 @@ def _estimate_bucketed_node_list(
272272 )
273273 return estimated_comm , comm_size_inp , comm_size_out
274274
275- if comm_func == torch .ops ._c10d_functional .all_gather_into_tensor .default :
275+ if comm_func == " torch.ops._c10d_functional.all_gather_into_tensor.default" :
276276 bucked_node = bucket_all_gathers (
277277 schedule_fallback_operation ,
278278 group_size ,
@@ -284,7 +284,7 @@ def _estimate_bucketed_node_list(
284284 )
285285 comm_size_inp = bucked_node [0 ].layout .size
286286 comm_size_out = bucked_node [1 ].layout .size
287- elif comm_func == torch .ops ._c10d_functional .reduce_scatter_tensor .default :
287+ elif comm_func == " torch.ops._c10d_functional.reduce_scatter_tensor.default" :
288288 bucked_node = bucket_reduce_scatters (
289289 schedule_fallback_operation ,
290290 group_size ,
@@ -311,7 +311,7 @@ def _estimate_bucketed_node_list(
311311
312312def estimate_bucketed_snode_runtime (
313313 node_bucket_dict : Dict [tuple [Any , ...], list ["scheduler.BaseSchedulerNode" ]],
314- schedule_fallback_operation : Callable [[Any ]],
314+ schedule_fallback_operation : Callable [[Any ], Any ],
315315 name_to_buf : Dict [str , "scheduler.SchedulerBuffer" ],
316316 comm_func : Callable [[Any ], Any ],
317317 comm_cache : Dict [Any , Any ],
@@ -687,7 +687,7 @@ def _get_dim0_padded_size(tensor_size: torch.Size, dim0_factor: int) -> torch.Si
687687
688688 reduce_scatter_tensor = schedule_fallback_operation (
689689 torch .ops ._c10d_functional .reduce_scatter_tensor .default ,
690- (reduce_scatter_input , reduce_op , group_size , group_name ),
690+ (reduce_scatter_input , str ( reduce_op ) , group_size , group_name ),
691691 {},
692692 dep_operations = chunk_cat ,
693693 )
0 commit comments