diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst
index 084be1e2a4f8e..0c0a54281e3f3 100644
--- a/docs/source/models/supported_models.rst
+++ b/docs/source/models/supported_models.rst
@@ -242,6 +242,11 @@ Multimodal Language Models
- Image\ :sup:`+`
- :code:`openbmb/MiniCPM-V-2` (see note), :code:`openbmb/MiniCPM-Llama3-V-2_5`, :code:`openbmb/MiniCPM-V-2_6`, etc.
-
+ * - :code:`QWenLMHeadModel`
+ - Qwen
+ - Image
+ - :code:`Qwen/Qwen-VL`, :code:`Qwen/Qwen-VL-Chat`, etc.
+ -
* - :code:`UltravoxModel`
- Ultravox
- Audio\ :sup:`E+`
diff --git a/examples/offline_inference_vision_language.py b/examples/offline_inference_vision_language.py
index 9a0e9d4bc5362..aa1580343aee7 100644
--- a/examples/offline_inference_vision_language.py
+++ b/examples/offline_inference_vision_language.py
@@ -159,6 +159,20 @@ def run_blip2(question):
return llm, prompt, stop_token_ids
+# Qwen
+def run_qwen_vl(question):
+
+ llm = LLM(
+ model="Qwen/Qwen-VL",
+ trust_remote_code=True,
+ max_num_seqs=5,
+ )
+
+ prompt = f"{question}Picture 1: \n"
+ stop_token_ids = None
+ return llm, prompt, stop_token_ids
+
+
model_example_map = {
"llava": run_llava,
"llava-next": run_llava_next,
@@ -169,6 +183,7 @@ def run_blip2(question):
"minicpmv": run_minicpmv,
"blip-2": run_blip2,
"internvl_chat": run_internvl,
+ "qwen_vl": run_qwen_vl,
}
diff --git a/tests/models/test_qwen.py b/tests/models/test_qwen.py
index 0f974fcc1885c..05f5cbf8c3435 100644
--- a/tests/models/test_qwen.py
+++ b/tests/models/test_qwen.py
@@ -1,48 +1,165 @@
-from typing import Type
+import pathlib
+from typing import List, Optional, Type
import pytest
-from ..conftest import HfRunner, VllmRunner
+from vllm.multimodal.utils import rescale_image_size
+
+from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
from .utils import check_logprobs_close
-models = ["qwen/qwen-vl"]
+pytestmark = pytest.mark.vlm
+text_only_models = [
+ "Qwen/Qwen-7B-Chat" # Has no visual component
+]
-@pytest.mark.parametrize("dtype", ["half"])
-@pytest.mark.parametrize("max_tokens", [32])
-@pytest.mark.parametrize("num_logprobs", [5])
-@pytest.mark.parametrize("model", models)
-def test_text_only_qwen_model(
+multimodal_models = ["Qwen/Qwen-VL"]
+
+HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
+ "stop_sign":
+ "Picture 1: \nWhat's the content of the image?: ",
+ "cherry_blossom":
+ "Picture 1: \nWhat is the season?: ",
+})
+
+
+### Tests for multimodal Qwen models
+def run_test(
+ tmp_path: pathlib.PosixPath,
hf_runner: Type[HfRunner],
vllm_runner: Type[VllmRunner],
- example_prompts,
+ image_assets: _ImageAssets,
model: str,
*,
+ size_factors: List[float],
dtype: str,
max_tokens: int,
num_logprobs: int,
+ tensor_parallel_size: int,
+ distributed_executor_backend: Optional[str] = None,
):
- # This test checks language inputs only, since the visual component
- # for qwen-vl is still unsupported in VLLM. In the near-future, the
- # implementation and this test will be extended to consider
- # visual inputs as well.
+ """Inference result should be the same between hf and vllm.
+
+ All the image fixtures for the test is under tests/images.
+ For huggingface runner, we provide the PIL images as input.
+ For vllm runner, we provide MultiModalDataDict objects
+ and corresponding MultiModalConfig as input.
+ Note, the text input is also adjusted to abide by vllm contract.
+ The text output is sanitized to be able to compare with hf.
+ """
+ images = [asset.pil_image for asset in image_assets]
+
+ # Export the images to a tempdir and substitute it into the hf prompt;
+ # the contents between / will be ignored by VLLM, but the
+ # transformers implementation for the visual transformer parses this to
+ # reload it in the forward call; the contents are treated as a URL or a
+ # local path.
+ for idx, asset in enumerate(image_assets):
+ image_tmp_path = tmp_path / f"{asset.name}.jpg"
+ asset.pil_image.save(image_tmp_path)
+ HF_IMAGE_PROMPTS[idx] = HF_IMAGE_PROMPTS[idx].replace(
+ "", f"{image_tmp_path}")
+
+ inputs_per_image = [(
+ [prompt for _ in size_factors],
+ [rescale_image_size(image, factor) for factor in size_factors],
+ ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
+
+ # NOTE: take care of the order. run vLLM first, and then run HF.
+ # vLLM needs a fresh new process without cuda initialization.
+ # if we run HF first, the cuda initialization will be done and it
+ # will hurt multiprocessing backend with fork method (the default method).
+
+ # max_model_len should be greater than image_feature_size
+ # Qwen encodes images into a fixed content size of 256
+ with vllm_runner(model,
+ max_model_len=300,
+ max_num_seqs=1,
+ dtype=dtype,
+ tensor_parallel_size=tensor_parallel_size,
+ distributed_executor_backend=distributed_executor_backend,
+ enforce_eager=True) as vllm_model:
+ vllm_outputs_per_image = [
+ vllm_model.generate_greedy_logprobs(prompts,
+ max_tokens,
+ num_logprobs=num_logprobs,
+ images=images)
+ for prompts, images in inputs_per_image
+ ]
+
with hf_runner(model, dtype=dtype) as hf_model:
- hf_outputs = hf_model.generate_greedy_logprobs_limit(
- example_prompts,
- max_tokens,
- num_logprobs=num_logprobs,
+ hf_outputs_per_image = [
+ hf_model.generate_greedy_logprobs_limit(prompts,
+ max_tokens,
+ num_logprobs=num_logprobs,
+ images=images)
+ for prompts, images in inputs_per_image
+ ]
+
+ for hf_outputs, vllm_outputs in zip(hf_outputs_per_image,
+ vllm_outputs_per_image):
+
+ check_logprobs_close(
+ outputs_0_lst=hf_outputs,
+ outputs_1_lst=vllm_outputs,
+ name_0="hf",
+ name_1="vllm",
)
+
+@pytest.mark.parametrize("model", multimodal_models)
+@pytest.mark.parametrize(
+ "size_factors",
+ [
+ # No image
+ [],
+ # Single-scale
+ [1.0],
+ # Single-scale, batched
+ [1.0, 1.0, 1.0],
+ # Multi-scale
+ [0.25, 0.5, 1.0],
+ ],
+)
+@pytest.mark.parametrize("dtype", ["bfloat16"])
+@pytest.mark.parametrize("max_tokens", [8])
+@pytest.mark.parametrize("num_logprobs", [5])
+def test_multimodal_models(tmp_path, hf_runner, vllm_runner, image_assets,
+ model, size_factors, dtype, max_tokens,
+ num_logprobs) -> None:
+ run_test(
+ tmp_path,
+ hf_runner,
+ vllm_runner,
+ image_assets,
+ model,
+ size_factors=size_factors,
+ dtype=dtype,
+ max_tokens=max_tokens,
+ num_logprobs=num_logprobs,
+ tensor_parallel_size=1,
+ )
+
+
+# Ensure that a text-only Qwen model can still be loaded and
+# used for inference in VLLM without throwing.
+@pytest.mark.parametrize("model", text_only_models)
+@pytest.mark.parametrize("dtype", ["bfloat16"])
+@pytest.mark.parametrize("max_tokens", [32])
+@pytest.mark.parametrize("num_logprobs", [5])
+def test_text_only_qwen_model_can_be_loaded_and_run(
+ vllm_runner: Type[VllmRunner],
+ example_prompts,
+ model: str,
+ *,
+ dtype: str,
+ max_tokens: int,
+ num_logprobs: int,
+):
with vllm_runner(model, dtype=dtype) as vllm_model:
- vllm_outputs = vllm_model.generate_greedy_logprobs(
+ vllm_model.generate_greedy_logprobs(
example_prompts,
max_tokens,
num_logprobs=num_logprobs,
)
-
- check_logprobs_close(
- outputs_0_lst=hf_outputs,
- outputs_1_lst=vllm_outputs,
- name_0="hf",
- name_1="vllm",
- )
diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py
index 9a7493649c795..f9f9536a7c160 100644
--- a/vllm/entrypoints/chat_utils.py
+++ b/vllm/entrypoints/chat_utils.py
@@ -150,6 +150,8 @@ def _placeholder_str(self, modality: ModalityStr,
if model_type in ("blip-2", "chatglm", "fuyu", "paligemma"):
# These models do not use image tokens in the prompt
return None
+ if model_type == "qwen":
+ return f"Picture {current_count}: "
if model_type.startswith("llava"):
return self._cached_token_str(self._tokenizer,
hf_config.image_token_index)
diff --git a/vllm/model_executor/layers/resampler.py b/vllm/model_executor/layers/resampler.py
new file mode 100644
index 0000000000000..8cd938fc85fb2
--- /dev/null
+++ b/vllm/model_executor/layers/resampler.py
@@ -0,0 +1,273 @@
+# coding=utf-8
+# Adapted from
+# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
+# https://huggingface.co/Qwen/Qwen-7B/blob/main/modeling_qwen.py
+# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
+#
+# Copyright 2023 The Qwen team.
+# Copyright 2023 The vLLM team.
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Shared resampler perceiver network used in multimodal models and
+related helpers for sincos positional embeddings.
+
+Example models: Qwen (Qwen-VL), Minicpmv2.0
+"""
+import math
+from functools import partial
+from typing import Callable, Optional, Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import nn
+from torch.nn.init import trunc_normal_
+
+from vllm.model_executor.layers.linear import ReplicatedLinear
+
+DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6)
+
+
+def get_abs_pos(abs_pos: torch.Tensor, tgt_size: Union[torch.Tensor,
+ int]) -> torch.Tensor:
+ # abs_pos: L, C
+ # tgt_size: (H, W)
+ # return: M, C
+ src_size = int(math.sqrt(abs_pos.size(0)))
+ dtype = abs_pos.dtype
+ if isinstance(tgt_size, int):
+ tgt_size = (tgt_size, tgt_size)
+ if (src_size == tgt_size[0] and src_size == tgt_size[1]):
+ return abs_pos
+ return (F.interpolate(
+ abs_pos.float().reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2),
+ size=(tgt_size[0], tgt_size[1]),
+ mode="bicubic",
+ align_corners=False,
+ ).permute(0, 2, 3, 1).flatten(0, 2).to(dtype=dtype))
+
+
+# sin/cos positional embedding helpers are adapted from:
+# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
+def get_1d_sincos_pos_embed_from_grid(
+ embed_dim: int, pos: np.ndarray,
+ version: Tuple[int, int] = (2, 0)) -> torch.Tensor:
+ """
+ embed_dim: output dimension for each position
+ pos: a list of positions to be encoded: size (M,) / (H, W)
+ out: (M, D) / (H, W, D)
+ """
+ assert embed_dim % 2 == 0
+ omega = np.arange(embed_dim // 2, dtype=np.float32)
+ omega /= embed_dim / 2.0
+ omega = 1.0 / 10000**omega # (D/2,)
+
+ if version == (2, 0):
+ pos = pos.reshape(-1) # (M,)
+ out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
+ emb_sin = np.sin(out) # (M, D/2)
+ emb_cos = np.cos(out) # (M, D/2)
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
+ else:
+ out = np.einsum("hw,d->hwd", pos, omega) # (H, W, D/2), outer product
+ emb_sin = np.sin(out) # (H, W, D/2)
+ emb_cos = np.cos(out) # (H, W, D/2)
+ emb = np.concatenate([emb_sin, emb_cos], axis=-1) # (H, W, D)
+ return emb
+
+
+def get_2d_sincos_pos_embed_from_grid(
+ embed_dim: int, grid: np.ndarray,
+ version: Tuple[int, int] = (2, 0)) -> torch.Tensor:
+ assert embed_dim % 2 == 0
+
+ # use half of dimensions to encode grid_h
+ emb_h = get_1d_sincos_pos_embed_from_grid(
+ embed_dim // 2, grid[0], version) # (H*W, D/2) or (H, W, D/2)
+ emb_w = get_1d_sincos_pos_embed_from_grid(
+ embed_dim // 2, grid[1], version) # (H*W, D/2) or (H, W, D/2)
+
+ if version == (2, 0):
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
+ else:
+ emb = np.concatenate([emb_h, emb_w], axis=-1) # (H, W, D)
+ return emb
+
+
+def get_2d_sincos_pos_embed(
+ embed_dim: int,
+ grid_size: Union[int, Tuple[int, int]],
+ cls_token: bool = False,
+ version: Tuple[int, int] = (2, 0),
+) -> torch.Tensor:
+ """
+ grid_size: int of the grid height and width
+ return:
+ pos_embed: [grid_size*grid_size, embed_dim] or
+ [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
+ """
+ if isinstance(grid_size, int):
+ grid_h_size, grid_w_size = grid_size, grid_size
+ else:
+ grid_h_size, grid_w_size = grid_size[0], grid_size[1]
+
+ grid_h = np.arange(grid_h_size, dtype=np.float32)
+ grid_w = np.arange(grid_w_size, dtype=np.float32)
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
+ grid = np.stack(grid, axis=0)
+ assert isinstance(grid, np.ndarray) and \
+ grid.shape == (2, grid_h_size, grid_w_size)
+
+ if version == (2, 0):
+ grid = grid.reshape([2, 1, grid_h_size, grid_w_size])
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, version)
+ if cls_token:
+ pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed],
+ axis=0)
+ else:
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, version)
+ return pos_embed
+
+
+class BaseResampler(nn.Module):
+ """
+ A 2D perceiver-resampler network with one cross attention layers by
+ (grid_size**2) learnable queries and 2d sincos pos_emb.
+ Outputs:
+ A tensor with the shape of (grid_size**2, embed_dim)
+ """
+
+ def __init__(
+ self,
+ num_queries: int,
+ embed_dim: int,
+ num_heads: int,
+ kv_dim: Optional[int] = None,
+ norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
+ do_post_projection: bool = True,
+ ) -> None:
+ super().__init__()
+
+ self.num_queries = num_queries
+ self.embed_dim = embed_dim
+ self.num_heads = num_heads
+
+ self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim))
+ trunc_normal_(self.query, std=0.02)
+ if kv_dim is not None and kv_dim != embed_dim:
+ self.kv_proj = ReplicatedLinear(kv_dim, embed_dim, bias=False)
+ else:
+ # Maintain the same return value with ReplicatedLinear.forward
+ self.kv_proj = lambda *args, **kwargs: ( # type: ignore # noqa
+ nn.Identity()(*args, **kwargs),
+ None,
+ )
+ self.attn = nn.MultiheadAttention(embed_dim, num_heads)
+ self.ln_q = norm_layer(embed_dim)
+ self.ln_kv = norm_layer(embed_dim)
+ self.do_post_projection = do_post_projection
+ self.ln_post = norm_layer(embed_dim) if do_post_projection else None
+ self.proj = nn.Parameter(
+ (embed_dim**-0.5) *
+ torch.randn(embed_dim, embed_dim)) if do_post_projection else None
+
+ def _init_weights(self, m: nn.Module) -> None:
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=0.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ def _repeat(self, query, N: int):
+ return query.unsqueeze(1).repeat(1, N, 1)
+
+
+class Resampler2(BaseResampler):
+ """Resampler-perceiver network to be used for a variety of model types,
+ e.g., Qwen-vl / Minicpmv 2.0. The main difference is the addition of the
+ do_post_projection arg, which indicates whether or not there should be
+ a post layer normalization and projector after the attention. This is
+ present in minicpmv2.0, but not qwen-vl.
+ """
+
+ def __init__(
+ self,
+ grid_size: int,
+ embed_dim: int,
+ num_heads: int,
+ kv_dim: Optional[int] = None,
+ norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
+ adaptive: bool = False,
+ do_post_projection: bool = True,
+ ) -> None:
+ super().__init__(grid_size**2,
+ embed_dim,
+ num_heads,
+ kv_dim,
+ norm_layer,
+ do_post_projection=do_post_projection)
+
+ self.adaptive = adaptive
+ pos_embed_arr = get_2d_sincos_pos_embed(embed_dim,
+ grid_size,
+ version=(2, 0))
+
+ self.pos_embed = nn.Parameter(
+ torch.from_numpy(pos_embed_arr).requires_grad_(False))
+
+ self.apply(self._init_weights)
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ tgt_sizes: Optional[torch.Tensor] = None,
+ attn_mask: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ if tgt_sizes is None:
+ tgt_sizes = int(math.sqrt(x.size(1)))
+ if self.adaptive:
+ pos_embed_arr = get_2d_sincos_pos_embed(self.embed_dim,
+ tgt_sizes,
+ version=(2, 0))
+ pos_embed = torch.from_numpy(pos_embed_arr).to(device=x.device,
+ dtype=x.dtype)
+ else:
+ pos_embed = get_abs_pos(self.pos_embed,
+ tgt_sizes).to(device=x.device,
+ dtype=x.dtype)
+
+ x, _ = self.kv_proj(x)
+ x = self.ln_kv(x).permute(1, 0, 2)
+
+ N = x.shape[1]
+ q = self.ln_q(self.query)
+ out = self.attn(
+ self._repeat(q, N) + self.pos_embed.unsqueeze(1),
+ x + pos_embed.unsqueeze(1),
+ x,
+ attn_mask=attn_mask,
+ )[0]
+ x = out.permute(1, 0, 2)
+ if self.do_post_projection:
+ x = self.ln_post(x)
+ x = x @ self.proj
+ return x
diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py
index e30370596496a..4db847029566f 100644
--- a/vllm/model_executor/models/__init__.py
+++ b/vllm/model_executor/models/__init__.py
@@ -51,7 +51,6 @@
"PhiForCausalLM": ("phi", "PhiForCausalLM"),
"Phi3ForCausalLM": ("llama", "LlamaForCausalLM"),
"PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
- "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
"Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
"Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
"RWForCausalLM": ("falcon", "FalconForCausalLM"),
@@ -88,6 +87,7 @@
"PaliGemmaForConditionalGeneration"),
"Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
"UltravoxModel": ("ultravox", "UltravoxModel"),
+ "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
}
_CONDITIONAL_GENERATION_MODELS = {
"BartModel": ("bart", "BartForConditionalGeneration"),
diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py
index dd10729b9ffb5..f8be9490ee55d 100644
--- a/vllm/model_executor/models/minicpmv.py
+++ b/vllm/model_executor/models/minicpmv.py
@@ -26,11 +26,9 @@
from array import array
from functools import partial
from typing import (Any, Callable, Iterable, List, Mapping, Optional, Tuple,
- TypedDict, Union)
+ TypedDict)
-import numpy as np
import torch
-import torch.nn.functional as F
import torch.types
from PIL import Image
from torch import nn
@@ -44,6 +42,8 @@
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.resampler import (Resampler2,
+ get_2d_sincos_pos_embed)
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
@@ -98,101 +98,6 @@ class MiniCPMVImagePixelInputs(TypedDict):
DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6)
-def get_abs_pos(abs_pos: torch.Tensor, tgt_size: torch.Tensor):
- # abs_pos: L, C
- # tgt_size: (H, W)
- # return: M, C
- src_size = int(math.sqrt(abs_pos.size(0)))
- # tgt_size = int(math.sqrt(tgt_size))
- dtype = abs_pos.dtype
-
- return (F.interpolate(
- abs_pos.float().reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2),
- size=(tgt_size[0], tgt_size[1]),
- mode="bicubic",
- align_corners=False,
- ).permute(0, 2, 3, 1).flatten(0, 2).to(dtype=dtype))
-
-
-# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
-def get_2d_sincos_pos_embed(
- embed_dim: int,
- grid_size: Union[int, Tuple[int, int]],
- cls_token: bool = False,
- version: Tuple[int, int] = (2, 0),
-):
- """
- grid_size: int of the grid height and width
- return:
- pos_embed: [grid_size*grid_size, embed_dim] or
- [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
- """
- if isinstance(grid_size, int):
- grid_h_size, grid_w_size = grid_size, grid_size
- else:
- grid_h_size, grid_w_size = grid_size[0], grid_size[1]
-
- grid_h = np.arange(grid_h_size, dtype=np.float32)
- grid_w = np.arange(grid_w_size, dtype=np.float32)
- grid = np.meshgrid(grid_w, grid_h) # here w goes first
- grid = np.stack(grid, axis=0)
-
- if version == (2, 0):
- grid = grid.reshape([2, 1, grid_h_size, grid_w_size])
- pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, version)
- if cls_token:
- pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed],
- axis=0)
- else:
- pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, version)
- return pos_embed
-
-
-def get_2d_sincos_pos_embed_from_grid(embed_dim: int,
- grid: np.ndarray,
- version: Tuple[int, int] = (2, 0)):
- assert embed_dim % 2 == 0
-
- # use half of dimensions to encode grid_h
- emb_h = get_1d_sincos_pos_embed_from_grid(
- embed_dim // 2, grid[0], version) # (H*W, D/2) or (H, W, D/2)
- emb_w = get_1d_sincos_pos_embed_from_grid(
- embed_dim // 2, grid[1], version) # (H*W, D/2) or (H, W, D/2)
-
- if version == (2, 0):
- emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
- else:
- emb = np.concatenate([emb_h, emb_w], axis=-1) # (H, W, D)
- return emb
-
-
-def get_1d_sincos_pos_embed_from_grid(embed_dim: int,
- pos: np.ndarray,
- version: Tuple[int, int] = (2, 0)):
- """
- embed_dim: output dimension for each position
- pos: a list of positions to be encoded: size (M,) / (H, W)
- out: (M, D) / (H, W, D)
- """
- assert embed_dim % 2 == 0
- omega = np.arange(embed_dim // 2, dtype=np.float32)
- omega /= embed_dim / 2.0
- omega = 1.0 / 10000**omega # (D/2,)
-
- if version == (2, 0):
- pos = pos.reshape(-1) # (M,)
- out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
- emb_sin = np.sin(out) # (M, D/2)
- emb_cos = np.cos(out) # (M, D/2)
- emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
- else:
- out = np.einsum("hw,d->hwd", pos, omega) # (H, W, D/2), outer product
- emb_sin = np.sin(out) # (H, W, D/2)
- emb_cos = np.cos(out) # (H, W, D/2)
- emb = np.concatenate([emb_sin, emb_cos], axis=-1) # (H, W, D)
- return emb
-
-
class BaseResampler(nn.Module):
"""
A 2D perceiver-resampler network with one cross attention layers by
@@ -245,62 +150,6 @@ def _repeat(self, query, N: int):
return query.unsqueeze(1).repeat(1, N, 1)
-class Resampler2(BaseResampler):
-
- def __init__(
- self,
- grid_size: int,
- embed_dim: int,
- num_heads: int,
- kv_dim: Optional[int] = None,
- norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
- adaptive: bool = False,
- ) -> None:
- super().__init__(grid_size**2, embed_dim, num_heads, kv_dim,
- norm_layer)
-
- self.adaptive = adaptive
- pos_embed_arr = get_2d_sincos_pos_embed(embed_dim,
- grid_size,
- version=(2, 0))
- self.pos_embed = nn.Parameter(
- torch.from_numpy(pos_embed_arr).float()).requires_grad_(False)
-
- self.apply(self._init_weights)
-
- def forward(
- self,
- x: torch.Tensor,
- tgt_sizes: torch.Tensor,
- attn_mask: Optional[torch.Tensor] = None,
- ):
- if self.adaptive:
- pos_embed_arr = get_2d_sincos_pos_embed(self.embed_dim,
- tgt_sizes,
- version=(2, 0))
- pos_embed = torch.from_numpy(pos_embed_arr).to(device=x.device,
- dtype=x.dtype)
- else:
- pos_embed = get_abs_pos(self.pos_embed, tgt_sizes)
-
- x, _ = self.kv_proj(x)
- x = self.ln_kv(x).permute(1, 0, 2)
-
- N = x.shape[1]
- q = self.ln_q(self.query)
- out = self.attn(
- self._repeat(q, N) + self.pos_embed.unsqueeze(1),
- x + pos_embed.unsqueeze(1),
- x,
- attn_mask=attn_mask,
- )[0]
- x = out.permute(1, 0, 2)
-
- x = self.ln_post(x)
- x = x @ self.proj
- return x
-
-
class Resampler2_5(BaseResampler):
def __init__(
@@ -782,7 +631,8 @@ def init_resampler(self, embed_dim: int, vision_dim: int) -> nn.Module:
num_heads=embed_dim // 128,
grid_size=int(math.sqrt(self.config.query_num)),
kv_dim=vision_dim,
- adaptive=True,
+ adaptive=False,
+ do_post_projection=True,
)
return resampler
diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py
index 8298e3bac4465..a726ec10984c0 100644
--- a/vllm/model_executor/models/qwen.py
+++ b/vllm/model_executor/models/qwen.py
@@ -4,36 +4,402 @@
# Copyright (c) Alibaba Cloud.
# LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE
"""Inference-only QWen model compatible with HuggingFace weights."""
-from typing import Any, Dict, Iterable, List, Optional, Tuple
+import math
+import re
+from array import array
+from functools import partial
+from typing import (Any, Callable, Dict, Iterable, List, Literal, Mapping,
+ Optional, Tuple, TypedDict, Union)
+
+import numpy as np
import torch
+from PIL import Image
from torch import nn
+from torchvision import transforms
+from torchvision.transforms import InterpolationMode
from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
-from vllm.config import CacheConfig
+from vllm.config import CacheConfig, MultiModalConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
-from vllm.model_executor.layers.activation import SiluAndMul
+from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
+from vllm.logger import init_logger
+from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn
from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
+from vllm.model_executor.layers.linear import (ColumnParallelLinear,
+ MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
+from vllm.model_executor.layers.resampler import Resampler2, get_abs_pos
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.models.interfaces import SupportsMultiModal
from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.sequence import IntermediateTensors
-from vllm.utils import print_warning_once
+from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.base import MultiModalInputs
+from vllm.multimodal.utils import cached_get_tokenizer
+from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
+ SequenceData)
+
+from .utils import flatten_bn, is_pp_missing_parameter, make_layers
+
+logger = init_logger(__name__)
+
+# NOTE: Qwen models have a few other special tags, e.g., ref, bbox, quad;
+# for the time being, these tags are not considered as special at encoding
+# time. This may change as VLLMs multimodal API changes in the future.
+IMG_START = ""
+IMG_END = ""
+IMG_PAD = ""
+# Image context is fixed at 256 for all images
+MAX_QWEN_IMG_TOKENS = 256
+# Image normalization params
+CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
+CLIP_STD = (0.26862954, 0.26130258, 0.27577711)
+
+
+class QwenImagePixelInputs(TypedDict):
+ type: Literal["pixel_values"]
+ data: torch.Tensor
+ """
+ Shape: `(batch_size * num_images, 3, image_size, image_size)`
+
+ Note that image_size is the value in the vision config to which we resize
+ the image to in the normalization transform. Currently multi-image support
+ can only be leveraged by passing image embeddings directly.
+ """
+
+
+class QwenImageEmbeddingInputs(TypedDict):
+ type: Literal["image_embeds"]
+ data: torch.Tensor
+ """Shape: `(batch_size * num_images, 256, hidden_size)`
+
+ `hidden_size` must match the hidden size of the language model backbone
+ and is stored in the visual config of the model if we have one.
+ """
+
+
+QwenImageInputs = Union[QwenImagePixelInputs, QwenImageEmbeddingInputs]
+
+
+class VisualAttention(nn.Module):
+ """self-attention layer class.
+ Self-attention layer takes input with size [s, b, h]
+ and returns output of the same size.
+ """
+
+ def __init__(
+ self,
+ embed_dim: int,
+ num_heads: int,
+ bias: bool = True,
+ kdim: Optional[int] = None,
+ vdim: Optional[int] = None,
+ ):
+ super().__init__()
+ self.embed_dim = embed_dim
+ self.kdim = kdim if kdim is not None else embed_dim
+ self.vdim = vdim if vdim is not None else embed_dim
+ self._qkv_same_embed_dim = self.kdim == embed_dim \
+ and self.vdim == embed_dim
+
+ self.num_heads = num_heads
+
+ # Per attention head and per partition values.
+ assert embed_dim % num_heads == 0
+ self.hidden_size_per_attention_head = embed_dim // num_heads
+ self.num_attention_heads_per_partition = num_heads
+ self.hidden_size_per_partition = embed_dim
+
+ # Strided linear layer.
+ assert self._qkv_same_embed_dim, \
+ 'Visual Attention implementation only supports self-attention'
+ self.in_proj = nn.Linear(embed_dim, 3 * embed_dim)
+ self.out_proj = nn.Linear(embed_dim, embed_dim)
+ self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ attn_mask: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ # query/key/value: [sq, b, h]
+ sq, b, _ = x.size()
+ mixed_x_layer = self.in_proj(x)
+
+ # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
+ new_tensor_shape = mixed_x_layer.size()[:-1] + \
+ (self.num_attention_heads_per_partition,
+ 3 * self.hidden_size_per_attention_head)
+ mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
+
+ # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
+ query_layer, key_layer, value_layer = mixed_x_layer.split(
+ self.hidden_size_per_attention_head, dim=-1)
+
+ # [sq, b, np, hn] -> [sq, b * np, hn]
+ query_layer = query_layer.view(
+ sq, b * self.num_attention_heads_per_partition,
+ self.hidden_size_per_attention_head).transpose(0, 1)
+ # [sk, b, np, hn] -> [sk, b * np, hn]
+ key_layer = key_layer.view(
+ sq, b * self.num_attention_heads_per_partition,
+ self.hidden_size_per_attention_head).transpose(0, 1)
+
+ q_scaled = query_layer / self.norm_factor
+ if attn_mask is not None:
+ attention_probs = torch.baddbmm(attn_mask, q_scaled,
+ key_layer.transpose(-2, -1))
+ else:
+ attention_probs = torch.bmm(q_scaled, key_layer.transpose(-2, -1))
+ attention_probs = attention_probs.softmax(dim=-1)
+
+ value_layer = value_layer.view(
+ sq, b * self.num_attention_heads_per_partition,
+ self.hidden_size_per_attention_head).transpose(0, 1)
+
+ # matmul: [b * np, sq, hn]
+ context_layer = torch.bmm(attention_probs, value_layer)
-from .utils import is_pp_missing_parameter, make_layers
+ # change view [b, np, sq, hn]
+ context_layer = context_layer.view(
+ b, self.num_attention_heads_per_partition, sq,
+ self.hidden_size_per_attention_head)
+
+ # [b, np, sq, hn] --> [sq, b, np, hn]
+ context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
+
+ # [sq, b, np, hn] --> [sq, b, hp]
+ new_context_layer_shape = context_layer.size()[:-2] + \
+ (self.hidden_size_per_partition,)
+ context_layer = context_layer.view(*new_context_layer_shape)
+
+ output = self.out_proj(context_layer)
+
+ return output
+
+
+class QwenVMLP(nn.Module):
+ """MLP for the visual component of the Qwen model."""
+
+ def __init__(
+ self,
+ hidden_size: int,
+ intermediate_size: int,
+ quant_config: Optional[QuantizationConfig] = None,
+ ):
+ super().__init__()
+ self.c_fc = ColumnParallelLinear(hidden_size,
+ intermediate_size,
+ bias=True,
+ quant_config=quant_config)
+ self.act_fn = get_act_fn("gelu", quant_config, intermediate_size)
+ self.c_proj = RowParallelLinear(
+ intermediate_size,
+ hidden_size,
+ bias=True,
+ quant_config=quant_config,
+ )
+
+ def forward(self, x):
+ x, _ = self.c_fc(x)
+ x = self.act_fn(x)
+ x, _ = self.c_proj(x)
+ return x
+
+
+class VisualAttentionBlock(nn.Module):
+
+ def __init__(
+ self,
+ d_model: int,
+ n_head: int,
+ mlp_ratio: float = 4.0,
+ norm_layer: Callable = nn.LayerNorm,
+ quant_config: Optional[QuantizationConfig] = None,
+ ):
+ super().__init__()
+
+ self.ln_1 = norm_layer(d_model)
+ self.ln_2 = norm_layer(d_model)
+ mlp_width = int(d_model * mlp_ratio)
+ self.attn = VisualAttention(d_model, n_head)
+ self.mlp = QwenVMLP(
+ hidden_size=d_model,
+ intermediate_size=mlp_width,
+ quant_config=quant_config,
+ )
+
+ def attention(
+ self,
+ x: torch.Tensor,
+ attn_mask: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ attn_mask = attn_mask.to(x.dtype) if attn_mask is not None else None
+ return self.attn(x, attn_mask=attn_mask)
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ attn_mask: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ x = x + self.attention(self.ln_1(x), attn_mask=attn_mask)
+ x = x + self.mlp(self.ln_2(x))
+ return x
+
+
+class TransformerBlock(nn.Module):
+
+ def __init__(
+ self,
+ width: int,
+ layers: int,
+ heads: int,
+ mlp_ratio: float = 4.0,
+ norm_layer: Callable = nn.LayerNorm,
+ quant_config: Optional[QuantizationConfig] = None,
+ ):
+ super().__init__()
+ self.width = width
+ self.layers = layers
+
+ self.resblocks = nn.ModuleList([
+ VisualAttentionBlock(width,
+ heads,
+ mlp_ratio,
+ norm_layer=norm_layer,
+ quant_config=quant_config)
+ for _ in range(layers)
+ ])
+
+ def get_cast_dtype(self) -> torch.dtype:
+ return self.resblocks[0].mlp.c_fc.weight.dtype
+
+ def get_cast_device(self) -> torch.device:
+ return self.resblocks[0].mlp.c_fc.weight.device
+
+ def forward(self,
+ x: torch.Tensor,
+ attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+ for r in self.resblocks:
+ x = r(x, attn_mask=attn_mask)
+ return x
+
+
+class VisionTransformer(nn.Module):
+
+ def __init__(self,
+ image_size: int,
+ patch_size: int,
+ width: int,
+ layers: int,
+ heads: int,
+ mlp_ratio: float,
+ n_queries: int = 256,
+ output_dim: int = 512,
+ image_start_id: int = 151857,
+ quant_config: Optional[QuantizationConfig] = None,
+ **kwargs):
+ super().__init__()
+ image_height, image_width = self.image_size = (image_size, image_size)
+ patch_height, patch_width = self.patch_size = (patch_size, patch_size)
+ self.grid_size = (image_height // patch_height,
+ image_width // patch_width)
+ self.output_dim = output_dim
+ self.conv1 = nn.Conv2d(in_channels=3,
+ out_channels=width,
+ kernel_size=patch_size,
+ stride=patch_size,
+ bias=False)
+
+ # class embeddings and positional embeddings
+ scale = width**-0.5
+ self.positional_embedding = nn.Parameter(scale *
+ torch.randn(256, width))
+
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
+
+ self.ln_pre = norm_layer(width)
+ self.transformer = TransformerBlock(width,
+ layers,
+ heads,
+ mlp_ratio,
+ norm_layer=norm_layer,
+ quant_config=quant_config)
+
+ self.attn_pool = Resampler2(
+ grid_size=int(math.sqrt(n_queries)),
+ embed_dim=output_dim,
+ num_heads=output_dim // 128,
+ kv_dim=width,
+ norm_layer=norm_layer,
+ adaptive=False,
+ do_post_projection=False,
+ ).to(
+ device=self.positional_embedding.device,
+ dtype=self.positional_embedding.dtype,
+ )
+
+ self.ln_post = norm_layer(output_dim)
+ self.proj = nn.Parameter(
+ (output_dim**-0.5) * torch.randn(output_dim, output_dim))
+ self.image_start_id = image_start_id
+ self.image_end_id = image_start_id + 1
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = x.to(
+ dtype=self.transformer.get_cast_dtype(),
+ device=self.transformer.get_cast_device(),
+ )
+
+ # to patches
+ x = self.conv1(x) # shape = [*, width, grid, grid]
+ x = x.reshape(x.shape[0], x.shape[1],
+ -1) # shape = [*, width, grid ** 2]
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
+
+ x = x + get_abs_pos(self.positional_embedding, int(math.sqrt(
+ x.size(1))))
+
+ x = self.ln_pre(x)
+
+ x = x.permute(1, 0, 2) # NLD -> LND
+ x = self.transformer(x)
+ x = x.permute(1, 0, 2) # LND -> NLD
+
+ x = self.attn_pool(x)
+ x = self.ln_post(x)
+ x = x @ self.proj
+
+ return x
+
+ def get_image_positions(self,
+ input_ids: torch.Tensor) -> Optional[torch.Tensor]:
+ """Given the input IDs, extracts start/stop points corresponding to
+ images.
+
+ args:
+ Returns:
+ Optional torch tensor corresponding to start/stop pairs of images.
+ """
+ if torch.any(input_ids == self.image_start_id):
+ bos_pos = torch.where(input_ids == self.image_start_id)
+ eos_pos = torch.where(input_ids == self.image_end_id)
+ return torch.stack((bos_pos[0], eos_pos[0]), dim=1)
+ return None
class QWenMLP(nn.Module):
+ """MLP for the language component of the Qwen model, which contains a
+ MergedColumnParallelLinear merging 2 outputs via silu activation."""
def __init__(
self,
@@ -56,7 +422,7 @@ def __init__(
"Only silu is supported for now.")
self.act_fn = SiluAndMul()
- def forward(self, x):
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.c_proj(x)
@@ -203,6 +569,9 @@ def __init__(
lambda prefix: QWenBlock(config, cache_config, quant_config),
prefix=f"{prefix}.h")
self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
+ self.visual = VisionTransformer(**config.visual,
+ quant_config=quant_config) if hasattr(
+ config, "visual") else None
def forward(
self,
@@ -211,9 +580,33 @@ def forward(
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors],
+ pixel_values: Optional[QwenImageInputs],
) -> torch.Tensor:
+ img_pos = None
+ # If pixel / visual embeddings are provided, this is a visual model
+ if pixel_values is not None and self.visual is not None:
+ if pixel_values["type"] != "image_embeds":
+ image_embeds = self.visual(pixel_values["data"])
+ else:
+ image_embeds = pixel_values["data"]
+
+ # features should be of shape (# images, 256, hidden_dim)
+ img_pos = self.visual.get_image_positions(input_ids)
+ if isinstance(
+ img_pos,
+ np.ndarray) and img_pos.shape[0] != image_embeds.shape[0]:
+ raise ValueError(
+ f"Number of placeholders: {img_pos.shape[0]} "
+ f"does not match number of images {image_embeds.shape[0]}."
+ )
+
if get_pp_group().is_first_rank:
hidden_states = self.wte(input_ids)
+ # Merge the image embeddings into the hidden states if actually have
+ # visual features and the corresponding image tokens
+ if img_pos is not None:
+ for idx, (img_bos, img_eos) in enumerate(img_pos):
+ hidden_states[img_bos + 1:img_eos] = image_embeds[idx]
residual = None
else:
assert intermediate_tensors is not None
@@ -237,16 +630,241 @@ def forward(
return hidden_states
-class QWenLMHeadModel(nn.Module):
+def get_image_text(image_num: int, padding: bool) -> str:
+ """Retrieves a placeholder text that when tokenized, will be expanded with
+ image pads.
+
+ Args:
+ image_num: The number of the image that we want a text prompt for.
+ Images should be indexed starting at 1.
+ padding: Whether or not padding should be manually added.
+
+ Returns:
+ Text placeholder prompt for the image being considered.
+ """
+ image_start = f"Picture {image_num}: {IMG_START}"
+ image_end = f"{IMG_END}\n"
+ if not padding:
+ return f"{image_start}{image_end}"
+ return f"{image_start}{MAX_QWEN_IMG_TOKENS * IMG_PAD}{image_end}"
+
+
+def input_processor_for_qwen(ctx: InputContext,
+ llm_inputs: LLMInputs) -> LLMInputs:
+ """Processes the inputs, which may or may not be multimodal.
+ Multimodal inputs will only be processed if the model has a "visual"
+ component in its model config, otherwise they'll be ignored.
+
+ Args:
+ ctx: Context of the loaded model.
+ llm_inputs: LLM inputs which may have a multi_modal_data attribute.
+
+ Returns:
+ If the model is language only or not multimodal inputs were provided,
+ returns llm_inputs unmodified. Otherwise, processes the multimodal
+ images / image embeddings and adds the fixed-length image placeholders.
+ """
+ multi_modal_data = llm_inputs.get("multi_modal_data")
+
+ # Only process images if we have multimodal data and a visual config
+ hf_config = ctx.get_hf_config()
+ if (multi_modal_data is None or "image" not in multi_modal_data
+ or not hasattr(hf_config, "visual")):
+ return llm_inputs
+
+ prompt = llm_inputs.get("prompt")
+ prompt_token_ids = llm_inputs["prompt_token_ids"]
+ model_config = ctx.model_config
+ tokenizer = cached_get_tokenizer(model_config.tokenizer,
+ trust_remote_code=True)
+ image_data = multi_modal_data["image"]
+ if isinstance(image_data, torch.Tensor):
+ num_dims = len(image_data.shape)
+ if num_dims < 2 or num_dims > 3:
+ raise ValueError(
+ f"Expected img embeds to be have 3 dimensions, got {num_dims}")
+ num_images = 1 if num_dims == 2 else image_data.shape[0]
+ else:
+ # TODO - handle multiple image inputs once the API is solidified
+ num_images = 1
+
+ if prompt is None:
+ prompt = tokenizer.decode(prompt_token_ids)
+
+ # Drops anything between / tags; encoding with the tokenizer
+ # will automatically add the image pads for the context.
+ new_prompt, num_matched_images = re.subn(
+ r"(Picture \d*: ).*?(<\/img>\n)",
+ r"\1\2",
+ prompt,
+ )
+
+ if num_matched_images != num_images:
+ logger.warning(
+ "Number of matched image placeholders %s doesn't match the number "
+ "of expected images %s; check your placeholder formatting.",
+ num_matched_images, num_images)
+
+ new_prompt_token_ids = tokenizer.encode(new_prompt)
+
+ return LLMInputs(prompt=new_prompt,
+ prompt_token_ids=new_prompt_token_ids,
+ multi_modal_data=multi_modal_data)
+
+
+def input_mapper_for_qwen(ctx: InputContext, data: object) -> MultiModalInputs:
+ """Maps the input data to its MultiModalInputs (if any).
+
+ Args:
+ ctx: Context of the loaded model.
+ data: data potentially containing image/image embeddings to be mapped
+ to pixel_values in .forward() for a visual QWenLMHeadModel model.
+
+ Returns:
+ MultiModalInputs containing the stacked normalized images tensor or
+ image embeddings.
+ """
+ # Early exit if we have provided an image to a language only Qwen model
+ hf_config = ctx.get_hf_config()
+ if not hasattr(hf_config, "visual"):
+ logger.warning(
+ "Images were provided but this model has no visual config; "
+ "multimodal inputs will not be forwarded to the model.")
+ return MultiModalInputs()
+
+ model_config = ctx.model_config
+ tokenizer = cached_get_tokenizer(model_config.tokenizer,
+ trust_remote_code=True)
+
+ image_pair_tok = tokenizer.encode(IMG_START + IMG_END,
+ add_special_tokens=False,
+ return_tensors="pt").squeeze()
+ image_start_id = image_pair_tok[0]
+ image_end_id = image_pair_tok[-1]
+ if (image_start_id + 1) != image_end_id:
+ raise ValueError(
+ f"Found image end ID {image_end_id}, but expected {IMG_START} + 1")
+ if len(image_pair_tok) != (MAX_QWEN_IMG_TOKENS + 2):
+ raise ValueError(
+ f"Expected image context length of {MAX_QWEN_IMG_TOKENS}, "
+ f"but got {image_pair_tok - 2}")
+
+ hf_config = ctx.get_hf_config()
+ image_size = hf_config.visual["image_size"]
+ img_emb_size = hf_config.visual["output_dim"]
+
+ if isinstance(data, torch.Tensor):
+ # It's expected that our values have already been processed
+ # by the visual transformer; shape is expected to be:
+ # (# images, 256, hidden_size)
+ if len(data.shape) == 2:
+ # Assume only one image embed was provided; unsqueeze the extra dim
+ data = data.unsqueeze(0)
+ if len(data.shape) != 3 or data.shape[
+ 1] != MAX_QWEN_IMG_TOKENS or data.shape[2] != img_emb_size:
+ raise ValueError(
+ "Expected image embeds to be a tensor of shape"
+ f"[# images, {MAX_QWEN_IMG_TOKENS}, {img_emb_size}], but "
+ f"received shape [{data.shape}]")
+ pixel_values = data
+
+ else:
+ transform = build_normalization_transform(image_size)
+ # TODO - handle multiple image inputs once the API is solidified
+ transformed_images = [transform(data)]
+ pixel_values = torch.stack(transformed_images, dim=0)
+ return MultiModalInputs({"pixel_values": pixel_values})
+
+
+def build_normalization_transform(image_size: int) -> transforms.Compose:
+ """Builds a normalization transform which can be applied to one or
+ more input images from which we want to extract visual features.
+
+ Args:
+ image_size: size of the image to be processed for visual embeddings.
+
+ Returns:
+ Callable transform for normalizing and resizing one RGB image.
+ """
+ return transforms.Compose([
+ transforms.Resize((image_size, image_size),
+ interpolation=InterpolationMode.BICUBIC),
+ transforms.ToTensor(),
+ transforms.Normalize(mean=CLIP_MEAN, std=CLIP_STD),
+ ])
+
+
+def dummy_data_for_qwen(
+ ctx: InputContext,
+ seq_len: int,
+ mm_counts: Mapping[str, int],
+) -> Tuple[SequenceData, Optional[Dict]]:
+ """Build dummy data for warming up Qwen models; this will only contain text
+ matching the defaults for VLLM unless the model has a visual config.
+
+ Args:
+ ctx: Context of the loaded model.
+ seq_len: Number of tokens in the text sequence.
+ mm_counts: multimodal data counts.
+
+ Returns:
+ Tuple containing sequential and multimodal data.
+ """
+ hf_config = ctx.get_hf_config()
+
+ # The presence of a visual config indicates this is a multimodal model.
+ # If we don't have it, the model is considered an LLM for warmup purposes.
+ if not hasattr(hf_config, "visual"):
+ seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [0] * seq_len))
+ mm_data = None
+ return seq_data, mm_data
+
+ # We have a visual component - use images to warm up
+ num_images = mm_counts["image"]
+ model_config = ctx.model_config
+ tokenizer = cached_get_tokenizer(model_config.tokenizer,
+ trust_remote_code=True)
+
+ # Build the image prompts with no imgpads; the tokenizer will add img pads
+ image_prompt = ''.join(
+ [get_image_text(idx, False) for idx in range(1, num_images + 1)])
+ toks = tokenizer.encode(image_prompt, add_special_tokens=False)
+
+ # Make sure we actually get the fixed context size per tok padding
+ num_pads = toks.count(tokenizer.encode(IMG_PAD)[0])
+ if num_pads != (num_images * MAX_QWEN_IMG_TOKENS):
+ raise ValueError(
+ f"Tokenized dummy data should encode {MAX_QWEN_IMG_TOKENS} pads"
+ f" per image, but got {num_pads} pads for {num_images} image(s)"
+ " in total. Are you using a qwen tokenizer?")
+
+ # Ensure the number of tokens is at minimum the sequence length provided
+ if len(toks) < seq_len:
+ toks += [0] * (seq_len - len(toks))
+
+ # Build the input images; width/height doesn't actually matter here since
+ # the data will get resized and the # of tokens per image is constant
+ image = Image.new("RGB", (224, 224), color=0)
+ mm_data = {"image": image if num_images == 1 else [image] * num_images}
+ return SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, toks)), mm_data
+
+
+@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_qwen)
+@MULTIMODAL_REGISTRY.register_max_image_tokens(MAX_QWEN_IMG_TOKENS)
+@INPUT_REGISTRY.register_dummy_data(dummy_data_for_qwen)
+@INPUT_REGISTRY.register_input_processor(input_processor_for_qwen)
+class QWenLMHeadModel(nn.Module, SupportsMultiModal):
def __init__(
self,
config: PretrainedConfig,
+ multimodal_config: MultiModalConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
+ self.multimodal_config = multimodal_config
self.quant_config = quant_config
self.transformer = QWenModel(config, cache_config, quant_config)
self.lm_head = ParallelLMHead(config.vocab_size,
@@ -257,16 +875,47 @@ def __init__(
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
- def forward(
- self,
- input_ids: torch.Tensor,
- positions: torch.Tensor,
- kv_caches: List[torch.Tensor],
- attn_metadata: AttentionMetadata,
- intermediate_tensors: Optional[IntermediateTensors] = None,
- ) -> torch.Tensor:
+ def _get_image_input_type(
+ self,
+ pixel_values: Optional[torch.Tensor]) -> Optional[QwenImageInputs]:
+ """Determines if the provided pixel_values are normalized pixel values
+ or image embeddings.
+
+ Args:
+ pixel_values: Optional data to processed into visual embeddings.
+
+ Returns:
+ None of the QwenImageInputs type used to determine whether or not
+ the visual transformer needs to process the pixel_values.
+ """
+ if pixel_values is not None and self.transformer.visual is not None:
+ pixel_values = flatten_bn(pixel_values)
+ if len(pixel_values.shape) == 3 and pixel_values.shape[
+ 1] == MAX_QWEN_IMG_TOKENS and pixel_values.shape[
+ 2] == self.config.visual["output_dim"]:
+ return QwenImageEmbeddingInputs(
+ type="image_embeds",
+ data=pixel_values,
+ )
+ else:
+ # If we have the wrong shape, assume we still need to process
+ return QwenImagePixelInputs(
+ type="pixel_values",
+ data=pixel_values,
+ )
+ return None
+
+ def forward(self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ kv_caches: List[torch.Tensor],
+ attn_metadata: AttentionMetadata,
+ intermediate_tensors: Optional[IntermediateTensors] = None,
+ pixel_values: Optional[torch.Tensor] = None) -> torch.Tensor:
+ pixel_values = self._get_image_input_type(pixel_values)
hidden_states = self.transformer(input_ids, positions, kv_caches,
- attn_metadata, intermediate_tensors)
+ attn_metadata, intermediate_tensors,
+ pixel_values)
return hidden_states
def make_empty_intermediate_tensors(
@@ -328,15 +977,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
- # Skip loading visual weights to support Qwen-VL models
- # in cases with text-only inputs
- # TODO: add support for Qwen-VL
- if (name not in params_dict
- and name.startswith("transformer.visual.")):
- print_warning_once(
- "Only text inputs are allowed. Images won't be handled "
- "until Qwen-VL models are fully supported.")
- continue
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
continue