Skip to content

Commit 291c6d6

Browse files
author
Levi-JQ
committed
[Bugfix] Fix a shape error raised in the _prepare_inputs phase for qwen2.5-vl-72b under high concurrency.
Signed-off-by: Levi-JQ <yujinqi2@huawei.com>
1 parent 71acc8d commit 291c6d6

File tree

3 files changed

+168
-10
lines changed

3 files changed

+168
-10
lines changed

vllm_ascend/models/qwen2_5_vl.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from vllm.model_executor.layers.layernorm import RMSNorm
3535
from vllm.model_executor.layers.quantization import QuantizationConfig
3636
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
37+
from vllm.model_executor.models.interfaces import MultiModalEmbeddings
3738
from vllm.model_executor.models.qwen2_5_vl import (
3839
Qwen2_5_VisionAttention, Qwen2_5_VisionBlock, Qwen2_5_VisionPatchEmbed,
3940
Qwen2_5_VisionRotaryEmbedding, Qwen2_5_VisionTransformer,
@@ -560,3 +561,68 @@ def _process_video_input(self, video_input) -> tuple[torch.Tensor, ...]:
560561
merge_size = self.visual.spatial_merge_size
561562
sizes = grid_thw.prod(-1) // merge_size // merge_size
562563
return video_embeds.split(sizes.tolist())
564+
565+
def _get_text_embeddings(
566+
self,
567+
input_ids: torch.Tensor,
568+
get_input_embeddings: Callable[[torch.Tensor], torch.Tensor],
569+
*,
570+
is_multimodal: Optional[torch.Tensor],
571+
handle_oov_mm_token: bool,
572+
) -> torch.Tensor:
573+
if handle_oov_mm_token and is_multimodal is not None:
574+
is_text = ~is_multimodal
575+
text_embeds = get_input_embeddings(input_ids[is_text])
576+
577+
return torch.empty(
578+
(input_ids.shape[0], text_embeds.shape[1]),
579+
dtype=text_embeds.dtype,
580+
device=text_embeds.device,
581+
).masked_scatter_(is_text.unsqueeze_(-1), text_embeds)
582+
583+
return get_input_embeddings(input_ids)
584+
585+
def get_input_embeddings(
586+
self,
587+
input_ids: torch.Tensor,
588+
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
589+
*,
590+
is_multimodal: Optional[torch.Tensor] = None,
591+
handle_oov_mm_token: bool = False,
592+
) -> torch.Tensor:
593+
"""
594+
Apply token embeddings to `input_ids`.
595+
596+
If `multimodal_embeddings` is passed, scatter them into
597+
`input_ids` according to the mask `is_multimodal`.
598+
599+
In case the multi-modal token IDs exceed the vocabulary size of
600+
the language model, you can set `handle_oov_mm_token=True`
601+
to avoid calling the language model's `get_input_embeddings` method
602+
on those tokens. Note however that doing so increases memory usage
603+
as an additional buffer is needed to hold the input embeddings.
604+
"""
605+
from vllm.model_executor.models.utils import \
606+
_merge_multimodal_embeddings
607+
608+
inputs_embeds = self._get_text_embeddings(
609+
input_ids,
610+
self.get_language_model().get_input_embeddings,
611+
is_multimodal=is_multimodal,
612+
handle_oov_mm_token=handle_oov_mm_token,
613+
)
614+
615+
if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
616+
return inputs_embeds
617+
618+
if is_multimodal is None:
619+
raise ValueError(
620+
"`get_input_embeddings` now requires `is_multimodal` arg, "
621+
"please update your model runner according to "
622+
"https://github.com/vllm-project/vllm/pull/16229.")
623+
624+
return _merge_multimodal_embeddings(
625+
inputs_embeds=inputs_embeds,
626+
is_multimodal=is_multimodal,
627+
multimodal_embeddings=multimodal_embeddings,
628+
)

vllm_ascend/models/qwen2_5_vl_without_padding.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from einops import rearrange
2727
from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
2828
Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig)
29+
from vllm.model_executor.models.interfaces import MultiModalEmbeddings
2930

