6 changes: 3 additions & 3 deletions docs/source/models/supported_models.md
@@ -1045,10 +1045,10 @@ Specified using `--task generate`.
*
* ✅︎
* ✅︎
- * `Ovis2ForConditionalGeneration`<sup>^</sup>
* Ovis2
- * `Ovis`
* Ovis2, Ovis1.6
* T + I<sup>+</sup>
* `AIDC-AI/Ovis2-1B`, `AIDC-AI/Ovis2-2B`, etc.
* `AIDC-AI/Ovis2-1B`, `AIDC-AI/Ovis1.6-Llama3.2-3B`, etc.
*
*
* ✅︎
21 changes: 12 additions & 9 deletions examples/offline_inference/vision_language.py
@@ -725,8 +725,8 @@ def run_nvlm_d(questions: list[str], modality: str) -> ModelRequestData:
)


# Ovis2
def run_ovis2(questions: list[str], modality: str) -> ModelRequestData:
# Ovis
def run_ovis(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image"

model_name = "AIDC-AI/Ovis2-1B"
@@ -737,15 +737,18 @@ def run_ovis2(questions: list[str], modality: str) -> ModelRequestData:
max_num_seqs=2,
trust_remote_code=True,
dtype="half",
hf_overrides={"architectures": ["Ovis2ForConditionalGeneration"]},
limit_mm_per_prompt={modality: 1},
)

placeholder = "<image>\n"
prompts = [("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
f"<|im_start|>user\n{placeholder}"
f"{question}<|im_end|>\n"
"<|im_start|>assistant\n") for question in questions]
tokenizer = AutoTokenizer.from_pretrained(model_name,
trust_remote_code=True)
messages = [[{
'role': 'user',
'content': f"<image>\n{question}"
}] for question in questions]
prompts = tokenizer.apply_chat_template(messages,
tokenize=False,
add_generation_prompt=True)

return ModelRequestData(
engine_args=engine_args,
@@ -1069,7 +1072,7 @@ def run_skyworkr1v(questions: list[str], modality: str) -> ModelRequestData:
"llama4": run_llama4,
"molmo": run_molmo,
"NVLM_D": run_nvlm_d,
"ovis2": run_ovis2,
"ovis": run_ovis,
"paligemma": run_paligemma,
"paligemma2": run_paligemma2,
"phi3_v": run_phi3v,
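For context, a minimal sketch (not part of this diff) of what the chat-template-based prompt construction in the new `run_ovis` produces. It assumes the `AIDC-AI/Ovis2-1B` tokenizer ships a chat template and that `trust_remote_code=True` is acceptable in your environment:

```python
# Sketch: build Ovis prompts via the tokenizer's chat template instead of a
# hard-coded Qwen2-style template, mirroring the updated run_ovis above.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("AIDC-AI/Ovis2-1B",
                                          trust_remote_code=True)
questions = ["What is the content of this image?"]
messages = [[{
    "role": "user",
    "content": f"<image>\n{question}"
}] for question in questions]
prompts = tokenizer.apply_chat_template(messages,
                                        tokenize=False,
                                        add_generation_prompt=True)
print(prompts[0])  # fully formatted prompt, including the model's own chat tags
```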
22 changes: 12 additions & 10 deletions examples/offline_inference/vision_language_multi_image.py
@@ -436,8 +436,8 @@ def load_nvlm_d(question: str, image_urls: list[str]) -> ModelRequestData:
)


# Ovis2
def load_ovis2(question: str, image_urls: list[str]) -> ModelRequestData:
# Ovis
def load_ovis(question: str, image_urls: list[str]) -> ModelRequestData:
model_name = "AIDC-AI/Ovis2-1B"

engine_args = EngineArgs(
@@ -447,15 +447,17 @@ def load_ovis2(question: str, image_urls: list[str]) -> ModelRequestData:
trust_remote_code=True,
dtype="half",
limit_mm_per_prompt={"image": len(image_urls)},
hf_overrides={"architectures": ["Ovis2ForConditionalGeneration"]},
)

placeholder = '\n'.join(
[f'Image {i+1}: <image>' for i in range(len(image_urls))]) + '\n'
prompt = ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
f"<|im_start|>user\n{placeholder}"
f"{question}<|im_end|>\n"
"<|im_start|>assistant\n")
placeholders = "\n".join(f"Image-{i}: <image>\n"
for i, _ in enumerate(image_urls, start=1))
messages = [{'role': 'user', 'content': f"{placeholders}\n{question}"}]

tokenizer = AutoTokenizer.from_pretrained(model_name,
trust_remote_code=True)
prompt = tokenizer.apply_chat_template(messages,
tokenize=False,
add_generation_prompt=True)

return ModelRequestData(
engine_args=engine_args,
@@ -713,7 +715,7 @@ def load_qwen2_5_vl(question: str, image_urls: list[str]) -> ModelRequestData:
"mistral3": load_mistral3,
"mllama": load_mllama,
"NVLM_D": load_nvlm_d,
"ovis2": load_ovis2,
"ovis": load_ovis,
"phi3_v": load_phi3v,
"phi4_mm": load_phi4mm,
"pixtral_hf": load_pixtral_hf,
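As a small illustration (outside the diff), this is the numbered multi-image placeholder string that `load_ovis` now builds before handing the conversation to the chat template; the image URLs below are hypothetical:

```python
# Sketch of the numbered multi-image placeholders built in load_ovis above.
image_urls = [
    "https://example.com/duck.jpg",  # placeholder URL for illustration
    "https://example.com/lion.jpg",  # placeholder URL for illustration
]
placeholders = "\n".join(f"Image-{i}: <image>\n"
                         for i, _ in enumerate(image_urls, start=1))
question = "What is the difference between the two images?"
content = f"{placeholders}\n{question}"
# content ==
# "Image-1: <image>\n"
# "\n"
# "Image-2: <image>\n"
# "\n"
# "What is the difference between the two images?"
messages = [{"role": "user", "content": content}]
```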
8 changes: 7 additions & 1 deletion tests/conftest.py
@@ -355,10 +355,16 @@ def __init__(
**model_kwargs,
)

# in case some unquantized custom models are not in same dtype
if (getattr(model, "quantization_method", None) is None
and any(p.dtype != self.dtype
for p in model.parameters())):
model = model.to(dtype=self.dtype)

if (getattr(model, "quantization_method", None) != "bitsandbytes"
and len({p.device
for p in model.parameters()}) < 2):
model = model.to(self.device)
model = model.to(device=self.device)

self.model = model

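As a standalone illustration (not part of the diff), the dtype/device normalization added above could be written as a helper like the one below; the function name and signature are invented for this sketch:

```python
import torch
import torch.nn as nn


def normalize_model(model: nn.Module, dtype: torch.dtype,
                    device: torch.device) -> nn.Module:
    """Hypothetical helper mirroring the conftest.py change above."""
    # Cast unquantized custom models whose parameters were loaded in a
    # different dtype than the one requested by the test runner.
    if (getattr(model, "quantization_method", None) is None
            and any(p.dtype != dtype for p in model.parameters())):
        model = model.to(dtype=dtype)

    # Skip the device move for bitsandbytes-quantized models and for models
    # already spread across multiple devices (e.g. loaded with a device_map).
    if (getattr(model, "quantization_method", None) != "bitsandbytes"
            and len({p.device for p in model.parameters()}) < 2):
        model = model.to(device=device)

    return model
```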
27 changes: 26 additions & 1 deletion tests/models/multimodal/generation/test_common.py
@@ -476,6 +476,31 @@
max_num_seqs=2,
patch_hf_runner=model_utils.molmo_patch_hf_runner,
),
"ovis1_6-gemma2": VLMTestInfo(
models=["AIDC-AI/Ovis1.6-Gemma2-9B"],
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
prompt_formatter=lambda img_prompt: f"<bos><start_of_turn>user\n{img_prompt}<end_of_turn>\n<start_of_turn>model\n", # noqa: E501
img_idx_to_prompt=lambda idx: "<image>\n", # noqa: E501
max_model_len=4096,
max_num_seqs=2,
dtype="half",
# use sdpa mode for hf runner since ovis models didn't work with flash_attn
hf_model_kwargs={"llm_attn_implementation": "sdpa"},
patch_hf_runner=model_utils.ovis_patch_hf_runner,
marks=[large_gpu_mark(min_gb=32)],
),
"ovis1_6": VLMTestInfo(
models=["AIDC-AI/Ovis1.6-Llama3.2-3B"],
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
prompt_formatter=lambda img_prompt: f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nYou are a helpful and honest multimodal assistant.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{img_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", # noqa: E501
img_idx_to_prompt=lambda idx: "<image>\n", # noqa: E501
max_model_len=4096,
max_num_seqs=2,
dtype="half",
# use sdpa mode for hf runner since ovis models didn't work with flash_attn
hf_model_kwargs={"llm_attn_implementation": "sdpa"},
patch_hf_runner=model_utils.ovis_patch_hf_runner,
),
"ovis2": VLMTestInfo(
models=["AIDC-AI/Ovis2-1B"],
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
@@ -486,7 +511,7 @@
dtype="half",
# use sdpa mode for hf runner since ovis2 didn't work with flash_attn
hf_model_kwargs={"llm_attn_implementation": "sdpa"},
patch_hf_runner=model_utils.ovis2_patch_hf_runner,
patch_hf_runner=model_utils.ovis_patch_hf_runner,
),
"phi3v": VLMTestInfo(
models=["microsoft/Phi-3.5-vision-instruct"],
17 changes: 11 additions & 6 deletions tests/models/multimodal/generation/vlm_utils/model_utils.py
@@ -678,20 +678,25 @@ def _generate(self, max_new_tokens=None, do_sample=None, **kwargs):
return hf_model


def ovis2_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
def ovis_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
"""Patches and returns an instance of the HfRunner to use for Ovis2."""
hf_model.model.visual_tokenizer.to(hf_model.dtype)
hf_model.model.vte.to(hf_model.dtype)
hf_model.model.llm.to(hf_model.dtype)

hf_model.model.get_output_embeddings = lambda: \
hf_model.model.llm.get_output_embeddings()

def processor(*args, text="", images=None, **kwargs):
text_tokenizer = hf_model.model.get_text_tokenizer()
images = [images] if isinstance(images, Image) else images

text = text.split("<|im_start|>user\n")[1].split("<|im_end|>\n")[0]
prompt_start_and_end = {
"qwen2": ("<|im_start|>user\n", "<|im_end|>\n"),
"llama":
("<|start_header_id|>user<|end_header_id|>\n\n", "<|eot_id|>"),
"gemma2": ("<start_of_turn>user\n", "<end_of_turn>\n"),
}
for start, end in prompt_start_and_end.values():
if start in text and end in text:
text = text.split(start)[1].split(end)[0]
break

prompt, input_ids, pixel_values = hf_model.model.preprocess_inputs(
text_or_conversations=text, images=images)
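To make the new prompt handling easier to follow, here is a self-contained sketch (not part of the diff) of how the patched processor strips a fully formatted prompt down to the user turn for each supported backbone chat format:

```python
# Standalone sketch of the user-turn extraction added to ovis_patch_hf_runner.
prompt_start_and_end = {
    "qwen2": ("<|im_start|>user\n", "<|im_end|>\n"),
    "llama": ("<|start_header_id|>user<|end_header_id|>\n\n", "<|eot_id|>"),
    "gemma2": ("<start_of_turn>user\n", "<end_of_turn>\n"),
}


def extract_user_text(text: str) -> str:
    """Return only the user turn, whichever chat format produced the prompt."""
    for start, end in prompt_start_and_end.values():
        if start in text and end in text:
            return text.split(start)[1].split(end)[0]
    return text  # leave the text unchanged if no known format matches


example = ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
           "<|im_start|>user\n<image>\nWhat is shown here?<|im_end|>\n"
           "<|im_start|>assistant\n")
assert extract_user_text(example) == "<image>\nWhat is shown here?"
```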
5 changes: 4 additions & 1 deletion tests/models/multimodal/processing/test_common.py
@@ -146,7 +146,8 @@ def _test_processing_correctness_hf(
batch_idx: int,
ignore_mm_keys: Optional[set[str]] = None,
):
if model_config.hf_config.model_type in ("mllama", "whisper", "ultravox"):
if model_config.hf_config.model_type in ("mllama", "ovis", "ultravox",
"whisper"):
# For some multimodal models, tokenizer will always add bos_token
# at the beginning of prompt by default, causing hf_processor outputs
# incorrect token ids. So we need use `add_special_tokens=False` here
@@ -274,6 +275,8 @@ def _test_processing_correctness_mistral(
"allenai/Molmo-7B-D-0924",
"allenai/Molmo-7B-O-0924",
"nvidia/NVLM-D-72B",
"AIDC-AI/Ovis1.6-Gemma2-9B",
"AIDC-AI/Ovis1.6-Llama3.2-3B",
"AIDC-AI/Ovis2-1B",
"google/paligemma-3b-mix-224",
"google/paligemma2-3b-ft-docci-448",
6 changes: 3 additions & 3 deletions tests/models/registry.py
@@ -355,9 +355,9 @@ def check_available_online(
max_transformers_version="4.48",
transformers_version_reason="Use of deprecated imports which have been removed.", # noqa: E501
extras={"phi3.5": "microsoft/Phi-3.5-vision-instruct"}), # noqa: E501
"Ovis2ForConditionalGeneration": _HfExamplesInfo("AIDC-AI/Ovis2-1B",
trust_remote_code=True,
hf_overrides={"architectures": ["Ovis2ForConditionalGeneration"]}), # noqa: E501
"Ovis": _HfExamplesInfo("AIDC-AI/Ovis2-1B", trust_remote_code=True,
extras={"1.6-llama": "AIDC-AI/Ovis1.6-Llama3.2-3B",
"1.6-gemma": "AIDC-AI/Ovis1.6-Gemma2-9B"}), # noqa: E501
"Phi4MMForCausalLM": _HfExamplesInfo("microsoft/Phi-4-multimodal-instruct",
trust_remote_code=True),
"PixtralForConditionalGeneration": _HfExamplesInfo("mistralai/Pixtral-12B-2409", # noqa: E501
2 changes: 1 addition & 1 deletion vllm/entrypoints/chat_utils.py
@@ -512,7 +512,7 @@ def _placeholder_str(self, modality: ModalityStr,
hf_config.image_token_index)

if model_type in ("aya_vision", "chameleon", "deepseek_vl_v2",
"internvl_chat", "ovis2", "skywork_chat",
"internvl_chat", "ovis", "skywork_chat",
"NVLM_D", "h2ovl_chat", "idefics3", "smolvlm"):
return "<image>"
if model_type in ("mllama", "llama4"):
127 changes: 2 additions & 125 deletions vllm/model_executor/models/aimv2.py
@@ -5,129 +5,14 @@
from typing import Optional

import torch
from torch import nn, softmax
import torch.nn as nn
from torch.nn import functional as F
from torch.nn.functional import gumbel_softmax, pad

from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.transformers_utils.configs.ovis2 import (AIMv2Config,
Aimv2VisualTokenizerConfig)

IMAGE_INDICATOR_IDS = [-301, -302, -303, -304,
-305] # kept for vocab prefixed tokens


def st_argmax(y_soft: torch.Tensor, dim: int): # straight-through softmax
index = y_soft.max(dim, keepdim=True)[1]
y_hard = torch.zeros_like(
y_soft, memory_format=torch.legacy_contiguous_format).scatter_(
dim, index, 1.0)
ret = y_hard - y_soft.detach() + y_soft
return ret


class Aimv2VisualTokenizer(torch.nn.Module):

def __init__(self,
config: Aimv2VisualTokenizerConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
**kwargs):
super().__init__()
self.config = config
self.backbone = AIMv2Model(
config=config.backbone_config, # noqa
quant_config=quant_config,
prefix=f"{prefix}.visual_tokenizer")
# reserved tokens for IMAGE_INDICATORS
head_dim = config.vocab_size - len(IMAGE_INDICATOR_IDS)
self.head = torch.nn.Sequential(
ReplicatedLinear(
config.backbone_config.hidden_size * config.hidden_stride *
config.hidden_stride,
head_dim,
bias=False,
), torch.nn.LayerNorm(head_dim))

@property
def dtype(self):
return self.backbone.dtype

@property
def device(self):
return self.backbone.device

def tokenize(self, logits):
if self.config.tokenize_function == 'softmax':
tokens = softmax(logits, dim=-1)
elif self.config.tokenize_function == 'gumbel_argmax':
tokens = gumbel_softmax(logits, tau=self.config.tau, hard=True)
elif self.config.tokenize_function == 'st_argmax':
tokens = st_argmax(logits, dim=-1)
else:
raise ValueError(
'Invalid `max_type`, expected softmax or gumbel_argmax '
f'or st_argmax, but got {self.config.tokenize_function}')
return tokens

def encode(self, pixel_values):
features = self.backbone(pixel_values)
if self.config.drop_cls_token:
features = features[:, 1:, :]

# merge number of `hidden_stride * hidden_stride` hidden states together
# to reduce token sequence length
# e.g., for hidden_stride=2, this leads to a token length reduction:
# 1024 -> 256 for aimv2
if self.config.hidden_stride > 1:
# this `d` maybe different from the above `d``
n, L, d = features.shape
sqrt_l = int(L**0.5)
assert sqrt_l**2 == L, (
"The token sequence length should be a perfect square.")
features = features.reshape(n, sqrt_l, sqrt_l, d)
pl = (self.config.hidden_stride -
(sqrt_l %
self.config.hidden_stride)) % self.config.hidden_stride
features = pad(features, (0, 0, 0, pl, 0, pl), "constant", 0)
sqrt_l += pl
features = features.reshape(n, sqrt_l // self.config.hidden_stride,
self.config.hidden_stride,
sqrt_l // self.config.hidden_stride,
self.config.hidden_stride, d)
# [n, sqrt_l/hs, sqrt_l/hs, hs, hs, d]
features = features.permute(0, 1, 3, 2, 4, 5)
# [n, sqrt_l/hs, sqrt_l/hs, hs*hs*d]
features = features.flatten(3)
# [n, sqrt_l/hs*sqrt_l/hs, hs*hs*d]
features = features.reshape(
n, -1,
self.config.hidden_stride * self.config.hidden_stride * d)

return features

def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
"""[BatchSize, ImageShape] -> [BatchSize, Token, VocabSize]"""
features = self.encode(pixel_values)
logits, _ = self.head[0](
features) # we spllit the sequncial here for not throwing an error
logits = self.head[1](logits)
tokens = self.tokenize(logits)
# tokens' shape is [BatchSize, #Token, VocabSize-5], so padding with
# [BatchSize, #Token, 5], after which, tokens' shape should become
# [BatchSize, #Token, VocabSize]
batch_size, token_len, _ = tokens.shape
padding_tensor = torch.zeros(size=(batch_size, token_len,
len(IMAGE_INDICATOR_IDS)),
dtype=tokens.dtype,
device=tokens.device,
layout=tokens.layout,
requires_grad=False)
tokens = torch.cat((tokens, padding_tensor), dim=2)
return tokens
from vllm.transformers_utils.configs.ovis import AIMv2Config


class AIMv2SwiGLUFFN(nn.Module):
@@ -302,14 +187,6 @@ def __init__(self,
quant_config=quant_config,
prefix=f"{prefix}.trunk")

@property
def dtype(self):
return self.trunk.blocks[0].attn.qkv.weight.dtype

@property
def device(self):
return self.trunk.blocks[0].attn.qkv.device

def forward(
self,
pixel_values: torch.Tensor,