
Commit fccaf6d

DarkLight1337 authored and xuebwang-amd committed
[V1] Support LLM.apply_model (vllm-project#18465)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: xuebwang-amd <xuebwang@amd.com>
1 parent dfda7b9 commit fccaf6d

File tree

17 files changed: +194 −169 lines


tests/conftest.py

Lines changed: 1 addition & 11 deletions
@@ -987,17 +987,7 @@ def score(
         return [req_output.outputs.score for req_output in req_outputs]
 
     def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
-        if hasattr(self.llm.llm_engine, "model_executor"):
-            # This works either in V0 or in V1 with
-            # VLLM_ENABLE_V1_MULTIPROCESSING=0
-            executor = self.llm.llm_engine.model_executor
-            return executor.apply_model(func)
-
-        # This works in V1 with VLLM_ALLOW_INSECURE_SERIALIZATION=1
-        def _apply_model(self):
-            return func(self.get_model())
-
-        return self.llm.llm_engine.collective_rpc(_apply_model)
+        return self.llm.apply_model(func)
 
     def get_llm(self) -> LLM:
         return self.llm
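
For reference, a minimal sketch of the API this simplified helper now delegates to: `LLM.apply_model`, called directly on an `LLM` instance. The model name below is only illustrative, and the environment variable mirrors the fixtures added elsewhere in this diff, since per the removed comment the V1 path pickles the callable before dispatching it to workers.

```python
import os

# Assumption mirrored from the fixtures in this commit: the V1 engine pickles
# the callable it ships to worker processes, so pickle-based (insecure)
# serialization has to be allowed explicitly.
os.environ["VLLM_ALLOW_INSECURE_SERIALIZATION"] = "1"

from vllm import LLM


def check_model(model):
    # Receives the loaded torch.nn.Module; whatever it returns is collected
    # into the list returned by apply_model (one entry per worker).
    return type(model).__name__


llm = LLM(model="facebook/opt-125m")  # illustrative model only
print(llm.apply_model(check_model))
```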

tests/kernels/moe/test_mxfp4_moe.py

Lines changed: 22 additions & 15 deletions
@@ -1,21 +1,24 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-import importlib
 import importlib.metadata
 from dataclasses import dataclass
+from importlib.util import find_spec
 from typing import Optional
 
 import pytest
 import torch
 from packaging import version
 
+from vllm.model_executor.layers.quantization.quark.quark import (  # noqa: E501
+    QuarkLinearMethod, QuarkW4A4MXFP4)
+from vllm.model_executor.layers.quantization.quark.quark_moe import (  # noqa: E501
+    QuarkW4A4MXFp4MoEMethod)
 from vllm.platforms import current_platform
 from vllm.utils.flashinfer import has_flashinfer
 
-QUARK_MXFP4_AVAILABLE = importlib.util.find_spec(
-    "quark") is not None and version.parse(
-        importlib.metadata.version("amd-quark")) >= version.parse('0.8.99')
+QUARK_MXFP4_AVAILABLE = find_spec("quark") is not None and version.parse(
+    importlib.metadata.version("amd-quark")) >= version.parse('0.8.99')
 
 TRTLLM_GEN_MXFP4_AVAILABLE = current_platform.is_cuda(
 ) and current_platform.is_device_capability(100)
@@ -39,6 +42,12 @@ class ModelCase:
     tp: int
 
 
+@pytest.fixture(scope="function", autouse=True)
+def enable_pickle(monkeypatch):
+    """`LLM.apply_model` requires pickling a function."""
+    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
+
+
 @pytest.mark.parametrize('model_case', [
     ModelCase("fxmarty/qwen_1.5-moe-a2.7b-mxfp4", tp=1),
     ModelCase("fxmarty/deepseek_r1_3_layers_mxfp4", tp=8),
@@ -55,21 +64,19 @@ def test_mxfp4_loading_and_execution_moe(vllm_runner, model_case: ModelCase):
                     tensor_parallel_size=model_case.tp,
                     load_format="dummy") as llm:
 
-        # TODO: llm.apply_model(check_model) currently relies on V0 internals.
-        # Re-enable this later.
-        # def check_model(model):
-        #     layer = model.model.layers[0]
+        def check_model(model):
+            layer = model.model.layers[0]
 
-        #     qkv_proj = layer.self_attn.qkv_proj
+            qkv_proj = layer.self_attn.qkv_proj
 
-        #     assert isinstance(qkv_proj.quant_method, QuarkLinearMethod)
-        #     assert isinstance(qkv_proj.scheme, QuarkW4A4MXFP4)
+            assert isinstance(qkv_proj.quant_method, QuarkLinearMethod)
+            assert isinstance(qkv_proj.scheme, QuarkW4A4MXFP4)
 
-        #     assert isinstance(layer.mlp.experts.quant_method,
-        #                       QuarkW4A4MXFp4MoEMethod)
+            assert isinstance(layer.mlp.experts.quant_method,
+                              QuarkW4A4MXFp4MoEMethod)
 
-        # if model_case.model_id == "fxmarty/qwen_1.5-moe-a2.7b-mxfp4":
-        #     llm.apply_model(check_model)
+        if model_case.model_id == "fxmarty/qwen_1.5-moe-a2.7b-mxfp4":
+            llm.apply_model(check_model)
 
         output = llm.generate_greedy("Today I am in the French Alps and",
                                      max_tokens=20)
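
The `enable_pickle` fixture above is repeated verbatim in several of the test modules touched by this commit. A possible consolidation (hypothetical, not part of this change) would be to host it once in a shared `conftest.py` so every module under that directory picks it up automatically:

```python
# conftest.py (hypothetical shared location, not part of this commit)
import pytest


@pytest.fixture(scope="function", autouse=True)
def enable_pickle(monkeypatch):
    """`LLM.apply_model` requires pickling a function."""
    # Same effect as the per-module fixtures: allow the V1 engine to pickle
    # the callable it ships to worker processes.
    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
```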

tests/models/multimodal/generation/test_qwen2_vl.py

Lines changed: 23 additions & 23 deletions
@@ -10,18 +10,17 @@
 
 from vllm.multimodal.image import rescale_image_size
 from vllm.multimodal.video import rescale_video_size, sample_frames_from_video
+from vllm.utils import set_default_torch_num_threads
 
 from ....conftest import (IMAGE_ASSETS, VIDEO_ASSETS, PromptImageInput,
                           PromptVideoInput, VllmRunner)
 from ...utils import check_logprobs_close
 
 
 @pytest.fixture(scope="function", autouse=True)
-def use_v0_only(monkeypatch):
-    """
-    V1 Test: batch_make_xxxxx_embeddings calls a V0 internal
-    """
-    monkeypatch.setenv('VLLM_USE_V1', '0')
+def enable_pickle(monkeypatch):
+    """`LLM.apply_model` requires pickling a function."""
+    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
 
 
 models = ["Qwen/Qwen2-VL-2B-Instruct"]
@@ -126,9 +125,8 @@ def get_image_embeds(model):
         image_grid_thw_on_device = image_grid_thw.to(visual.device,
                                                      dtype=torch.int64)
         return visual(pixel_values_on_device,
-                      grid_thw=image_grid_thw_on_device)
+                      grid_thw=image_grid_thw_on_device).cpu()
 
-    # V1 Test: this calls a V0 internal.
     image_embeds = torch.concat(llm.apply_model(get_image_embeds))
 
     # split into original batches
@@ -210,7 +208,7 @@ def get_image_embeds(model):
         video_grid_thw_on_device = video_grid_thw.to(visual.device,
                                                      dtype=torch.int64)
         return visual(pixel_values_on_device,
-                      grid_thw=video_grid_thw_on_device)
+                      grid_thw=video_grid_thw_on_device).cpu()
 
     # V1 Test: this calls a V0 internal.
     video_embeds = torch.concat(llm.apply_model(get_image_embeds))
@@ -266,19 +264,22 @@ def run_embedding_input_test(
     processor = AutoProcessor.from_pretrained(model)
 
     # max_model_len should be greater than image_feature_size
-    with vllm_runner(model,
-                     runner="generate",
-                     max_model_len=4000,
-                     max_num_seqs=3,
-                     dtype=dtype,
-                     limit_mm_per_prompt={
-                         "image": mm_limit,
-                         "video": mm_limit
-                     },
-                     tensor_parallel_size=tensor_parallel_size,
-                     distributed_executor_backend=distributed_executor_backend
-                     ) as vllm_model:
+    with set_default_torch_num_threads(1):
+        vllm_model = vllm_runner(
+            model,
+            runner="generate",
+            max_model_len=4000,
+            max_num_seqs=3,
+            dtype=dtype,
+            limit_mm_per_prompt={
+                "image": mm_limit,
+                "video": mm_limit
+            },
+            tensor_parallel_size=tensor_parallel_size,
+            distributed_executor_backend=distributed_executor_backend,
+        )
 
+    with vllm_model:
         outputs_per_case_for_original_input = [
             vllm_model.generate_greedy_logprobs(prompts,
                                                 max_tokens,
@@ -329,9 +330,8 @@ def run_embedding_input_test(
 @pytest.mark.parametrize("max_tokens", [128])
 @pytest.mark.parametrize("num_logprobs", [10])
 def test_qwen2_vl_image_embeddings_input(vllm_runner, image_assets, model,
-                                         size_factors, dtype: str,
-                                         max_tokens: int,
-                                         num_logprobs: int) -> None:
+                                         size_factors, dtype, max_tokens,
+                                         num_logprobs, monkeypatch) -> None:
     images = [asset.pil_image for asset in image_assets]
 
     inputs_per_case: list[tuple[

tests/models/quantization/test_awq.py

Lines changed: 1 addition & 1 deletion
@@ -112,7 +112,7 @@ def test_awq_models(vllm_runner, image_assets, source_model, quant_model,
                     monkeypatch) -> None:
 
     # Test V1: this test hangs during setup on single-scale input.
-    # TODO: fixure out why and re-enable this on V1.
+    # TODO: figure out why and re-enable this on V1.
    monkeypatch.setenv("VLLM_USE_V1", "0")
     run_awq_test(
         vllm_runner,

tests/quantization/test_compressed_tensors.py

Lines changed: 8 additions & 10 deletions
@@ -43,12 +43,9 @@
 
 
 @pytest.fixture(scope="function", autouse=True)
-def use_v0_only(monkeypatch):
-    """
-    This module relies on V0 internals, so set VLLM_USE_V1=0.
-    """
-    if not current_platform.is_cpu():
-        monkeypatch.setenv('VLLM_USE_V1', '0')
+def enable_pickle(monkeypatch):
+    """`LLM.apply_model` requires pickling a function."""
+    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
 
 
 @pytest.mark.parametrize(
@@ -176,10 +173,11 @@ def test_compressed_tensors_w8a8_logprobs(
 
     dtype = "bfloat16"
 
-    # skip language translation prompt for the static per tensor asym model
-    if (model_path ==
-            "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Asym"
-        ):  # noqa: E501
+    # skip language translation prompt for the static per tensor models
+    if model_path in (
+            "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Sym",
+            "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Asym",
+    ):
         example_prompts = example_prompts[0:-1]
 
     with hf_runner(model_path, dtype=dtype) as hf_model:

tests/quantization/test_fp8.py

Lines changed: 4 additions & 4 deletions
@@ -60,8 +60,8 @@ def test_kv_cache_model_load_and_run(vllm_runner, model_id: str,
     if use_rocm_aiter:
         monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
 
-    # vllm_runner.apply_model() relies on V0 internals.
-    monkeypatch.setenv("VLLM_USE_V1", "0")
+    # `LLM.apply_model` requires pickling a function.
+    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
     with vllm_runner(model_id, kv_cache_dtype="fp8") as llm:
 
         def check_model(model):
@@ -104,8 +104,8 @@ def test_load_fp16_model(vllm_runner, kv_cache_dtype: str, force_marlin: bool,
     if use_rocm_aiter:
         monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
 
-    # vllm_runner.apply_model() relies on V0 internals.
-    monkeypatch.setenv("VLLM_USE_V1", "0")
+    # `LLM.apply_model` requires pickling a function.
+    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
 
     if force_marlin:
         monkeypatch.setenv("VLLM_TEST_FORCE_FP8_MARLIN", "1")

tests/quantization/test_gptq_dynamic.py

Lines changed: 38 additions & 33 deletions
@@ -31,41 +31,46 @@
 @pytest.mark.parametrize("model_id, use_marlin_kernel", MODEL_QUANT)
 def test_gptq_with_dynamic(vllm_runner, model_id: str, use_marlin_kernel: bool,
                            monkeypatch):
-    # vllm_runner.apply_model() relies on V0 internals.
-    monkeypatch.setenv("VLLM_USE_V1", "0")
-
-    vllm_model = vllm_runner(model_id, dtype=torch.float16, max_model_len=2048)
+    # `LLM.apply_model` requires pickling a function.
+    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
 
     linear_method_cls = GPTQMarlinLinearMethod if use_marlin_kernel else (
         GPTQLinearMethod)
 
-    for name, submodule in (vllm_model.llm.llm_engine.model_executor.
-                            driver_worker.model_runner.model.named_modules()):
-        if name == "lm_head":
-            assert isinstance(submodule.quant_method, linear_method_cls)
-        elif name == 'model.layers.0.self_attn.qkv_proj':
-            # The first layer is quantized using bits=4, group_size=128
-            # desc_act=True
-            assert isinstance(submodule.quant_method, linear_method_cls)
-            config = submodule.quant_method.quant_config
-            assert config.weight_bits == 4
-            assert config.group_size == 128
-            assert config.desc_act
-        elif name == 'model.layers.1.self_attn.qkv_proj':
-            # The second layer is quantized using bits=8, group_size=32
-            # desc_act=False
-            assert isinstance(submodule.quant_method, linear_method_cls)
-            config = submodule.quant_method.quant_config
-            assert get_dynamic_override(config, layer_name=name,
-                                        key="bits") == 8
-            assert get_dynamic_override(config,
-                                        layer_name=name,
-                                        key="group_size") == 32
-            assert not get_dynamic_override(
-                config, layer_name=name, key="desc_act")
-        elif (name == 'model.layers.2.self_attn.qkv_proj'
-              or name == 'model.layers.2.mlp.gate_up_proj'):
-            # All other layers (layer index >= 2) are not quantized
-            assert isinstance(submodule.quant_method, UnquantizedLinearMethod)
+    with vllm_runner(model_id, dtype=torch.float16, max_model_len=2048) as llm:
+
+        def check_model(model):
+            for name, submodule in model.named_modules():
+                if name == "lm_head":
+                    assert isinstance(submodule.quant_method,
+                                      linear_method_cls)
+                elif name == 'model.layers.0.self_attn.qkv_proj':
+                    # The first layer is quantized using bits=4, group_size=128
+                    # desc_act=True
+                    assert isinstance(submodule.quant_method,
+                                      linear_method_cls)
+                    config = submodule.quant_method.quant_config
+                    assert config.weight_bits == 4
+                    assert config.group_size == 128
+                    assert config.desc_act
+                elif name == 'model.layers.1.self_attn.qkv_proj':
+                    # The second layer is quantized using bits=8, group_size=32
+                    # desc_act=False
+                    assert isinstance(submodule.quant_method,
+                                      linear_method_cls)
+                    config = submodule.quant_method.quant_config
+                    assert get_dynamic_override(config,
+                                                layer_name=name,
+                                                key="bits") == 8
+                    assert get_dynamic_override(config,
+                                                layer_name=name,
+                                                key="group_size") == 32
+                    assert not get_dynamic_override(
+                        config, layer_name=name, key="desc_act")
+                elif (name == 'model.layers.2.self_attn.qkv_proj'
+                      or name == 'model.layers.2.mlp.gate_up_proj'):
+                    # All other layers (layer index >= 2) are not quantized
+                    assert isinstance(submodule.quant_method,
+                                      UnquantizedLinearMethod)
 
-    del vllm_model
+        llm.apply_model(check_model)
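
Worth noting about the refactor above: the old assertions walked `vllm_model.llm.llm_engine.model_executor.driver_worker.model_runner.model`, a path that, per the removed conftest comment, only works in V0 or in V1 with multiprocessing disabled. Routing the same checks through `llm.apply_model(check_model)` presumably keeps the test valid regardless of where the V1 engine places the model, at the cost of requiring the pickle-based serialization enabled at the top of the test.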

tests/quantization/test_lm_head.py

Lines changed: 2 additions & 2 deletions
@@ -29,8 +29,8 @@ def test_lm_head(
     lm_head_quantized: bool,
     monkeypatch,
 ) -> None:
-    # vllm_runner.apply_model() relies on V0 internals.
-    monkeypatch.setenv("VLLM_USE_V1", "0")
+    # `LLM.apply_model` requires pickling a function.
+    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
     with vllm_runner(model_id, dtype=torch.float16,
                      max_model_len=2048) as vllm_model:
 
tests/quantization/test_modelopt.py

Lines changed: 3 additions & 7 deletions
@@ -11,16 +11,12 @@
 import torch
 
 from tests.quantization.utils import is_quant_method_supported
-from vllm.platforms import current_platform
 
 
 @pytest.fixture(scope="function", autouse=True)
-def use_v0_only(monkeypatch):
-    """
-    This module relies on V0 internals, so set VLLM_USE_V1=0.
-    """
-    if not current_platform.is_cpu():
-        monkeypatch.setenv('VLLM_USE_V1', '0')
+def enable_pickle(monkeypatch):
+    """`LLM.apply_model` requires pickling a function."""
+    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
 
 
 @pytest.mark.skipif(not is_quant_method_supported("modelopt"),