3031
try:
3132
from transformers.models.qwen3_vl.configuration_qwen3_vl import \
@@ -523,6 +524,71 @@ def _process_video_input(self, video_input) -> tuple[torch.Tensor, ...]:
523524
sizes = grid_thw.prod(-1) // merge_size // merge_size
524525
return video_embeds.split(sizes.tolist())
525526

527+
def _get_text_embeddings(
528+
self,
529+
input_ids: torch.Tensor,
530+
get_input_embeddings: Callable[[torch.Tensor], torch.Tensor],
531+
*,
532+
is_multimodal: Optional[torch.Tensor],
533+
handle_oov_mm_token: bool,
534+
) -> torch.Tensor:
535+
if handle_oov_mm_token and is_multimodal is not None:
536+
is_text = ~is_multimodal
537+
text_embeds = get_input_embeddings(input_ids[is_text])
538+
539+
return torch.empty(
540+
(input_ids.shape[0], text_embeds.shape[1]),
541+
dtype=text_embeds.dtype,
542+
device=text_embeds.device,
543+
).masked_scatter_(is_text.unsqueeze_(-1), text_embeds)
544+
545+
return get_input_embeddings(input_ids)
546+
547+
def get_input_embeddings(
548+
self,
549+
input_ids: torch.Tensor,
550+
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
551+
*,
552+
is_multimodal: Optional[torch.Tensor] = None,
553+
handle_oov_mm_token: bool = False,
554+
) -> torch.Tensor:
555+
"""
556+
Apply token embeddings to `input_ids`.
557+
558+
If `multimodal_embeddings` is passed, scatter them into
559+
`input_ids` according to the mask `is_multimodal`.
560+
561+
In case the multi-modal token IDs exceed the vocabulary size of
562+
the language model, you can set `handle_oov_mm_token=True`
563+
to avoid calling the language model's `get_input_embeddings` method
564+
on those tokens. Note however that doing so increases memory usage
565+
as an additional buffer is needed to hold the input embeddings.
566+
"""
567+
from vllm.model_executor.models.utils import \
568+
_merge_multimodal_embeddings
569+
570+
inputs_embeds = self._get_text_embeddings(
571+
input_ids,
572+
self.get_language_model().get_input_embeddings,
573+
is_multimodal=is_multimodal,
574+
handle_oov_mm_token=handle_oov_mm_token,
575+
)
576+
577+
if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
578+
return inputs_embeds
579+
580+
if is_multimodal is None:
581+
raise ValueError(
582+
"`get_input_embeddings` now requires `is_multimodal` arg, "
583+
"please update your model runner according to "
584+
"https://github.com/vllm-project/vllm/pull/16229.")
585+
586+
return _merge_multimodal_embeddings(
587+
inputs_embeds=inputs_embeds,
588+
is_multimodal=is_multimodal,
589+
multimodal_embeddings=multimodal_embeddings,
590+
)
591+
526592

527593
@MULTIMODAL_REGISTRY.register_processor(Qwen3VLMultiModalProcessor,
528594
info=Qwen3VLProcessingInfo,

vllm_ascend/worker/model_runner_v1.py

Lines changed: 36 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@
6262
from vllm.model_executor.models.interfaces import supports_transcription
6363
from vllm.model_executor.models.interfaces_base import (
6464
VllmModelForPooling, is_pooling_model, is_text_generation_model)
65+
from vllm.multimodal import MULTIMODAL_REGISTRY
6566
from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
6667
from vllm.multimodal.utils import group_mm_kwargs_by_modality
6768
from vllm.pooling_params import PoolingParams
@@ -550,6 +551,14 @@ def _init_mc2_tokens_capacity(self):
550551
num_tokens_per_tp_rank = (max_num_tokens + tp_size - 1) // tp_size
551552
self.mc2_tokens_capacity = num_tokens_per_tp_rank * tp_size
552553

554+
# Only relevant for multimodal models
555+
self.mm_registry = MULTIMODAL_REGISTRY
556+
self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
557+
self.model_config)
558+
if self.supports_mm_inputs:
559+
self.is_mm_embed = self._make_buffer(self.max_num_tokens,
560+
dtype=torch.bool)
561+
553562
def _make_buffer(self,
554563
*size: Union[int, torch.SymInt],
555564
dtype: torch.dtype,
@@ -1034,7 +1043,7 @@ def _batch_mm_kwargs_from_scheduler(
10341043
def _gather_mm_embeddings(
10351044
self,
10361045
scheduler_output: "SchedulerOutput",
1037-
) -> list[torch.Tensor]:
1046+
) -> tuple[list[torch.Tensor], torch.Tensor]:
10381047

10391048
def _iter_mm_features(req_state: CachedRequestState):
10401049
assert req_state.mm_features is not None
@@ -1044,8 +1053,15 @@ def _iter_mm_features(req_state: CachedRequestState):
10441053
pos_info, "is_embed", None)
10451054

