     logger.info("Attempting to use flux but flux not installed.")
     use_flux = False
 
-
 # Depends on arch, see auto_tile_shape in include/flux/gemm_hparams.h
 # Can be 256 on sm80.
 FLUX_TILE_SIZE: int = 128
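The tile-size comment above is the constraint that decides whether the flux path is usable for a given shape. A minimal sketch of that kind of eligibility check, where slice_is_tileable and its world_size argument are hypothetical names for illustration, not code from this diff:

import torch

FLUX_TILE_SIZE: int = 128  # per the comment above, 256 may apply on sm80

def slice_is_tileable(residual: torch.Tensor, world_size: int) -> bool:
    # Hypothetical check: flux's fused GEMM + reduce-scatter kernel tiles
    # the M dimension, so each rank's slice of the residual presumably has
    # to divide evenly into FLUX_TILE_SIZE-row tiles.
    return (residual.shape[0] // world_size) % FLUX_TILE_SIZE == 0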
@@ -60,11 +59,11 @@ def residual_slice_shape_fake(residual: torch.Tensor, rank: int) -> int:
 
 
 def match_gemm_rs_ag_gemm(
-    residual: torch.Tensor,
-    gemm_1_weights: torch.Tensor,
-    gemm_1_activations: torch.Tensor,
-    rms_norm_weights: torch.Tensor,
-    gemm_2_weights: torch.Tensor,
+        residual: torch.Tensor,
+        gemm_1_weights: torch.Tensor,
+        gemm_1_activations: torch.Tensor,
+        rms_norm_weights: torch.Tensor,
+        gemm_2_weights: torch.Tensor,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     gemm_1_w_perm = torch.ops.aten.permute.default(gemm_1_weights, [1, 0])
     mm_1 = torch.ops.aten.mm.default(gemm_1_activations, gemm_1_w_perm)
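The aten calls in this matcher are just what a linear projection lowers to: an mm against the permuted weight is x @ W.T. A toy check of that equivalence (shapes chosen arbitrarily for illustration):

import torch

x = torch.randn(4, 8)
w = torch.randn(16, 8)
w_perm = torch.ops.aten.permute.default(w, [1, 0])
mm = torch.ops.aten.mm.default(x, w_perm)
assert torch.allclose(mm, x @ w.t())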
@@ -239,10 +238,10 @@ def gemm_rs_ag_gemm_fake(
 
 
 def match_final(
-    my_residual: torch.Tensor,
-    gemm_1_weights: torch.Tensor,
-    gemm_1_activations: torch.Tensor,
-    rms_norm_weights: torch.Tensor,
+        my_residual: torch.Tensor,
+        gemm_1_weights: torch.Tensor,
+        gemm_1_activations: torch.Tensor,
+        rms_norm_weights: torch.Tensor,
 ) -> torch.Tensor:
     gemm_1_w_perm = torch.ops.aten.permute.default(gemm_1_weights, [1, 0])
     mm_1 = torch.ops.aten.mm.default(gemm_1_activations, gemm_1_w_perm)
@@ -260,7 +259,7 @@ def match_final(
     return normalized
 
 
-# Register this as a custom op since all reduce cannot be torch.compiled yet.
+# Register this as a custom op since all gather cannot be torch.compiled yet.
 def gemm_ag_final(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor,
                   gemm_1_activations: torch.Tensor,
                   rms_norm_weights: torch.Tensor) -> torch.Tensor:
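The corrected comment is the point of this hunk: gemm_ag_final wraps an all-gather, and collectives still cannot be traced by torch.compile, so the function is registered as a custom op. A minimal sketch of one way to do such a registration, assuming PyTorch >= 2.4 and using stand-in names (mylib::ag_mm, a single-process "gather") rather than the registration helper this file actually uses:

import torch
from torch.library import custom_op

@custom_op("mylib::ag_mm", mutates_args=())
def ag_mm(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # A real implementation would all-gather x across the TP group here;
    # this stand-in pretends world_size == 1.
    gathered = x
    return torch.mm(gathered, w.t())

@ag_mm.register_fake
def _(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # Shape-only (meta) implementation so the compiler can trace through
    # the op without executing the collective.
    return x.new_empty(x.shape[0], w.shape[0])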
@@ -333,17 +332,14 @@ def __init__(self, config: CompilationConfig):
         inputs = [resid, x, w, resid_w, x2]
         final_inputs = [x, w, resid, resid_w]
 
-        register_replacement(
-            match_gemm_rs_ag_gemm,
-            match_gemm_rs_ag_gemm,
-            inputs,
-            fwd_only, [self.gemm_rs_ag_gemm_pattern],
-            extra_check=lambda m: self.record_match(m))
+        register_replacement(match_gemm_rs_ag_gemm,
+                             match_gemm_rs_ag_gemm,
+                             inputs,
+                             fwd_only, [self.gemm_rs_ag_gemm_pattern],
+                             extra_check=lambda m: self.record_match(m))
 
-        register_replacement(match_final,
-                             torch.ops.vllm.gemm_ag_final,
-                             final_inputs, fwd_only,
-                             [self.final_pattern])
+        register_replacement(match_final, torch.ops.vllm.gemm_ag_final,
+                             final_inputs, fwd_only, [self.final_pattern])
 
     def record_match(self, match: Match) -> bool:
         # Hijack the extra_check to record the match and
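For readers new to the inductor pattern matcher: register_replacement traces the search function with the example inputs, matches the resulting pattern structurally against the graph, and splices in the replacement wherever extra_check approves, which is the hook record_match hijacks above. A toy registration along the same lines, assuming the private torch._inductor.pattern_matcher API as of this commit:

import torch
from torch._inductor.pattern_matcher import (PatternMatcherPass, fwd_only,
                                             register_replacement)

def search_fn(x: torch.Tensor) -> torch.Tensor:
    return torch.relu(x) + 1

def replace_fn(x: torch.Tensor) -> torch.Tensor:
    # Behaviorally identical stand-in for a fused/faster version.
    return torch.clamp(x, min=0) + 1

patterns = PatternMatcherPass()
register_replacement(search_fn, replace_fn, [torch.empty(8)], fwd_only,
                     [patterns])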
@@ -394,16 +390,10 @@ def find_min_index(match: Match) -> int:
 
         # Extract group_name from matched code. Use to
         # generate proper replacement code.
-        #ar_node = find_auto_fn(match.nodes, torch.ops.vllm.inplace_all_reduce.default)
-        ar_node = None
-        if ar_node is not None:
-            tp_group_name = ar_node.kwargs["group_name"]
-        else:
-            ar_node = find_fn(
-                match.nodes,
-                torch.ops.vllm.all_reduce.default)
-            assert ar_node is not None
-            tp_group_name = ar_node.args[1]
+        ar_node = find_fn(match.nodes,
+                          torch.ops.vllm.all_reduce.default)
+        assert ar_node is not None
+        tp_group_name = ar_node.args[1]
 
         fused_gemm_func = get_gemm_rs_ag_gemm(
             use_flux, max_m, gemm_1.dtype, gemm_1.shape, gemm_2.dtype,
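The simplification in this hunk assumes every match contains a plain all_reduce node whose second positional argument is the group name. A plausible shape for the find_fn helper it leans on, sketched here since the helper itself is outside this diff:

from typing import Iterable, Optional
from torch.fx import Node

def find_fn(nodes: Iterable[Node], target) -> Optional[Node]:
    # Return the first call_function node in the matched subgraph whose
    # target is the given op overload, or None if the op never appears.
    for node in nodes:
        if node.op == "call_function" and node.target == target:
            return node
    return None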