diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md
index 793831fd06de..8477158a0040 100644
--- a/docs/source/models/supported_models.md
+++ b/docs/source/models/supported_models.md
@@ -921,6 +921,13 @@ See [this page](#generative-models) for more information on how to use generativ
* ✅︎
* ✅︎
* ✅︎
+- * `SkyworkR1VChatModel`
+ * Skywork-R1V-38B
+ * T + I
+ * `Skywork/Skywork-R1V-38B`
+ *
+ * ✅︎
+ * ✅︎
- * `UltravoxModel`
* Ultravox
* T + AE+
diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py
index 0adbe574370d..572eabe26193 100644
--- a/examples/offline_inference/vision_language.py
+++ b/examples/offline_inference/vision_language.py
@@ -804,6 +804,41 @@ def run_qwen2_5_vl(questions: list[str], modality: str) -> ModelRequestData:
)
+# SkyworkR1V
+def run_skyworkr1v(questions: list[str], modality: str) -> ModelRequestData:
+ assert modality == "image"
+
+ model_name = "Skywork/Skywork-R1V-38B"
+
+ engine_args = EngineArgs(
+ model=model_name,
+ trust_remote_code=True,
+ max_model_len=4096,
+ disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
+ )
+
+ tokenizer = AutoTokenizer.from_pretrained(model_name,
+ trust_remote_code=True)
+ messages = [[{
+ 'role': 'user',
+ 'content': f"<image>\n{question}"
+ }] for question in questions]
+ prompts = tokenizer.apply_chat_template(messages,
+ tokenize=False,
+ add_generation_prompt=True)
+
+ # Stop tokens for SkyworkR1V
+ # https://huggingface.co/Skywork/Skywork-R1V-38B/blob/main/conversation.py
+ stop_tokens = ["<|end▁of▁sentence|>", "<|endoftext|>"]
+ stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
+
+ return ModelRequestData(
+ engine_args=engine_args,
+ prompts=prompts,
+ stop_token_ids=stop_token_ids,
+ )
+
+
model_example_map = {
"aria": run_aria,
"blip-2": run_blip2,
@@ -834,6 +869,7 @@ def run_qwen2_5_vl(questions: list[str], modality: str) -> ModelRequestData:
"qwen_vl": run_qwen_vl,
"qwen2_vl": run_qwen2_vl,
"qwen2_5_vl": run_qwen2_5_vl,
+ "skywork_chat": run_skyworkr1v,
}
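
For context, a minimal sketch of how this script's `main()` consumes the returned `ModelRequestData` for the new `skywork_chat` entry. It assumes it runs inside this script after CLI parsing (so the module-level `args` referenced by `run_skyworkr1v` exists); the stop-sign asset and sampling settings are illustrative stand-ins, not values taken from the patch.

```python
from dataclasses import asdict

from vllm import LLM, SamplingParams
from vllm.assets.image import ImageAsset

# Build engine args, prompts and stop tokens via the new helper.
req_data = run_skyworkr1v(["What's in this image?"], modality="image")

llm = LLM(**asdict(req_data.engine_args))
sampling_params = SamplingParams(temperature=0.0,
                                 max_tokens=128,
                                 stop_token_ids=req_data.stop_token_ids)

# Pair the formatted prompt with a PIL image via multi_modal_data.
image = ImageAsset("stop_sign").pil_image
outputs = llm.generate(
    {
        "prompt": req_data.prompts[0],
        "multi_modal_data": {"image": image},
    },
    sampling_params=sampling_params,
)
print(outputs[0].outputs[0].text)
```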
diff --git a/tests/models/decoder_only/vision_language/test_models.py b/tests/models/decoder_only/vision_language/test_models.py
index d500ef5d8b80..0d1d237e5693 100644
--- a/tests/models/decoder_only/vision_language/test_models.py
+++ b/tests/models/decoder_only/vision_language/test_models.py
@@ -474,6 +474,20 @@
vllm_output_post_proc=model_utils.qwen_vllm_to_hf_output,
prompt_path_encoder=model_utils.qwen_prompt_path_encoder,
),
+ "skywork_r1v": VLMTestInfo(
+ models=["Skywork/Skywork-R1V-38B"],
+ test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
+ prompt_formatter=lambda img_prompt: f"<|begin▁of▁sentence|><|User|>\n{img_prompt}<|Assistant|>\n", # noqa: E501
+ single_image_prompts=IMAGE_ASSETS.prompts({
+ "stop_sign": "\nWhat's the content in the center of the image?", # noqa: E501
+ "cherry_blossom": "\nWhat is the season?",
+ }),
+ multi_image_prompt="<image>\n<image>\nDescribe the two images in short.", # noqa: E501
+ max_model_len=4096,
+ use_tokenizer_eos=True,
+ patch_hf_runner=model_utils.skyworkr1v_patch_hf_runner,
+ marks=[large_gpu_mark(min_gb=80)],
+ ),
### Tensor parallel / multi-gpu broadcast tests
"chameleon-broadcast": VLMTestInfo(
models=["facebook/chameleon-7b"],
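
To make the new test entry easier to read, here is a small, self-contained illustration of what the `prompt_formatter` above produces for one of the single-image prompts. It mirrors the lambda verbatim; the `<image>` marker is the generic placeholder that the processor later expands into image tokens.

```python
prompt_formatter = (
    lambda img_prompt:
    f"<|begin▁of▁sentence|><|User|>\n{img_prompt}<|Assistant|>\n")

single_image_prompt = "<image>\nWhat is the season?"
print(prompt_formatter(single_image_prompt))
# <|begin▁of▁sentence|><|User|>
# <image>
# What is the season?<|Assistant|>
```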
diff --git a/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py b/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py
index c84bf6dc15f4..2ddf28aca4f6 100644
--- a/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py
+++ b/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py
@@ -376,6 +376,63 @@ def __call__(self, text: str, images: Union[Image, list[Image]],
return hf_model
+def skyworkr1v_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
+ """Patches and returns an instance of the HfRunner to use for SkyworkR1V."""
+
+ class SkyworkR1VProcessor:
+ """A simple processor for SkyworkR1V."""
+
+ def __init__(self, hf_runner: HfRunner):
+ self.num_image_token = hf_runner.model.num_image_token
+ self.tokenizer = hf_runner.tokenizer
+
+ self.config = AutoConfig.from_pretrained(hf_runner.model_name,
+ trust_remote_code=True)
+ self.vision_config = self.config.vision_config
+ self.use_thumbnail = self.config.use_thumbnail
+ self.min_num = self.config.min_dynamic_patch
+ self.max_num = self.config.max_dynamic_patch
+ self.image_size = self.vision_config.image_size
+
+ def __call__(self, text: str, images: Union[Image, list[Image]],
+ **kwargs):
+ from vllm.model_executor.models.skyworkr1v import (
+ IMG_CONTEXT, IMG_END, IMG_START,
+ image_to_pixel_values_skyworkr1v)
+ images = [images] if isinstance(images, Image) else images
+ pixel_values = [
+ image_to_pixel_values_skyworkr1v(
+ image,
+ input_size=self.image_size,
+ min_num=self.min_num,
+ max_num=self.max_num,
+ use_thumbnail=self.use_thumbnail,
+ ) for image in images
+ ]
+ num_patches_list = [
+ pixel_value.shape[0] for pixel_value in pixel_values
+ ]
+ pixel_values = torch.cat(pixel_values, dim=0)
+ for num_patches in num_patches_list:
+ context_tokens = IMG_CONTEXT * self.num_image_token \
+ * num_patches
+ image_tokens = IMG_START + context_tokens + IMG_END
+ text = text.replace('<image>', image_tokens, 1)
+ prompt = self.tokenizer(text, return_tensors="pt")
+ prompt.update({"pixel_values": pixel_values})
+ return prompt
+
+ img_context_token_id = hf_model.tokenizer.convert_tokens_to_ids(
+ "")
+ hf_model.model.img_context_token_id = img_context_token_id
+ hf_model.processor = SkyworkR1VProcessor(hf_model)
+ hf_model.model.get_output_embeddings = lambda: \
+ hf_model.model.language_model.get_output_embeddings()
+ hf_model.model.generate = types.MethodType(_internvl_generate,
+ hf_model.model)
+ return hf_model
+
+
def internvl_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
"""Patches and returns an instance of the HfRunner to use for InternVL."""
diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py
index 078ed21537b8..e4f1d297fc09 100644
--- a/tests/models/multimodal/processing/test_common.py
+++ b/tests/models/multimodal/processing/test_common.py
@@ -262,22 +262,23 @@ def _test_processing_correctness_mistral(
"llava-hf/llava-onevision-qwen2-0.5b-ov-hf",
"meta-llama/Llama-3.2-11B-Vision-Instruct",
"TIGER-Lab/Mantis-8B-siglip-llama3",
- "mistralai/Pixtral-12B-2409",
- "mistral-community/pixtral-12b",
"openbmb/MiniCPM-Llama3-V-2_5",
"openbmb/MiniCPM-o-2_6",
"openbmb/MiniCPM-V-2_6",
"allenai/Molmo-7B-D-0924",
"allenai/Molmo-7B-O-0924",
"nvidia/NVLM-D-72B",
+ "google/paligemma-3b-mix-224",
+ "google/paligemma2-3b-ft-docci-448",
+ "mistralai/Pixtral-12B-2409",
+ "mistral-community/pixtral-12b",
"Qwen/Qwen-VL-Chat",
"Qwen/Qwen2-VL-2B-Instruct",
"Qwen/Qwen2.5-VL-3B-Instruct",
"Qwen/Qwen2-Audio-7B-Instruct",
+ "Skywork/Skywork-R1V-38B",
"fixie-ai/ultravox-v0_5-llama-3_2-1b",
"openai/whisper-large-v3",
- "google/paligemma-3b-mix-224",
- "google/paligemma2-3b-ft-docci-448",
])
@pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0])
@pytest.mark.parametrize("num_batches", [32])
diff --git a/tests/models/registry.py b/tests/models/registry.py
index d7946b75b797..ff0c37a6afd7 100644
--- a/tests/models/registry.py
+++ b/tests/models/registry.py
@@ -294,6 +294,7 @@ def check_available_online(
"Qwen2VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2-VL-2B-Instruct"), # noqa: E501
"Qwen2_5_VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-VL-3B-Instruct", # noqa: E501
min_transformers_version="4.49"), # noqa: E501
+ "SkyworkR1VChatModel": _HfExamplesInfo("Skywork/Skywork-R1V-38B"),
"UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_5-llama-3_2-1b", # noqa: E501
trust_remote_code=True),
# [Encoder-decoder]
diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py
index 73a69d3037f7..24382142768b 100644
--- a/vllm/entrypoints/chat_utils.py
+++ b/vllm/entrypoints/chat_utils.py
@@ -496,7 +496,7 @@ def _placeholder_str(self, modality: ModalityStr,
return self._cached_token_str(self._tokenizer,
hf_config.image_token_index)
if model_type in ("chameleon", "deepseek_vl_v2", "internvl_chat",
- "NVLM_D", "h2ovl_chat"):
+ "skywork_chat", "NVLM_D", "h2ovl_chat"):
return ""
if model_type == "mllama":
return "<|image|>"
diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index 7797d9a2cc20..9288a4b81748 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -190,6 +190,7 @@
# [Encoder-decoder]
"Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"), # noqa: E501
"MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"), # noqa: E501
+ "SkyworkR1VChatModel": ("skyworkr1v", "SkyworkR1VChatModel"),
"WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"), # noqa: E501
}
diff --git a/vllm/model_executor/models/skyworkr1v.py b/vllm/model_executor/models/skyworkr1v.py
new file mode 100644
index 000000000000..ac5de0e36b89
--- /dev/null
+++ b/vllm/model_executor/models/skyworkr1v.py
@@ -0,0 +1,1014 @@
+# SPDX-License-Identifier: Apache-2.0
+
+# adapted from https://huggingface.co/Skywork/Skywork-R1V-38B/blob/main/modeling_skywork_chat.py
+# --------------------------------------------------------
+# SkyworkR1V
+# Copyright (c) 2025 Skywork
+# Licensed under The MIT License [see LICENSE for details]
+# --------------------------------------------------------
+from abc import ABC, abstractmethod
+from collections.abc import Iterable, Mapping, Sequence
+from functools import cached_property
+from typing import Literal, Optional, Set, Tuple, TypedDict, TypeVar, Union
+
+import torch
+import torch.nn as nn
+import torchvision.transforms as T
+from PIL import Image
+from transformers import BatchEncoding, PretrainedConfig, TensorType
+
+from vllm.config import VllmConfig
+from vllm.model_executor.layers.linear import ReplicatedLinear
+from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.quantization.awq import AWQConfig
+from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
+from vllm.model_executor.models.intern_vit import (InternVisionModel,
+ InternVisionPatchModel)
+from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
+ NestedTensors)
+from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
+ ImageSize, MultiModalDataItems)
+from vllm.multimodal.processing import (BaseMultiModalProcessor,
+ BaseProcessingInfo, PromptReplacement,
+ PromptUpdate, PromptUpdateDetails)
+from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
+from vllm.sequence import IntermediateTensors
+from vllm.transformers_utils.tokenizer import AnyTokenizer
+
+from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
+from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
+ maybe_prefix, merge_multimodal_embeddings)
+from .vision import scatter_patch_features, select_patch_features
+
+IMG_START = '<img>'
+IMG_END = '</img>'
+IMG_CONTEXT = '<IMG_CONTEXT>'
+
+IMAGENET_MEAN = (0.485, 0.456, 0.406)
+IMAGENET_STD = (0.229, 0.224, 0.225)
+
+
+class SkyworkR1VImagePixelInputs(TypedDict):
+ type: Literal["pixel_values"]
+ pixel_values_flat: torch.Tensor
+ """
+ Shape:
+ `(batch_size * num_images * (1 + num_patches), num_channels, height, width)`
+ """
+
+ num_patches: torch.Tensor
+ """Shape: `(batch_size * num_images)`"""
+
+ embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
+ """
+ A boolean mask indicating which image embeddings correspond
+ to patch tokens.
+
+ Shape: `(batch_size * num_images, num_embeds)`
+ """
+
+
+class SkyworkR1VImageEmbeddingInputs(TypedDict):
+ type: Literal["image_embeds"]
+ data: Union[torch.Tensor, list[torch.Tensor]]
+ """
+ A tensor of shape `(num_images, total_image_feature_size, hidden_size)`
+ or a list of tensors of shape `(total_image_feature_size, hidden_size)`
+
+ `hidden_size` must match the hidden size of language model backbone.
+ """
+
+
+SkyworkR1VImageInputs = Union[SkyworkR1VImagePixelInputs,
+ SkyworkR1VImageEmbeddingInputs]
+
+
+# adapted from https://huggingface.co/Skywork/Skywork-R1V-38B/
+def build_transform(input_size: int):
+ MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
+ return T.Compose([
+ T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
+ T.Resize((input_size, input_size),
+ interpolation=T.InterpolationMode.BICUBIC),
+ T.ToTensor(),
+ T.Normalize(mean=MEAN, std=STD)
+ ])
+
+
+# adapted from https://huggingface.co/Skywork/Skywork-R1V-38B/
+def find_closest_aspect_ratio(
+ aspect_ratio: float,
+ target_ratios: list[tuple[int, int]],
+ *,
+ width: int,
+ height: int,
+ image_size: int,
+) -> tuple[int, int]:
+ best_ratio_diff = float('inf')
+ best_ratio = (1, 1)
+ area = width * height
+ for ratio in target_ratios:
+ target_aspect_ratio = ratio[0] / ratio[1]
+ ratio_diff = abs(aspect_ratio - target_aspect_ratio)
+ if ratio_diff < best_ratio_diff:
+ best_ratio_diff = ratio_diff
+ best_ratio = ratio
+ elif ratio_diff == best_ratio_diff:
+ if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
+ best_ratio = ratio
+ return best_ratio
+
+
+def resolve_skyworkr1v_min_max_num(
+ *,
+ min_dynamic_patch: int,
+ max_dynamic_patch: int,
+ dynamic_image_size: bool,
+ use_thumbnail: bool,
+) -> tuple[int, int]:
+ min_dynamic_patch = min_dynamic_patch if dynamic_image_size else 1
+ max_dynamic_patch = max_dynamic_patch if dynamic_image_size else 1
+
+ if use_thumbnail and max_dynamic_patch != 1:
+ max_dynamic_patch += 1
+
+ return min_dynamic_patch, max_dynamic_patch
+
+
+def get_skyworkr1v_target_ratios(
+ min_num: int,
+ max_num: int,
+) -> list[tuple[int, int]]:
+ target_ratios = {(i, j)
+ for n in range(min_num, max_num + 1)
+ for i in range(1, n + 1)
+ for j in range(1, n + 1) if min_num <= i * j <= max_num}
+ return sorted(target_ratios, key=lambda x: x[0] * x[1])
+
+
+def calculate_skyworkr1v_targets(
+ *,
+ orig_width: int,
+ orig_height: int,
+ target_ratios: list[tuple[int, int]],
+ image_size: int,
+ use_thumbnail: bool,
+) -> tuple[int, int, int]:
+ aspect_ratio = orig_width / orig_height
+
+ # find the closest aspect ratio to the target
+ target_aspect_ratio = find_closest_aspect_ratio(
+ aspect_ratio,
+ target_ratios,
+ width=orig_width,
+ height=orig_height,
+ image_size=image_size,
+ )
+
+ # calculate the target width and height
+ target_width = image_size * target_aspect_ratio[0]
+ target_height = image_size * target_aspect_ratio[1]
+ blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
+
+ # add thumbnail image if num_blocks != 1
+ if use_thumbnail and blocks != 1:
+ blocks += 1
+
+ return blocks, target_width, target_height
+
+
+def dynamic_preprocess_skyworkr1v(
+ image: Image.Image,
+ *,
+ target_ratios: list[tuple[int, int]],
+ image_size: int,
+ use_thumbnail: bool,
+) -> list[Image.Image]:
+ orig_width, orig_height = image.size
+
+ # calculate the number of blocks without thumbnail
+ blocks, target_width, target_height = calculate_skyworkr1v_targets(
+ orig_width=orig_width,
+ orig_height=orig_height,
+ target_ratios=target_ratios,
+ image_size=image_size,
+ use_thumbnail=False,
+ )
+
+ # resize the image
+ resized_img = image.resize((target_width, target_height))
+ processed_images = []
+ for i in range(blocks):
+ box = ((i % (target_width // image_size)) * image_size,
+ (i // (target_width // image_size)) * image_size,
+ ((i % (target_width // image_size)) + 1) * image_size,
+ ((i // (target_width // image_size)) + 1) * image_size)
+ # split the image
+ split_img = resized_img.crop(box)
+ processed_images.append(split_img)
+
+ assert len(processed_images) == blocks
+
+ if use_thumbnail and len(processed_images) != 1:
+ thumbnail_img = image.resize((image_size, image_size))
+ processed_images.append(thumbnail_img)
+
+ return processed_images
+
+
+# adapted from https://huggingface.co/Skywork/Skywork-R1V-38B
+def image_to_pixel_values_skyworkr1v(
+ image: Image.Image,
+ *,
+ input_size: int,
+ min_num: int,
+ max_num: int,
+ use_thumbnail: bool,
+) -> torch.Tensor:
+ target_ratios = get_skyworkr1v_target_ratios(min_num, max_num)
+
+ transform = build_transform(input_size=input_size)
+ images = dynamic_preprocess_skyworkr1v(
+ image,
+ target_ratios=target_ratios,
+ image_size=input_size,
+ use_thumbnail=use_thumbnail,
+ )
+
+ pixel_values = torch.stack([transform(image) for image in images])
+ return pixel_values
+
+
+class BaseSkyworkR1VProcessor(ABC):
+ """
+ This model doesn't define its own HF processor,
+ so we implement our own one here.
+
+ The code to insert image tokens is based on:
+ https://huggingface.co/Skywork/Skywork-R1V-38B/blob/main/modeling_skywork_chat.py#L252
+ """
+
+ def __init__(
+ self,
+ config: PretrainedConfig,
+ tokenizer: AnyTokenizer,
+ *,
+ min_dynamic_patch: Optional[int] = None,
+ max_dynamic_patch: Optional[int] = None,
+ dynamic_image_size: Optional[bool] = None,
+ ) -> None:
+ super().__init__()
+
+ self.config = config
+ self.tokenizer = tokenizer
+
+ image_size: int = config.vision_config.image_size
+ patch_size: int = config.vision_config.patch_size
+
+ if min_dynamic_patch is None:
+ min_dynamic_patch = config.min_dynamic_patch
+ assert isinstance(min_dynamic_patch, int)
+
+ if max_dynamic_patch is None:
+ max_dynamic_patch = config.max_dynamic_patch
+ assert isinstance(max_dynamic_patch, int)
+
+ if dynamic_image_size is None:
+ dynamic_image_size = config.dynamic_image_size
+ assert isinstance(dynamic_image_size, bool)
+
+ self.num_image_token = int(
+ (image_size // patch_size)**2 * (config.downsample_ratio**2))
+ self.image_size = image_size
+ self.min_dynamic_patch = min_dynamic_patch
+ self.max_dynamic_patch = max_dynamic_patch
+ self.dynamic_image_size = dynamic_image_size
+ self.use_thumbnail: bool = config.use_thumbnail
+
+ @property
+ @abstractmethod
+ def image_token_id(self) -> int:
+ raise NotImplementedError
+
+ @abstractmethod
+ def get_image_repl(
+ self,
+ feature_size: int,
+ num_patches: Optional[int],
+ ) -> PromptUpdateDetails[str]:
+ raise NotImplementedError
+
+ def resolve_min_max_num(
+ self,
+ *,
+ min_dynamic_patch: Optional[int] = None,
+ max_dynamic_patch: Optional[int] = None,
+ dynamic_image_size: Optional[bool] = None,
+ use_thumbnail: Optional[bool] = None,
+ ) -> tuple[int, int]:
+ min_dynamic_patch = (self.min_dynamic_patch if min_dynamic_patch
+ is None else min_dynamic_patch)
+ max_dynamic_patch = (self.max_dynamic_patch if max_dynamic_patch
+ is None else max_dynamic_patch)
+ dynamic_image_size = (self.dynamic_image_size if dynamic_image_size
+ is None else dynamic_image_size)
+ use_thumbnail = (self.use_thumbnail
+ if use_thumbnail is None else use_thumbnail)
+
+ return resolve_skyworkr1v_min_max_num(
+ min_dynamic_patch=min_dynamic_patch,
+ max_dynamic_patch=max_dynamic_patch,
+ dynamic_image_size=dynamic_image_size,
+ use_thumbnail=use_thumbnail,
+ )
+
+ def resolve_target_ratios(
+ self,
+ *,
+ min_dynamic_patch: Optional[int] = None,
+ max_dynamic_patch: Optional[int] = None,
+ dynamic_image_size: Optional[bool] = None,
+ use_thumbnail: Optional[bool] = None,
+ ) -> list[tuple[int, int]]:
+ min_num, max_num = self.resolve_min_max_num(
+ min_dynamic_patch=min_dynamic_patch,
+ max_dynamic_patch=max_dynamic_patch,
+ dynamic_image_size=dynamic_image_size,
+ use_thumbnail=use_thumbnail,
+ )
+
+ return get_skyworkr1v_target_ratios(min_num, max_num)
+
+ def get_num_image_tokens(
+ self,
+ *,
+ image_width: int,
+ image_height: int,
+ ) -> int:
+ target_ratios = self.resolve_target_ratios(
+ use_thumbnail=False, # Applied in calculate_targets
+ )
+
+ num_patches, _, _ = calculate_skyworkr1v_targets(
+ orig_width=image_width,
+ orig_height=image_height,
+ image_size=self.image_size,
+ target_ratios=target_ratios,
+ use_thumbnail=self.use_thumbnail,
+ )
+
+ return num_patches * self.num_image_token
+
+ def _images_to_pixel_values_lst(
+ self,
+ images: list[Image.Image],
+ min_dynamic_patch: Optional[int] = None,
+ max_dynamic_patch: Optional[int] = None,
+ dynamic_image_size: Optional[bool] = None,
+ ) -> list[torch.Tensor]:
+ min_num, max_num = self.resolve_min_max_num(
+ min_dynamic_patch=min_dynamic_patch,
+ max_dynamic_patch=max_dynamic_patch,
+ dynamic_image_size=dynamic_image_size,
+ use_thumbnail=False, # Applied in image_to_pixel_values
+ )
+
+ return [
+ image_to_pixel_values_skyworkr1v(
+ image,
+ input_size=self.image_size,
+ min_num=min_num,
+ max_num=max_num,
+ use_thumbnail=self.use_thumbnail,
+ ) for image in images
+ ]
+
+ def __call__(
+ self,
+ text: Optional[Union[str, list[str]]] = None,
+ images: Optional[Union[Image.Image, list[Image.Image]]] = None,
+ min_dynamic_patch: Optional[int] = None,
+ max_dynamic_patch: Optional[int] = None,
+ dynamic_image_size: Optional[bool] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ ) -> Mapping[str, NestedTensors]:
+ if text is None:
+ text = []
+ if not isinstance(text, list):
+ text = [text]
+ if images is None:
+ images = []
+ if not isinstance(images, list):
+ images = [images]
+
+ if len(images) == 0:
+ image_inputs = {}
+ else:
+ pixel_values_lst = self._images_to_pixel_values_lst(
+ images,
+ min_dynamic_patch=min_dynamic_patch,
+ max_dynamic_patch=max_dynamic_patch,
+ dynamic_image_size=dynamic_image_size,
+ )
+ image_inputs: dict[str, NestedTensors] = {
+ "pixel_values_flat":
+ torch.cat(pixel_values_lst),
+ "image_num_patches":
+ torch.tensor([len(item) for item in pixel_values_lst]),
+ }
+
+ tokenizer = self.tokenizer
+ image_token_id = self.image_token_id
+
+ embed_is_patch = list[torch.Tensor]()
+
+ for pixel_values in pixel_values_lst:
+ num_patches = pixel_values.shape[0]
+ feature_size = num_patches * self.num_image_token
+
+ image_repl = self.get_image_repl(feature_size, num_patches)
+ feature_tokens = tokenizer.encode(image_repl.features,
+ add_special_tokens=False)
+
+ text = [t.replace('<image>', image_repl.full, 1) for t in text]
+ embed_is_patch.append(
+ torch.tensor(feature_tokens) == image_token_id)
+
+ image_inputs["embed_is_patch"] = embed_is_patch
+
+ text_inputs = self.tokenizer(text)
+
+ return {
+ **BatchEncoding(text_inputs, tensor_type=return_tensors),
+ **image_inputs,
+ }
+
+
+class SkyworkR1VProcessor(BaseSkyworkR1VProcessor):
+
+ @property
+ def image_token_id(self) -> int:
+ return self.tokenizer.get_vocab()[IMG_CONTEXT]
+
+ def get_image_repl(
+ self,
+ feature_size: int,
+ num_patches: Optional[int],
+ ) -> PromptUpdateDetails[str]:
+ repl_features = IMG_CONTEXT * feature_size
+ repl_full = IMG_START + repl_features + IMG_END
+
+ return PromptUpdateDetails(full=repl_full, features=repl_features)
+
+
+class BaseSkyworkR1VProcessingInfo(BaseProcessingInfo):
+
+ @abstractmethod
+ def get_hf_processor(
+ self,
+ *,
+ min_dynamic_patch: Optional[int] = None,
+ max_dynamic_patch: Optional[int] = None,
+ dynamic_image_size: Optional[bool] = None,
+ **kwargs: object,
+ ) -> BaseSkyworkR1VProcessor:
+ raise NotImplementedError
+
+ def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
+ return {"image": None}
+
+ def get_mm_max_tokens_per_item(
+ self,
+ seq_len: int,
+ mm_counts: Mapping[str, int],
+ ) -> Mapping[str, int]:
+ return {"image": self.get_max_image_tokens()}
+
+ def get_num_image_tokens(
+ self,
+ *,
+ image_width: int,
+ image_height: int,
+ processor: Optional[BaseSkyworkR1VProcessor],
+ ) -> int:
+ if processor is None:
+ processor = self.get_hf_processor()
+
+ return processor.get_num_image_tokens(
+ image_width=image_width,
+ image_height=image_height,
+ )
+
+ def get_max_image_tokens(self) -> int:
+ target_width, target_height = self.get_image_size_with_most_features()
+
+ return self.get_num_image_tokens(
+ image_width=target_width,
+ image_height=target_height,
+ processor=None,
+ )
+
+ def get_image_size_with_most_features(self) -> ImageSize:
+ processor = self.get_hf_processor()
+
+ base_size = processor.image_size
+ target_ratios = processor.resolve_target_ratios()
+
+ largest_feature_size, largest_feature_pinpoint = 0, None
+ for wr, hr in target_ratios:
+ width, height = base_size * wr, base_size * hr
+
+ feat_size = self.get_num_image_tokens(
+ image_width=width,
+ image_height=height,
+ processor=processor,
+ )
+ if feat_size > largest_feature_size:
+ largest_feature_size = feat_size
+ largest_feature_pinpoint = ImageSize(width=width,
+ height=height)
+
+ if largest_feature_size == 0 or largest_feature_pinpoint is None:
+ raise ValueError("Cannot have a largest feature size of 0!")
+
+ return largest_feature_pinpoint
+
+
+_I = TypeVar("_I", bound=BaseSkyworkR1VProcessingInfo)
+
+
+class SkyworkR1VDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
+
+ def get_dummy_processor_inputs(
+ self,
+ seq_len: int,
+ mm_counts: Mapping[str, int],
+ ) -> ProcessorInputs:
+ target_width, target_height = \
+ self.info.get_image_size_with_most_features()
+ num_images = mm_counts.get("image", 0)
+
+ mm_data = {
+ "image":
+ self._get_dummy_images(width=target_width,
+ height=target_height,
+ num_images=num_images)
+ }
+
+ return ProcessorInputs(
+ prompt_text="<image>" * num_images,
+ mm_data=mm_data,
+ )
+
+
+class SkyworkR1VMultiModalProcessor(BaseMultiModalProcessor[_I]):
+
+ def _call_hf_processor(
+ self,
+ prompt: str,
+ mm_data: Mapping[str, object],
+ mm_kwargs: Mapping[str, object],
+ ) -> Mapping[str, NestedTensors]:
+ processed_outputs = super()._call_hf_processor(
+ prompt=prompt,
+ mm_data=mm_data,
+ mm_kwargs=mm_kwargs,
+ )
+
+ hf_processor = self.info.get_hf_processor(**mm_kwargs)
+ image_token_id = hf_processor.image_token_id
+
+ # Since there may be extra tokens in the feature placeholders,
+ # we need to pass the image token ID to the model to select the
+ # tokens to merge from the vision encoder outputs
+ processed_outputs["image_token_id"] = torch.tensor(image_token_id)
+
+ return processed_outputs
+
+ def _get_mm_fields_config(
+ self,
+ hf_inputs: Mapping[str, NestedTensors],
+ hf_processor_mm_kwargs: Mapping[str, object],
+ ) -> Mapping[str, MultiModalFieldConfig]:
+ image_num_patches = hf_inputs.get("image_num_patches", torch.empty(0))
+ num_images = len(image_num_patches)
+
+ return dict(
+ pixel_values_flat=MultiModalFieldConfig.flat_from_sizes(
+ "image", image_num_patches),
+ image_num_patches=MultiModalFieldConfig.batched("image"),
+ embed_is_patch=MultiModalFieldConfig.batched("image"),
+ image_embeds=MultiModalFieldConfig.batched("image"),
+ image_token_id=MultiModalFieldConfig.shared("image", num_images),
+ )
+
+ def _get_prompt_updates(
+ self,
+ mm_items: MultiModalDataItems,
+ hf_processor_mm_kwargs: Mapping[str, object],
+ out_mm_kwargs: MultiModalKwargs,
+ ) -> Sequence[PromptUpdate]:
+ hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
+
+ if "image_num_patches" in out_mm_kwargs:
+ image_num_patches = out_mm_kwargs["image_num_patches"]
+ assert isinstance(image_num_patches, torch.Tensor)
+ image_num_patches = image_num_patches.tolist()
+ elif "image_embeds" in out_mm_kwargs:
+ # TODO: Use image size information in dictionary embedding inputs
+ # to compute num_patches (similar to Qwen2-VL)
+ image_num_patches = [None] * len(out_mm_kwargs["image_embeds"])
+ else:
+ image_num_patches = []
+
+ def get_replacement_skyworkr1v(item_idx: int):
+ images = mm_items.get_items(
+ "image", (ImageEmbeddingItems, ImageProcessorItems))
+
+ if isinstance(images, ImageEmbeddingItems):
+ feature_size = images.get_feature_size(item_idx)
+ else:
+ image_size = images.get_image_size(item_idx)
+ feature_size = self.info.get_num_image_tokens(
+ image_width=image_size.width,
+ image_height=image_size.height,
+ processor=hf_processor,
+ )
+
+ num_patches = image_num_patches[item_idx]
+ if num_patches is not None:
+ assert isinstance(num_patches, int)
+
+ return hf_processor.get_image_repl(feature_size, num_patches)
+
+ return [
+ PromptReplacement(
+ modality="image",
+ target="<image>",
+ replacement=get_replacement_skyworkr1v,
+ )
+ ]
+
+
+class SkyworkR1VProcessingInfo(BaseSkyworkR1VProcessingInfo):
+
+ def get_hf_processor(
+ self,
+ *,
+ min_dynamic_patch: Optional[int] = None,
+ max_dynamic_patch: Optional[int] = None,
+ dynamic_image_size: Optional[bool] = None,
+ **kwargs: object,
+ ) -> SkyworkR1VProcessor:
+ if min_dynamic_patch is not None:
+ kwargs["min_dynamic_patch"] = min_dynamic_patch
+ if max_dynamic_patch is not None:
+ kwargs["max_dynamic_patch"] = max_dynamic_patch
+ if dynamic_image_size is not None:
+ kwargs["dynamic_image_size"] = dynamic_image_size
+
+ return self.ctx.init_processor(
+ SkyworkR1VProcessor,
+ config=self.get_hf_config(),
+ tokenizer=self.get_tokenizer(),
+ **kwargs,
+ )
+
+
+@MULTIMODAL_REGISTRY.register_processor(
+ SkyworkR1VMultiModalProcessor,
+ info=SkyworkR1VProcessingInfo,
+ dummy_inputs=SkyworkR1VDummyInputsBuilder)
+class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
+
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
+ super().__init__()
+
+ config = vllm_config.model_config.hf_config
+ quant_config = vllm_config.quant_config
+ multimodal_config = vllm_config.model_config.multimodal_config
+
+ self.config = config
+ self.multimodal_config = multimodal_config
+ self._patch_quant_config(config, quant_config)
+
+ image_size = config.force_image_size or config.vision_config.image_size
+ patch_size = config.vision_config.patch_size
+ self.patch_size = patch_size
+ self.num_image_token = int(
+ (image_size // patch_size)**2 * (config.downsample_ratio**2))
+ self.downsample_ratio = config.downsample_ratio
+ self.ps_version = config.ps_version
+
+ self.llm_arch_name = config.text_config.architectures[0]
+ self.is_mono = self.llm_arch_name == 'SkyworkLM2VEForCausalLM'
+ self.vision_model = self._init_vision_model(
+ config,
+ quant_config=quant_config,
+ is_mono=self.is_mono,
+ prefix=maybe_prefix(prefix, "vision_model"),
+ )
+
+ self.language_model = init_vllm_registered_model(
+ vllm_config=vllm_config,
+ hf_config=config.text_config,
+ prefix=maybe_prefix(prefix, "language_model"),
+ )
+
+ self.mlp1 = self._init_mlp1(config)
+
+ self.img_context_token_id = None
+ self.visual_token_mask = None
+ self.make_empty_intermediate_tensors = (
+ self.language_model.make_empty_intermediate_tensors)
+
+ def _patch_quant_config(self, config: PretrainedConfig,
+ quant_config: QuantizationConfig):
+ # The AWQ models from OpenGVLab are missing `modules_to_not_convert`;
+ # patch the quant_config to add `modules_to_not_convert` back.
+ if isinstance(quant_config, AWQConfig):
+ text_config = config.text_config
+ llm_quant_config = getattr(text_config, "quantization_config",
+ None)
+ if (not quant_config.modules_to_not_convert) and \
+ (llm_quant_config is not None):
+ quant_config.modules_to_not_convert.append("vision_model")
+
+ @cached_property
+ def sampler(self):
+ if hasattr(self.language_model, "sampler"):
+ return self.language_model.sampler
+
+ return get_sampler()
+
+ def _init_vision_model(
+ self,
+ config: PretrainedConfig,
+ quant_config: Optional[QuantizationConfig],
+ *,
+ is_mono: bool,
+ prefix: str,
+ ):
+ if not is_mono:
+ vision_feature_layer = config.select_layer
+ if vision_feature_layer < 0:
+ num_hidden_layers = config.vision_config.num_hidden_layers \
+ + vision_feature_layer + 1
+ else:
+ num_hidden_layers = vision_feature_layer + 1
+
+ return InternVisionModel(
+ config.vision_config,
+ quant_config=quant_config,
+ num_hidden_layers_override=num_hidden_layers,
+ prefix=prefix,
+ )
+ else:
+ return InternVisionPatchModel(config.vision_config)
+
+ def _init_mlp1(self, config: PretrainedConfig) -> nn.Sequential:
+ vit_hidden_size = config.vision_config.hidden_size
+ llm_hidden_size = config.text_config.hidden_size
+
+ return nn.Sequential(
+ nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio)**2),
+ ReplicatedLinear(vit_hidden_size *
+ int(1 / self.downsample_ratio)**2,
+ llm_hidden_size,
+ return_bias=False),
+ nn.GELU(),
+ ReplicatedLinear(llm_hidden_size,
+ llm_hidden_size,
+ return_bias=False),
+ )
+
+ def pixel_shuffle(self, x, scale_factor=0.5):
+ n, w, h, c = x.size()
+ # N, W, H, C --> N, W, H * scale, C // scale
+ x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
+ # N, W, H * scale, C // scale --> N, H * scale, W, C // scale
+ x = x.permute(0, 2, 1, 3).contiguous()
+ x = x.view(n, int(h * scale_factor), int(w * scale_factor),
+ int(c / (scale_factor * scale_factor)))
+ if self.ps_version == 'v1':
+ pass
+ else:
+ x = x.permute(0, 2, 1, 3).contiguous()
+ return x
+
+ def extract_feature(self, pixel_values: torch.Tensor) -> torch.Tensor:
+ vit_embeds = self.vision_model(pixel_values=pixel_values)
+ vit_embeds = vit_embeds[:, 1:, :]
+
+ h = w = int(vit_embeds.shape[1]**0.5)
+ vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
+ vit_embeds = self.pixel_shuffle(vit_embeds,
+ scale_factor=self.downsample_ratio)
+ vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1,
+ vit_embeds.shape[-1])
+ vit_embeds = self.mlp1(vit_embeds)
+ return vit_embeds
+
+ def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
+
+ h = w = self.config.vision_config.image_size
+ expected_dims = (3, h, w)
+
+ def _validate_shape(d: torch.Tensor):
+ actual_dims = tuple(d.shape)
+
+ if actual_dims != expected_dims:
+ expected_expr = str(expected_dims)
+ raise ValueError(
+ "The expected shape of pixel values per image per batch "
+ f" per patch is {expected_expr}. "
+ f"You supplied {tuple(d.shape)}.")
+
+ for d in data:
+ _validate_shape(d)
+
+ return data
+
+ def _parse_and_validate_image_input(
+ self, **kwargs: object) -> Optional[SkyworkR1VImageInputs]:
+ pixel_values_flat = kwargs.pop("pixel_values_flat", None)
+ image_num_patches = kwargs.pop("image_num_patches", None)
+ embed_is_patch = kwargs.pop("embed_is_patch", None)
+ image_embeds = kwargs.pop("image_embeds", None)
+
+ if pixel_values_flat is None and image_embeds is None:
+ return None
+
+ if image_embeds is not None:
+ if not isinstance(image_embeds, (torch.Tensor, list)):
+ raise ValueError("Incorrect type of image embeddings. "
+ f"Got type: {type(image_embeds)}")
+
+ return SkyworkR1VImageEmbeddingInputs(
+ type="image_embeds",
+ data=flatten_bn(image_embeds),
+ )
+
+ image_token_id = kwargs["image_token_id"]
+ assert isinstance(image_token_id, torch.Tensor)
+ self.img_context_token_id = image_token_id.flatten().unique().item()
+
+ if pixel_values_flat is not None:
+ if not isinstance(pixel_values_flat, (torch.Tensor, list)):
+ raise ValueError("Incorrect type of pixel values. "
+ f"Got type: {type(pixel_values_flat)}")
+
+ if not isinstance(image_num_patches, (torch.Tensor, list)):
+ raise ValueError("Incorrect type of image_num_patches. "
+ f"Got type: {type(image_num_patches)}")
+
+ if not isinstance(embed_is_patch, (torch.Tensor, list)):
+ raise ValueError("Incorrect type of embed_is_patch. "
+ f"Got type: {type(embed_is_patch)}")
+
+ pixel_values_flat = flatten_bn(pixel_values_flat, concat=True)
+ image_num_patches = flatten_bn(image_num_patches, concat=True)
+ embed_is_patch = flatten_bn(embed_is_patch)
+
+ return SkyworkR1VImagePixelInputs(
+ type="pixel_values",
+ pixel_values_flat=self._validate_pixel_values(
+ pixel_values_flat),
+ num_patches=image_num_patches,
+ embed_is_patch=embed_is_patch,
+ )
+
+ raise AssertionError("This line should be unreachable.")
+
+ def _process_image_input(
+ self,
+ image_input: SkyworkR1VImageInputs,
+ ) -> Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor, ...]]:
+ if image_input["type"] == "image_embeds":
+ return image_input["data"]
+
+ assert self.vision_model is not None
+
+ image_embeds = self.extract_feature(image_input["pixel_values_flat"])
+
+ num_patches = image_input["num_patches"]
+
+ # Only one image in the current batch
+ if len(num_patches) == 1:
+ return image_embeds.view(
+ -1, self.config.text_config.hidden_size).unsqueeze(0)
+
+ # NOTE: Image embeddings are split into separate tensors for each image
+ # by the size of each embedding.
+ feature_size = image_embeds.shape[1]
+ image_embeds = image_embeds.view(-1,
+ self.config.text_config.hidden_size)
+ image_feature_sizes = [
+ num_patches * feature_size for num_patches in num_patches
+ ]
+ return image_embeds.split(image_feature_sizes)
+
+ def _set_visual_token_mask(self, input_ids: torch.Tensor) -> None:
+ if self.is_mono:
+ self.visual_token_mask = (
+ input_ids == self.img_context_token_id).reshape(-1, 1)
+ else:
+ self.visual_token_mask = None
+
+ def get_multimodal_embeddings(
+ self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
+ image_input = self._parse_and_validate_image_input(**kwargs)
+ if image_input is None:
+ return None
+
+ image_features = self._process_image_input(image_input)
+
+ if image_input["type"] != "pixel_values":
+ return image_features
+
+ return scatter_patch_features(
+ image_features,
+ image_input["embed_is_patch"],
+ )
+
+ def get_input_embeddings(
+ self,
+ input_ids: torch.Tensor,
+ multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
+ ) -> torch.Tensor:
+ inputs_embeds = self.language_model.get_input_embeddings(input_ids)
+ if multimodal_embeddings is not None:
+ assert self.img_context_token_id is not None
+ self._set_visual_token_mask(input_ids)
+ inputs_embeds = merge_multimodal_embeddings(
+ input_ids,
+ inputs_embeds,
+ select_patch_features(multimodal_embeddings),
+ self.img_context_token_id,
+ )
+ return inputs_embeds
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ intermediate_tensors: Optional[IntermediateTensors] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ **kwargs: object,
+ ) -> Union[SamplerOutput, IntermediateTensors]:
+
+ if intermediate_tensors is not None:
+ input_ids = None
+ inputs_embeds = None
+
+ # NOTE: In v1, inputs_embeds is always generated at model runner, this
+ # condition is for v0 compatibility.
+ elif inputs_embeds is None:
+ vision_embeddings = self.get_multimodal_embeddings(**kwargs)
+ inputs_embeds = self.get_input_embeddings(input_ids,
+ vision_embeddings)
+ input_ids = None
+
+ forward_kwargs = {
+ "input_ids": input_ids,
+ "positions": positions,
+ "intermediate_tensors": intermediate_tensors,
+ "inputs_embeds": inputs_embeds,
+ }
+
+ # Only required if the model is mono-architecture
+ if self.visual_token_mask is not None:
+ forward_kwargs.update(
+ {"visual_token_mask": self.visual_token_mask})
+ self.visual_token_mask = None
+
+ hidden_states = self.language_model.model(**forward_kwargs)
+ return hidden_states
+
+ def compute_logits(
+ self,
+ hidden_states: torch.Tensor,
+ sampling_metadata: SamplingMetadata,
+ ) -> Optional[torch.Tensor]:
+ return self.language_model.compute_logits(hidden_states,
+ sampling_metadata)
+
+ def sample(
+ self,
+ logits: torch.Tensor,
+ sampling_metadata: SamplingMetadata,
+ ) -> Optional[SamplerOutput]:
+ return self.language_model.sample(logits, sampling_metadata)
+
+ def load_weights(self, weights: Iterable[Tuple[str,
+ torch.Tensor]]) -> Set[str]:
+ skip_prefixes = [
+ "action_embed", "temporal_embed", "track_embed",
+ "track_embed_decoder", "box_token", "cg_criterion", "cg_model",
+ "loc_encoder", "loc_decoder", "sam", "temporal_token",
+ "track_token"
+ ]
+ loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
+ return loader.load_weights(weights)
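
To make the dynamic-tiling path above easier to follow, here is a small worked example using the helpers defined in this file. The 448-pixel tile size and the 1..6 patch range are assumptions that mirror typical InternViT-style checkpoints, not values read from the Skywork-R1V-38B config.

```python
from vllm.model_executor.models.skyworkr1v import (
    calculate_skyworkr1v_targets, get_skyworkr1v_target_ratios,
    resolve_skyworkr1v_min_max_num)

min_num, max_num = resolve_skyworkr1v_min_max_num(
    min_dynamic_patch=1,
    max_dynamic_patch=6,
    dynamic_image_size=True,
    use_thumbnail=True,   # bumps max_num to 7 so the thumbnail fits
)
target_ratios = get_skyworkr1v_target_ratios(min_num, max_num)

blocks, target_w, target_h = calculate_skyworkr1v_targets(
    orig_width=1920,
    orig_height=1080,
    target_ratios=target_ratios,
    image_size=448,
    use_thumbnail=True,
)
# A 1920x1080 image is resized to 896x448 and cut into 2 tiles; the
# thumbnail brings the total to 3 patches. With ~256 tokens per patch
# (assuming 14-pixel ViT patches and a 0.5 downsample ratio) that is
# roughly 768 <IMG_CONTEXT> positions in the prompt.
print(blocks, target_w, target_h)  # 3 896 448
```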
diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py
index 1937b1388471..71990468c315 100644
--- a/vllm/transformers_utils/config.py
+++ b/vllm/transformers_utils/config.py
@@ -37,8 +37,8 @@
MLPSpeculatorConfig, MPTConfig,
NemotronConfig, NVLM_D_Config,
Olmo2Config, RWConfig,
- SolarConfig, Telechat2Config,
- UltravoxConfig)
+ SkyworkR1VChatConfig, SolarConfig,
+ Telechat2Config, UltravoxConfig)
# yapf: enable
from vllm.transformers_utils.utils import check_gguf_file
from vllm.utils import resolve_obj_by_qualname
@@ -76,6 +76,7 @@
"NVLM_D": NVLM_D_Config,
"olmo2": Olmo2Config,
"solar": SolarConfig,
+ "skywork_chat": SkyworkR1VChatConfig,
"telechat": Telechat2Config,
"ultravox": UltravoxConfig,
**_CONFIG_REGISTRY_OVERRIDE_HF
diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py
index 9060565596b2..53699341bfba 100644
--- a/vllm/transformers_utils/configs/__init__.py
+++ b/vllm/transformers_utils/configs/__init__.py
@@ -20,6 +20,7 @@
from vllm.transformers_utils.configs.nemotron import NemotronConfig
from vllm.transformers_utils.configs.nvlm_d import NVLM_D_Config
from vllm.transformers_utils.configs.olmo2 import Olmo2Config
+from vllm.transformers_utils.configs.skyworkr1v import SkyworkR1VChatConfig
from vllm.transformers_utils.configs.solar import SolarConfig
from vllm.transformers_utils.configs.telechat2 import Telechat2Config
from vllm.transformers_utils.configs.ultravox import UltravoxConfig
@@ -42,6 +43,7 @@
"NemotronConfig",
"NVLM_D_Config",
"Olmo2Config",
+ "SkyworkR1VChatConfig",
"SolarConfig",
"Telechat2Config",
"UltravoxConfig",
diff --git a/vllm/transformers_utils/configs/skyworkr1v.py b/vllm/transformers_utils/configs/skyworkr1v.py
new file mode 100644
index 000000000000..ef5f9ba85c23
--- /dev/null
+++ b/vllm/transformers_utils/configs/skyworkr1v.py
@@ -0,0 +1,53 @@
+# SPDX-License-Identifier: Apache-2.0
+
+# Adapted from
+# https://huggingface.co/Skywork/Skywork-R1V-38B/blob/main/configuration_skywork_chat.py
+# --------------------------------------------------------
+# SkyworkR1V
+# Copyright (c) 2025 Skywork
+# Licensed under The MIT License [see LICENSE for details]
+# --------------------------------------------------------
+from transformers.configuration_utils import PretrainedConfig
+
+
+class SkyworkR1VChatConfig(PretrainedConfig):
+ model_type = 'internvl_chat'
+ is_composition = True
+
+ def __init__(self,
+ vision_config=None,
+ llm_config=None,
+ use_backbone_lora=0,
+ use_llm_lora=0,
+ select_layer=-1,
+ force_image_size=None,
+ downsample_ratio=0.5,
+ template=None,
+ dynamic_image_size=False,
+ use_thumbnail=False,
+ ps_version='v1',
+ min_dynamic_patch=1,
+ max_dynamic_patch=6,
+ **kwargs):
+ super().__init__(**kwargs)
+
+ if vision_config is None:
+ vision_config = {}
+
+ if llm_config is None:
+ llm_config = {}
+
+ self.vision_config = PretrainedConfig(**vision_config)
+ self.text_config = PretrainedConfig(**llm_config)
+
+ self.use_backbone_lora = use_backbone_lora
+ self.use_llm_lora = use_llm_lora
+ self.select_layer = select_layer
+ self.force_image_size = force_image_size
+ self.downsample_ratio = downsample_ratio
+ self.template = template
+ self.dynamic_image_size = dynamic_image_size
+ self.use_thumbnail = use_thumbnail
+ self.ps_version = ps_version # pixel shuffle version
+ self.min_dynamic_patch = min_dynamic_patch
+ self.max_dynamic_patch = max_dynamic_patch
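
Finally, a minimal sketch of how the composed config behaves, using illustrative values rather than the real checkpoint's: `llm_config` is exposed as `text_config`, which is the attribute the model implementation above reads.

```python
from transformers import PretrainedConfig

from vllm.transformers_utils.configs import SkyworkR1VChatConfig

cfg = SkyworkR1VChatConfig(
    vision_config={"image_size": 448, "patch_size": 14},
    llm_config={"architectures": ["Qwen2ForCausalLM"], "hidden_size": 5120},
    dynamic_image_size=True,
    use_thumbnail=True,
    max_dynamic_patch=6,
)

# The sub-dicts are wrapped in PretrainedConfig instances, and llm_config
# is stored under the text_config attribute used by SkyworkR1VChatModel.
assert isinstance(cfg.vision_config, PretrainedConfig)
assert cfg.text_config.hidden_size == 5120
assert cfg.max_dynamic_patch == 6
```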