34 changes: 34 additions & 0 deletions vllm/model_executor/models/utils.py
@@ -14,6 +14,7 @@
from vllm.logger import init_logger
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal import MultiModalPlaceholderMap, NestedTensors
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.utils import is_pin_memory_available

@@ -443,6 +444,14 @@ def merge_multimodal_embeddings(
    Note:
        This updates ``inputs_embeds`` in place.
    """
    if current_platform.is_hpu():
        return _hpu_merge_multimodal_embeddings(
            input_ids,
            inputs_embeds,
            multimodal_embeddings,
            placeholder_token_id,
        )

    if isinstance(placeholder_token_id, list):
        placeholder_token_id = torch.tensor(placeholder_token_id,
                                            device=input_ids.device)
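For context, the existing (non-HPU) path shown above promotes a list of placeholder ids to a tensor before building the boolean mask used for the merge. A minimal standalone sketch of that masking step, with hypothetical token ids chosen only for illustration:

```python
import torch

# Hypothetical placeholder ids for two modalities (illustration only).
placeholder_token_id = [32000, 32001]
input_ids = torch.tensor([101, 32000, 7, 32001, 102])

placeholder = torch.tensor(placeholder_token_id, device=input_ids.device)
is_placeholder = torch.isin(input_ids, placeholder)
# tensor([False,  True, False,  True, False]) -- these are the positions
# whose embeddings get overwritten by the multimodal embeddings.
print(is_placeholder)
```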
@@ -641,3 +650,28 @@ def extract_layer_index(layer_name: str) -> int:
    assert len(int_vals) == 1, (f"layer name {layer_name} should"
                                " only contain one integer")
    return int_vals[0]


def _hpu_merge_multimodal_embeddings(
    input_ids: torch.Tensor,
    inputs_embeds: torch.Tensor,
    multimodal_embeddings: NestedTensors,
    placeholder_token_id: Union[int, List[int]],
) -> torch.Tensor:
"""
Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
positions in ``inputs_embeds`` corresponding to placeholder tokens in
``input_ids``.
merge_multimodal_embeddings on HPU to avoid dynamicity.
Note:
This updates ``inputs_embeds`` in place.
"""
    batch_size, seq_length, hidden_size = inputs_embeds.shape
    # Flatten the batch so one boolean mask can address every token position.
    inputs_embeds = inputs_embeds.reshape(-1, hidden_size)
    multimodal_embeddings = multimodal_embeddings.reshape(-1, hidden_size)
    placeholder_token_id = torch.tensor(placeholder_token_id,
                                        device=input_ids.device)
    # Overwrite the rows at placeholder positions with the multimodal rows.
    mask = torch.isin(input_ids.reshape(-1), placeholder_token_id)
    inputs_embeds.index_put_((mask, ), multimodal_embeddings)
    inputs_embeds = inputs_embeds.reshape(batch_size, seq_length, hidden_size)
    return inputs_embeds
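A toy end-to-end sketch of the new helper's behaviour (it runs on CPU too, since it only uses generic torch ops); the shapes and the placeholder id below are assumptions for illustration, not values from this change:

```python
import torch

from vllm.model_executor.models.utils import _hpu_merge_multimodal_embeddings

# Hypothetical shapes and placeholder id, chosen only for illustration.
batch_size, seq_len, hidden_size = 1, 4, 8
image_token_id = 32000

input_ids = torch.tensor([[101, image_token_id, image_token_id, 102]])
inputs_embeds = torch.zeros(batch_size, seq_len, hidden_size)
# One embedding row per placeholder position in input_ids.
multimodal_embeddings = torch.randn(2, hidden_size)

merged = _hpu_merge_multimodal_embeddings(input_ids, inputs_embeds,
                                           multimodal_embeddings,
                                           image_token_id)

# Placeholder positions (1 and 2) now hold the multimodal rows; all other
# positions keep their original embeddings, and the shape is unchanged.
assert merged.shape == (batch_size, seq_len, hidden_size)
assert torch.equal(merged[0, 1], multimodal_embeddings[0])
assert torch.equal(merged[0, 2], multimodal_embeddings[1])
```

Because every tensor in this path is flattened to a fixed shape before the masked `index_put_`, the merge avoids data-dependent shapes, which appears to be the "dynamicity" the docstring refers to.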