diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md
index 793831fd06de..8477158a0040 100644
--- a/docs/source/models/supported_models.md
+++ b/docs/source/models/supported_models.md
@@ -921,6 +921,13 @@ See [this page](#generative-models) for more information on how to use generativ
   * ✅︎
   * ✅︎
   * ✅︎
+- * `SkyworkR1VChatModel`
+  * Skywork-R1V-38B
+  * T + I
+  * `Skywork/Skywork-R1V-38B`
+  *
+  * ✅︎
+  * ✅︎
 - * `UltravoxModel`
   * Ultravox
   * T + A<sup>E+</sup>
diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py
index 0adbe574370d..572eabe26193 100644
--- a/examples/offline_inference/vision_language.py
+++ b/examples/offline_inference/vision_language.py
@@ -804,6 +804,41 @@ def run_qwen2_5_vl(questions: list[str], modality: str) -> ModelRequestData:
     )
 
 
+# SkyworkR1V
+def run_skyworkr1v(questions: list[str], modality: str) -> ModelRequestData:
+    assert modality == "image"
+
+    model_name = "Skywork/Skywork-R1V-38B"
+
+    engine_args = EngineArgs(
+        model=model_name,
+        trust_remote_code=True,
+        max_model_len=4096,
+        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
+    )
+
+    tokenizer = AutoTokenizer.from_pretrained(model_name,
+                                              trust_remote_code=True)
+    messages = [[{
+        'role': 'user',
+        'content': f"<image>\n{question}"
+    }] for question in questions]
+    prompts = tokenizer.apply_chat_template(messages,
+                                            tokenize=False,
+                                            add_generation_prompt=True)
+
+    # Stop tokens for SkyworkR1V
+    # https://huggingface.co/Skywork/Skywork-R1V-38B/blob/main/conversation.py
+    stop_tokens = ["<|end▁of▁sentence|>", "<|endoftext|>"]
+    stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
+
+    return ModelRequestData(
+        engine_args=engine_args,
+        prompts=prompts,
+        stop_token_ids=stop_token_ids,
+    )
+
+
 model_example_map = {
     "aria": run_aria,
     "blip-2": run_blip2,
@@ -834,6 +869,7 @@ def run_qwen2_5_vl(questions: list[str], modality: str) -> ModelRequestData:
     "qwen_vl": run_qwen_vl,
     "qwen2_vl": run_qwen2_vl,
     "qwen2_5_vl": run_qwen2_5_vl,
+    "skywork_chat": run_skyworkr1v,
 }
diff --git a/tests/models/decoder_only/vision_language/test_models.py b/tests/models/decoder_only/vision_language/test_models.py
index d500ef5d8b80..0d1d237e5693 100644
--- a/tests/models/decoder_only/vision_language/test_models.py
+++ b/tests/models/decoder_only/vision_language/test_models.py
@@ -474,6 +474,20 @@
         vllm_output_post_proc=model_utils.qwen_vllm_to_hf_output,
         prompt_path_encoder=model_utils.qwen_prompt_path_encoder,
     ),
+    "skywork_r1v": VLMTestInfo(
+        models=["Skywork/Skywork-R1V-38B"],
+        test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
+        prompt_formatter=lambda img_prompt: f"<|begin▁of▁sentence|><|User|>\n{img_prompt}<|Assistant|>\n",  # noqa: E501
+        single_image_prompts=IMAGE_ASSETS.prompts({
+            "stop_sign": "<image>\nWhat's the content in the center of the image?",  # noqa: E501
+            "cherry_blossom": "<image>\nWhat is the season?",
+        }),
+        multi_image_prompt="<image>\n<image>\nDescribe the two images in short.",  # noqa: E501
+        max_model_len=4096,
+        use_tokenizer_eos=True,
+        patch_hf_runner=model_utils.skyworkr1v_patch_hf_runner,
+        marks=[large_gpu_mark(min_gb=80)],
+    ),
     ### Tensor parallel / multi-gpu broadcast tests
     "chameleon-broadcast": VLMTestInfo(
         models=["facebook/chameleon-7b"],
diff --git a/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py b/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py
index c84bf6dc15f4..2ddf28aca4f6 100644
--- a/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py
+++ b/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py
@@ -376,6 +376,63 @@ def __call__(self, text: str, images: Union[Image, list[Image]],
     return hf_model
 
 
+def skyworkr1v_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
+    """Patches and returns an instance of the HfRunner to use for SkyworkR1V."""
+
+    class SkyworkR1VProcessor:
+        """A simple processor for SkyworkR1V."""
+
+        def __init__(self, hf_runner: HfRunner):
+            self.num_image_token = hf_runner.model.num_image_token
+            self.tokenizer = hf_runner.tokenizer
+
+            self.config = AutoConfig.from_pretrained(hf_runner.model_name,
+                                                     trust_remote_code=True)
+            self.vision_config = self.config.vision_config
+            self.use_thumbnail = self.config.use_thumbnail
+            self.min_num = self.config.min_dynamic_patch
+            self.max_num = self.config.max_dynamic_patch
+            self.image_size = self.vision_config.image_size
+
+        def __call__(self, text: str, images: Union[Image, list[Image]],
+                     **kwargs):
+            from vllm.model_executor.models.skyworkr1v import (
+                IMG_CONTEXT, IMG_END, IMG_START,
+                image_to_pixel_values_skyworkr1v)
+            images = [images] if isinstance(images, Image) else images
+            pixel_values = [
+                image_to_pixel_values_skyworkr1v(
+                    image,
+                    input_size=self.image_size,
+                    min_num=self.min_num,
+                    max_num=self.max_num,
+                    use_thumbnail=self.use_thumbnail,
+                ) for image in images
+            ]
+            num_patches_list = [
+                pixel_value.shape[0] for pixel_value in pixel_values
+            ]
+            pixel_values = torch.cat(pixel_values, dim=0)
+            for num_patches in num_patches_list:
+                context_tokens = IMG_CONTEXT * self.num_image_token \
+                    * num_patches
+                image_tokens = IMG_START + context_tokens + IMG_END
+                text = text.replace('<image>', image_tokens, 1)
+            prompt = self.tokenizer(text, return_tensors="pt")
+            prompt.update({"pixel_values": pixel_values})
+            return prompt
+
+    img_context_token_id = hf_model.tokenizer.convert_tokens_to_ids(
+        "<IMG_CONTEXT>")
+    hf_model.model.img_context_token_id = img_context_token_id
+    hf_model.processor = SkyworkR1VProcessor(hf_model)
+    hf_model.model.get_output_embeddings = lambda: \
+        hf_model.model.language_model.get_output_embeddings()
+    hf_model.model.generate = types.MethodType(_internvl_generate,
+                                               hf_model.model)
+    return hf_model
+
+
 def internvl_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
     """Patches and returns an instance of the HfRunner to use for InternVL."""
diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py
index 078ed21537b8..e4f1d297fc09 100644
--- a/tests/models/multimodal/processing/test_common.py
+++ b/tests/models/multimodal/processing/test_common.py
@@ -262,22 +262,23 @@ def _test_processing_correctness_mistral(
     "llava-hf/llava-onevision-qwen2-0.5b-ov-hf",
     "meta-llama/Llama-3.2-11B-Vision-Instruct",
     "TIGER-Lab/Mantis-8B-siglip-llama3",
-    "mistralai/Pixtral-12B-2409",
-    "mistral-community/pixtral-12b",
     "openbmb/MiniCPM-Llama3-V-2_5",
     "openbmb/MiniCPM-o-2_6",
     "openbmb/MiniCPM-V-2_6",
     "allenai/Molmo-7B-D-0924",
     "allenai/Molmo-7B-O-0924",
     "nvidia/NVLM-D-72B",
+    "google/paligemma-3b-mix-224",
+    "google/paligemma2-3b-ft-docci-448",
+    "mistralai/Pixtral-12B-2409",
+    "mistral-community/pixtral-12b",
     "Qwen/Qwen-VL-Chat",
     "Qwen/Qwen2-VL-2B-Instruct",
     "Qwen/Qwen2.5-VL-3B-Instruct",
     "Qwen/Qwen2-Audio-7B-Instruct",
+    "Skywork/Skywork-R1V-38B",
     "fixie-ai/ultravox-v0_5-llama-3_2-1b",
     "openai/whisper-large-v3",
-    "google/paligemma-3b-mix-224",
-    "google/paligemma2-3b-ft-docci-448",
 ])
 @pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0])
 @pytest.mark.parametrize("num_batches", [32])
diff --git a/tests/models/registry.py b/tests/models/registry.py
index d7946b75b797..ff0c37a6afd7 100644
--- a/tests/models/registry.py
+++ b/tests/models/registry.py
@@ -294,6 +294,7 @@ def check_available_online(
     "Qwen2VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2-VL-2B-Instruct"),  # noqa: E501
     "Qwen2_5_VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-VL-3B-Instruct",  # noqa: E501
                                                           min_transformers_version="4.49"),  # noqa: E501
+    "SkyworkR1VChatModel": _HfExamplesInfo("Skywork/Skywork-R1V-38B"),
     "UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_5-llama-3_2-1b",  # noqa: E501
                                      trust_remote_code=True),
     # [Encoder-decoder]
diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py
index 73a69d3037f7..24382142768b 100644
--- a/vllm/entrypoints/chat_utils.py
+++ b/vllm/entrypoints/chat_utils.py
@@ -496,7 +496,7 @@ def _placeholder_str(self, modality: ModalityStr,
                 return self._cached_token_str(self._tokenizer,
                                               hf_config.image_token_index)
             if model_type in ("chameleon", "deepseek_vl_v2", "internvl_chat",
-                              "NVLM_D", "h2ovl_chat"):
+                              "skywork_chat", "NVLM_D", "h2ovl_chat"):
                 return "<image>"
             if model_type == "mllama":
                 return "<|image|>"
diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index 7797d9a2cc20..9288a4b81748 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -190,6 +190,7 @@
     # [Encoder-decoder]
     "Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"),  # noqa: E501
     "MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"),  # noqa: E501
+    "SkyworkR1VChatModel": ("skyworkr1v", "SkyworkR1VChatModel"),
     "WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"),  # noqa: E501
 }
diff --git a/vllm/model_executor/models/skyworkr1v.py b/vllm/model_executor/models/skyworkr1v.py
new file mode 100644
index 000000000000..ac5de0e36b89
--- /dev/null
+++ b/vllm/model_executor/models/skyworkr1v.py
@@ -0,0 +1,1014 @@
+# SPDX-License-Identifier: Apache-2.0
+
+# adapted from https://huggingface.co/Skywork/Skywork-R1V-38B/blob/main/modeling_skywork_chat.py
+# --------------------------------------------------------
+# SkyworkR1V
+# Copyright (c) 2025 Skywork
+# Licensed under The MIT License [see LICENSE for details]
+# --------------------------------------------------------
+from abc import ABC, abstractmethod
+from collections.abc import Iterable, Mapping, Sequence
+from functools import cached_property
+from typing import Literal, Optional, Set, Tuple, TypedDict, TypeVar, Union
+
+import torch
+import torch.nn as nn
+import torchvision.transforms as T
+from PIL import Image
+from transformers import BatchEncoding, PretrainedConfig, TensorType
+
+from vllm.config import VllmConfig
+from vllm.model_executor.layers.linear import ReplicatedLinear
+from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.quantization.awq import AWQConfig
+from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
+from vllm.model_executor.models.intern_vit import (InternVisionModel,
+                                                   InternVisionPatchModel)
+from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
+                                    NestedTensors)
+from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
+                                   ImageSize, MultiModalDataItems)
+from vllm.multimodal.processing import (BaseMultiModalProcessor,
+                                        BaseProcessingInfo, PromptReplacement,
+                                        PromptUpdate, PromptUpdateDetails)
+from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
+from vllm.sequence import IntermediateTensors
+from vllm.transformers_utils.tokenizer import AnyTokenizer
+
+from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
+from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
+                    maybe_prefix, merge_multimodal_embeddings)
+from .vision import scatter_patch_features, select_patch_features
+
+IMG_START = '<img>'
+IMG_END = '</img>'
+IMG_CONTEXT = '<IMG_CONTEXT>'
+
+IMAGENET_MEAN = (0.485, 0.456, 0.406)
+IMAGENET_STD = (0.229, 0.224, 0.225)
+
+
+class SkyworkR1VImagePixelInputs(TypedDict):
+    type: Literal["pixel_values"]
+    pixel_values_flat: torch.Tensor
+    """
+    Shape:
+    `(batch_size * num_images * (1 + num_patches), num_channels, height, width)`
+    """
+
+    num_patches: torch.Tensor
+    """Shape: `(batch_size * num_images)`"""
+
+    embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
+    """
+    A boolean mask indicating which image embeddings correspond
+    to patch tokens.
+
+    Shape: `(batch_size * num_images, num_embeds)`
+    """
+
+
+class SkyworkR1VImageEmbeddingInputs(TypedDict):
+    type: Literal["image_embeds"]
+    data: Union[torch.Tensor, list[torch.Tensor]]
+    """
+    A tensor of shape `(num_images, total_image_feature_size, hidden_size)`
+    or a list of tensors of shape `(total_image_feature_size, hidden_size)`
+
+    `hidden_size` must match the hidden size of language model backbone.
+    """
+
+
+SkyworkR1VImageInputs = Union[SkyworkR1VImagePixelInputs,
+                              SkyworkR1VImageEmbeddingInputs]
+
+
+# adapted from https://huggingface.co/Skywork/Skywork-R1V-38B/
+def build_transform(input_size: int):
+    MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
+    return T.Compose([
+        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
+        T.Resize((input_size, input_size),
+                 interpolation=T.InterpolationMode.BICUBIC),
+        T.ToTensor(),
+        T.Normalize(mean=MEAN, std=STD)
+    ])
+
+
+# adapted from https://huggingface.co/Skywork/Skywork-R1V-38B/
+def find_closest_aspect_ratio(
+    aspect_ratio: float,
+    target_ratios: list[tuple[int, int]],
+    *,
+    width: int,
+    height: int,
+    image_size: int,
+) -> tuple[int, int]:
+    best_ratio_diff = float('inf')
+    best_ratio = (1, 1)
+    area = width * height
+    for ratio in target_ratios:
+        target_aspect_ratio = ratio[0] / ratio[1]
+        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
+        if ratio_diff < best_ratio_diff:
+            best_ratio_diff = ratio_diff
+            best_ratio = ratio
+        elif ratio_diff == best_ratio_diff:
+            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
+                best_ratio = ratio
+    return best_ratio
+
+
+def resolve_skyworkr1v_min_max_num(
+    *,
+    min_dynamic_patch: int,
+    max_dynamic_patch: int,
+    dynamic_image_size: bool,
+    use_thumbnail: bool,
+) -> tuple[int, int]:
+    min_dynamic_patch = min_dynamic_patch if dynamic_image_size else 1
+    max_dynamic_patch = max_dynamic_patch if dynamic_image_size else 1
+
+    if use_thumbnail and max_dynamic_patch != 1:
+        max_dynamic_patch += 1
+
+    return min_dynamic_patch, max_dynamic_patch
+
+
+def get_skyworkr1v_target_ratios(
+    min_num: int,
+    max_num: int,
+) -> list[tuple[int, int]]:
+    target_ratios = {(i, j)
+                     for n in range(min_num, max_num + 1)
+                     for i in range(1, n + 1)
+                     for j in range(1, n + 1) if min_num <= i * j <= max_num}
+    return sorted(target_ratios, key=lambda x: x[0] * x[1])
+
+
+def calculate_skyworkr1v_targets(
+    *,
+    orig_width: int,
+    orig_height: int,
+    target_ratios:
list[tuple[int, int]], + image_size: int, + use_thumbnail: bool, +) -> tuple[int, int, int]: + aspect_ratio = orig_width / orig_height + + # find the closest aspect ratio to the target + target_aspect_ratio = find_closest_aspect_ratio( + aspect_ratio, + target_ratios, + width=orig_width, + height=orig_height, + image_size=image_size, + ) + + # calculate the target width and height + target_width = image_size * target_aspect_ratio[0] + target_height = image_size * target_aspect_ratio[1] + blocks = target_aspect_ratio[0] * target_aspect_ratio[1] + + # add thumbnail image if num_blocks != 1 + if use_thumbnail and blocks != 1: + blocks += 1 + + return blocks, target_width, target_height + + +def dynamic_preprocess_skyworkr1v( + image: Image.Image, + *, + target_ratios: list[tuple[int, int]], + image_size: int, + use_thumbnail: bool, +) -> list[Image.Image]: + orig_width, orig_height = image.size + + # calculate the number of blocks without thumbnail + blocks, target_width, target_height = calculate_skyworkr1v_targets( + orig_width=orig_width, + orig_height=orig_height, + target_ratios=target_ratios, + image_size=image_size, + use_thumbnail=False, + ) + + # resize the image + resized_img = image.resize((target_width, target_height)) + processed_images = [] + for i in range(blocks): + box = ((i % (target_width // image_size)) * image_size, + (i // (target_width // image_size)) * image_size, + ((i % (target_width // image_size)) + 1) * image_size, + ((i // (target_width // image_size)) + 1) * image_size) + # split the image + split_img = resized_img.crop(box) + processed_images.append(split_img) + + assert len(processed_images) == blocks + + if use_thumbnail and len(processed_images) != 1: + thumbnail_img = image.resize((image_size, image_size)) + processed_images.append(thumbnail_img) + + return processed_images + + +# adapted from https://huggingface.co/Skywork/Skywork-R1V-38B +def image_to_pixel_values_skyworkr1v( + image: Image.Image, + *, + input_size: int, + min_num: int, + max_num: int, + use_thumbnail: bool, +) -> torch.Tensor: + target_ratios = get_skyworkr1v_target_ratios(min_num, max_num) + + transform = build_transform(input_size=input_size) + images = dynamic_preprocess_skyworkr1v( + image, + target_ratios=target_ratios, + image_size=input_size, + use_thumbnail=use_thumbnail, + ) + + pixel_values = torch.stack([transform(image) for image in images]) + return pixel_values + + +class BaseSkyworkR1VProcessor(ABC): + """ + This model doesn't define its own HF processor, + so we implement our own one here. 
+ + The code to insert image tokens is based on: + https://huggingface.co/Skywork/Skywork-R1V-38B/blob/main/modeling_skywork_chat.py#L252 + """ + + def __init__( + self, + config: PretrainedConfig, + tokenizer: AnyTokenizer, + *, + min_dynamic_patch: Optional[int] = None, + max_dynamic_patch: Optional[int] = None, + dynamic_image_size: Optional[bool] = None, + ) -> None: + super().__init__() + + self.config = config + self.tokenizer = tokenizer + + image_size: int = config.vision_config.image_size + patch_size: int = config.vision_config.patch_size + + if min_dynamic_patch is None: + min_dynamic_patch = config.min_dynamic_patch + assert isinstance(min_dynamic_patch, int) + + if max_dynamic_patch is None: + max_dynamic_patch = config.max_dynamic_patch + assert isinstance(max_dynamic_patch, int) + + if dynamic_image_size is None: + dynamic_image_size = config.dynamic_image_size + assert isinstance(dynamic_image_size, bool) + + self.num_image_token = int( + (image_size // patch_size)**2 * (config.downsample_ratio**2)) + self.image_size = image_size + self.min_dynamic_patch = min_dynamic_patch + self.max_dynamic_patch = max_dynamic_patch + self.dynamic_image_size = dynamic_image_size + self.use_thumbnail: bool = config.use_thumbnail + + @property + @abstractmethod + def image_token_id(self) -> int: + raise NotImplementedError + + @abstractmethod + def get_image_repl( + self, + feature_size: int, + num_patches: Optional[int], + ) -> PromptUpdateDetails[str]: + raise NotImplementedError + + def resolve_min_max_num( + self, + *, + min_dynamic_patch: Optional[int] = None, + max_dynamic_patch: Optional[int] = None, + dynamic_image_size: Optional[bool] = None, + use_thumbnail: Optional[bool] = None, + ) -> tuple[int, int]: + min_dynamic_patch = (self.min_dynamic_patch if min_dynamic_patch + is None else min_dynamic_patch) + max_dynamic_patch = (self.max_dynamic_patch if max_dynamic_patch + is None else max_dynamic_patch) + dynamic_image_size = (self.dynamic_image_size if dynamic_image_size + is None else dynamic_image_size) + use_thumbnail = (self.use_thumbnail + if use_thumbnail is None else use_thumbnail) + + return resolve_skyworkr1v_min_max_num( + min_dynamic_patch=min_dynamic_patch, + max_dynamic_patch=max_dynamic_patch, + dynamic_image_size=dynamic_image_size, + use_thumbnail=use_thumbnail, + ) + + def resolve_target_ratios( + self, + *, + min_dynamic_patch: Optional[int] = None, + max_dynamic_patch: Optional[int] = None, + dynamic_image_size: Optional[bool] = None, + use_thumbnail: Optional[bool] = None, + ) -> list[tuple[int, int]]: + min_num, max_num = self.resolve_min_max_num( + min_dynamic_patch=min_dynamic_patch, + max_dynamic_patch=max_dynamic_patch, + dynamic_image_size=dynamic_image_size, + use_thumbnail=use_thumbnail, + ) + + return get_skyworkr1v_target_ratios(min_num, max_num) + + def get_num_image_tokens( + self, + *, + image_width: int, + image_height: int, + ) -> int: + target_ratios = self.resolve_target_ratios( + use_thumbnail=False, # Applied in calculate_targets + ) + + num_patches, _, _ = calculate_skyworkr1v_targets( + orig_width=image_width, + orig_height=image_height, + image_size=self.image_size, + target_ratios=target_ratios, + use_thumbnail=self.use_thumbnail, + ) + + return num_patches * self.num_image_token + + def _images_to_pixel_values_lst( + self, + images: list[Image.Image], + min_dynamic_patch: Optional[int] = None, + max_dynamic_patch: Optional[int] = None, + dynamic_image_size: Optional[bool] = None, + ) -> list[torch.Tensor]: + min_num, max_num = 
self.resolve_min_max_num(
+            min_dynamic_patch=min_dynamic_patch,
+            max_dynamic_patch=max_dynamic_patch,
+            dynamic_image_size=dynamic_image_size,
+            use_thumbnail=False,  # Applied in image_to_pixel_values
+        )
+
+        return [
+            image_to_pixel_values_skyworkr1v(
+                image,
+                input_size=self.image_size,
+                min_num=min_num,
+                max_num=max_num,
+                use_thumbnail=self.use_thumbnail,
+            ) for image in images
+        ]
+
+    def __call__(
+        self,
+        text: Optional[Union[str, list[str]]] = None,
+        images: Optional[Union[Image.Image, list[Image.Image]]] = None,
+        min_dynamic_patch: Optional[int] = None,
+        max_dynamic_patch: Optional[int] = None,
+        dynamic_image_size: Optional[bool] = None,
+        return_tensors: Optional[Union[str, TensorType]] = None,
+    ) -> Mapping[str, NestedTensors]:
+        if text is None:
+            text = []
+        if not isinstance(text, list):
+            text = [text]
+        if images is None:
+            images = []
+        if not isinstance(images, list):
+            images = [images]
+
+        if len(images) == 0:
+            image_inputs = {}
+        else:
+            pixel_values_lst = self._images_to_pixel_values_lst(
+                images,
+                min_dynamic_patch=min_dynamic_patch,
+                max_dynamic_patch=max_dynamic_patch,
+                dynamic_image_size=dynamic_image_size,
+            )
+            image_inputs: dict[str, NestedTensors] = {
+                "pixel_values_flat":
+                torch.cat(pixel_values_lst),
+                "image_num_patches":
+                torch.tensor([len(item) for item in pixel_values_lst]),
+            }
+
+            tokenizer = self.tokenizer
+            image_token_id = self.image_token_id
+
+            embed_is_patch = list[torch.Tensor]()
+
+            for pixel_values in pixel_values_lst:
+                num_patches = pixel_values.shape[0]
+                feature_size = num_patches * self.num_image_token
+
+                image_repl = self.get_image_repl(feature_size, num_patches)
+                feature_tokens = tokenizer.encode(image_repl.features,
+                                                  add_special_tokens=False)
+
+                text = [t.replace('<image>', image_repl.full, 1) for t in text]
+                embed_is_patch.append(
+                    torch.tensor(feature_tokens) == image_token_id)
+
+            image_inputs["embed_is_patch"] = embed_is_patch
+
+        text_inputs = self.tokenizer(text)
+
+        return {
+            **BatchEncoding(text_inputs, tensor_type=return_tensors),
+            **image_inputs,
+        }
+
+
+class SkyworkR1VProcessor(BaseSkyworkR1VProcessor):
+
+    @property
+    def image_token_id(self) -> int:
+        return self.tokenizer.get_vocab()[IMG_CONTEXT]
+
+    def get_image_repl(
+        self,
+        feature_size: int,
+        num_patches: Optional[int],
+    ) -> PromptUpdateDetails[str]:
+        repl_features = IMG_CONTEXT * feature_size
+        repl_full = IMG_START + repl_features + IMG_END
+
+        return PromptUpdateDetails(full=repl_full, features=repl_features)
+
+
+class BaseSkyworkR1VProcessingInfo(BaseProcessingInfo):
+
+    @abstractmethod
+    def get_hf_processor(
+        self,
+        *,
+        min_dynamic_patch: Optional[int] = None,
+        max_dynamic_patch: Optional[int] = None,
+        dynamic_image_size: Optional[bool] = None,
+        **kwargs: object,
+    ) -> BaseSkyworkR1VProcessor:
+        raise NotImplementedError
+
+    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
+        return {"image": None}
+
+    def get_mm_max_tokens_per_item(
+        self,
+        seq_len: int,
+        mm_counts: Mapping[str, int],
+    ) -> Mapping[str, int]:
+        return {"image": self.get_max_image_tokens()}
+
+    def get_num_image_tokens(
+        self,
+        *,
+        image_width: int,
+        image_height: int,
+        processor: Optional[BaseSkyworkR1VProcessor],
+    ) -> int:
+        if processor is None:
+            processor = self.get_hf_processor()
+
+        return processor.get_num_image_tokens(
+            image_width=image_width,
+            image_height=image_height,
+        )
+
+    def get_max_image_tokens(self) -> int:
+        target_width, target_height = self.get_image_size_with_most_features()
+
+        return self.get_num_image_tokens(
+            image_width=target_width,
+            image_height=target_height,
+            processor=None,
+        )
+
+    def get_image_size_with_most_features(self) -> ImageSize:
+        processor = self.get_hf_processor()
+
+        base_size = processor.image_size
+        target_ratios = processor.resolve_target_ratios()
+
+        largest_feature_size, largest_feature_pinpoint = 0, None
+        for wr, hr in target_ratios:
+            width, height = base_size * wr, base_size * hr
+
+            feat_size = self.get_num_image_tokens(
+                image_width=width,
+                image_height=height,
+                processor=processor,
+            )
+            if feat_size > largest_feature_size:
+                largest_feature_size = feat_size
+                largest_feature_pinpoint = ImageSize(width=width,
+                                                     height=height)
+
+        if largest_feature_size == 0 or largest_feature_pinpoint is None:
+            raise ValueError("Cannot have a largest feature size of 0!")
+
+        return largest_feature_pinpoint
+
+
+_I = TypeVar("_I", bound=BaseSkyworkR1VProcessingInfo)
+
+
+class SkyworkR1VDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
+
+    def get_dummy_processor_inputs(
+        self,
+        seq_len: int,
+        mm_counts: Mapping[str, int],
+    ) -> ProcessorInputs:
+        target_width, target_height = \
+            self.info.get_image_size_with_most_features()
+        num_images = mm_counts.get("image", 0)
+
+        mm_data = {
+            "image":
+            self._get_dummy_images(width=target_width,
+                                   height=target_height,
+                                   num_images=num_images)
+        }
+
+        return ProcessorInputs(
+            prompt_text="<image>" * num_images,
+            mm_data=mm_data,
+        )
+
+
+class SkyworkR1VMultiModalProcessor(BaseMultiModalProcessor[_I]):
+
+    def _call_hf_processor(
+        self,
+        prompt: str,
+        mm_data: Mapping[str, object],
+        mm_kwargs: Mapping[str, object],
+    ) -> Mapping[str, NestedTensors]:
+        processed_outputs = super()._call_hf_processor(
+            prompt=prompt,
+            mm_data=mm_data,
+            mm_kwargs=mm_kwargs,
+        )
+
+        hf_processor = self.info.get_hf_processor(**mm_kwargs)
+        image_token_id = hf_processor.image_token_id
+
+        # Since there may be extra tokens in the feature placeholders,
+        # we need to pass the image token ID to the model to select the
+        # tokens to merge from the vision encoder outputs
+        processed_outputs["image_token_id"] = torch.tensor(image_token_id)
+
+        return processed_outputs
+
+    def _get_mm_fields_config(
+        self,
+        hf_inputs: Mapping[str, NestedTensors],
+        hf_processor_mm_kwargs: Mapping[str, object],
+    ) -> Mapping[str, MultiModalFieldConfig]:
+        image_num_patches = hf_inputs.get("image_num_patches", torch.empty(0))
+        num_images = len(image_num_patches)
+
+        return dict(
+            pixel_values_flat=MultiModalFieldConfig.flat_from_sizes(
+                "image", image_num_patches),
+            image_num_patches=MultiModalFieldConfig.batched("image"),
+            embed_is_patch=MultiModalFieldConfig.batched("image"),
+            image_embeds=MultiModalFieldConfig.batched("image"),
+            image_token_id=MultiModalFieldConfig.shared("image", num_images),
+        )
+
+    def _get_prompt_updates(
+        self,
+        mm_items: MultiModalDataItems,
+        hf_processor_mm_kwargs: Mapping[str, object],
+        out_mm_kwargs: MultiModalKwargs,
+    ) -> Sequence[PromptUpdate]:
+        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
+
+        if "image_num_patches" in out_mm_kwargs:
+            image_num_patches = out_mm_kwargs["image_num_patches"]
+            assert isinstance(image_num_patches, torch.Tensor)
+            image_num_patches = image_num_patches.tolist()
+        elif "image_embeds" in out_mm_kwargs:
+            # TODO: Use image size information in dictionary embedding inputs
+            # to compute num_patches (similar to Qwen2-VL)
+            image_num_patches = [None] * len(out_mm_kwargs["image_embeds"])
+        else:
+            image_num_patches = []
+
+        def get_replacement_skyworkr1v(item_idx: int):
+            images = mm_items.get_items(
+                "image", (ImageEmbeddingItems, ImageProcessorItems))
+
+            if isinstance(images, ImageEmbeddingItems):
+                feature_size = images.get_feature_size(item_idx)
+            else:
+                image_size = images.get_image_size(item_idx)
+                feature_size = self.info.get_num_image_tokens(
+                    image_width=image_size.width,
+                    image_height=image_size.height,
+                    processor=hf_processor,
+                )
+
+            num_patches = image_num_patches[item_idx]
+            if num_patches is not None:
+                assert isinstance(num_patches, int)
+
+            return hf_processor.get_image_repl(feature_size, num_patches)
+
+        return [
+            PromptReplacement(
+                modality="image",
+                target="<image>",
+                replacement=get_replacement_skyworkr1v,
+            )
+        ]
+
+
+class SkyworkR1VProcessingInfo(BaseSkyworkR1VProcessingInfo):
+
+    def get_hf_processor(
+        self,
+        *,
+        min_dynamic_patch: Optional[int] = None,
+        max_dynamic_patch: Optional[int] = None,
+        dynamic_image_size: Optional[bool] = None,
+        **kwargs: object,
+    ) -> SkyworkR1VProcessor:
+        if min_dynamic_patch is not None:
+            kwargs["min_dynamic_patch"] = min_dynamic_patch
+        if max_dynamic_patch is not None:
+            kwargs["max_dynamic_patch"] = max_dynamic_patch
+        if dynamic_image_size is not None:
+            kwargs["dynamic_image_size"] = dynamic_image_size
+
+        return self.ctx.init_processor(
+            SkyworkR1VProcessor,
+            config=self.get_hf_config(),
+            tokenizer=self.get_tokenizer(),
+            **kwargs,
+        )
+
+
+@MULTIMODAL_REGISTRY.register_processor(
+    SkyworkR1VMultiModalProcessor,
+    info=SkyworkR1VProcessingInfo,
+    dummy_inputs=SkyworkR1VDummyInputsBuilder)
+class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
+        super().__init__()
+
+        config = vllm_config.model_config.hf_config
+        quant_config = vllm_config.quant_config
+        multimodal_config = vllm_config.model_config.multimodal_config
+
+        self.config = config
+        self.multimodal_config = multimodal_config
+        self._patch_quant_config(config, quant_config)
+
+        image_size = config.force_image_size or config.vision_config.image_size
+        patch_size = config.vision_config.patch_size
+        self.patch_size = patch_size
+        self.num_image_token = int(
+            (image_size // patch_size)**2 * (config.downsample_ratio**2))
+        self.downsample_ratio = config.downsample_ratio
+        self.ps_version = config.ps_version
+
+        self.llm_arch_name = config.text_config.architectures[0]
+        self.is_mono = self.llm_arch_name == 'SkyworkLM2VEForCausalLM'
+        self.vision_model = self._init_vision_model(
+            config,
+            quant_config=quant_config,
+            is_mono=self.is_mono,
+            prefix=maybe_prefix(prefix, "vision_model"),
+        )
+
+        self.language_model = init_vllm_registered_model(
+            vllm_config=vllm_config,
+            hf_config=config.text_config,
+            prefix=maybe_prefix(prefix, "language_model"),
+        )
+
+        self.mlp1 = self._init_mlp1(config)
+
+        self.img_context_token_id = None
+        self.visual_token_mask = None
+        self.make_empty_intermediate_tensors = (
+            self.language_model.make_empty_intermediate_tensors)
+
+    def _patch_quant_config(self, config: PretrainedConfig,
+                            quant_config: QuantizationConfig):
+        # The AWQ models from OpenGVLab are missing `modules_to_not_convert`,
+        # so patch the quant_config to add it back.
+        if isinstance(quant_config, AWQConfig):
+            text_config = config.text_config
+            llm_quant_config = getattr(text_config, "quantization_config",
+                                       None)
+            if (not quant_config.modules_to_not_convert) and \
+                    (llm_quant_config is not None):
+                quant_config.modules_to_not_convert.append("vision_model")
+
+    @cached_property
+ def sampler(self): + if hasattr(self.language_model, "sampler"): + return self.language_model.sampler + + return get_sampler() + + def _init_vision_model( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig], + *, + is_mono: bool, + prefix: str, + ): + if not is_mono: + vision_feature_layer = config.select_layer + if vision_feature_layer < 0: + num_hidden_layers = config.vision_config.num_hidden_layers \ + + vision_feature_layer + 1 + else: + num_hidden_layers = vision_feature_layer + 1 + + return InternVisionModel( + config.vision_config, + quant_config=quant_config, + num_hidden_layers_override=num_hidden_layers, + prefix=prefix, + ) + else: + return InternVisionPatchModel(config.vision_config) + + def _init_mlp1(self, config: PretrainedConfig) -> nn.Sequential: + vit_hidden_size = config.vision_config.hidden_size + llm_hidden_size = config.text_config.hidden_size + + return nn.Sequential( + nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio)**2), + ReplicatedLinear(vit_hidden_size * + int(1 / self.downsample_ratio)**2, + llm_hidden_size, + return_bias=False), + nn.GELU(), + ReplicatedLinear(llm_hidden_size, + llm_hidden_size, + return_bias=False), + ) + + def pixel_shuffle(self, x, scale_factor=0.5): + n, w, h, c = x.size() + # N, W, H, C --> N, W, H * scale, C // scale + x = x.view(n, w, int(h * scale_factor), int(c / scale_factor)) + # N, W, H * scale, C // scale --> N, H * scale, W, C // scale + x = x.permute(0, 2, 1, 3).contiguous() + x = x.view(n, int(h * scale_factor), int(w * scale_factor), + int(c / (scale_factor * scale_factor))) + if self.ps_version == 'v1': + pass + else: + x = x.permute(0, 2, 1, 3).contiguous() + return x + + def extract_feature(self, pixel_values: torch.Tensor) -> torch.Tensor: + vit_embeds = self.vision_model(pixel_values=pixel_values) + vit_embeds = vit_embeds[:, 1:, :] + + h = w = int(vit_embeds.shape[1]**0.5) + vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1) + vit_embeds = self.pixel_shuffle(vit_embeds, + scale_factor=self.downsample_ratio) + vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, + vit_embeds.shape[-1]) + vit_embeds = self.mlp1(vit_embeds) + return vit_embeds + + def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: + + h = w = self.config.vision_config.image_size + expected_dims = (3, h, w) + + def _validate_shape(d: torch.Tensor): + actual_dims = tuple(d.shape) + + if actual_dims != expected_dims: + expected_expr = str(expected_dims) + raise ValueError( + "The expected shape of pixel values per image per batch " + f" per patch is {expected_expr}. " + f"You supplied {tuple(d.shape)}.") + + for d in data: + _validate_shape(d) + + return data + + def _parse_and_validate_image_input( + self, **kwargs: object) -> Optional[SkyworkR1VImageInputs]: + pixel_values_flat = kwargs.pop("pixel_values_flat", None) + image_num_patches = kwargs.pop("image_num_patches", None) + embed_is_patch = kwargs.pop("embed_is_patch", None) + image_embeds = kwargs.pop("image_embeds", None) + + if pixel_values_flat is None and image_embeds is None: + return None + + if image_embeds is not None: + if not isinstance(image_embeds, (torch.Tensor, list)): + raise ValueError("Incorrect type of image embeddings. 
" + f"Got type: {type(image_embeds)}") + + return SkyworkR1VImageEmbeddingInputs( + type="image_embeds", + data=flatten_bn(image_embeds), + ) + + image_token_id = kwargs["image_token_id"] + assert isinstance(image_token_id, torch.Tensor) + self.img_context_token_id = image_token_id.flatten().unique().item() + + if pixel_values_flat is not None: + if not isinstance(pixel_values_flat, (torch.Tensor, list)): + raise ValueError("Incorrect type of pixel values. " + f"Got type: {type(pixel_values_flat)}") + + if not isinstance(image_num_patches, (torch.Tensor, list)): + raise ValueError("Incorrect type of image_num_patches. " + f"Got type: {type(image_num_patches)}") + + if not isinstance(embed_is_patch, (torch.Tensor, list)): + raise ValueError("Incorrect type of embed_is_patch. " + f"Got type: {type(embed_is_patch)}") + + pixel_values_flat = flatten_bn(pixel_values_flat, concat=True) + image_num_patches = flatten_bn(image_num_patches, concat=True) + embed_is_patch = flatten_bn(embed_is_patch) + + return SkyworkR1VImagePixelInputs( + type="pixel_values", + pixel_values_flat=self._validate_pixel_values( + pixel_values_flat), + num_patches=image_num_patches, + embed_is_patch=embed_is_patch, + ) + + raise AssertionError("This line should be unreachable.") + + def _process_image_input( + self, + image_input: SkyworkR1VImageInputs, + ) -> Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor, ...]]: + if image_input["type"] == "image_embeds": + return image_input["data"] + + assert self.vision_model is not None + + image_embeds = self.extract_feature(image_input["pixel_values_flat"]) + + num_patches = image_input["num_patches"] + + # Only one image in the current batch + if len(num_patches) == 1: + return image_embeds.view( + -1, self.config.text_config.hidden_size).unsqueeze(0) + + # NOTE: Image embeddings are split into separate tensors for each image + # by the size of each embedding. 
+ feature_size = image_embeds.shape[1] + image_embeds = image_embeds.view(-1, + self.config.text_config.hidden_size) + image_feature_sizes = [ + num_patches * feature_size for num_patches in num_patches + ] + return image_embeds.split(image_feature_sizes) + + def _set_visual_token_mask(self, input_ids: torch.Tensor) -> None: + if self.is_mono: + self.visual_token_mask = ( + input_ids == self.img_context_token_id).reshape(-1, 1) + else: + self.visual_token_mask = None + + def get_multimodal_embeddings( + self, **kwargs: object) -> Optional[MultiModalEmbeddings]: + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is None: + return None + + image_features = self._process_image_input(image_input) + + if image_input["type"] != "pixel_values": + return image_features + + return scatter_patch_features( + image_features, + image_input["embed_is_patch"], + ) + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + ) -> torch.Tensor: + inputs_embeds = self.language_model.get_input_embeddings(input_ids) + if multimodal_embeddings is not None: + assert self.img_context_token_id is not None + self._set_visual_token_mask(input_ids) + inputs_embeds = merge_multimodal_embeddings( + input_ids, + inputs_embeds, + select_patch_features(multimodal_embeddings), + self.img_context_token_id, + ) + return inputs_embeds + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ) -> Union[SamplerOutput, IntermediateTensors]: + + if intermediate_tensors is not None: + input_ids = None + inputs_embeds = None + + # NOTE: In v1, inputs_embeds is always generated at model runner, this + # condition is for v0 compatibility. 
+ elif inputs_embeds is None: + vision_embeddings = self.get_multimodal_embeddings(**kwargs) + inputs_embeds = self.get_input_embeddings(input_ids, + vision_embeddings) + input_ids = None + + forward_kwargs = { + "input_ids": input_ids, + "positions": positions, + "intermediate_tensors": intermediate_tensors, + "inputs_embeds": inputs_embeds, + } + + # Only required if the model is mono-architecture + if self.visual_token_mask is not None: + forward_kwargs.update( + {"visual_token_mask": self.visual_token_mask}) + self.visual_token_mask = None + + hidden_states = self.language_model.model(**forward_kwargs) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + return self.language_model.compute_logits(hidden_states, + sampling_metadata) + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + return self.language_model.sample(logits, sampling_metadata) + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + skip_prefixes = [ + "action_embed", "temporal_embed", "track_embed", + "track_embed_decoder", "box_token", "cg_criterion", "cg_model", + "loc_encoder", "loc_decoder", "sam", "temporal_token", + "track_token" + ] + loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes) + return loader.load_weights(weights) diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 1937b1388471..71990468c315 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -37,8 +37,8 @@ MLPSpeculatorConfig, MPTConfig, NemotronConfig, NVLM_D_Config, Olmo2Config, RWConfig, - SolarConfig, Telechat2Config, - UltravoxConfig) + SkyworkR1VChatConfig, SolarConfig, + Telechat2Config, UltravoxConfig) # yapf: enable from vllm.transformers_utils.utils import check_gguf_file from vllm.utils import resolve_obj_by_qualname @@ -76,6 +76,7 @@ "NVLM_D": NVLM_D_Config, "olmo2": Olmo2Config, "solar": SolarConfig, + "skywork_chat": SkyworkR1VChatConfig, "telechat": Telechat2Config, "ultravox": UltravoxConfig, **_CONFIG_REGISTRY_OVERRIDE_HF diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index 9060565596b2..53699341bfba 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -20,6 +20,7 @@ from vllm.transformers_utils.configs.nemotron import NemotronConfig from vllm.transformers_utils.configs.nvlm_d import NVLM_D_Config from vllm.transformers_utils.configs.olmo2 import Olmo2Config +from vllm.transformers_utils.configs.skyworkr1v import SkyworkR1VChatConfig from vllm.transformers_utils.configs.solar import SolarConfig from vllm.transformers_utils.configs.telechat2 import Telechat2Config from vllm.transformers_utils.configs.ultravox import UltravoxConfig @@ -42,6 +43,7 @@ "NemotronConfig", "NVLM_D_Config", "Olmo2Config", + "SkyworkR1VChatConfig", "SolarConfig", "Telechat2Config", "UltravoxConfig", diff --git a/vllm/transformers_utils/configs/skyworkr1v.py b/vllm/transformers_utils/configs/skyworkr1v.py new file mode 100644 index 000000000000..ef5f9ba85c23 --- /dev/null +++ b/vllm/transformers_utils/configs/skyworkr1v.py @@ -0,0 +1,53 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Adapted from +# https://huggingface.co/Skywork/Skywork-R1V-38B/blob/main/configuration_skywork_chat.py +# -------------------------------------------------------- +# SkyworkR1V +# 
Copyright (c) 2025 Skywork +# Licensed under The MIT License [see LICENSE for details] +# -------------------------------------------------------- +from transformers.configuration_utils import PretrainedConfig + + +class SkyworkR1VChatConfig(PretrainedConfig): + model_type = 'internvl_chat' + is_composition = True + + def __init__(self, + vision_config=None, + llm_config=None, + use_backbone_lora=0, + use_llm_lora=0, + select_layer=-1, + force_image_size=None, + downsample_ratio=0.5, + template=None, + dynamic_image_size=False, + use_thumbnail=False, + ps_version='v1', + min_dynamic_patch=1, + max_dynamic_patch=6, + **kwargs): + super().__init__(**kwargs) + + if vision_config is None: + vision_config = {} + + if llm_config is None: + llm_config = {} + + self.vision_config = PretrainedConfig(**vision_config) + self.text_config = PretrainedConfig(**llm_config) + + self.use_backbone_lora = use_backbone_lora + self.use_llm_lora = use_llm_lora + self.select_layer = select_layer + self.force_image_size = force_image_size + self.downsample_ratio = downsample_ratio + self.template = template + self.dynamic_image_size = dynamic_image_size + self.use_thumbnail = use_thumbnail + self.ps_version = ps_version # pixel shuffle version + self.min_dynamic_patch = min_dynamic_patch + self.max_dynamic_patch = max_dynamic_patch
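
Usage sketch (supplementary; not part of the diff above). The snippet below exercises the new SkyworkR1VChatModel end to end through vLLM's offline LLM API. It mirrors the run_skyworkr1v example added in examples/offline_inference/vision_language.py: the chat template, the <image> placeholder, and the stop tokens come from that example, while the image path and sampling values are arbitrary placeholder choices.

# Minimal offline-inference sketch for Skywork-R1V-38B
# (assumes a GPU setup large enough to hold the 38B weights).
from PIL import Image
from transformers import AutoTokenizer

from vllm import LLM, SamplingParams

model_name = "Skywork/Skywork-R1V-38B"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

# Build the prompt the same way as run_skyworkr1v: one <image> placeholder per image.
messages = [{"role": "user", "content": "<image>\nWhat is shown in this image?"}]
prompt = tokenizer.apply_chat_template(messages,
                                       tokenize=False,
                                       add_generation_prompt=True)

# Stop tokens listed in the Skywork-R1V conversation template.
stop_token_ids = [
    tokenizer.convert_tokens_to_ids(t)
    for t in ("<|end▁of▁sentence|>", "<|endoftext|>")
]

llm = LLM(model=model_name, trust_remote_code=True, max_model_len=4096)
sampling_params = SamplingParams(max_tokens=256, stop_token_ids=stop_token_ids)

image = Image.open("example.jpg")  # placeholder image path
outputs = llm.generate(
    {
        "prompt": prompt,
        "multi_modal_data": {"image": image},
    },
    sampling_params,
)
print(outputs[0].outputs[0].text)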