Commit ad72306

Support matching to native rms_norm, fp8_quant
Signed-off-by: ilmarkov <markovilya197@gmail.com>
1 parent 3e4e159 commit ad72306

File tree

4 files changed: +628 additions, −373 deletions


tests/compile/test_fusion_all_reduce.py

Lines changed: 2 additions & 1 deletion
@@ -196,7 +196,8 @@ def all_reduce_fusion_pass_on_test_model(local_rank: int, world_size: int,
     initialize_model_parallel(tensor_model_parallel_size=world_size)

     vllm_config = VllmConfig(compilation_config=CompilationConfig(
-        level=CompilationLevel.PIECEWISE, custom_ops=["+rms_norm"]))
+        level=CompilationLevel.PIECEWISE,
+        custom_ops=["+rms_norm", "+quant_fp8"]))
     vllm_config.compilation_config.pass_config = PassConfig(
         enable_fi_allreduce_fusion=True, enable_noop=False)
     vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))
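In vLLM's `CompilationConfig`, entries in the `custom_ops` list use a sign prefix: `"+name"` enables the custom-op implementation and `"-name"` disables it, which is why the test now passes `"+quant_fp8"` alongside `"+rms_norm"` so the fusion pass can match both ops. A minimal sketch of that prefix convention (`resolve_custom_ops` is a hypothetical helper for illustration, not part of vLLM's API):

```python
def resolve_custom_ops(entries: list[str]) -> tuple[set[str], set[str]]:
    """Split '+name'/'-name' entries into (enabled, disabled) op-name sets.

    Hypothetical helper illustrating the prefix convention used by
    CompilationConfig.custom_ops; not vLLM code.
    """
    enabled: set[str] = set()
    disabled: set[str] = set()
    for entry in entries:
        if entry.startswith("+"):
            enabled.add(entry[1:])
        elif entry.startswith("-"):
            disabled.add(entry[1:])
        else:
            raise ValueError(f"entry must start with '+' or '-': {entry!r}")
    return enabled, disabled


# The list used in the updated test enables both ops:
enabled, disabled = resolve_custom_ops(["+rms_norm", "+quant_fp8"])
```

With both ops enabled as custom ops, they appear as matchable nodes in the compiled graph rather than being inlined by torch.compile, which is what allows the all-reduce fusion pass to pattern-match them.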

0 commit comments
