Commit 82276a9

ilmarkov committed
Try custom_ops in fallback
Signed-off-by: ilmarkov <imarkov@redhat.com>
1 parent: 27a145e

File tree

1 file changed: +8 -8 lines


vllm/compilation/collective_fusion.py

Lines changed: 8 additions & 8 deletions
@@ -470,8 +470,7 @@ def call_trtllm_fused_allreduce_norm(
         )
     else:
         allreduce_out = tensor_model_parallel_all_reduce(allreduce_in)
-        if (scale_factor is not None and scale_out is None
-                and fuse_rms_quant):
+        if (scale_factor is not None and scale_out is None):
             # Do fused rms norm static fp8 quant fused op
             if norm_out is None:
                 torch.ops._C.fused_add_rms_norm_static_fp8_quant(
@@ -490,12 +489,13 @@ def call_trtllm_fused_allreduce_norm(
                 torch.ops._C.rms_norm(norm_out, allreduce_out, rms_gamma,
                                       rms_eps)
             if scale_factor is not None:
-                if scale_out is not None:
-                    torch.ops._C.scaled_fp4_quant(quant_out, norm_out,
-                                                  scale_out, scale_factor)
-                else:
-                    torch.ops._C.static_scaled_fp8_quant(
-                        quant_out, norm_out, scale_factor)
+                assert scale_out is not None
+                torch.ops._C.scaled_fp4_quant(quant_out, norm_out,
+                                              scale_out, scale_factor)
+                # if scale_out is not None:
+                # else:
+                #     torch.ops._C.static_scaled_fp8_quant(
+                #         quant_out, norm_out, scale_factor)
         if scale_factor is None or norm_out is not None:
             # we need to return allreduce outpput
             # in cases of non quant fused AR + RMS norm
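
To read the net effect of the two hunks in one place, here is a minimal sketch of the fallback path as it stands after this commit. It is a reconstruction, not the verbatim source: the function name fallback_allreduce_rms_quant is hypothetical, elided branches are marked with `...`, and only the calls visible in the diff (tensor_model_parallel_all_reduce, torch.ops._C.rms_norm, torch.ops._C.scaled_fp4_quant) are taken from it.

import torch
from vllm.distributed import tensor_model_parallel_all_reduce


# Hypothetical wrapper condensing the fallback branch shown in the diff;
# it assumes vllm's compiled custom ops (torch.ops._C) are loaded.
def fallback_allreduce_rms_quant(allreduce_in, rms_gamma, rms_eps,
                                 norm_out, quant_out, scale_factor,
                                 scale_out):
    allreduce_out = tensor_model_parallel_all_reduce(allreduce_in)
    if scale_factor is not None and scale_out is None:
        # fp8 static-quant path: the commit drops the fuse_rms_quant gate,
        # so the fused RMS norm + fp8 quant custom ops are always tried here.
        ...  # torch.ops._C.fused_add_rms_norm_static_fp8_quant(...), per hunk 1
    else:
        if norm_out is not None:
            torch.ops._C.rms_norm(norm_out, allreduce_out, rms_gamma, rms_eps)
        # (the norm_out is None branch, fused add + RMS norm, is elided)
        if scale_factor is not None:
            # Only the fp4 quant path remains live; the fp8 fallback is
            # commented out by this commit.
            assert scale_out is not None
            torch.ops._C.scaled_fp4_quant(quant_out, norm_out, scale_out,
                                          scale_factor)
    return allreduce_out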
