@@ -417,7 +417,6 @@ def call_trtllm_fused_allreduce_norm(
         fp32_acc: bool,
         max_token_num: int,
         pattern_code: int,
-        fuse_rms_quant: bool,
         norm_out: Optional[torch.Tensor] = None,
         quant_out: Optional[torch.Tensor] = None,
         scale_out: Optional[torch.Tensor] = None,
@@ -489,13 +488,8 @@ def call_trtllm_fused_allreduce_norm(
             torch.ops._C.rms_norm(norm_out, allreduce_out, rms_gamma,
                                   rms_eps)
         if scale_factor is not None:
-            assert scale_out is not None
             torch.ops._C.scaled_fp4_quant(quant_out, norm_out,
                                           scale_out, scale_factor)
-        # if scale_out is not None:
-        # else:
-        #     torch.ops._C.static_scaled_fp8_quant(
-        #         quant_out, norm_out, scale_factor)
         if scale_factor is None or norm_out is not None:
             # we need to return allreduce output
             # in cases of non quant fused AR + RMS norm
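
With the commented-out FP8 branch gone, the post-all-reduce tail of the fallback path does exactly one thing per stage: RMS norm into `norm_out`, then NVFP4 quantization when `scale_factor` is set. Below is a minimal sketch of that tail using vLLM's Python wrappers; it assumes the `vllm._custom_ops` signatures and uses illustrative shapes, so it is not the file's code.

    # Hedged sketch, not the file's code: the same fallback sequence via
    # vLLM's Python wrappers, assuming the vllm._custom_ops signatures.
    import torch

    from vllm import _custom_ops as ops

    hidden_size = 4096
    x = torch.randn(8, hidden_size, dtype=torch.bfloat16, device="cuda")
    gamma = torch.ones(hidden_size, dtype=torch.bfloat16, device="cuda")
    norm_out = torch.empty_like(x)

    # RMS norm into a preallocated buffer, as in the fallback above.
    ops.rms_norm(norm_out, x, gamma, 1e-6)

    # NVFP4 quantization: the wrapper allocates and returns the packed
    # fp4 tensor and its blockwise scales (needs NVFP4-capable hardware).
    global_scale = torch.tensor([1.0], dtype=torch.float32, device="cuda")
    quant_out, scale_out = ops.scaled_fp4_quant(norm_out, global_scale)
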
@@ -514,7 +508,6 @@ def call_trtllm_fused_allreduce_norm_fake(
         fp32_acc: bool,
         max_token_num: int,
         pattern_code: int,
-        fuse_rms_quant: bool,
         norm_out: Optional[torch.Tensor] = None,
         quant_out: Optional[torch.Tensor] = None,
         scale_out: Optional[torch.Tensor] = None,
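
The fake variant has to mirror the real op's signature exactly, because torch.compile traces against the registered fake implementation; dropping `fuse_rms_quant` from one but not the other would break tracing. A hedged sketch of how such a real/fake pair stays in sync, using the standard `torch.library` API with a made-up op name rather than the file's actual registration helper:

    # Hedged sketch with a made-up op name; vLLM's actual registration
    # helper differs, but the signature-sync rule is the same.
    import torch


    @torch.library.custom_op("demo::fused_norm", mutates_args=("out",))
    def fused_norm(out: torch.Tensor, x: torch.Tensor, eps: float) -> None:
        # Real implementation: RMS-normalize x into the preallocated buffer.
        out.copy_(x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps))


    @fused_norm.register_fake
    def _(out: torch.Tensor, x: torch.Tensor, eps: float) -> None:
        # Fake impl: same signature, no work. Remove an argument from the
        # real op and this one must change in lockstep, or tracing breaks.
        pass


    x = torch.randn(4, 8)
    out = torch.empty_like(x)
    fused_norm(out, x, 1e-6)  # eager call dispatches to the real impl
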
@@ -547,17 +540,14 @@ def __init__(
         world_size: int,
         use_fp32_lamport: bool = False,
         max_token_num: int = 1024,
-        fuse_rms_quant: bool = False,
     ):
         self.rank = rank
         self.world_size = world_size
         self.use_fp32_lamport = use_fp32_lamport
         self.trigger_completion_at_end = True
         self.launch_with_pdl = True
         self.fp32_acc = True
-        self.use_oneshot = False
         self.max_token_num = max_token_num
-        self.fuse_rms_quant = fuse_rms_quant

     def get_trtllm_fused_allreduce_kwargs(self):
         return {
@@ -567,7 +557,6 @@ def get_trtllm_fused_allreduce_kwargs(self):
             "trigger_completion_at_end": self.trigger_completion_at_end,
             "fp32_acc": self.fp32_acc,
             "max_token_num": self.max_token_num,
-            "fuse_rms_quant": self.fuse_rms_quant,
         }
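
Removing the key here matters because this dict is splatted straight into the fused op call, so a stale entry would surface as a `TypeError` at call time. A self-contained toy demonstrating that contract (the stub stands in for the registered op and is illustrative only):

    # Toy stub standing in for the registered fused op; illustrative only.
    def fused_allreduce_stub(*, launch_with_pdl: bool,
                             trigger_completion_at_end: bool,
                             fp32_acc: bool, max_token_num: int) -> None:
        pass


    kwargs = {
        "launch_with_pdl": True,
        "trigger_completion_at_end": True,
        "fp32_acc": True,
        "max_token_num": 1024,
        # "fuse_rms_quant": True,  # stale key -> TypeError on the call
    }
    fused_allreduce_stub(**kwargs)
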
@@ -1103,10 +1092,7 @@ def __init__(self, config: VllmConfig):
             world_size=self.tp_size,
             use_fp32_lamport=use_fp32_lamport,
             max_token_num=max_num_token,
-            # fuse rms norm static fp8 quant fused op
-            # in fallback path, when we don't use flashinfer
-            fuse_rms_quant=config.compilation_config.pass_config.enable_fusion)
-
+        )
         for epsilon in [1e-5, 1e-6]:
             AllReduceFusedRMSNormStaticQuantFP8Pattern(
                 epsilon,
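
For reference, the deleted call-site lines threaded the pass config's fusion flag into the params object; after this change the params object no longer carries it. A small hedged sketch of where the removed value came from, using the attribute path visible in the deleted lines (the helper function itself is illustrative, not part of the file):

    # Hedged sketch: the attribute path is taken from the deleted lines;
    # the helper itself is illustrative, not part of the file.
    from vllm.config import VllmConfig


    def fusion_enabled(config: VllmConfig) -> bool:
        return config.compilation_config.pass_config.enable_fusion
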