
Commit b7f52bf

allreduce fusion working with/without custom ops (except fp4)

Signed-off-by: Luka Govedič <lgovedic@redhat.com>

1 parent: 54189a9

1 file changed: +35 −11 lines changed

tests/compile/test_fusion_all_reduce.py

@@ -66,8 +66,9 @@ def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
     def forward(self, hidden_states, residual):
         view = hidden_states.reshape(-1, self.hidden_size)
         all_reduce = tensor_model_parallel_all_reduce(view)
-        norm, _ = self.norm(all_reduce, residual)
-        return norm
+        norm, res = self.norm(all_reduce, residual)
+
+        return norm, res

     def ops_in_model_before(self):
         return [torch.ops.vllm.all_reduce.default]
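
Note: the model now returns both outputs of the norm. vLLM's fused-add RMSNorm adds the residual before normalizing and hands the updated residual back to the caller, so returning it keeps that second output live in the traced graph. A minimal sketch of the reference semantics (illustrative only, not vLLM's implementation):

import torch

def fused_add_rms_norm(x: torch.Tensor, residual: torch.Tensor,
                       weight: torch.Tensor, eps: float = 1e-6):
    # The residual add happens before the norm; the summed tensor is
    # returned as the updated residual alongside the normalized output.
    x = x + residual
    res = x
    var = x.float().pow(2).mean(dim=-1, keepdim=True)
    out = (x.float() * torch.rsqrt(var + eps)).to(x.dtype) * weight
    return out, res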
@@ -98,7 +99,9 @@ def ops_in_model_after(self):
     def ops_in_model_before(self):
         return [
             torch.ops.vllm.all_reduce.default,
-            torch.ops._C.static_scaled_fp8_quant.default,
+            torch.ops._C.static_scaled_fp8_quant.default
+            if self.quant_fp8.enabled()
+            else torch.ops.aten.reciprocal.default,
         ]
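
Note: when the `+quant_fp8` custom op is disabled, `torch.ops._C.static_scaled_fp8_quant` never appears in the graph; the native torch path divides by the scale, which the compiler traces as `aten.reciprocal` followed by a multiply, so that op serves as the pre-fusion marker instead. A rough sketch of such a native path (an assumption for illustration, not vLLM's exact code):

import torch

def static_scaled_fp8_quant_native(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Dividing by the scale is what surfaces as aten.reciprocal in the
    # compiled graph.
    finfo = torch.finfo(torch.float8_e4m3fn)
    q = (x * scale.reciprocal()).clamp(min=finfo.min, max=finfo.max)
    return q.to(torch.float8_e4m3fn)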
@@ -139,19 +142,21 @@ def ops_in_model_before(self):

 @multi_gpu_test(num_gpus=2)
 @pytest.mark.parametrize(
-    "test_model",
+    "test_model, enable_quant_fp8",
     [
-        TestAllReduceRMSNormModel,
-        TestAllReduceFusedAddRMSNormModel,
-        TestAllReduceFusedAddRMSNormStaticQuantFP8Model,
+        (TestAllReduceRMSNormModel, False),
+        (TestAllReduceFusedAddRMSNormModel, False),
+        (TestAllReduceFusedAddRMSNormStaticQuantFP8Model, True),
+        (TestAllReduceFusedAddRMSNormStaticQuantFP8Model, False),
         # TODO: Enable with torch==2.8.0
-        # TestAllReduceFusedAddRMSNormStaticQuantFP4Model,
+        # (TestAllReduceFusedAddRMSNormStaticQuantFP4Model, False),
     ],
 )
 @pytest.mark.parametrize("batch_size", [8])
 @pytest.mark.parametrize("seq_len", [8])
 @pytest.mark.parametrize("hidden_size", [16])
 @pytest.mark.parametrize("dtype", [torch.bfloat16])
+@pytest.mark.parametrize("enable_rms_norm", [True, False])
 @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA")
 @pytest.mark.skipif(
     not find_spec("flashinfer")
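
Note: stacked `@pytest.mark.parametrize` decorators combine multiplicatively, so the four (test_model, enable_quant_fp8) pairs crossed with the two enable_rms_norm values yield eight test cases per device configuration. A quick way to enumerate the matrix:

import itertools

model_cases = [
    ("TestAllReduceRMSNormModel", False),
    ("TestAllReduceFusedAddRMSNormModel", False),
    ("TestAllReduceFusedAddRMSNormStaticQuantFP8Model", True),
    ("TestAllReduceFusedAddRMSNormStaticQuantFP8Model", False),
]
for (model, enable_quant_fp8), enable_rms_norm in itertools.product(
    model_cases, (True, False)
):
    print(model, enable_quant_fp8, enable_rms_norm)  # 4 x 2 = 8 cases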
@@ -165,6 +170,8 @@ def test_all_reduce_fusion_pass_replace(
     seq_len: int,
     hidden_size: int,
     dtype: torch.dtype,
+    enable_rms_norm,
+    enable_quant_fp8,
 ):
     num_processes = 2
     if (
@@ -179,7 +186,16 @@ def test_all_reduce_fusion_pass_replace(
     def run_torch_spawn(fn, nprocs):
         torch.multiprocessing.spawn(
             fn,
-            args=(num_processes, test_model, batch_size, seq_len, hidden_size, dtype),
+            args=(
+                num_processes,
+                test_model,
+                batch_size,
+                seq_len,
+                hidden_size,
+                dtype,
+                enable_rms_norm,
+                enable_quant_fp8,
+            ),
             nprocs=nprocs,
         )
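
Note: `torch.multiprocessing.spawn` prepends the process index to `args`, so the worker is invoked as `fn(local_rank, *args)`; the two new flags simply ride along at the end of the tuple. A standalone example of that calling convention:

import torch.multiprocessing as mp

def worker(local_rank: int, world_size: int, tag: str):
    # spawn() supplies local_rank; world_size and tag come from args=.
    print(f"rank {local_rank}/{world_size}: {tag}")

if __name__ == "__main__":
    mp.spawn(worker, args=(2, "demo"), nprocs=2)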

@@ -194,6 +210,8 @@ def all_reduce_fusion_pass_on_test_model(
     seq_len: int,
     hidden_size: int,
     dtype: torch.dtype,
+    enable_rms_norm,
+    enable_quant_fp8,
 ):
     current_platform.seed_everything(0)
@@ -215,9 +233,15 @@ def all_reduce_fusion_pass_on_test_model(
     init_distributed_environment()
     initialize_model_parallel(tensor_model_parallel_size=world_size)

+    custom_ops = []
+    if enable_rms_norm:
+        custom_ops.append("+rms_norm")
+    if enable_quant_fp8:
+        custom_ops.append("+quant_fp8")
+
     vllm_config = VllmConfig(
         compilation_config=CompilationConfig(
-            level=CompilationLevel.PIECEWISE, custom_ops=["+rms_norm", "+quant_fp8"]
+            level=CompilationLevel.PIECEWISE, custom_ops=custom_ops
         )
     )
     vllm_config.compilation_config.pass_config = PassConfig(
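
Note: in `CompilationConfig`, a `"+name"` entry in `custom_ops` opts that op into its custom implementation, while omitting it keeps the native torch path, so building the list from the two flags lets one test body cover every combination. A small helper sketch of the same logic (hypothetical name, for illustration):

def build_custom_ops(enable_rms_norm: bool, enable_quant_fp8: bool) -> list[str]:
    # "+op" opts into the custom kernel; an absent entry leaves the
    # native torch implementation in the graph.
    ops: list[str] = []
    if enable_rms_norm:
        ops.append("+rms_norm")
    if enable_quant_fp8:
        ops.append("+quant_fp8")
    return ops

assert build_custom_ops(True, True) == ["+rms_norm", "+quant_fp8"]
assert build_custom_ops(False, False) == []  # all-native graph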
@@ -239,7 +263,7 @@ def all_reduce_fusion_pass_on_test_model(
     cleanup_pass = PostCleanupPass(vllm_config)

     backend = TestBackend(
-        all_reduce_fusion_pass, noop_pass, func_pass, cleanup_pass
+        noop_pass, all_reduce_fusion_pass, func_pass, cleanup_pass
     )

     token_num = batch_size * seq_len
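
Note: the reordering runs the no-op elimination pass before the all-reduce fusion pass, presumably so the pattern matcher sees a graph already stripped of no-op reshapes that would otherwise sit between the ops it expects to be adjacent. A toy model of why pass order matters (not vLLM's pass manager API):

def drop_noops(graph: list[str]) -> list[str]:
    return [op for op in graph if op != "noop_reshape"]

def fuse_allreduce_norm(graph: list[str]) -> list[str]:
    out: list[str] = []
    for op in graph:
        if out and out[-1] == "all_reduce" and op == "rms_norm":
            out[-1] = "fused_allreduce_rms_norm"  # pattern matched
        else:
            out.append(op)
    return out

graph = ["all_reduce", "noop_reshape", "rms_norm"]
# Fusion only fires if the no-ops are removed first:
assert fuse_allreduce_norm(drop_noops(graph)) == ["fused_allreduce_rms_norm"]
assert fuse_allreduce_norm(graph) == graph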
