diff --git a/tests/models/decoder_only/vision_language/test_models.py b/tests/models/decoder_only/vision_language/test_models.py index 9985cb579e10..6073364c0199 100644 --- a/tests/models/decoder_only/vision_language/test_models.py +++ b/tests/models/decoder_only/vision_language/test_models.py @@ -446,6 +446,19 @@ hf_output_post_proc=model_utils.minicpmv_trunc_hf_output, patch_hf_runner=model_utils.minicpmv_26_patch_hf_runner, ), + "minimax_vl_01": VLMTestInfo( + models=["MiniMaxAI/MiniMax-VL-01"], + prompt_formatter=lambda img_prompt: f"user: {img_prompt} assistant:", # noqa: E501 + img_idx_to_prompt=lambda _: "", + test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), + max_model_len=8192, + max_num_seqs=4, + dtype="bfloat16", + hf_output_post_proc=model_utils.minimax_vl_01_hf_output, + patch_hf_runner=model_utils.minimax_vl_01_patch_hf_runner, + auto_cls=AutoModelForImageTextToText, + marks=[large_gpu_mark(min_gb=80)], + ), "molmo": VLMTestInfo( models=["allenai/Molmo-7B-D-0924"], test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), diff --git a/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py b/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py index 49305332726e..1185d80b97e3 100644 --- a/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py +++ b/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py @@ -229,6 +229,14 @@ def minicpmv_trunc_hf_output(hf_output: RunnerOutput, return output_ids, output_str, out_logprobs +def minimax_vl_01_hf_output(hf_output: RunnerOutput, + model: str) -> RunnerOutput: + output_ids, output_str, out_logprobs = hf_output + if output_str.endswith(""): + output_str = output_str.split("")[0] + return output_ids, output_str, out_logprobs + + ####### Functions for converting image assets to embeddings def get_llava_embeddings(image_assets: _ImageAssets): return [asset.image_embeds for asset in image_assets] @@ -627,6 +635,17 @@ def _generate(self, *args, image_sizes=None, **kwargs): return hf_model +def minimax_vl_01_patch_hf_runner(hf_model: HfRunner) -> HfRunner: + orig_generate = hf_model.model.generate + + def _generate(self, *args, image_sizes=None, **kwargs): + return orig_generate(*args, decode_text=False, **kwargs) + + hf_model.model.generate = types.MethodType(_generate, hf_model.model) + + return hf_model + + def molmo_patch_hf_runner(hf_model: HfRunner) -> HfRunner: """Patches and returns an instance of the HfRunner to use for Molmo.""" hf_processor = hf_model.processor diff --git a/tests/models/multimodal/processing/test_minimax_vl_01.py b/tests/models/multimodal/processing/test_minimax_vl_01.py new file mode 100644 index 000000000000..d333c32dcaf6 --- /dev/null +++ b/tests/models/multimodal/processing/test_minimax_vl_01.py @@ -0,0 +1,99 @@ +# SPDX-License-Identifier: Apache-2.0 + +import pytest +from PIL import Image + +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.parse import ImageSize +from vllm.multimodal.processing import BaseMultiModalProcessor + +from ....conftest import _ImageAssets +from ...utils import build_model_context + + +@pytest.mark.parametrize("model_id", ["MiniMaxAI/MiniMax-VL-01"]) +# yapf: enable +@pytest.mark.parametrize("num_imgs", [1, 2]) +def test_processor_override( + image_assets: _ImageAssets, + model_id: str, + num_imgs: int, +): + ctx = build_model_context( + model_id, + mm_processor_kwargs=None, + limit_mm_per_prompt={"image": num_imgs}, + ) + processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config) + prompt = "" * num_imgs + image = Image.new("RGB", size=(364, 364)) + mm_data = {"image": [image] * num_imgs} + + processed_inputs = processor.apply(prompt, mm_data, {}) + image_placeholders = processed_inputs["mm_placeholders"]["image"] + + assert len(image_placeholders) == num_imgs + + +def _validate_image_prompt_replacements_one( + processor: BaseMultiModalProcessor, + num_imgs: int, + failed_size_excs: list[tuple[ImageSize, Exception]], + image_size: ImageSize, +) -> None: + prompt = "" * num_imgs + image = Image.new("RGB", size=image_size) + mm_data = {"image": [image] * num_imgs} + + try: + processed_inputs = processor.apply(prompt, mm_data, {}) + + image_placeholders = processed_inputs["mm_placeholders"]["image"] + assert len(image_placeholders) == num_imgs + + except Exception as exc: + failed_size_excs.append((image_size, exc)) + + +def _test_image_prompt_replacements( + processor, + *, + num_imgs: int, + image_sizes: list[ImageSize], +) -> None: + + failed_size_excs = list[tuple[ImageSize, Exception]]() + + for size in image_sizes: + _validate_image_prompt_replacements_one(processor, num_imgs, + failed_size_excs, size) + + if failed_size_excs: + msg = "Found failing image sizes:" \ + + "\n========\n".join(f"[{size}]\n{exc}" + for size, exc in failed_size_excs) + raise AssertionError(msg) + + +@pytest.mark.parametrize("model_id", ["MiniMaxAI/MiniMax-VL-01"]) +@pytest.mark.parametrize("num_imgs", [1, 2]) +def test_processor_prompt_replacements_regression(model_id, num_imgs): + ctx = build_model_context( + model_id, + mm_processor_kwargs=None, + limit_mm_per_prompt={"image": num_imgs}, + ) + processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config) + + image_ratios = [(171, 152), (184, 161), (198, 176), (333, 296), (369, 328), + (488, 183), (2560, 1669)] + image_sizes = [ + size for w, h in image_ratios + for size in [ImageSize(w, h), ImageSize(h, w)] + ] + + _test_image_prompt_replacements( + processor, + num_imgs=num_imgs, + image_sizes=image_sizes, + ) diff --git a/tests/models/registry.py b/tests/models/registry.py index c15ae3619844..142362e09680 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -333,6 +333,8 @@ def check_available_online( "MiniCPMV": _HfExamplesInfo("openbmb/MiniCPM-Llama3-V-2_5", extras={"2.6": "openbmb/MiniCPM-V-2_6"}, # noqa: E501 trust_remote_code=True), + "MiniMaxVL01ForConditionalGeneration": _HfExamplesInfo("MiniMaxAI/MiniMax-VL-01", # noqa: E501 + trust_remote_code=True), "Mistral3ForConditionalGeneration": _HfExamplesInfo("mistralai/Mistral-Small-3.1-24B-Instruct-2503", # noqa: E501 min_transformers_version="4.50", # noqa: E501 extras={"fp8": "nm-testing/Mistral-Small-3.1-24B-Instruct-2503-FP8-dynamic"}), # noqa: E501 diff --git a/vllm/model_executor/models/minimax_text_01.py b/vllm/model_executor/models/minimax_text_01.py index 74be08159cd8..951f4e2304a1 100644 --- a/vllm/model_executor/models/minimax_text_01.py +++ b/vllm/model_executor/models/minimax_text_01.py @@ -3,7 +3,7 @@ import copy import math import re -from typing import Dict, Iterable, List, Optional, Tuple, Union +from typing import Dict, Iterable, List, Optional, Set, Tuple, Union import torch import torch.distributed @@ -110,7 +110,17 @@ def _forward( variance = tensor_model_parallel_all_reduce( variance) / self.tp_world x = x * torch.rsqrt(variance + self.variance_epsilon) - x = x.to(orig_dtype) * self.weight + + weight = self.weight + if x.size(-1) != self.weight.size(0): + if self.weight.size(0) < x.size(-1): + repeat_count = (x.size(-1) + self.weight.size(0)) // x.size(-1) + full_weight = self.weight.repeat(repeat_count) + weight = full_weight[:x.size(-1)] + else: + weight = self.weight[:x.size(-1)] + + x = x.to(orig_dtype) * weight return x def forward( @@ -421,6 +431,10 @@ def _prefill_and_mix_infer(self, q, k, v, kv_cache, state_indices_tensor, attn_metadata): hidden = [] for _prefill_idx in range(getattr(attn_metadata, "num_prefills", 0)): + if _prefill_idx >= len(attn_metadata.query_start_loc): + break + if _prefill_idx >= len(state_indices_tensor): + break _start = attn_metadata.query_start_loc[_prefill_idx] _end = attn_metadata.query_start_loc[_prefill_idx + 1] slot_id = state_indices_tensor[_prefill_idx] @@ -443,6 +457,10 @@ def _prefill_and_mix_infer(self, q, k, v, kv_cache, state_indices_tensor, hidden.append( self._decode_infer(q, k, v, kv_cache, state_indices_tensor, attn_metadata)) + + if not hidden: + return torch.empty((0, q.size(-1)), device=q.device, dtype=q.dtype) + hidden = torch.concat(hidden, dim=0).contiguous() return hidden @@ -663,6 +681,9 @@ def __init__( self.shared_moe = False shared_intermediate = getattr(config, 'shared_intermediate_size', 0) + if isinstance(shared_intermediate, list): + shared_intermediate = shared_intermediate[ + layer_id] if layer_id < len(shared_intermediate) else 0 if shared_intermediate > 0: self.shared_moe = True self.shared_mlp = MiniMaxText01MLP( @@ -875,6 +896,8 @@ def _clear_prefill_cache(self, attn_metadata, slots_to_clear = [] for _prefill_id in range(getattr(attn_metadata, "num_prefills", 0)): + if _prefill_id >= len(seq_id_map): + break seq_id = seq_id_map[_prefill_id] if attn_metadata.context_lens_tensor[ _prefill_id] == 0 and seq_id in seq_to_slot_maps: @@ -886,13 +909,18 @@ def _clear_prefill_cache(self, attn_metadata, dtype=torch.long) minimax_cache_tensors[:, slots_tensor, ...] = 0 + def get_input_embeddings( + self, + input_ids: torch.Tensor, + ) -> torch.Tensor: + return self.embed_tokens(input_ids) + def forward(self, input_ids: Optional[torch.Tensor], positions: torch.Tensor, - kv_caches: List[torch.Tensor], - intermediate_tensors=None, + intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, - **kwargs) -> torch.Tensor: + **kwargs) -> Union[torch.Tensor, IntermediateTensors]: forward_context = get_forward_context() attn_metadata = forward_context.attn_metadata if attn_metadata is None: @@ -901,6 +929,7 @@ def forward(self, kwargs["request_ids_to_seq_ids"] = {} if "finished_requests_ids" not in kwargs: kwargs["finished_requests_ids"] = [] + ( minimax_cache_tensors, state_indices_tensor, @@ -922,15 +951,11 @@ def forward(self, hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - kv_cache_index = 0 minimax_cache_index = 0 attn_metadata.rotary_emb = self.rotary_emb for i in range(self.start_layer, self.end_layer): layer = self.layers[i] _caches = None - if isinstance(layer.self_attn, MiniMaxText01Attention): - _caches = kv_caches[kv_cache_index] - kv_cache_index += 1 if isinstance(layer.self_attn, MiniMaxText01LinearAttention): current_state_layer = minimax_cache_index _caches = minimax_cache_params.at_layer_idx( @@ -1009,15 +1034,20 @@ def get_seqlen_agnostic_capture_inputs(self, batch_size: int): return self.model.minimax_cache.get_seqlen_agnostic_capture_inputs( batch_size) + def get_input_embeddings( + self, + input_ids: torch.Tensor, + ) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor: - hidden_states = self.model(input_ids, positions, self.kv_cache, - intermediate_tensors, inputs_embeds, - **kwargs) + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds, **kwargs) return hidden_states @@ -1043,8 +1073,9 @@ def make_empty_intermediate_tensors( }) def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> None: + torch.Tensor]]) -> Set[str]: params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() def which_layer(name: str) -> int: if "layers" in name: @@ -1108,6 +1139,7 @@ def load_sparse_moe_weight(name: str, loaded_weight: torch.Tensor, weight_name, expert_id=expert_id, shard_id=shard_id) + loaded_params.add(name) break else: if is_pp_missing_parameter(name, self): @@ -1117,6 +1149,7 @@ def load_sparse_moe_weight(name: str, loaded_weight: torch.Tensor, default_weight_loader) weight_loader = weight_loader_with_alias(name)(weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) return def is_shared_mlp_weight(name: str) -> bool: @@ -1154,6 +1187,7 @@ def load_shared_mlp_weight(name: str, loaded_weight: torch.Tensor, else: raise AssertionError( "MLP weight not in [gate_up_proj, down_proj]") + loaded_params.add(name) return def is_mha_weight(name: str) -> bool: @@ -1170,6 +1204,7 @@ def load_linear_attn_weight(name: str, loaded_weight: torch.Tensor, MiniMaxText01LinearAttention.weight_direct_load) weight_loader = weight_loader_with_alias(name)(weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) return def load_flash_attn_weight(name: str, loaded_weight: torch.Tensor, @@ -1194,6 +1229,7 @@ def load_flash_attn_weight(name: str, loaded_weight: torch.Tensor, default_weight_loader) weight_loader = weight_loader_with_alias(name)(weight_loader) weight_loader(param, loaded_weight, shard_id) + loaded_params.add(name) break else: if is_pp_missing_parameter(name, self): @@ -1204,6 +1240,7 @@ def load_flash_attn_weight(name: str, loaded_weight: torch.Tensor, default_weight_loader) weight_loader = weight_loader_with_alias(name)(weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) return def is_layer_norm_weight(name: str) -> bool: @@ -1219,6 +1256,7 @@ def load_layer_norm_weight(name: str, loaded_weight: torch.Tensor, default_weight_loader) weight_loader = weight_loader_with_alias(name)(weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) return def load_basic_weight(name: str, loaded_weight: torch.Tensor, @@ -1230,6 +1268,7 @@ def load_basic_weight(name: str, loaded_weight: torch.Tensor, default_weight_loader) weight_loader = weight_loader_with_alias(name)(weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) return for name, loaded_weight in weights: @@ -1258,4 +1297,4 @@ def load_basic_weight(name: str, loaded_weight: torch.Tensor, continue load_basic_weight(name, loaded_weight, self) - return + return loaded_params diff --git a/vllm/model_executor/models/minimax_vl_01.py b/vllm/model_executor/models/minimax_vl_01.py new file mode 100644 index 000000000000..14e105586b56 --- /dev/null +++ b/vllm/model_executor/models/minimax_vl_01.py @@ -0,0 +1,615 @@ +# SPDX-License-Identifier: Apache-2.0 + +from abc import abstractmethod +from collections.abc import Iterable, Mapping, Sequence +from dataclasses import dataclass +from typing import (Final, Literal, Optional, Protocol, Set, Tuple, TypedDict, + TypeVar, Union, cast) + +import numpy as np +import torch +import torch.nn as nn +from transformers import BatchFeature, CLIPVisionConfig, PretrainedConfig +from transformers.image_processing_utils import select_best_resolution + +from vllm.config import VllmConfig +from vllm.jsontree import json_map_leaves +from vllm.logger import init_logger +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict +from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs +from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, + ImageSize, MultiModalDataItems) +from vllm.multimodal.processing import (BaseMultiModalProcessor, + BaseProcessingInfo, PromptReplacement, + PromptUpdate) +from vllm.multimodal.profiling import BaseDummyInputsBuilder +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.configs.minimax_vl_01 import MiniMaxVL01Config + +from .clip import CLIPVisionModel +from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP +from .pixtral import PixtralHFVisionModel +from .siglip import SiglipVisionModel +from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, + maybe_prefix, merge_multimodal_embeddings) +from .vision import get_vision_encoder_info + +logger = init_logger(__name__) + + +# For dummy input only +@dataclass +class MaxImageTokenMeta: + width: int = 1024 + height: int = 1024 + + +class MiniMaxVL01ImagePixelInputs(TypedDict): + type: Literal["pixel_values"] + pixel_values: torch.Tensor + """ + Shape: `(batch_size * num_images, num_channels, height, width)` + + Note that `height` or `width` may be different per batch and image, + in which case the data is passed as a list instead of a batched tensor. + """ + + +class MiniMaxVL01ImageEmbeddingInputs(TypedDict): + type: Literal["image_embeds"] + data: torch.Tensor + """Shape: `(batch_size * num_images, image_feature_size, hidden_size)` + + `hidden_size` must match the hidden size of language model backbone. + """ + + +def image_size_to_num_patches(image_size, grid_pinpoints, patch_size: int): + if not isinstance(grid_pinpoints, list): + raise TypeError("grid_pinpoints should be a list of tuples or lists") + + # ! VERY IMPORTANT if image_size is tensor, must convert to into tuple, + # otherwise it will cause wrong calculate + if not isinstance(image_size, (list, tuple)): + if not isinstance(image_size, (torch.Tensor, np.ndarray)): + raise TypeError("image_size invalid type " + + f"{type(image_size)} with value {image_size}") + image_size = image_size.tolist() + + best_resolution = select_best_resolution(image_size, grid_pinpoints) + height, width = best_resolution + num_patches = 0 + # consider change to ceil(height/patch_size)*ceil(width/patch_size) + 1 + for i in range(0, height, patch_size): + for j in range(0, width, patch_size): + num_patches += 1 + # add the base patch + num_patches += 1 + return num_patches + + +def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): + if not isinstance(grid_pinpoints, list): + raise TypeError("grid_pinpoints should be a list of tuples or lists") + + # ! VERY IMPORTANT if image_size is tensor, + # must convert to into tuple, + # otherwise it will cause wrong calculate + if not isinstance(image_size, (list, tuple)): + if not isinstance(image_size, (torch.Tensor, np.ndarray)): + raise TypeError( + "image_size invalid type " + + f"{type(image_size)} not valid, " + + "should be either list, tuple, np.ndarray or tensor") + image_size = image_size.tolist() + + height, width = select_best_resolution(image_size, grid_pinpoints) + return height // patch_size, width // patch_size + + +def unpad_image(tensor, original_size): + original_height, original_width = original_size + current_height, current_width = tensor.shape[1:] + + original_aspect_ratio = original_width / original_height + current_aspect_ratio = current_width / current_height + + if original_aspect_ratio > current_aspect_ratio: + new_height = int(original_height * current_width) // original_width + padding = (current_height - new_height) // 2 + unpadded_tensor = tensor[:, padding:current_height - padding, :] + else: + new_width = int(original_width * current_height) // original_height + padding = (current_width - new_width) // 2 + unpadded_tensor = tensor[:, :, padding:current_width - padding] + + return unpadded_tensor + + +class MiniMaxVL01MultiModalProjector(nn.Module): + + def __init__(self, + vision_hidden_size: int, + text_hidden_size: int, + projector_hidden_act: str, + multimodal_projector_bias: bool, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): + super().__init__() + + self.linear_1 = ColumnParallelLinear(vision_hidden_size, + text_hidden_size, + bias=multimodal_projector_bias, + quant_config=quant_config, + prefix=f"{prefix}.linear_1") + self.act = get_act_fn(projector_hidden_act) + self.linear_2 = RowParallelLinear(text_hidden_size, + text_hidden_size, + bias=multimodal_projector_bias, + quant_config=quant_config, + prefix=f"{prefix}.linear_2") + + def forward(self, image_features: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.linear_1(image_features) + hidden_states = self.act(hidden_states) + hidden_states, _ = self.linear_2(hidden_states) + return hidden_states + + +class MiniMaxVL01LikeConfig(Protocol): + vision_config: Final[PretrainedConfig] + image_token_index: Final[int] + vision_feature_select_strategy: Final[str] + vision_feature_layer: Final[Union[int, list[int]]] + + +class MiniMaxVL01LikeProcessor(Protocol): + image_token: Final[str] + + +_I = TypeVar("_I", bound=BaseProcessingInfo) + + +class MiniMaxVL01DummyInputsBuilder(BaseDummyInputsBuilder[_I]): + + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_images = mm_counts.get("image", 0) + processor = self.info.get_hf_processor() + image_token = processor.image_token + return image_token * num_images + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> MultiModalDataDict: + num_images = mm_counts.get("image", 0) + + return { + "image": + self._get_dummy_images(width=MaxImageTokenMeta.width, + height=MaxImageTokenMeta.height, + num_images=num_images) + } + + +class MiniMaxVL01ProcessingInfo(BaseProcessingInfo): + + def get_hf_config(self): + return self.ctx.get_hf_config(MiniMaxVL01Config) + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": None} + + def get_vision_encoder_info(self): + return get_vision_encoder_info(self.get_hf_config()) + + def _apply_feature_select_strategy( + self, + strategy: str, + encoder_num_image_tokens: int, + ) -> int: + if strategy == "default": + return encoder_num_image_tokens - 1 + if strategy == "full": + return encoder_num_image_tokens + + msg = f"Unexpected feature select strategy: {strategy!r}" + raise NotImplementedError(msg) + + def get_num_image_tokens( + self, + *, + image_width: int, + image_height: int, + ) -> int: + hf_config = self.get_hf_config() + vision_encoder_info = self.get_vision_encoder_info() + + return self._apply_feature_select_strategy( + hf_config.vision_feature_select_strategy, + vision_encoder_info.get_num_image_tokens( + image_width=image_width, + image_height=image_height, + ), + ) + + def get_image_size_with_most_features(self) -> ImageSize: + vision_encoder_info = self.get_vision_encoder_info() + width = height = vision_encoder_info.get_image_size() + return ImageSize(width=width, height=height) + + def get_max_image_tokens(self) -> int: + target_width, target_height = self.get_image_size_with_most_features() + + return self.get_num_image_tokens( + image_width=target_width, + image_height=target_height, + ) + + +class BaseMiniMaxVL01MultiModalProcessor(BaseMultiModalProcessor[_I]): + + # Copied from BaseMultiModalProcessor + @abstractmethod + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + raise NotImplementedError + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargs, + ) -> Sequence[PromptUpdate]: + hf_config = self.info.get_hf_config() + image_token_id = hf_config.image_token_index + + def get_replacement(item_idx: int): + images = mm_items.get_items( + "image", (ImageEmbeddingItems, ImageProcessorItems)) + + if isinstance(images, ImageEmbeddingItems): + num_image_tokens = images.get_feature_size(item_idx) + else: + image_size = images.get_image_size(item_idx) + num_image_tokens = self.info.get_num_image_tokens( + image_width=image_size.width, + image_height=image_size.height, + ) + + return [image_token_id] * num_image_tokens + + return [ + PromptReplacement( + modality="image", + target=[image_token_id], + replacement=get_replacement, + ), + ] + + +class MiniMaxVL01MultiModalProcessor( + BaseMiniMaxVL01MultiModalProcessor[MiniMaxVL01ProcessingInfo]): + + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + ) -> BatchFeature: + processed_outputs = super()._call_hf_processor( + prompt=prompt, + mm_data=mm_data, + mm_kwargs=mm_kwargs, + ) + + pixel_values = processed_outputs.get("pixel_values") + if pixel_values is not None: + image_sizes = processed_outputs["image_sizes"] + min_len = min(len(pixel_values), len(image_sizes)) + pixel_values = pixel_values[:min_len] + image_sizes = image_sizes[:min_len] + assert len(pixel_values) == len(image_sizes) + + processed_outputs["pixel_values"] = [ + p[:, :h, :w] for p, (h, w) in zip(pixel_values, image_sizes) + ] + + return processed_outputs + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return { + "pixel_values": MultiModalFieldConfig.batched("image"), + "image_embeds": MultiModalFieldConfig.batched("image"), + } + + +def _get_num_hidden_layers(hf_config: MiniMaxVL01LikeConfig) -> int: + """Determine the number of hidden layers to initialize up to in the + visual encoder. + + Args: + hf_config: Model config with vision feature layer(s). + """ + feature_layers = hf_config.vision_feature_layer + num_hidden_layers = hf_config.vision_config.num_hidden_layers + # If we have one feature layer, initialize up to that layer + if isinstance(feature_layers, int): + return _get_layer_index(feature_layers, num_hidden_layers) + # If we have multiple feature layers, initialize up to the deepest one + elif isinstance(feature_layers, (list, tuple)): + return max( + _get_layer_index(idx, num_hidden_layers) for idx in feature_layers) + raise TypeError(f"vision_layer_feature type: {type(feature_layers)}" + " is not supported") + + +def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int: + """Given a signed vision feature layer, get the number of hidden layers + needed to leverage it. + + Args: + feature_layer_index: Index of a required layer in the visual encoder. + num_hidden_layers: The total number of hidden layers in the visual + encoder. + """ + if feature_layer_index < 0: + return num_hidden_layers + feature_layer_index + 1 + return feature_layer_index + + +def init_vision_tower_for_MiniMaxVL01( + hf_config: MiniMaxVL01LikeConfig, + quant_config: Optional[QuantizationConfig], + *, + require_post_norm: Optional[bool] = None, + prefix: str = "", +) -> Union[CLIPVisionModel, SiglipVisionModel, PixtralHFVisionModel]: + vision_config = hf_config.vision_config + + # Initialize the vision tower only up to the deepest required feature layer + num_hidden_layers = _get_num_hidden_layers(hf_config) + + if isinstance(vision_config, CLIPVisionConfig): + return CLIPVisionModel( + vision_config, + quant_config=quant_config, + num_hidden_layers_override=num_hidden_layers, + require_post_norm=require_post_norm, + prefix=prefix, + ) + + msg = f"Unsupported vision config: {type(vision_config)}" + raise NotImplementedError(msg) + + +@MULTIMODAL_REGISTRY.register_processor( + MiniMaxVL01MultiModalProcessor, + info=MiniMaxVL01ProcessingInfo, + dummy_inputs=MiniMaxVL01DummyInputsBuilder) +class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal, + SupportsPP): + + packed_modules_mapping = { + "qkv_proj": ["q_proj", "k_proj", "v_proj"], + "gate_up_proj": ["gate_proj", "up_proj"] + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: + super().__init__() + + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config + + self.config = config + self.multimodal_config = multimodal_config + + # TODO: Optionally initializes this for supporting embeddings. + self.vision_tower = init_vision_tower_for_MiniMaxVL01( + config, + quant_config, + require_post_norm=False, + prefix=maybe_prefix(prefix, "vision_tower")) + self.multi_modal_projector = MiniMaxVL01MultiModalProjector( + vision_hidden_size=config.vision_config.hidden_size, + text_hidden_size=config.text_config.hidden_size, + projector_hidden_act=config.projector_hidden_act, + multimodal_projector_bias=True, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "multi_modal_projector")) + self.image_newline = nn.Parameter( + torch.empty(config.text_config.hidden_size)) + self.language_model = init_vllm_registered_model( + vllm_config=vllm_config, + hf_config=config.text_config, + prefix=maybe_prefix(prefix, "language_model"), + ) + self.vision_feature_layer = config.vision_feature_layer + self.vocab_size = config.text_config.vocab_size + self.pad_token_id = -1 + if self.config.pad_token_id is not None: + self.pad_token_id = self.config.pad_token_id + + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors) + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + ) -> torch.Tensor: + inputs_embeds = self.language_model.get_input_embeddings(input_ids) + if multimodal_embeddings is not None: + inputs_embeds = merge_multimodal_embeddings( + input_ids, + inputs_embeds, + multimodal_embeddings, + self.config.image_token_index, + ) + return inputs_embeds + + def get_language_model(self) -> torch.nn.Module: + return self.language_model + + def _select_image_features(self, image_features: torch.Tensor, *, + strategy: str) -> torch.Tensor: + if strategy == "default": + return image_features[:, 1:] + elif strategy == "full": + return image_features + + raise ValueError(f"Unexpected select feature strategy: {strategy}") + + def _image_pixels_to_features( + self, + vision_tower: Union[CLIPVisionModel], + pixel_values: Union[torch.Tensor, list[torch.Tensor]], + ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]: + # NOTE: we skip the step to select the vision feature layer since + # this is already done inside the vision tower + image_features = vision_tower(pixel_values) + + def select_features(leaf: torch.Tensor): + return self._select_image_features( + leaf, + strategy=self.config.vision_feature_select_strategy, + ) + + return cast( + Union[torch.Tensor, tuple[torch.Tensor, ...]], + json_map_leaves(select_features, image_features), + ) + + def _process_image_pixels( + self, + inputs: Union[MiniMaxVL01ImagePixelInputs], + ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]: + assert self.vision_tower is not None + + pixel_values = inputs["pixel_values"] + + return self._image_pixels_to_features(self.vision_tower, pixel_values) + + def _process_image_input( + self, + image_input: MiniMaxVL01ImagePixelInputs, + ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]: + if image_input["type"] == "image_embeds": + return image_input["data"] + + assert self.vision_tower is not None + image_features = self._process_image_pixels(image_input) + + if isinstance(image_features, torch.Tensor): + return self.multi_modal_projector(image_features) + + feature_sizes = [ + image_feature.shape[0] for image_feature in image_features + ] + + image_embeds = self.multi_modal_projector(torch.cat(image_features)) + image_embeds = torch.split(image_embeds, feature_sizes) + return image_embeds + + def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: + h = w = self.config.vision_config.image_size + expected_dims = (3, h, w) + actual_dims = tuple(data.shape[1:]) + + if actual_dims != expected_dims: + expected_expr = ("batch_size", *map(str, expected_dims)) + raise ValueError( + f"The expected shape of pixel values is {expected_expr}. " + f"You supplied {tuple(data.shape)}.") + + return data + + def _parse_and_validate_image_input( + self, **kwargs: object) -> Optional[MiniMaxVL01ImagePixelInputs]: + pixel_values = kwargs.pop("pixel_values", None) + image_embeds = kwargs.pop("image_embeds", None) + + if pixel_values is None and image_embeds is None: + return None + + if pixel_values is not None: + if not isinstance(pixel_values, (torch.Tensor, list)): + raise ValueError("Incorrect type of pixel values. " + f"Got type: {type(pixel_values)}") + + return MiniMaxVL01ImagePixelInputs( + type="pixel_values", + pixel_values=self._validate_pixel_values( + flatten_bn(pixel_values, concat=True)), + ) + + if image_embeds is not None: + if not isinstance(image_embeds, (torch.Tensor, list)): + raise ValueError("Incorrect type of image embeddings. " + f"Got type: {type(image_embeds)}") + + return MiniMaxVL01ImageEmbeddingInputs( + type="image_embeds", + data=flatten_bn(image_embeds, concat=True), + ) + + raise AssertionError("This line should be unreachable.") + + def get_multimodal_embeddings( + self, **kwargs: object) -> Optional[MultiModalEmbeddings]: + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is None: + return None + + return self._process_image_input(image_input) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ) -> Union[torch.Tensor, IntermediateTensors]: + + if intermediate_tensors is not None: + inputs_embeds = None + elif inputs_embeds is None: + vision_embeddings = self.get_multimodal_embeddings(**kwargs) + inputs_embeds = self.get_input_embeddings(input_ids, + vision_embeddings) + input_ids = None + + hidden_states = self.language_model.model(input_ids, + positions, + intermediate_tensors, + inputs_embeds=inputs_embeds) + + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + return self.language_model.compute_logits(hidden_states, + sampling_metadata) + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 621b9d69faa5..fced9f3d798c 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -188,6 +188,7 @@ "LlavaNextVideoForConditionalGeneration": ("llava_next_video", "LlavaNextVideoForConditionalGeneration"), # noqa: E501 "LlavaOnevisionForConditionalGeneration": ("llava_onevision", "LlavaOnevisionForConditionalGeneration"), # noqa: E501 "MantisForConditionalGeneration": ("llava", "MantisForConditionalGeneration"), # noqa: E501 + "MiniMaxVL01ForConditionalGeneration": ("minimax_vl_01", "MiniMaxVL01ForConditionalGeneration"), # noqa: E501 "MiniCPMO": ("minicpmo", "MiniCPMO"), "MiniCPMV": ("minicpmv", "MiniCPMV"), "Mistral3ForConditionalGeneration": ("mistral3", "Mistral3ForConditionalGeneration"), # noqa: E501 diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index ee991eaeb34a..fd09fa21ec9f 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -34,11 +34,13 @@ H2OVLChatConfig, InternVLChatConfig, JAISConfig, KimiVLConfig, MedusaConfig, - MllamaConfig, MLPSpeculatorConfig, - MPTConfig, NemotronConfig, - NVLM_D_Config, RWConfig, - SkyworkR1VChatConfig, SolarConfig, - Telechat2Config, UltravoxConfig) + MiniMaxText01Config, + MiniMaxVL01Config, MllamaConfig, + MLPSpeculatorConfig, MPTConfig, + NemotronConfig, NVLM_D_Config, + RWConfig, SkyworkR1VChatConfig, + SolarConfig, Telechat2Config, + UltravoxConfig) # yapf: enable from vllm.transformers_utils.utils import check_gguf_file from vllm.utils import resolve_obj_by_qualname @@ -73,6 +75,8 @@ "exaone": ExaoneConfig, "h2ovl_chat": H2OVLChatConfig, "internvl_chat": InternVLChatConfig, + "minimax_text_01": MiniMaxText01Config, + "minimax_vl_01": MiniMaxVL01Config, "nemotron": NemotronConfig, "NVLM_D": NVLM_D_Config, "solar": SolarConfig, diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index 8812d4c484b1..8945c45ea86e 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -15,6 +15,8 @@ from vllm.transformers_utils.configs.jais import JAISConfig from vllm.transformers_utils.configs.kimi_vl import KimiVLConfig from vllm.transformers_utils.configs.medusa import MedusaConfig +from vllm.transformers_utils.configs.minimax_text_01 import MiniMaxText01Config +from vllm.transformers_utils.configs.minimax_vl_01 import MiniMaxVL01Config from vllm.transformers_utils.configs.mllama import MllamaConfig from vllm.transformers_utils.configs.mlp_speculator import MLPSpeculatorConfig from vllm.transformers_utils.configs.moonvit import MoonViTConfig @@ -39,6 +41,8 @@ "MedusaConfig", "EAGLEConfig", "ExaoneConfig", + "MiniMaxText01Config", + "MiniMaxVL01Config", "MllamaConfig", "MLPSpeculatorConfig", "MoonViTConfig", diff --git a/vllm/transformers_utils/configs/minimax_text_01.py b/vllm/transformers_utils/configs/minimax_text_01.py new file mode 100644 index 000000000000..660e870ac62d --- /dev/null +++ b/vllm/transformers_utils/configs/minimax_text_01.py @@ -0,0 +1,69 @@ +# SPDX-License-Identifier: Apache-2.0 +""" MiniMaxText01 model configuration""" + +from transformers.configuration_utils import PretrainedConfig + + +class MiniMaxText01Config(PretrainedConfig): + model_type = "MiniMaxText01" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=32000, + hidden_size=4096, + intermediate_size=14336, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=8, + hidden_act="silu", + max_position_embeddings=4096 * 32, + initializer_range=0.02, + rms_norm_eps=1e-5, + use_cache=True, + pad_token_id=None, + bos_token_id=None, + eos_token_id=None, + tie_word_embeddings=False, + rope_theta=1e6, + sliding_window=None, + attention_dropout=0.0, + num_experts_per_tok=2, + num_local_experts=8, + output_router_logits=False, + router_aux_loss_coef=0.001, + router_jitter_noise=0.0, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.sliding_window = sliding_window + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.attention_dropout = attention_dropout + + self.num_experts_per_tok = num_experts_per_tok + self.num_local_experts = num_local_experts + self.output_router_logits = output_router_logits + self.router_aux_loss_coef = router_aux_loss_coef + self.router_jitter_noise = router_jitter_noise + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) diff --git a/vllm/transformers_utils/configs/minimax_vl_01.py b/vllm/transformers_utils/configs/minimax_vl_01.py new file mode 100644 index 000000000000..99e0d249dc5a --- /dev/null +++ b/vllm/transformers_utils/configs/minimax_vl_01.py @@ -0,0 +1,70 @@ +# SPDX-License-Identifier: Apache-2.0 +"""MiniMaxVL01 model configuration""" + +from transformers.configuration_utils import PretrainedConfig +from transformers.models.auto import CONFIG_MAPPING + +from .minimax_text_01 import MiniMaxText01Config + + +class MiniMaxVL01Config(PretrainedConfig): + model_type = "minimax_vl_01" + + def __init__( + self, + vision_config=None, + text_config=None, + ignore_index=-100, + image_token_index=32000, + projector_hidden_act="gelu", + vision_feature_select_strategy="default", + vision_feature_layer=-2, + image_grid_pinpoints=None, + tie_word_embeddings=False, + image_seq_length=576, + **kwargs, + ): + self.ignore_index = ignore_index + self.image_token_index = image_token_index + self.projector_hidden_act = projector_hidden_act + self.image_seq_length = image_seq_length + + if vision_feature_select_strategy not in ["default", "full"]: + raise ValueError("vision_feature_select_strategy should " + + "be one of 'default', 'full'." + + f"Got: {vision_feature_select_strategy}") + + self.vision_feature_select_strategy = vision_feature_select_strategy + self.vision_feature_layer = vision_feature_layer + image_grid_pinpoints = ( + image_grid_pinpoints if image_grid_pinpoints is not None else + [[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]]) + self.image_grid_pinpoints = image_grid_pinpoints + + if isinstance(vision_config, dict): + if "model_type" not in vision_config: + vision_config["model_type"] = "clip_vision_model" + vision_config = CONFIG_MAPPING[vision_config["model_type"]]( + **vision_config) + elif vision_config is None: + vision_config = CONFIG_MAPPING["clip_vision_model"]( + intermediate_size=4096, + hidden_size=1024, + patch_size=14, + image_size=336, + num_hidden_layers=24, + num_attention_heads=16, + vocab_size=32000, + projection_dim=768, + ) + + self.vision_config = vision_config + + if text_config is not None: + text_config = MiniMaxText01Config(**text_config) + else: + text_config = MiniMaxText01Config() + + self.text_config = text_config + + super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)