66 changes: 66 additions & 0 deletions vllm_ascend/models/qwen2_5_vl.py
@@ -34,6 +34,7 @@
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import MultiModalEmbeddings
from vllm.model_executor.models.qwen2_5_vl import (
Qwen2_5_VisionAttention, Qwen2_5_VisionBlock, Qwen2_5_VisionPatchEmbed,
Qwen2_5_VisionRotaryEmbedding, Qwen2_5_VisionTransformer,
@@ -560,3 +561,68 @@ def _process_video_input(self, video_input) -> tuple[torch.Tensor, ...]:
merge_size = self.visual.spatial_merge_size
sizes = grid_thw.prod(-1) // merge_size // merge_size
return video_embeds.split(sizes.tolist())

def _get_text_embeddings(
self,
input_ids: torch.Tensor,
get_input_embeddings: Callable[[torch.Tensor], torch.Tensor],
*,
is_multimodal: Optional[torch.Tensor],
handle_oov_mm_token: bool,
) -> torch.Tensor:
if handle_oov_mm_token and is_multimodal is not None:
is_text = ~is_multimodal
text_embeds = get_input_embeddings(input_ids[is_text])

return torch.empty(
(input_ids.shape[0], text_embeds.shape[1]),
dtype=text_embeds.dtype,
device=text_embeds.device,
).masked_scatter_(is_text.unsqueeze_(-1), text_embeds)

return get_input_embeddings(input_ids)

def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
*,
is_multimodal: Optional[torch.Tensor] = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor:
"""
Apply token embeddings to `input_ids`.

If `multimodal_embeddings` is passed, scatter them into
`input_ids` according to the mask `is_multimodal`.

In case the multi-modal token IDs exceed the vocabulary size of
the language model, you can set `handle_oov_mm_token=True`
to avoid calling the language model's `get_input_embeddings` method
on those tokens. Note however that doing so increases memory usage
as an additional buffer is needed to hold the input embeddings.
"""
from vllm.model_executor.models.utils import \
_merge_multimodal_embeddings

inputs_embeds = self._get_text_embeddings(
input_ids,
self.get_language_model().get_input_embeddings,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)

if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
return inputs_embeds

if is_multimodal is None:
raise ValueError(
"`get_input_embeddings` now requires `is_multimodal` arg, "
"please update your model runner according to "
"https://github.com/vllm-project/vllm/pull/16229.")

return _merge_multimodal_embeddings(
inputs_embeds=inputs_embeds,
is_multimodal=is_multimodal,
multimodal_embeddings=multimodal_embeddings,
)
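
Aside from the patch itself: a self-contained toy sketch of the `masked_scatter_` path that `_get_text_embeddings` above relies on, with a plain `nn.Embedding` standing in for the language model's embedding layer and a made-up placeholder ID of 100 (not the production code path):

import torch
import torch.nn as nn

embed = nn.Embedding(num_embeddings=8, embedding_dim=4)      # tiny text vocab

input_ids = torch.tensor([1, 2, 100, 100, 3])                # 100 is out of vocab
is_multimodal = torch.tensor([False, False, True, True, False])

# Embed only the text positions, then scatter them into a full-size buffer;
# the out-of-vocabulary IDs at multimodal positions are never looked up.
is_text = ~is_multimodal
text_embeds = embed(input_ids[is_text])
inputs_embeds = torch.empty(
    (input_ids.shape[0], text_embeds.shape[1]),
    dtype=text_embeds.dtype,
    device=text_embeds.device,
).masked_scatter_(is_text.unsqueeze(-1), text_embeds)

# The multimodal rows are still uninitialized; the caller fills them through
# the same boolean mask with the vision encoder outputs, e.g.:
image_embeds = torch.randn(2, 4)
inputs_embeds.masked_scatter_(is_multimodal.unsqueeze(-1), image_embeds)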
66 changes: 66 additions & 0 deletions vllm_ascend/models/qwen2_5_vl_without_padding.py
@@ -26,6 +26,7 @@
from einops import rearrange
from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig)
from vllm.model_executor.models.interfaces import MultiModalEmbeddings

try:
from transformers.models.qwen3_vl.configuration_qwen3_vl import \
@@ -523,6 +524,71 @@ def _process_video_input(self, video_input) -> tuple[torch.Tensor, ...]:
sizes = grid_thw.prod(-1) // merge_size // merge_size
return video_embeds.split(sizes.tolist())

def _get_text_embeddings(
self,
input_ids: torch.Tensor,
get_input_embeddings: Callable[[torch.Tensor], torch.Tensor],
*,
is_multimodal: Optional[torch.Tensor],
handle_oov_mm_token: bool,
) -> torch.Tensor:
if handle_oov_mm_token and is_multimodal is not None:
is_text = ~is_multimodal
text_embeds = get_input_embeddings(input_ids[is_text])

return torch.empty(
(input_ids.shape[0], text_embeds.shape[1]),
dtype=text_embeds.dtype,
device=text_embeds.device,
).masked_scatter_(is_text.unsqueeze_(-1), text_embeds)

return get_input_embeddings(input_ids)

def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
*,
is_multimodal: Optional[torch.Tensor] = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor:
"""
Apply token embeddings to `input_ids`.

If `multimodal_embeddings` is passed, scatter them into
`input_ids` according to the mask `is_multimodal`.

In case the multi-modal token IDs exceed the vocabulary size of
the language model, you can set `handle_oov_mm_token=True`
to avoid calling the language model's `get_input_embeddings` method
on those tokens. Note however that doing so increases memory usage
as an additional buffer is needed to hold the input embeddings.
"""
from vllm.model_executor.models.utils import \
_merge_multimodal_embeddings

inputs_embeds = self._get_text_embeddings(
input_ids,
self.get_language_model().get_input_embeddings,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)

if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
return inputs_embeds

if is_multimodal is None:
raise ValueError(
"`get_input_embeddings` now requires `is_multimodal` arg, "
"please update your model runner according to "
"https://github.com/vllm-project/vllm/pull/16229.")

return _merge_multimodal_embeddings(
inputs_embeds=inputs_embeds,
is_multimodal=is_multimodal,
multimodal_embeddings=multimodal_embeddings,
)


@MULTIMODAL_REGISTRY.register_processor(Qwen3VLMultiModalProcessor,
info=Qwen3VLProcessingInfo,
46 changes: 36 additions & 10 deletions vllm_ascend/worker/model_runner_v1.py
@@ -62,6 +62,7 @@
from vllm.model_executor.models.interfaces import supports_transcription
from vllm.model_executor.models.interfaces_base import (
VllmModelForPooling, is_pooling_model, is_text_generation_model)
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
from vllm.multimodal.utils import group_mm_kwargs_by_modality
from vllm.pooling_params import PoolingParams
@@ -550,6 +551,14 @@ def _init_mc2_tokens_capacity(self):
num_tokens_per_tp_rank = (max_num_tokens + tp_size - 1) // tp_size
self.mc2_tokens_capacity = num_tokens_per_tp_rank * tp_size

# Only relevant for multimodal models
self.mm_registry = MULTIMODAL_REGISTRY
self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
self.model_config)
if self.supports_mm_inputs:
self.is_mm_embed = self._make_buffer(self.max_num_tokens,
dtype=torch.bool)

def _make_buffer(self,
*size: Union[int, torch.SymInt],
dtype: torch.dtype,
@@ -1034,7 +1043,7 @@ def _batch_mm_kwargs_from_scheduler(
def _gather_mm_embeddings(
self,
scheduler_output: "SchedulerOutput",
) -> list[torch.Tensor]:
) -> tuple[list[torch.Tensor], torch.Tensor]:

def _iter_mm_features(req_state: CachedRequestState):
assert req_state.mm_features is not None
@@ -1044,8 +1053,15 @@ def _iter_mm_features(req_state: CachedRequestState):
pos_info, "is_embed", None)

mm_embeds: list[torch.Tensor] = []
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
is_mm_embed = self.is_mm_embed.cpu
is_mm_embed[:total_num_scheduled_tokens] = False

req_start_idx = 0

for req_id in self.input_batch.req_ids:
mm_embeds_req: list[torch.Tensor] = []

num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
req_id]
req_state = self.requests[req_id]
@@ -1074,12 +1090,22 @@ def _iter_mm_features(req_state: CachedRequestState):
if is_embed is not None:
is_embed = is_embed[start_idx:end_idx]

req_start_pos = req_start_idx + start_pos - num_computed_tokens
is_mm_embed[req_start_pos+start_idx:req_start_pos + end_idx] \
= True if is_embed is None else is_embed

mm_embeds_item = gather_mm_placeholders(
encoder_output[start_idx:end_idx],
is_embed=is_embed,
)
mm_embeds.append(mm_embeds_item)
return mm_embeds
mm_embeds_req.append(mm_embeds_item)

mm_embeds.extend(mm_embeds_req)
req_start_idx += num_scheduled_tokens

is_mm_embed = self.is_mm_embed.copy_to_gpu(total_num_scheduled_tokens)

return mm_embeds, is_mm_embed

def _get_cumsum_and_arange(
self,
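
As a toy illustration of the index arithmetic in `_gather_mm_embeddings` above (all numbers made up; `start_idx`/`end_idx` are assumed to already be the placeholder slice clipped to this step's token window by the surrounding, collapsed code):

# Hypothetical numbers, not from a real trace.
req_start_idx = 16         # first flattened slot of this request in the batch
num_computed_tokens = 10   # prompt tokens finished in earlier steps
num_scheduled_tokens = 8   # tokens scheduled for this request in this step
start_pos = 12             # placeholder start position within the prompt
start_idx, end_idx = 0, 6  # placeholder slice covered by this step (pre-clipped)

req_start_pos = req_start_idx + start_pos - num_computed_tokens  # 16 + 12 - 10 = 18
# is_mm_embed[18:24] is set, so flattened positions 18..23 of the scheduled
# token buffer receive vision embeddings instead of text embeddings.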
@@ -1362,17 +1388,17 @@ def _prepare_inputs(
if self.is_multimodal_model:
# Run the multimodal encoder if any.
self._execute_mm_encoder(scheduler_output)
mm_embeds = self._gather_mm_embeddings(scheduler_output)

mm_embeds, is_mm_embed = self._gather_mm_embeddings(
scheduler_output)
# NOTE(woosuk): To unify token ids and soft tokens (vision
# embeddings), we always use embeddings (rather than token ids)
# as input to the multimodal model, even when the input is text.
input_ids = self.input_ids[:total_num_scheduled_tokens]
if mm_embeds:
inputs_embeds = self.model.get_input_embeddings(
input_ids, mm_embeds)
else:
inputs_embeds = self.model.get_input_embeddings(input_ids)
inputs_embeds = self.model.get_input_embeddings(
input_ids,
multimodal_embeddings=mm_embeds,
is_multimodal=is_mm_embed,
)
Comment on lines +1397 to +1401 — Contributor (critical):
The call to get_input_embeddings is missing the handle_oov_mm_token argument. Without it, the default value False will be used, and the logic to handle out-of-vocabulary (OOV) multimodal tokens will not be triggered. This can lead to errors when the input contains multimodal tokens with IDs outside the language model's vocabulary range, which is likely the root cause of the issue this PR aims to fix. You should pass handle_oov_mm_token=self.supports_mm_inputs to correctly handle these cases.

Suggested change:
- inputs_embeds = self.model.get_input_embeddings(
-     input_ids,
-     multimodal_embeddings=mm_embeds,
-     is_multimodal=is_mm_embed,
- )
+ inputs_embeds = self.model.get_input_embeddings(
+     input_ids,
+     multimodal_embeddings=mm_embeds,
+     is_multimodal=is_mm_embed,
+     handle_oov_mm_token=self.supports_mm_inputs,
+ )
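
For context on this suggestion, a small self-contained demonstration (toy vocabulary size, made-up placeholder ID) of the failure the flag guards against when out-of-vocabulary placeholder IDs reach the text embedding table directly:

import torch
import torch.nn as nn

embed = nn.Embedding(num_embeddings=8, embedding_dim=4)  # text vocab of 8

embed(torch.tensor([1, 2, 3]))       # in-vocabulary IDs embed fine

try:
    embed(torch.tensor([100]))       # an image-placeholder ID past the vocab
except IndexError as err:
    print(f"OOV lookup failed: {err}")

# With handle_oov_mm_token=True, get_input_embeddings() embeds only the text
# positions (see _get_text_embeddings in the model files above), so OOV
# placeholder IDs never reach the embedding table.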

# TODO(woosuk): Avoid the copy. Optimize.
self.inputs_embeds[:total_num_scheduled_tokens].copy_(
inputs_embeds)