@@ -470,8 +470,7 @@ def call_trtllm_fused_allreduce_norm(
470470 )
471471 else :
472472 allreduce_out = tensor_model_parallel_all_reduce (allreduce_in )
473- if (scale_factor is not None and scale_out is None
474- and fuse_rms_quant ):
473+ if (scale_factor is not None and scale_out is None ):
475474 # Do fused rms norm static fp8 quant fused op
476475 if norm_out is None :
477476 torch .ops ._C .fused_add_rms_norm_static_fp8_quant (
@@ -490,12 +489,13 @@ def call_trtllm_fused_allreduce_norm(
490489 torch .ops ._C .rms_norm (norm_out , allreduce_out , rms_gamma ,
491490 rms_eps )
492491 if scale_factor is not None :
493- if scale_out is not None :
494- torch .ops ._C .scaled_fp4_quant (quant_out , norm_out ,
495- scale_out , scale_factor )
496- else :
497- torch .ops ._C .static_scaled_fp8_quant (
498- quant_out , norm_out , scale_factor )
492+ assert scale_out is not None
493+ torch .ops ._C .scaled_fp4_quant (quant_out , norm_out ,
494+ scale_out , scale_factor )
495+ # if scale_out is not None:
496+ # else:
497+ # torch.ops._C.static_scaled_fp8_quant(
498+ # quant_out, norm_out, scale_factor)
499499 if scale_factor is None or norm_out is not None :
500500 # we need to return allreduce output
501501 # in cases of non quant fused AR + RMS norm
0 commit comments