10461055
mm_embeds: list[torch.Tensor] = []
1056+
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
1057+
is_mm_embed = self.is_mm_embed.cpu
1058+
is_mm_embed[:total_num_scheduled_tokens] = False
1059+
1060+
req_start_idx = 0
10471061

10481062
for req_id in self.input_batch.req_ids:
1063+
mm_embeds_req: list[torch.Tensor] = []
1064+
10491065
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
10501066
req_id]
10511067
req_state = self.requests[req_id]
@@ -1074,12 +1090,22 @@ def _iter_mm_features(req_state: CachedRequestState):
10741090
if is_embed is not None:
10751091
is_embed = is_embed[start_idx:end_idx]
10761092

1093+
req_start_pos = req_start_idx + start_pos - num_computed_tokens
1094+
is_mm_embed[req_start_pos+start_idx:req_start_pos + end_idx] \
1095+
= True if is_embed is None else is_embed
1096+
10771097
mm_embeds_item = gather_mm_placeholders(
10781098
encoder_output[start_idx:end_idx],
10791099
is_embed=is_embed,
10801100
)
1081-
mm_embeds.append(mm_embeds_item)
1082-
return mm_embeds
1101+
mm_embeds_req.append(mm_embeds_item)
1102+
1103+
mm_embeds.extend(mm_embeds_req)
1104+
req_start_idx += num_scheduled_tokens
1105+
1106+
is_mm_embed = self.is_mm_embed.copy_to_gpu(total_num_scheduled_tokens)
1107+
1108+
return mm_embeds, is_mm_embed
10831109

10841110
def _get_cumsum_and_arange(
10851111
self,
@@ -1362,17 +1388,17 @@ def _prepare_inputs(
13621388
if self.is_multimodal_model:
13631389
# Run the multimodal encoder if any.
13641390
self._execute_mm_encoder(scheduler_output)
1365-
mm_embeds = self._gather_mm_embeddings(scheduler_output)
1366-
1391+
mm_embeds, is_mm_embed = self._gather_mm_embeddings(
1392+
scheduler_output)
13671393
# NOTE(woosuk): To unify token ids and soft tokens (vision
13681394
# embeddings), we always use embeddings (rather than token ids)
13691395
# as input to the multimodal model, even when the input is text.
13701396
input_ids = self.input_ids[:total_num_scheduled_tokens]
1371-
if mm_embeds:
1372-
inputs_embeds = self.model.get_input_embeddings(
1373-
input_ids, mm_embeds)
1374-
else:
1375-
inputs_embeds = self.model.get_input_embeddings(input_ids)
1397+
inputs_embeds = self.model.get_input_embeddings(
1398+
input_ids,
1399+
multimodal_embeddings=mm_embeds,
1400+
is_multimodal=is_mm_embed,
1401+
)
13761402
# TODO(woosuk): Avoid the copy. Optimize.
13771403
self.inputs_embeds[:total_num_scheduled_tokens].copy_(
13781404
inputs_embeds)

0 commit comments

Comments
 (0)