1 change: 1 addition & 0 deletions tests/distributed/test_pipeline_parallel.py
@@ -233,6 +233,7 @@ def iter_params(self, model_id: str):
"openbmb/MiniCPM-Llama3-V-2_5": PPTestSettings.fast(),
"allenai/Molmo-7B-D-0924": PPTestSettings.fast(),
"AIDC-AI/Ovis2-1B": PPTestSettings.fast(),
"AIDC-AI/Ovis2.5-2B": PPTestSettings.fast(),
Member Author commented:
I've confirmed that this test set passes after increasing max_model_len to 8192.

"microsoft/Phi-3.5-vision-instruct": PPTestSettings.fast(),
"mistralai/Pixtral-12B-2409": PPTestSettings.fast(load_format="dummy"),
"Qwen/Qwen-VL-Chat": PPTestSettings.fast(),
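The comment above refers to the new "AIDC-AI/Ovis2.5-2B" pipeline-parallel test entry. As a rough sketch (not part of this PR; the exact arguments the PP test harness passes are an assumption), the setting it exercises corresponds to launching the model with pipeline parallelism and the longer context:

# A minimal sketch, assuming standard vLLM engine arguments.
from vllm import LLM

llm = LLM(
    model="AIDC-AI/Ovis2.5-2B",
    trust_remote_code=True,
    pipeline_parallel_size=2,  # needs at least two GPUs
    max_model_len=8192,        # the bump mentioned in the comment above
)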
6 changes: 1 addition & 5 deletions tests/models/multimodal/generation/test_common.py
@@ -11,7 +11,6 @@
import pytest
from transformers import (AutoModel, AutoModelForImageTextToText,
AutoModelForTextToWaveform, AutoModelForVision2Seq)
-from transformers.utils import is_flash_attn_2_available

from vllm.platforms import current_platform
from vllm.utils import identity
@@ -637,10 +636,7 @@
dtype="half",
num_logprobs=10,
patch_hf_runner=model_utils.ovis2_5_patch_hf_runner,
-marks=[pytest.mark.skipif(
-not is_flash_attn_2_available(),
-reason="HF model needs `flash_attn` installed"
-)],
+hf_model_kwargs={"revision": "refs/pr/5"},
Member Author commented:
With this revision, we can now run the model test without flash-attn installed:

tests/models/multimodal/generation/test_common.py::test_video_models[ovis2_5-test_case3]
  /kaggle/working/vllm/tests/models/multimodal/generation/vlm_utils/core.py:154: UserWarning: Test1:
  Matched tokens:       [151667, 198, 20002, 99601, 85106, 101042, 100678, 99487, 87140, 103027, 1773, 101140, 50930, 102650, 5122, 102833, 100469, 103645, 100811, 3837, 108391, 105666, 104433, 104972, 3837, 102196, 33108, 102936, 99165, 100243, 116434, 1773, 101889]
  hf:   '<think>\n用户现在需要分析为什么这个视频有趣。首先看画面:婴儿戴着眼镜,模仿大人读书的样子,动作和表情很滑稽。然后分解元素:\n\n1. 婴儿的“阅读”行为:婴儿模仿大人读书,动作笨拙但可爱,比如翻页、专注的样子,和成人读书的场景形成反差,很幽默。\n2. 眼镜的拟人化:婴儿戴眼镜,像是在认真阅读,这种拟人化的表现很有趣,因为婴儿戴眼镜是现实中不太常见的,加上模仿阅读,强化了喜剧效果。\n3. �'     {107799: -1.5044023990631104, 104449: -2.0981523990631104, 50930: -2.2387773990631104, 100062: -2.8325273990631104, 30534: -2.9419023990631104, 99172: -3.0044023990631104, 20412: -3.1762773990631104, 104107: -3.6137773990631104, 3837: -3.7856523990631104, 101348: -3.7856523990631104}
  vllm: '<think>\n用户现在需要分析为什么这个视频有趣。首先看画面:婴儿戴着眼镜,模仿大人读书的样子,动作和表情很滑稽。然后细节:婴儿的动作(翻书、抬手)像在认真阅读,眼镜的拟人化,还有环境(床上、背景的家具)营造的居家氛围,加上婴儿的天真可爱,模仿成人行为的反差萌,这些元素结合起来让视频有幽默感。\n\n首先,**拟人化与模仿**:婴儿戴着眼镜,模仿大人读书,这种“成人化”的行为在婴儿身上显得滑稽,因为婴儿本'     {104449: Logprob(logprob=-1.8062855005264282, rank=1, decoded_token='细节'), 107799: Logprob(logprob=-2.4156603813171387, rank=2, decoded_token='分解'), 30534: Logprob(logprob=-2.6187853813171387, rank=3, decoded_token='要'), 100374: Logprob(logprob=-2.6187853813171387, rank=4, decoded_token='结合'), 50930: Logprob(logprob=-2.7281603813171387, rank=5, decoded_token='看'), 20412: Logprob(logprob=-2.8531603813171387, rank=6, decoded_token='是'), 102122: Logprob(logprob=-2.9156603813171387, rank=7, decoded_token='场景'), 99719: Logprob(logprob=-3.0562853813171387, rank=8, decoded_token='环境'), 99172: Logprob(logprob=-3.4156603813171387, rank=9, decoded_token='想'), 3837: Logprob(logprob=-3.5250353813171387, rank=10, decoded_token=',')}
    comparator(

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
=============================================== 12 passed, 302 deselected, 29 warnings in 1048.85s (0:17:28) ================================================

),
"phi3v": VLMTestInfo(
models=["microsoft/Phi-3.5-vision-instruct"],
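On the Hugging Face side, the new hf_model_kwargs={"revision": "refs/pr/5"} simply pins the reference model to a repo revision that no longer hard-requires flash-attn, which is why the skipif mark could be dropped. A hedged sketch of what that kwarg ultimately controls (the Auto class and dtype below are assumptions, not the test harness code):

from transformers import AutoModelForCausalLM

hf_model = AutoModelForCausalLM.from_pretrained(
    "AIDC-AI/Ovis2.5-2B",
    revision="refs/pr/5",   # the revision pinned by the test
    trust_remote_code=True,
    torch_dtype="auto",
)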
4 changes: 1 addition & 3 deletions tests/models/registry.py
@@ -465,9 +465,7 @@ def check_available_online(
extras={"1.6-llama": "AIDC-AI/Ovis1.6-Llama3.2-3B",
"1.6-gemma": "AIDC-AI/Ovis1.6-Gemma2-9B"}), # noqa: E501
"Ovis2_5": _HfExamplesInfo("AIDC-AI/Ovis2.5-2B",
-trust_remote_code=True,
-max_transformers_version="4.53",
-transformers_version_reason="HF model is not compatible"), # noqa: E501
+trust_remote_code=True),
"PaliGemmaForConditionalGeneration": _HfExamplesInfo("google/paligemma-3b-mix-224", # noqa: E501
extras={"v2": "google/paligemma2-3b-ft-docci-448"}), # noqa: E501
"Phi3VForCausalLM": _HfExamplesInfo("microsoft/Phi-3-vision-128k-instruct",
36 changes: 20 additions & 16 deletions vllm/model_executor/models/ovis2_5.py
@@ -30,7 +30,7 @@
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.processors.ovis2_5 import Ovis2_5Processor

-from .interfaces import MultiModalEmbeddings, SupportsMultiModal
+from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP

IMAGE_TOKEN = "<image>"
VIDEO_TOKEN = "<video>"
@@ -70,13 +70,15 @@ def __init__(
visual_vocab_size: int,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
+use_data_parallel: bool = False,
):
super().__init__()
self.config = config
self.vit = self._init_backbone(
config=config,
quant_config=quant_config,
prefix=f"{prefix}.vit",
+use_data_parallel=use_data_parallel,
)
# reserved tokens for INDICATOR_IDS
head_dim = visual_vocab_size - len(INDICATOR_IDS)
@@ -93,39 +95,42 @@ def _init_backbone(
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
+use_data_parallel: bool = False,
):
model_type = config.model_type
if model_type == "siglip2_navit":
-return Siglip2NavitModel(config=config, )
+return Siglip2NavitModel(config=config,
+quant_config=quant_config,
+prefix=prefix,
+use_data_parallel=use_data_parallel)
raise ValueError(
f"Unsupported visual tokenizer model_type: {model_type}")

@property
-def dtype(self):
+def dtype(self) -> torch.dtype:
return next(self.head.parameters()).dtype

@property
-def device(self):
+def device(self) -> torch.device:
return next(self.head.parameters()).device

-def tokenize(self, logits):
+def tokenize(self, logits: torch.Tensor) -> torch.Tensor:
tokens = torch.softmax(logits, dim=-1,
dtype=torch.float32).to(logits.dtype)
return tokens

-def encode(self, pixel_values, grid_thws):
-features = self.vit(pixel_values,
-grid_thws,
-output_hidden_states=True,
-return_dict=True)
+def encode(self, pixel_values: torch.Tensor,
+grid_thws: torch.Tensor) -> torch.Tensor:
+features = self.vit(pixel_values, grid_thws)
# refer to qwen2.5-vl patchmerger
seq_len, _ = features.shape
features = features.reshape(seq_len // (self.config.hidden_stride**2),
-1)

return features

-def forward(self, pixel_values, grid_thws) -> torch.Tensor:
+def forward(self, pixel_values: torch.Tensor,
+grid_thws: torch.Tensor) -> torch.Tensor:
features = self.encode(pixel_values, grid_thws)
logits = self.head(features)
tokens = self.tokenize(logits)
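The reshape in encode() above follows the Qwen2.5-VL patch-merger layout: features from hidden_stride**2 neighboring ViT patches are folded into one visual token with a correspondingly wider feature dimension. A small sketch with assumed shapes:

import torch

hidden_stride = 2                      # assumed value from the ViT config
seq_len, hidden_size = 1024, 1152      # e.g. 1024 patch features from self.vit
features = torch.randn(seq_len, hidden_size)

# Same reshape as encode(): each 2x2 patch group becomes one token with 4x features.
merged = features.reshape(seq_len // hidden_stride**2, -1)
print(merged.shape)                    # torch.Size([256, 4608])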
@@ -395,7 +400,7 @@ def get_replacement_ovis(item_idx, modality: str):
@MULTIMODAL_REGISTRY.register_processor(Ovis2_5MultiModalProcessor,
info=Ovis2_5ProcessingInfo,
dummy_inputs=Ovis2_5DummyInputsBuilder)
-class Ovis2_5(nn.Module, SupportsMultiModal):
+class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP):

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
@@ -421,9 +426,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
text_model_type = self.config.get_text_config().model_type
self.image_pad_token_id = IMAGE_PAD_TOKEN_ID_MAP[text_model_type]

-# TODO(Isotr0py): PP support
-# self.make_empty_intermediate_tensors = (
-#     self.language_model.make_empty_intermediate_tensors)
+self.make_empty_intermediate_tensors = (
+self.get_language_model().make_empty_intermediate_tensors)

def _parse_and_validate_visual_input(
self, is_video,
@@ -567,4 +571,4 @@ def load_weights(self, weights: Iterable[tuple[str,
return loader.load_weights(weights)

def get_language_model(self) -> torch.nn.Module:
-return self.llm
+return self.llm
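The SupportsPP change above comes down to delegating make_empty_intermediate_tensors to the wrapped language model, so non-first pipeline stages can allocate placeholder hidden states to receive from the previous stage instead of re-running the multimodal front end. A generic sketch of the pattern, with toy classes standing in for vLLM's internals (names and shapes are assumptions):

import torch

class ToyLanguageModel:
    hidden_size = 4096

    def make_empty_intermediate_tensors(
            self, batch_size: int, dtype: torch.dtype,
            device: torch.device) -> dict[str, torch.Tensor]:
        # Placeholder hidden states to be filled by the previous pipeline stage.
        return {
            "hidden_states":
            torch.zeros(batch_size, self.hidden_size, dtype=dtype, device=device)
        }

class ToyMultiModalModel:
    def __init__(self) -> None:
        self.llm = ToyLanguageModel()
        # Mirrors the PR: forward the factory instead of leaving PP as a TODO.
        self.make_empty_intermediate_tensors = (
            self.llm.make_empty_intermediate_tensors)

model = ToyMultiModalModel()
tensors = model.make_empty_intermediate_tensors(
    batch_size=2, dtype=torch.float16, device=torch.device("cpu"))
print(tensors["hidden_states"].shape)  # torch.Size([2, 4096])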