Commit 0511091

[compile] Enable sequence parallelism matching w/o custom ops enabled
Signed-off-by: angelayi <yiangela7@gmail.com>
1 parent f50cc22 commit 0511091
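
For context, the test below now parametrizes `custom_ops` so the sequence-parallelism pass is exercised both with and without the custom `quant_fp8` op. A minimal sketch of the configuration the updated test constructs, assuming the "+quant_fp8" / "-quant_fp8" strings respectively enable and disable vLLM's custom FP8 quant op (imports as used in the test module); values are illustrative, not part of this commit:

    # Sketch only: mirrors the CompilationConfig built in
    # sequence_parallelism_pass_on_test_model below.
    from vllm.config import CompilationConfig, PassConfig

    custom_ops = "+quant_fp8"  # the test also runs "" and "-quant_fp8"
    custom_ops_list = custom_ops.split(",") if custom_ops else []

    compilation_config = CompilationConfig(
        custom_ops=custom_ops_list,
        pass_config=PassConfig(
            enable_sequence_parallelism=True,
            enable_fusion=True,  # fused RMSNorm+FP8 pattern only appears with "+quant_fp8"
            enable_noop=True,    # NoOp elimination is needed for fusion
        ),
    )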

2 files changed: +106 -220 lines changed
tests/compile/test_sequence_parallelism.py

Lines changed: 28 additions & 22 deletions
@@ -99,7 +99,6 @@ def __init__(self, hidden_size=16, intermediate_size=32):
         super().__init__()
         self.hidden_size = hidden_size
         self.intermediate_size = intermediate_size
-        self.vllm_config = get_current_vllm_config()
         self.gate_proj = torch.nn.Parameter(
             torch.empty((intermediate_size, hidden_size)), requires_grad=False
         )
@@ -152,41 +151,36 @@ def forward(self, hidden_states, residual):
     def ops_in_model_before(self):
         ops_to_remove = [torch.ops.vllm.all_reduce.default]  # Always removed by SP
         # The following are only removed if fusion happens
-        if (
-            self.vllm_config
-            and self.vllm_config.compilation_config.pass_config.enable_fusion
-        ):
-            ops_to_remove.extend(
-                [
-                    torch.ops._C.fused_add_rms_norm.default,
-                    torch.ops._C.static_scaled_fp8_quant.default,
-                ]
-            )
+        config = get_current_vllm_config()
+        if config.compilation_config.pass_config.enable_fusion:
+            ops_to_remove.append(torch.ops._C.fused_add_rms_norm.default)
+            # Only check for static_scaled_fp8_quant if custom quant_fp8 is enabled
+            if "+quant_fp8" in config.compilation_config.custom_ops:
+                ops_to_remove.append(torch.ops._C.static_scaled_fp8_quant.default)
         return ops_to_remove
 
     def ops_in_model_after(self):
         ops_to_add = [
             torch.ops.vllm.reduce_scatter.default,
             torch.ops.vllm.all_gather.default,
         ]
-        # The following is only added if fusion happens
+        # The following is only added if fusion happens and custom quant_fp8 is enabled
+        config = get_current_vllm_config()
         if (
-            self.vllm_config
-            and self.vllm_config.compilation_config.pass_config.enable_fusion
+            config.compilation_config.pass_config.enable_fusion
+            and "+quant_fp8" in config.compilation_config.custom_ops
         ):
             ops_to_add.append(torch.ops._C.fused_add_rms_norm_static_fp8_quant.default)
         return ops_to_add
 
     def ops_in_model(self):
-        if (
-            self.vllm_config
-            and self.vllm_config.compilation_config.pass_config.enable_fusion
-        ):
-            # If fusion happens, the fused op is the one
+        config = get_current_vllm_config()
+        if config.compilation_config.pass_config.enable_fusion:
+            # If fusion happens with custom quant_fp8, the fused op is the one
             # we check for (de)functionalization
             return [torch.ops._C.fused_add_rms_norm_static_fp8_quant.default]
         else:
-            # If no fusion, the original ops are checked
+            # If no fusion or using native quant, the original ops are checked
             return [
                 torch.ops._C.fused_add_rms_norm.default,
                 # TODO functionalization pass does not handle this yet
@@ -195,7 +189,14 @@ def ops_in_model(self):
 
 
 @multi_gpu_test(num_gpus=2)
-@pytest.mark.parametrize("test_model_cls", [TestModel, TestQuantModel])
+@pytest.mark.parametrize(
+    "test_model_cls, custom_ops",
+    [
+        (TestModel, ""),
+        (TestQuantModel, "+quant_fp8"),
+        (TestQuantModel, "-quant_fp8"),
+    ],
+)
 @pytest.mark.parametrize("batch_size", [8])
 @pytest.mark.parametrize("seq_len", [16])
 @pytest.mark.parametrize("hidden_size", [16])
@@ -204,6 +205,7 @@ def ops_in_model(self):
 @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA")
 def test_sequence_parallelism_pass(
     test_model_cls: type[torch.nn.Module],
+    custom_ops: str,
     batch_size: int,
     seq_len: int,
     hidden_size: int,
@@ -220,6 +222,7 @@ def run_torch_spawn(fn, nprocs):
            args=(
                num_processes,
                test_model_cls,
+               custom_ops,
                batch_size,
                seq_len,
                hidden_size,
@@ -236,6 +239,7 @@ def sequence_parallelism_pass_on_test_model(
     local_rank: int,
     world_size: int,
     test_model_cls: type[torch.nn.Module],
+    custom_ops: str,
     batch_size: int,
     seq_len: int,
     hidden_size: int,
@@ -264,12 +268,14 @@ def sequence_parallelism_pass_on_test_model(
     initialize_model_parallel(tensor_model_parallel_size=world_size)
 
     # configure vllm config for SequenceParallelismPass
+    custom_ops_list = custom_ops.split(",") if custom_ops else []
     compilation_config = CompilationConfig(
+        custom_ops=custom_ops_list,
         pass_config=PassConfig(
             enable_sequence_parallelism=True,
             enable_fusion=enable_fusion,
             enable_noop=True,
-        )
+        ),
     )  # NoOp needed for fusion
     device_config = DeviceConfig(device=torch.device("cuda"))
 
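As a reading aid for the new parametrizations, here is a small sketch of which collective and quant ops the TestQuantModel helpers above expect to be removed or added, assuming enable_fusion=True; it only restates the logic of ops_in_model_before/ops_in_model_after, with shortened illustrative op names:

    # Sketch only: restates ops_in_model_before/ops_in_model_after above.
    def expected_quant_ops(custom_ops_list: list[str], enable_fusion: bool = True):
        removed = ["vllm.all_reduce"]                       # always removed by SP
        added = ["vllm.reduce_scatter", "vllm.all_gather"]  # always added by SP
        if enable_fusion:
            removed.append("_C.fused_add_rms_norm")
            if "+quant_fp8" in custom_ops_list:
                removed.append("_C.static_scaled_fp8_quant")
                added.append("_C.fused_add_rms_norm_static_fp8_quant")
        return removed, added

    print(expected_quant_ops(["+quant_fp8"]))  # custom FP8 quant op: fused pattern expected
    print(expected_quant_ops(["-quant_fp8"]))  # native quant path: no fused FP8 op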