
Commit 3db307d

Fix test to support non custom ops

Signed-off-by: ilmarkov <markovilya197@gmail.com>
1 parent dd988a2


tests/compile/test_fusion_all_reduce.py

Lines changed: 77 additions & 24 deletions
@@ -6,12 +6,14 @@
 import torch
 
 import vllm.envs as envs
+from vllm import _custom_ops as ops
 from vllm.compilation.collective_fusion import AllReduceFusionPass
 from vllm.compilation.fix_functionalization import FixFunctionalizationPass
+from vllm.compilation.fx_utils import find_op_nodes
 from vllm.compilation.noop_elimination import NoOpEliminationPass
 from vllm.config import (CompilationConfig, CompilationLevel, DeviceConfig,
                          ModelConfig, PassConfig, VllmConfig,
-                         set_current_vllm_config)
+                         get_current_vllm_config, set_current_vllm_config)
 from vllm.distributed import tensor_model_parallel_all_reduce
 from vllm.distributed.parallel_state import (init_distributed_environment,
                                              initialize_model_parallel)
@@ -25,7 +27,19 @@
 from .backend import TestBackend
 
 
+def finisher(hidden_states):
+    custom_ops = get_current_vllm_config().compilation_config.custom_ops
+    if not custom_ops or "+quant_fp8" not in custom_ops:
+        # Hack: use dynamic fp8 quantization to
+        # suppress torch.compile optimizations
+        # that prevent pattern matching
+        return ops.scaled_fp8_quant(hidden_states)
+    else:
+        return hidden_states
+
+
 class TestAllReduceRMSNormModel(torch.nn.Module):
+    pattern_code = 1
 
     def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
         super().__init__()
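Note on the finisher helper introduced above: when the +quant_fp8 custom op is not enabled, it runs dynamic FP8 quantization on its input purely so that torch.compile cannot optimize the tensor in a way that breaks the fusion pattern matching; otherwise it passes the tensor through unchanged. A minimal sketch of the two paths, assuming ops.scaled_fp8_quant performs dynamic per-tensor FP8 quantization and returns an (output, scale) tuple:

    import torch
    from vllm import _custom_ops as ops

    x = torch.randn(16, 16, dtype=torch.float16, device="cuda")
    # "+quant_fp8" absent: dynamic quantization keeps the pattern alive
    quantized, scale = ops.scaled_fp8_quant(x)
    # "+quant_fp8" enabled: the tensor is returned as-is
    passthrough = x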
@@ -34,10 +48,14 @@ def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
         self.norm = RMSNorm(hidden_size, eps)
 
     def forward(self, hidden_states, residual):
-        view = hidden_states.reshape(-1, self.hidden_size)
-        all_reduce = tensor_model_parallel_all_reduce(view)
-        norm = self.norm(all_reduce)
-        return norm
+        # view = hidden_states.reshape(-1, self.hidden_size)
+        all_reduce = tensor_model_parallel_all_reduce(hidden_states)
+
+        hidden_states = self.norm(all_reduce)
+
+        hidden_states = finisher(hidden_states)
+
+        return hidden_states
 
     def ops_in_model_before(self):
         return [torch.ops.vllm.all_reduce.default]
@@ -47,6 +65,7 @@ def ops_in_model_after(self):
 
 
 class TestAllReduceFusedAddRMSNormModel(torch.nn.Module):
+    pattern_code = 1
 
     def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
         super().__init__()
@@ -57,35 +76,54 @@ def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
     def forward(self, hidden_states, residual):
         view = hidden_states.reshape(-1, self.hidden_size)
         all_reduce = tensor_model_parallel_all_reduce(view)
-        norm, _ = self.norm(all_reduce, residual)
-        return norm
-
-    def ops_in_model_before(self):
-        return [torch.ops.vllm.all_reduce.default]
+        hidden_states, residual = self.norm(all_reduce, residual)
+        # Hack: use dynamic fp8 quantization to
+        # suppress torch.compile optimizations
+        # that prevent pattern matching
+        hidden_states = finisher(hidden_states)
+        return hidden_states, residual
 
     def ops_in_model_after(self):
         return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default]
 
+    def ops_in_model_before(self):
+        return [
+            torch.ops.vllm.all_reduce.default,
+        ]
+
 
 class TestAllReduceFusedAddRMSNormStaticQuantFP8Model(torch.nn.Module):
+    pattern_code = 2
 
     def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
         super().__init__()
         self.hidden_size = hidden_size
         self.eps = eps
         self.norm = RMSNorm(hidden_size, eps)
-        self.quant_fp8 = QuantFP8(static=True,
-                                  group_shape=GroupShape.PER_TENSOR)
-        self.scale = torch.rand(1, dtype=torch.float32)
         self.output = torch.empty((token_num, hidden_size),
-                                  dtype=torch.float32)
+                                  dtype=current_platform.fp8_dtype())
+
+        def _quant_fp8_wrapper(x, scale):
+            torch.ops._C.static_scaled_fp8_quant(self.output, x, scale)
+            return self.output, scale
+
+        vllm_config = get_current_vllm_config()
+        if "+quant_fp8" in vllm_config.compilation_config.custom_ops:
+            # Need to use static_scaled_fp8_quant instead of QuantFP8
+            # due to failure in TestBackend with copying graph
+            self.quant_fp8 = _quant_fp8_wrapper
+        else:
+            self.quant_fp8 = QuantFP8(static=True,
+                                      group_shape=GroupShape.PER_TENSOR)
+        self.scale = torch.rand(1, dtype=torch.float32)
 
     def forward(self, hidden_states, residual):
         view = hidden_states.reshape(-1, self.hidden_size)
         all_reduce = tensor_model_parallel_all_reduce(view)
         norm_output, residual_output = self.norm(all_reduce, residual)
-        self.output, _ = self.quant_fp8(norm_output, self.scale)
-        return self.output, residual_output
+        output, _ = self.quant_fp8(norm_output, self.scale)
+        hidden_states = finisher(output.to(hidden_states.dtype))
+        return hidden_states, residual_output
 
     def ops_in_model_after(self):
         return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default]
@@ -97,6 +135,7 @@ def ops_in_model_before(self):
 
 
 class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module):
+    pattern_code = 3
 
     def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
         super().__init__()
@@ -143,6 +182,9 @@ def ops_in_model_before(self):
     # TODO: Enable with torch==2.8.0
     # TestAllReduceFusedAddRMSNormStaticQuantFP4Model,
 ])
+@pytest.mark.parametrize(
+    "custom_ops",
+    [[], ["+rms_norm"], ["+quant_fp8"], ["+rms_norm", "+quant_fp8"]])
 @pytest.mark.parametrize("batch_size", [8])
 @pytest.mark.parametrize("seq_len", [8])
 @pytest.mark.parametrize("hidden_size", [16])
@@ -155,19 +197,23 @@ def ops_in_model_before(self):
                     reason="flashinfer is not found or flashinfer "
                     "is not compiled with trtllm_allreduce_fusion")
 def test_all_reduce_fusion_pass_replace(test_model: torch.nn.Module,
-                                        batch_size: int, seq_len: int,
-                                        hidden_size: int, dtype: torch.dtype):
+                                        custom_ops: list[str], batch_size: int,
+                                        seq_len: int, hidden_size: int,
+                                        dtype: torch.dtype):
     num_processes = 2
     if (test_model == TestAllReduceFusedAddRMSNormStaticQuantFP4Model
             and not current_platform.has_device_capability(100)):
         pytest.skip("Skip as nvfp4 is only supported on "
                     "devices with compute capability 10.0 (Blackwell)")
+    if (test_model != TestAllReduceFusedAddRMSNormStaticQuantFP8Model
+            and ("+quant_fp8" in custom_ops)):
+        pytest.skip()
 
     def run_torch_spawn(fn, nprocs):
         torch.multiprocessing.spawn(fn,
                                     args=(num_processes, test_model,
                                           batch_size, seq_len, hidden_size,
-                                          dtype),
+                                          dtype, custom_ops),
                                     nprocs=nprocs)
 
     run_torch_spawn(all_reduce_fusion_pass_on_test_model, num_processes)
@@ -176,7 +222,8 @@ def run_torch_spawn(fn, nprocs):
 def all_reduce_fusion_pass_on_test_model(local_rank: int, world_size: int,
                                          test_model_cls: torch.nn.Module,
                                          batch_size: int, seq_len: int,
-                                         hidden_size: int, dtype: torch.dtype):
+                                         hidden_size: int, dtype: torch.dtype,
+                                         custom_ops: list[str]):
     current_platform.seed_everything(0)
 
     device = torch.device(f"cuda:{local_rank}")
@@ -196,10 +243,9 @@ def all_reduce_fusion_pass_on_test_model(local_rank: int, world_size: int,
     initialize_model_parallel(tensor_model_parallel_size=world_size)
 
     vllm_config = VllmConfig(compilation_config=CompilationConfig(
-        level=CompilationLevel.PIECEWISE,
-        custom_ops=["+rms_norm", "+quant_fp8"]))
+        level=CompilationLevel.PIECEWISE, custom_ops=custom_ops))
     vllm_config.compilation_config.pass_config = PassConfig(
-        enable_fi_allreduce_fusion=True, enable_noop=False)
+        enable_fi_allreduce_fusion=True, enable_noop=True)
     vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))
 
     # this is a fake model name to construct the model config
@@ -221,11 +267,18 @@ def all_reduce_fusion_pass_on_test_model(local_rank: int, world_size: int,
 
     hidden_states = torch.randn((token_num, hidden_size),
                                 requires_grad=False)
-    residual = torch.randn((token_num, hidden_size), requires_grad=False)
+    residual = torch.randn((token_num, hidden_size),
+                           dtype=torch.float32,
+                           requires_grad=False)
 
     compiled_model = torch.compile(model, backend=backend)
     compiled_model(hidden_states, residual)
 
     backend.check_before_ops(model.ops_in_model_before(),
                              fully_replaced=False)
     backend.check_after_ops(model.ops_in_model_after())
+    for node in find_op_nodes(
+            torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default,
+            backend.graph_post_pass):
+        assert (
+            node.kwargs.get("pattern_code") == test_model_cls.pattern_code)
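The new loop asserts that every fused flashinfer_trtllm_fused_allreduce_norm node in the post-pass graph carries the pattern_code the test model expects. The expected values come directly from the pattern_code class attributes added in this commit; summarized as a sketch:

    # Expected pattern_code per test model (taken from the class attributes above);
    # the meaning of each code inside AllReduceFusionPass is an assumption, not shown here.
    expected_pattern_code = {
        TestAllReduceRMSNormModel: 1,
        TestAllReduceFusedAddRMSNormModel: 1,
        TestAllReduceFusedAddRMSNormStaticQuantFP8Model: 2,
        TestAllReduceFusedAddRMSNormStaticQuantFP4Model: 3,
    }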
