
Commit ad366bf

[Bugfix] Follow vLLM Qwen-Moe/VL and KV Connector change to fix broken CI (#2181)
### What this PR does / why we need it?

This PR fixes the broken CI:

1. Adapt to vllm-project/vllm@ee2eb6e, which fused the gate and up projections in the vision MLP into a single linear layer. This improves performance by eliminating one matrix multiplication. Accordingly, this PR:
   - specifies that the two linear layers are fused as `mlp.gate_up_proj` when loading the weights, and
   - uses a SiluAndMul activation function.
2. Adapt to vllm-project/vllm@aefeea0 by updating the `ModelRunnerOutput` parameters to match its changes.
3. Adapt to [vllm-commit](https://github.com/vllm-project/vllm/pull/20815/files#diff-3ffb829a39ab2b3e4706aa28f5e476815f36c3a87b98d6a66514ebedc8f3ffb4R354-R356) by fixing the Qwen-MoE model.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.10.0
- vLLM main: vllm-project/vllm@fed5849

---------

Signed-off-by: wangli <wangli858794774@gmail.com>
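Most of the diff below repeats a single compatibility shim for the KV-connector change. As a minimal sketch of that pattern, assuming only what the diffs themselves show (the helper name `build_kv_extra_args` is illustrative and does not exist in the patch):

```python
# Sketch of the version-gating pattern applied throughout this commit:
# on vLLM v0.10.0 the finished KV-transfer request IDs are plain fields on
# ModelRunnerOutput; on newer vLLM main they are wrapped in KVConnectorOutput.
from typing import Optional

from vllm_ascend.utils import vllm_version_is


def build_kv_extra_args(finished_sending: Optional[set],
                        finished_recving: Optional[set]) -> dict:
    # Hypothetical helper; the patch inlines this logic at each call site.
    if vllm_version_is("0.10.0"):
        return {
            "finished_sending": finished_sending,
            "finished_recving": finished_recving,
        }
    from vllm.v1.worker.kv_connector_model_runner_mixin import \
        KVConnectorOutput
    return {
        "kv_connector_output":
        KVConnectorOutput(finished_sending=finished_sending,
                          finished_recving=finished_recving)
    }
```

The resulting dict is then splatted into the `ModelRunnerOutput(...)` constructor alongside the version-independent fields.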
1 parent e38fab0 commit ad366bf

File tree

8 files changed: +137 −56 lines changed

tests/ut/kv_connector/test_remote_decode_lifecycle.py

Lines changed: 15 additions & 2 deletions
```diff
@@ -25,6 +25,7 @@
                                          create_model_runner_output,
                                          create_request, create_scheduler,
                                          create_vllm_config)
+from vllm_ascend.utils import vllm_version_is
 
 
 def test_basic_lifecycle():
@@ -102,7 +103,13 @@ def test_basic_lifecycle():
 
     # (3b): execute_model()
     model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
-    model_runner_output.finished_sending = [request_id]
+    if vllm_version_is("0.10.0"):
+        model_runner_output.finished_sending = [request_id]
+    else:
+        from vllm.v1.worker.kv_connector_model_runner_mixin import \
+            KVConnectorOutput  # type: ignore # noqa
+        model_runner_output.kv_connector_output = KVConnectorOutput(
+            finished_sending=[request_id])
 
     # (3c): update_from_output()
     scheduler.update_from_output(scheduler_output, model_runner_output)
@@ -157,7 +164,13 @@ def test_prefix_cache_lifecycle():
     scheduler_output = scheduler.schedule()
     scheduler.schedule()
     model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
-    model_runner_output.finished_sending = [request_remote.request_id]
+    if vllm_version_is("0.10.0"):
+        model_runner_output.finished_sending = [request_remote.request_id]
+    else:
+        from vllm.v1.worker.kv_connector_model_runner_mixin import \
+            KVConnectorOutput  # noqa
+        model_runner_output.kv_connector_output = KVConnectorOutput(
+            finished_sending=[request_remote.request_id])
     scheduler.update_from_output(scheduler_output, model_runner_output)
     _ = scheduler.schedule()
     assert_scheduler_empty(scheduler)
```
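The tests above gate on `vllm_version_is("0.10.0")`. A rough sketch of what such a check amounts to, assuming it compares the installed vLLM release against an exact version string (the real helper in `vllm_ascend.utils` may be implemented differently):

```python
# Illustrative approximation only; see vllm_ascend/utils.py for the real check.
from importlib.metadata import version


def vllm_version_is(target: str) -> bool:
    # True when the installed vLLM release exactly matches `target`,
    # e.g. vllm_version_is("0.10.0") on a v0.10.0 install.
    return version("vllm") == target
```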

tests/ut/kv_connector/test_remote_prefill_lifecycle.py

Lines changed: 16 additions & 19 deletions
```diff
@@ -19,7 +19,7 @@
 import copy
 
 from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT
-from vllm.v1.request import FinishReason, RequestStatus
+from vllm.v1.request import RequestStatus
 
 from tests.ut.kv_connector.utils import (assert_scheduler_empty,
                                          create_model_runner_output,
@@ -55,10 +55,7 @@ def test_basic_lifecycle():
     # Nothing running and empty scheduler output.
     assert len(scheduler.running) == 0
     assert len(scheduler_output.scheduled_new_reqs) == 0
-    if vllm_version_is("0.9.1"):
-        assert len(scheduler_output.scheduled_cached_reqs) == 0
-    else:
-        assert scheduler_output.scheduled_cached_reqs.num_reqs == 0
+    assert scheduler_output.scheduled_cached_reqs.num_reqs == 0
     assert len(scheduler_output.num_scheduled_tokens) == 0
     assert scheduler_output.total_num_scheduled_tokens == 0
 
@@ -94,7 +91,13 @@ def test_basic_lifecycle():
 
     # (2b): forward(): request finishes recv.
     model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
-    model_runner_output.finished_recving = [request_id]
+    if vllm_version_is("0.10.0"):
+        model_runner_output.finished_recving = [request_id]
+    else:
+        from vllm.v1.worker.kv_connector_model_runner_mixin import \
+            KVConnectorOutput  # type: ignore # noqa
+        model_runner_output.kv_connector_output = KVConnectorOutput(
+            finished_recving=[request_id])
 
     # (2c): update_from_output():
     engine_core_outputs = scheduler.update_from_output(scheduler_output,
@@ -135,11 +138,6 @@ def test_basic_lifecycle():
                                  model_runner_output)
     scheduler.schedule()
 
-    if vllm_version_is("0.9.1"):
-        outputs = engine_core_outputs[0].outputs
-        assert len(outputs) == 1
-        output = outputs[0]
-        assert output.finish_reason == FinishReason.STOP
     assert_scheduler_empty(scheduler)
 
 
@@ -213,7 +211,13 @@ def test_full_block_prompt():
     # # STEP (2): Recv.
     scheduler_output = scheduler.schedule()
     model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
-    model_runner_output.finished_recving = [request_id]
+    if vllm_version_is("0.10.0"):
+        model_runner_output.finished_recving = [request_id]
+    else:
+        from vllm.v1.worker.kv_connector_model_runner_mixin import \
+            KVConnectorOutput  # type: ignore # noqa
+        model_runner_output.kv_connector_output = KVConnectorOutput(
+            finished_recving=[request_id])
     scheduler.update_from_output(scheduler_output, model_runner_output)
     assert len(scheduler.waiting) == 1
     assert (request_id in scheduler.finished_recving_kv_req_ids)
@@ -236,13 +240,6 @@ def test_full_block_prompt():
     # # Step (4): Hit EOS.
     scheduler_output = scheduler.schedule()
     model_runner_output = create_model_runner_output([request], use_eos=True)
-    engine_core_outputs = scheduler.update_from_output(scheduler_output,
-                                                       model_runner_output)
     scheduler.schedule()
 
-    if vllm_version_is("0.9.1"):
-        outputs = engine_core_outputs[0].outputs
-        assert len(outputs) == 1
-        output = outputs[0]
-        assert output.finish_reason == FinishReason.STOP
     assert_scheduler_empty(scheduler)
```

tests/ut/kv_connector/utils.py

Lines changed: 16 additions & 5 deletions
```diff
@@ -186,16 +186,27 @@ def create_model_runner_output(
     sampled_token_ids = [[sampled_token] for _ in req_ids]
 
     # Make output data structure.
+    extra_args = {}
+    if not vllm_version_is("0.10.0"):
+        from vllm.v1.worker.kv_connector_model_runner_mixin import \
+            KVConnectorOutput  # type: ignore # noqa
+        kv_connector_output = KVConnectorOutput(
+            finished_sending=finished_sending,
+            finished_recving=finished_recving)
+        extra_args = {"kv_connector_output": kv_connector_output}
+    else:
+        extra_args = {
+            "finished_sending": finished_sending,
+            "finished_recving": finished_recving,
+        }
+
     return ModelRunnerOutput(
         req_ids=req_ids,
         req_id_to_index=req_id_to_index,
         sampled_token_ids=sampled_token_ids,
         spec_token_ids=None,
         logprobs=None,
         prompt_logprobs_dict={},
-        **({
-            "pooler_output": []
-        } if not vllm_version_is("0.9.1") else {}),
-        finished_sending=finished_sending,
-        finished_recving=finished_recving,
+        pooler_output=[],
+        **extra_args,
    )
```

vllm_ascend/models/qwen2_5_vl.py

Lines changed: 16 additions & 4 deletions
```diff
@@ -18,7 +18,7 @@
 # limitations under the License.
 
 from functools import partial
-from typing import Callable, Iterable, Optional, Set, Tuple
+from typing import Callable, Iterable, Optional, Set, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -30,7 +30,8 @@
 from vllm.config import VllmConfig
 from vllm.distributed import parallel_state
 from vllm.distributed import utils as dist_utils
-from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY
+from vllm.model_executor.layers.activation import (_ACTIVATION_REGISTRY,
+                                                   get_act_and_mul_fn)
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -42,6 +43,8 @@
 from vllm.model_executor.models.utils import maybe_prefix
 from vllm.multimodal import MULTIMODAL_REGISTRY
 
+from vllm_ascend.utils import vllm_version_is
+
 MIN_PAD_SIZE = 64  # min_size to pad weight
 MAX_PAD_SIZE = 128  # max_size to pad weight
 
@@ -197,12 +200,16 @@ def __init__(
             in_channels=vision_config.in_channels,
             hidden_size=self.hidden_size,
         )
+
+        act_fn = get_act_and_mul_fn(vision_config.hidden_act)
+        if vllm_version_is("0.10.0"):
+            act_fn = _ACTIVATION_REGISTRY[vision_config.hidden_act]
         self.blocks = nn.ModuleList([
             AscendQwen2_5_VisionBlock(
                 dim=self.hidden_size,
                 num_heads=self.num_heads,
                 mlp_hidden_dim=vision_config.intermediate_size,
-                act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act],
+                act_fn=act_fn,
                 norm_layer=norm_layer,
                 quant_config=quant_config,
                 prefix=f"{prefix}.blocks.{layer_idx}")
@@ -291,12 +298,17 @@ def pad_proj_weight(self, data):
 
     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
-        stacked_params_mapping = [
+        stacked_params_mapping: list[tuple[str, str, Union[str, int]]] = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
             ("qkv_proj", "k_proj", "k"),
             ("qkv_proj", "v_proj", "v"),
         ]
+        if not vllm_version_is("0.10.0"):
+            stacked_params_mapping.extend([
+                ("mlp.gate_up_proj.", "mlp.gate_proj.", 0),
+                ("mlp.gate_up_proj.", "mlp.up_proj.", 1),
+            ])
         params_dict = dict(self.named_parameters(remove_duplicate=False))
         loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
```
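For intuition on the fused-MLP change this file adapts to: vLLM now stacks the gate and up projections into a single `gate_up_proj` weight, so one matmul produces `[gate | up]` along the last dimension, and an act-and-mul function (`SiluAndMul` when `hidden_act` is silu) splits it back as `silu(gate) * up`. A self-contained sketch with a simplified stand-in for vLLM's activation (shapes and names here are illustrative):

```python
import torch
import torch.nn.functional as F
from torch import nn


class NaiveSiluAndMul(nn.Module):
    # Simplified stand-in for vLLM's SiluAndMul: split the fused projection
    # output in half along the last dim and compute silu(gate) * up.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate, up = x.chunk(2, dim=-1)
        return F.silu(gate) * up


hidden, intermediate = 16, 32
# One fused weight replaces separate gate_proj and up_proj matmuls.
gate_up_proj = nn.Linear(hidden, 2 * intermediate, bias=False)
x = torch.randn(4, hidden)
y = NaiveSiluAndMul()(gate_up_proj(x))  # one matmul instead of two
assert y.shape == (4, intermediate)
```

This is also why `load_weights` maps `mlp.gate_proj` to shard 0 and `mlp.up_proj` to shard 1 of `mlp.gate_up_proj`: checkpoints still ship the two weights separately, and they are stacked at load time.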

vllm_ascend/models/qwen2_5_vl_without_padding.py

Lines changed: 8 additions & 2 deletions
```diff
@@ -30,7 +30,8 @@
 from vllm.config import VllmConfig
 from vllm.distributed import parallel_state
 from vllm.distributed import utils as dist_utils
-from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY
+from vllm.model_executor.layers.activation import (_ACTIVATION_REGISTRY,
+                                                   get_act_and_mul_fn)
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.models.qwen2_5_vl import (
@@ -42,6 +43,7 @@
 from vllm.multimodal import MULTIMODAL_REGISTRY
 
 from vllm_ascend.models.qwen2_5_vl import AscendQwen2_5_VisionRotaryEmbedding
+from vllm_ascend.utils import vllm_version_is
 
 
 class AscendQwen2_5_VisionAttention_Without_Padding(Qwen2_5_VisionAttention):
@@ -171,12 +173,16 @@ def __init__(
             in_channels=vision_config.in_channels,
             hidden_size=self.hidden_size,
         )
+
+        act_fn = get_act_and_mul_fn(vision_config.hidden_act)
+        if vllm_version_is("0.10.0"):
+            act_fn = _ACTIVATION_REGISTRY[vision_config.hidden_act]
         self.blocks = nn.ModuleList([
             AscendQwen2_5_VisionBlock_Without_Padding(
                 dim=self.hidden_size,
                 num_heads=self.num_heads,
                 mlp_hidden_dim=vision_config.intermediate_size,
-                act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act],
+                act_fn=act_fn,
                 norm_layer=norm_layer,
                 quant_config=quant_config,
                 prefix=f"{prefix}.blocks.{layer_idx}")
```

vllm_ascend/models/qwen3_moe.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -100,6 +100,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
 
+        parallel_config = vllm_config.parallel_config
+        self.num_redundant_experts = parallel_config.num_redundant_experts
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
         self.config = config
```
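The `num_redundant_experts` field read here comes from the expert-parallel load-balancing work that vLLM PR #20815 builds on. As a purely hypothetical illustration, not taken from this commit, of how a redundant-expert count typically enters expert-parallel sizing:

```python
# Hypothetical illustration, not from this commit: redundant replicas are
# added on top of the logical experts, and the total must divide evenly
# across expert-parallel (EP) ranks.
def physical_experts_per_rank(num_logical_experts: int,
                              num_redundant_experts: int,
                              ep_size: int) -> int:
    total = num_logical_experts + num_redundant_experts
    assert total % ep_size == 0, "experts must split evenly across EP ranks"
    return total // ep_size


print(physical_experts_per_rank(128, 16, 8))  # -> 18 physical experts/rank
```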

vllm_ascend/worker/model_runner_v1.py

Lines changed: 43 additions & 18 deletions
```diff
@@ -94,6 +94,8 @@
 
 if not vllm_version_is("0.10.0"):
     from vllm.tasks import GenerationTask, SupportedTask
+    from vllm.v1.worker.kv_connector_model_runner_mixin import \
+        KVConnectorOutput
 
 if TYPE_CHECKING:
     import xgrammar as xgr  # type: ignore[import-untyped]
@@ -1472,8 +1474,9 @@ def _pool(
         hidden_states: torch.Tensor,
         num_scheduled_tokens: int,
         num_scheduled_tokens_np: np.ndarray,
-        finished_sending: Optional[set[str]],
-        finished_receiving: Optional[set[str]],
+        finished_sending: Optional[set[str]] = None,
+        finished_recving: Optional[set[str]] = None,
+        kv_connector_output: Optional["KVConnectorOutput"] = None,
     ) -> ModelRunnerOutput:
         assert self.input_batch.num_reqs ==\
             len(self.input_batch.pooling_params), \
@@ -1499,6 +1502,12 @@ def _pool(
                 pooler_output.append(raw_output.data.cpu())
             else:
                 pooler_output.append(None)
+        extra_args = ({
+            "finished_sending": finished_sending,
+            "finished_recving": finished_recving
+        } if vllm_version_is("0.10.0") else {
+            "kv_connector_output": kv_connector_output
+        })
 
         return ModelRunnerOutput(
             req_ids=self.input_batch.req_ids,
@@ -1508,8 +1517,8 @@ def _pool(
             logprobs=None,
             prompt_logprobs_dict={},
             pooler_output=pooler_output,
-            finished_sending=finished_sending,
-            finished_recving=finished_receiving)
+            **extra_args,
+        )
 
     @torch.inference_mode()
     def execute_model(
@@ -1533,7 +1542,13 @@ def execute_model(
             num_scheduled_tokens_np, finished_sending,
             finished_recving) = (self._process_reqs(scheduler_output,
                                                     intermediate_tensors))
-
+        kv_connector_output = None
+        if not vllm_version_is("0.10.0"):
+            kv_connector_output = KVConnectorOutput(
+                finished_sending=finished_sending,
+                finished_recving=finished_recving)
+            finished_sending = None
+            finished_recving = None
         with ProfileExecuteDuration().capture_async("post process"):
             # Broadcast PP output for external_launcher (torchrun)
             # to make sure we are synced across pp ranks
@@ -1545,7 +1560,10 @@ def execute_model(
             if not get_pp_group().is_last_rank:
                 # For mid-pipeline stages, return the hidden states.
                 if not broadcast_pp_output:
-                    if finished_sending or finished_recving:
+                    if kv_connector_output is not None:
+                        hidden_states.kv_connector_output = kv_connector_output
+                    else:
+                        #TODO: Remove this after we drop vllm v0.10.0
                         hidden_states.finished_sending = finished_sending
                         hidden_states.finished_recving = finished_recving
                     return hidden_states
@@ -1557,7 +1575,8 @@ def execute_model(
             if self.input_batch.pooling_params:
                 return self._pool(hidden_states, num_scheduled_tokens,
                                   num_scheduled_tokens_np,
-                                  finished_sending, finished_recving)
+                                  finished_sending, finished_recving,
+                                  kv_connector_output)
             sample_hidden_states = hidden_states[logits_indices]
             logits = self.model.compute_logits(sample_hidden_states, None)
             if broadcast_pp_output:
@@ -1691,17 +1710,23 @@ def execute_model(
             if has_kv_transfer_group():
                 get_kv_transfer_group().clear_connector_metadata()
 
-            model_runner_output = ModelRunnerOutput(
-                req_ids=self.input_batch.req_ids,
-                req_id_to_index=self.input_batch.req_id_to_index,
-                sampled_token_ids=valid_sampled_token_ids,
-                spec_token_ids=spec_token_ids,
-                logprobs=logprobs_lists,
-                prompt_logprobs_dict=prompt_logprobs_dict,
-                pooler_output=[],
-                finished_sending=finished_sending,
-                finished_recving=finished_recving,
-            )
+            extra_args = ({
+                "finished_sending": finished_sending,
+                "finished_recving": finished_recving
+            } if vllm_version_is("0.10.0") else {
+                "kv_connector_output": kv_connector_output
+            })
+
+            model_runner_output = ModelRunnerOutput(
+                req_ids=self.input_batch.req_ids,
+                req_id_to_index=self.input_batch.req_id_to_index,
+                sampled_token_ids=valid_sampled_token_ids,
+                spec_token_ids=spec_token_ids,
+                logprobs=logprobs_lists,
+                prompt_logprobs_dict=prompt_logprobs_dict,
+                pooler_output=[],
+                **extra_args,
+            )
 
         durations = ProfileExecuteDuration().pop_captured_sync()
         if durations:
```
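One subtlety worth noting in the runner change: once the IDs are wrapped in `KVConnectorOutput`, `finished_sending` and `finished_recving` are reset to `None` so the legacy fields are never populated alongside the new container. A condensed, illustrative restatement of that step (simplified from the diff above, not the literal code):

```python
# Condensed restatement of the execute_model() wrapping step, for illustration.
from typing import Optional

from vllm_ascend.utils import vllm_version_is


def wrap_kv_transfer_results(finished_sending: Optional[set],
                             finished_recving: Optional[set]):
    kv_connector_output = None
    if not vllm_version_is("0.10.0"):
        from vllm.v1.worker.kv_connector_model_runner_mixin import \
            KVConnectorOutput
        kv_connector_output = KVConnectorOutput(
            finished_sending=finished_sending,
            finished_recving=finished_recving)
        # Clear the legacy fields so the same IDs are not reported twice.
        finished_sending = None
        finished_recving = None
    return finished_sending, finished_recving, kv_connector_output
```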
