
Commit de7405b

PR comments: add _custom_op suffix
Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Parent commit: 24f1298

2 files changed: +25 −21 lines


tests/compile/test_fusion.py

Lines changed: 15 additions & 11 deletions
@@ -71,8 +71,8 @@ def __init__(
             act_quant_group_shape=group_shape,
         )
 
-        self.enable_rms_norm = self.norm[0].enabled()
-        self.enable_quant_fp8 = self.fp8_linear.quant_fp8.enabled()
+        self.enable_rms_norm_custom_op = self.norm[0].enabled()
+        self.enable_quant_fp8_custom_op = self.fp8_linear.quant_fp8.enabled()
 
     def forward(self, x):
         # avoid having graph input be an arg to a pattern directly
@@ -107,21 +107,25 @@ def ops_in_model_after(self):
     def ops_in_model_before(self):
         return (
             [QUANT_OPS[self.quant_key]]
-            if self.enable_quant_fp8
+            if self.enable_quant_fp8_custom_op
             else [torch.ops.aten.reciprocal]
         )
 
     def ops_in_model_before_partial(self):
-        return [RMS_OP, RMS_ADD_OP] if self.enable_rms_norm else [torch.ops.aten.rsqrt]
+        return (
+            [RMS_OP, RMS_ADD_OP]
+            if self.enable_rms_norm_custom_op
+            else [torch.ops.aten.rsqrt]
+        )
 
 
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
 @pytest.mark.parametrize("hidden_size", [64])
 @pytest.mark.parametrize("num_tokens", [257])
 @pytest.mark.parametrize("eps", [1e-5, 1e-6])
 @pytest.mark.parametrize("static", [True, False])
-@pytest.mark.parametrize("enable_rms_norm", [True, False])
-@pytest.mark.parametrize("enable_quant_fp8", [True, False])
+@pytest.mark.parametrize("enable_rms_norm_custom_op", [True, False])
+@pytest.mark.parametrize("enable_quant_fp8_custom_op", [True, False])
 # cuda_force_torch used to test torch code path on platforms that
 # cutlass_fp8_supported() == True.
 @pytest.mark.parametrize(
@@ -136,8 +140,8 @@ def test_fusion_rmsnorm_quant(
     num_tokens,
     eps,
     static,
-    enable_rms_norm,
-    enable_quant_fp8,
+    enable_rms_norm_custom_op,
+    enable_quant_fp8_custom_op,
     cuda_force_torch,
 ):
     torch.set_default_device("cuda")
@@ -146,9 +150,9 @@ def test_fusion_rmsnorm_quant(
     maybe_create_device_identity()  # needed for certain non-cutlass fp8 paths
 
     custom_ops = []
-    if enable_rms_norm:
+    if enable_rms_norm_custom_op:
         custom_ops.append("+rms_norm")
-    if enable_quant_fp8:
+    if enable_quant_fp8_custom_op:
         custom_ops.append("+quant_fp8")
     vllm_config = VllmConfig(
         model_config=ModelConfig(dtype=dtype),
@@ -195,7 +199,7 @@ def test_fusion_rmsnorm_quant(
     # there's a risk that the fused add doesn't get included in the
     # replacement and only the rms part gets fused with quant.
     # Hence, we check only 2 add nodes are left (final fused rmsnorm add).
-    if not enable_rms_norm:
+    if not enable_rms_norm_custom_op:
         n_add_nodes = lambda g: sum(1 for _ in find_op_nodes(torch.ops.aten.add, g))
         # 7 = 1 (RMS) + 3x2 (3xRMS_ADD, 2 each)
         assert n_add_nodes(backend.graph_pre_pass) == 7
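
The `custom_ops` list built inside `test_fusion_rmsnorm_quant` is driven directly by the two renamed flags, which is what the `_custom_op` suffix signals: the parameters toggle vLLM custom-op selection strings, not the numerics under test. A minimal, self-contained sketch of that mapping (the `build_custom_ops` helper is hypothetical, added here only for illustration):

```python
# Hypothetical helper mirroring the flag -> custom_ops mapping in the test
# above; it is not part of tests/compile/test_fusion.py.
def build_custom_ops(
    enable_rms_norm_custom_op: bool,
    enable_quant_fp8_custom_op: bool,
) -> list[str]:
    custom_ops: list[str] = []
    if enable_rms_norm_custom_op:
        # request the custom rms_norm op instead of the torch-native path
        custom_ops.append("+rms_norm")
    if enable_quant_fp8_custom_op:
        # request the custom quant_fp8 op instead of the torch-native path
        custom_ops.append("+quant_fp8")
    return custom_ops


# Example: only the RMSNorm custom op enabled.
assert build_custom_ops(True, False) == ["+rms_norm"]
assert build_custom_ops(False, True) == ["+quant_fp8"]
```

The resulting strings are presumably what the test feeds into the `VllmConfig` it builds right after, and they determine which ops `ops_in_model_before` / `ops_in_model_before_partial` expect to find in the pre-pass graph.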

tests/compile/test_fusion_all_reduce.py

Lines changed: 10 additions & 10 deletions
@@ -194,7 +194,7 @@ def ops_in_model_before(self):
 
 @multi_gpu_test(num_gpus=2)
 @pytest.mark.parametrize(
-    "test_model, enable_quant_fp8",
+    "test_model, enable_quant_fp8_custom_op",
     [
         (TestAllReduceRMSNormModel, False),
         (TestAllReduceRMSNormStaticQuantFP8Model, True),
@@ -206,7 +206,7 @@ def ops_in_model_before(self):
 @pytest.mark.parametrize("seq_len", [8])
 @pytest.mark.parametrize("hidden_size", [64])
 @pytest.mark.parametrize("dtype", [torch.bfloat16])
-@pytest.mark.parametrize("enable_rms_norm", [True, False])
+@pytest.mark.parametrize("enable_rms_norm_custom_op", [True, False])
 @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA")
 @pytest.mark.skipif(
     not find_spec("flashinfer")
@@ -220,8 +220,8 @@ def test_all_reduce_fusion_pass_replace(
     seq_len: int,
     hidden_size: int,
     dtype: torch.dtype,
-    enable_rms_norm,
-    enable_quant_fp8,
+    enable_rms_norm_custom_op,
+    enable_quant_fp8_custom_op,
 ):
     num_processes = 2
     if (
@@ -243,8 +243,8 @@ def run_torch_spawn(fn, nprocs):
                 seq_len,
                 hidden_size,
                 dtype,
-                enable_rms_norm,
-                enable_quant_fp8,
+                enable_rms_norm_custom_op,
+                enable_quant_fp8_custom_op,
             ),
             nprocs=nprocs,
         )
@@ -260,8 +260,8 @@ def all_reduce_fusion_pass_on_test_model(
     seq_len: int,
     hidden_size: int,
     dtype: torch.dtype,
-    enable_rms_norm,
-    enable_quant_fp8,
+    enable_rms_norm_custom_op,
+    enable_quant_fp8_custom_op,
 ):
     current_platform.seed_everything(0)
 
@@ -284,9 +284,9 @@ def all_reduce_fusion_pass_on_test_model(
     initialize_model_parallel(tensor_model_parallel_size=world_size)
 
     custom_ops = []
-    if enable_rms_norm:
+    if enable_rms_norm_custom_op:
         custom_ops.append("+rms_norm")
-    if enable_quant_fp8:
+    if enable_quant_fp8_custom_op:
         custom_ops.append("+quant_fp8")
 
     vllm_config = VllmConfig(
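
The same rename flows through the multi-GPU test unchanged: the flags are forwarded into each spawned worker, and `test_model` is parametrized together with `enable_quant_fp8_custom_op` while `enable_rms_norm_custom_op` varies independently. A minimal pytest sketch of that combined parametrization, assuming only pytest is installed (the `Fake*` classes and test name are stand-ins, not the real test models):

```python
import pytest


class FakeRMSNormModel: ...
class FakeRMSNormStaticQuantFP8Model: ...


@pytest.mark.parametrize(
    "test_model, enable_quant_fp8_custom_op",
    [
        # each model travels with its matching quant_fp8 custom-op flag
        (FakeRMSNormModel, False),
        (FakeRMSNormStaticQuantFP8Model, True),
    ],
)
@pytest.mark.parametrize("enable_rms_norm_custom_op", [True, False])
def test_custom_op_flag_naming(
    test_model, enable_quant_fp8_custom_op, enable_rms_norm_custom_op
):
    # The _custom_op suffix makes clear these flags select vLLM custom ops
    # versus the torch-native code path; both paths are exercised.
    assert isinstance(enable_quant_fp8_custom_op, bool)
    assert isinstance(enable_rms_norm_custom_op, bool)
```

Stacking the two parametrize decorators yields four cases (two model/flag pairs times two RMSNorm settings), matching how the real test covers both the custom-op and torch-native code paths.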
