3 changes: 1 addition & 2 deletions tests/models/multimodal/generation/test_common.py
@@ -222,8 +222,7 @@
vllm_runner_kwargs={
"model_impl": "transformers",
},
# FIXME: Investigate mrope issue
marks=[large_gpu_mark(min_gb=32), pytest.mark.skip(reason="Mrope issue")],
marks=[large_gpu_mark(min_gb=32)],
),
#### Extended model tests
"aria": VLMTestInfo(
103 changes: 37 additions & 66 deletions vllm/model_executor/models/transformers.py
@@ -79,7 +79,6 @@
AutoWeightsLoader,
PPMissingLayer,
WeightsMapper,
flatten_bn,
make_empty_intermediate_tensors_factory,
maybe_prefix,
)
@@ -347,54 +346,37 @@ def _get_prompt_updates(

def _get_mm_fields_config(
self,
hf_inputs,
hf_processor_mm_kwargs,
num_image_patches: torch.Tensor = None,
):
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
# HF Processors always return a mask but vLLM doesn't need it
hf_inputs.pop("attention_mask", None)
num_image_patches = hf_inputs.get("num_image_patches")
mm_fields = {
key: MultiModalFieldConfig.flat_from_sizes("image", num_image_patches)
for key in hf_inputs
}
mm_fields["image_embeds"] = MultiModalFieldConfig.flat_from_sizes(
"image", num_image_patches
)

# Keep these as batched, as they always have batch size as first dim
mm_fields["image_grid_thw"] = MultiModalFieldConfig.batched("image")
mm_fields["video_grid_thw"] = MultiModalFieldConfig.batched("image")
mm_fields["num_image_patches"] = MultiModalFieldConfig.batched("image")
return mm_fields
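
For orientation, a minimal sketch (plain PyTorch with made-up shapes, not vLLM's internal slicing code) of the flat-versus-batched distinction used above: flat fields concatenate every image's patches along dim 0 and are recovered per image from `num_image_patches`, while batched fields keep one row per image.

import torch

# Illustrative sketch, not part of the diff: hypothetical shapes for
# 2 images with 4 and 6 patches and hidden size 8.
num_image_patches = torch.tensor([4, 6])
pixel_values = torch.randn(int(num_image_patches.sum()), 8)   # flat: (10, 8)
image_grid_thw = torch.tensor([[1, 2, 2], [1, 2, 3]])          # batched: (2, 3)

# flat_from_sizes semantics: slice the flat tensor back into per-image chunks.
per_image = torch.split(pixel_values, num_image_patches.tolist(), dim=0)
assert [t.shape[0] for t in per_image] == [4, 6]

# batched semantics: the first dimension already indexes images.
assert image_grid_thw.shape[0] == 2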

def _apply_hf_processor_text_mm(
def _get_hf_mm_data(
self,
prompt_text: str,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object],
) -> tuple[list[int], BatchFeature, bool]:
) -> tuple[Mapping[str, object], Mapping[str, object]]:
"""
Apply the HF processor on the prompt text and multi-modal data
together.

In addition, return whether prompt replacements have been applied.
In contrast to the base class, this method always adds
`return_mm_token_type_ids` to the processor data.
"""
processor_data, passthrough_data = self._get_hf_mm_data(mm_items)
processor_data, passthrough_data = super()._get_hf_mm_data(mm_items)
processor_data["return_mm_token_type_ids"] = True

processed_data = self._call_hf_processor(
prompt=prompt_text,
mm_data=processor_data,
mm_kwargs=hf_processor_mm_kwargs,
tok_kwargs=tokenization_kwargs,
)
processed_data.update(passthrough_data)

(prompt_ids,) = processed_data.pop("input_ids").tolist()
mm_token_type_ids = (
processed_data.pop("mm_token_type_ids")
if "mm_token_type_ids" in processed_data
else processed_data.pop("token_type_ids")
) # for gemma3 only

return prompt_ids, processed_data, mm_token_type_ids
return processor_data, passthrough_data
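
For intuition, a small hypothetical example (made-up token ids, not a real processor call) of what the extra flag provides: the processor output then carries a 0/1 mask aligned with `input_ids`, which `apply()` below uses to locate multimodal tokens.

import torch

# Hypothetical output when return_mm_token_type_ids=True is passed:
# positions 2-4 hold image placeholder tokens in this made-up prompt.
processed = {
    "input_ids": torch.tensor([[101, 2023, 32000, 32000, 32000, 2003]]),
    "mm_token_type_ids": torch.tensor([[0, 0, 1, 1, 1, 0]]),
}
# The mask is aligned token-for-token with the prompt.
assert processed["mm_token_type_ids"].shape == processed["input_ids"].shape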

def apply(
self,
@@ -421,18 +403,28 @@ def apply(
# into string
prompt = hf_processor.decode(prompt)

(prompt_ids, processed_data, mm_token_type_ids) = (
self._apply_hf_processor_text_mm(
prompt_text=prompt,
mm_items=mm_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
)
# Bypass the cached processor and always apply to the full set of mm inputs.
# NOTE: we can't just set caching=False because the base class method
# transforms outputs into `MultiModalKwargs`, which does not work for
# Transformers; a lot of the logic below is tied to
# `mm_tokens_per_modality`
prompt_ids, processed_data, _ = self._apply_hf_processor_text_mm(
prompt_text=prompt,
mm_items=mm_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
)

# The HF processor returns `mm_token_type_ids`, from which we can infer
# mm_placeholders. Tested on Llava; prompts and `mm_token_type_ids`
# always have batch size 1.
# For gemma3 the key is `token_type_ids`
token_type_key = (
"mm_token_type_ids"
if "mm_token_type_ids" in processed_data
else "token_type_ids"
)
mm_token_type_ids = processed_data.pop(token_type_key)

# We can infer vLLM-style placeholders from the token type ids if we
# split them per input `mm_data`.
mm_positions = torch.where(mm_token_type_ids == 1)[1]
images = mm_items.get_items("image", ImageProcessorItems)
multimodal_config = self.info.ctx.model_config.multimodal_config
@@ -462,17 +454,12 @@
]
mm_placeholders = {"image": ranges}

num_image_patches = (
torch.tensor(mm_tokens_per_modality["num_image_patches"])
if "num_image_patches" in mm_tokens_per_modality
else None
processed_data["num_image_patches"] = torch.tensor(
mm_tokens_per_modality["num_image_patches"]
)
processed_data["num_image_patches"] = num_image_patches
mm_kwargs = MultiModalKwargsItems.from_hf_inputs(
processed_data,
self._get_mm_fields_config(
processed_data, hf_processor_mm_kwargs, num_image_patches
),
self._get_mm_fields_config(processed_data, hf_processor_mm_kwargs),
)

# Use overrides if provided; fallback to data-dependent hashing.
@@ -533,8 +520,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.ignore_unexpected_suffixes.append(".bias")

# Set correct attn and init on "meta" to delay allocating GPU tensors
# TODO: @raushan, use the public `model.set_attn_implementation()`
# method once its checks are fixed in Transformers.
self.text_config._attn_implementation = "vllm"
with init_on_device_without_buffers("meta"):
self.model: PreTrainedModel = AutoModel.from_config(
Expand Down Expand Up @@ -842,17 +827,6 @@ def compute_logits(
return logits


def flatten_and_concat(x: list[torch.Tensor]) -> torch.Tensor:
"""Flatten until a list of tensors can be concatenated then do concat"""

def _can_concat(x: list[torch.Tensor]):
return len(set(map(lambda _x: _x.shape[1:], x))) == 1

if _can_concat(x):
return torch.concat(x)
return flatten_and_concat(flatten_bn(x))


@MULTIMODAL_REGISTRY.register_processor(
MultiModalProcessor,
info=MultiModalProcessingInfo,
@@ -933,9 +907,6 @@ def get_multimodal_embeddings(self, **kwargs):
vision_embeddings = self.model.get_image_features(pixel_values, **kwargs)

if isinstance(vision_embeddings, torch.Tensor):
if isinstance(num_image_patches, list):
num_image_patches = torch.cat(num_image_patches)

if vision_embeddings.ndim == 2:
vision_embeddings = vision_embeddings.unsqueeze(0)
