diff --git a/requirements-test.txt b/requirements-test.txt index 5f3fd15c7ee56..62d6cc49eade4 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -20,6 +20,9 @@ sentence-transformers # required for embedding compressed-tensors==0.4.0 # required for compressed-tensors timm # required for internvl test +# TODO: Add this after fully implementing llava(mantis) +# git+https://github.com/TIGER-AI-Lab/Mantis.git # required for llava(mantis) test + # Benchmarking aiohttp diff --git a/tests/models/test_llava.py b/tests/models/test_llava.py index 79ab58c364f64..749d3353717e2 100644 --- a/tests/models/test_llava.py +++ b/tests/models/test_llava.py @@ -1,10 +1,11 @@ from typing import List, Optional, Tuple, Type import pytest -from transformers import AutoTokenizer +from transformers import AutoConfig, AutoTokenizer from vllm.multimodal.utils import rescale_image_size from vllm.sequence import SampleLogprobs +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets from .utils import check_logprobs_close @@ -18,9 +19,11 @@ "USER: \nWhat is the season?\nASSISTANT:", }) -IMAGE_TOKEN_ID = 32000 - -models = ["llava-hf/llava-1.5-7b-hf"] +models = [ + "llava-hf/llava-1.5-7b-hf", + # TODO: Get this model to produce meaningful output in vLLM + # "TIGER-Lab/Mantis-8B-siglip-llama3", +] def vllm_to_hf_output(vllm_output: Tuple[List[int], str, @@ -29,12 +32,15 @@ def vllm_to_hf_output(vllm_output: Tuple[List[int], str, """Sanitize vllm output to be comparable with hf output.""" output_ids, output_str, out_logprobs = vllm_output + config = AutoConfig.from_pretrained(model) + image_token_id = config.image_token_index + tokenizer = AutoTokenizer.from_pretrained(model) eos_token_id = tokenizer.eos_token_id hf_output_ids = [ token_id for idx, token_id in enumerate(output_ids) - if token_id != IMAGE_TOKEN_ID or output_ids[idx - 1] != IMAGE_TOKEN_ID + if token_id != image_token_id or output_ids[idx - 1] != image_token_id ] assert output_str[0] == " " @@ -67,6 +73,17 @@ def run_test( Note, the text input is also adjusted to abide by vllm contract. The text output is sanitized to be able to compare with hf. 
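For context on vllm_to_hf_output above: vLLM emits one placeholder token per image-feature position, while the HF runner keeps a single placeholder, so consecutive image tokens are collapsed before comparison. A minimal self-contained sketch of that collapse (the token id 32000 is illustrative only, not read from any tokenizer):

from typing import List

def collapse_image_tokens(output_ids: List[int], image_token_id: int) -> List[int]:
    # Keep a token unless it is an image token immediately preceded by
    # another image token, i.e. collapse runs of placeholders to length 1.
    return [
        tok for idx, tok in enumerate(output_ids)
        if tok != image_token_id or output_ids[idx - 1] != image_token_id
    ]

# 32000 is a stand-in image token id for illustration.
assert collapse_image_tokens([1, 32000, 32000, 32000, 5, 6], 32000) == [1, 32000, 5, 6]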
""" + # NOTE: For local use; this isn't tested in CI yet (see TODO above) + if model.startswith("TIGER-Lab/Mantis"): + from mantis.models.mllava import MLlavaProcessor + + torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype] + mantis_processor = MLlavaProcessor.from_pretrained( + model, torch_dtype=torch_dtype) + assert isinstance(mantis_processor, MLlavaProcessor) + else: + mantis_processor = None + images = [asset.pil_image for asset in image_assets] inputs_per_image = [( @@ -94,6 +111,15 @@ def run_test( ] with hf_runner(model, dtype=dtype, is_vision_model=True) as hf_model: + if mantis_processor is not None: + + def process(*args, **kwargs): + output = mantis_processor(*args, **kwargs) + output["pixel_values"] = output["pixel_values"].to(torch_dtype) + return output + + hf_model.processor = process + hf_outputs_per_image = [ hf_model.generate_greedy_logprobs_limit(prompts, max_tokens, diff --git a/tests/models/test_llava_next.py b/tests/models/test_llava_next.py index b6d72dee5c5b5..60c7fc33b72fe 100644 --- a/tests/models/test_llava_next.py +++ b/tests/models/test_llava_next.py @@ -1,7 +1,7 @@ from typing import List, Optional, Tuple, Type, overload import pytest -from transformers import AutoTokenizer +from transformers import AutoConfig, AutoTokenizer from vllm.multimodal.utils import rescale_image_size from vllm.sequence import SampleLogprobs @@ -23,8 +23,6 @@ f"{_PREFACE} USER: \nWhat is the season? ASSISTANT:", }) -IMAGE_TOKEN_ID = 32000 - models = ["llava-hf/llava-v1.6-vicuna-7b-hf"] @@ -34,12 +32,15 @@ def vllm_to_hf_output(vllm_output: Tuple[List[int], str, """Sanitize vllm output to be comparable with hf output.""" output_ids, output_str, out_logprobs = vllm_output + config = AutoConfig.from_pretrained(model) + image_token_id = config.image_token_index + tokenizer = AutoTokenizer.from_pretrained(model) eos_token_id = tokenizer.eos_token_id hf_output_ids = [ token_id for idx, token_id in enumerate(output_ids) - if token_id != IMAGE_TOKEN_ID or output_ids[idx - 1] != IMAGE_TOKEN_ID + if token_id != image_token_id or output_ids[idx - 1] != image_token_id ] assert output_str[0] == " " diff --git a/tests/models/test_paligemma.py b/tests/models/test_paligemma.py index e1c39ee6fecb6..f3f682b1c2cda 100644 --- a/tests/models/test_paligemma.py +++ b/tests/models/test_paligemma.py @@ -2,7 +2,7 @@ from typing import List, Optional, Tuple, Type import pytest -from transformers import AutoTokenizer +from transformers import AutoConfig, AutoTokenizer from vllm.multimodal.utils import rescale_image_size from vllm.sequence import SampleLogprobs @@ -20,8 +20,6 @@ "What is in the picture?", }) -IMAGE_TOKEN_ID = 257152 - models = ["google/paligemma-3b-mix-224"] # ROCm Triton FA can run into compilation issues with these models due to, @@ -37,12 +35,15 @@ def vllm_to_hf_output(vllm_output: Tuple[List[int], str, """Sanitize vllm output to be comparable with hf output.""" output_ids, output_str, out_logprobs = vllm_output + config = AutoConfig.from_pretrained(model) + image_token_id = config.image_token_index + tokenizer = AutoTokenizer.from_pretrained(model) eos_token_id = tokenizer.eos_token_id hf_output_ids = [ token_id for idx, token_id in enumerate(output_ids) - if token_id != IMAGE_TOKEN_ID or output_ids[idx - 1] != IMAGE_TOKEN_ID + if token_id != image_token_id or output_ids[idx - 1] != image_token_id ] hf_output_str = output_str diff --git a/tests/models/test_registry.py b/tests/models/test_registry.py index 547ab10051f1b..b058e2755c245 100644 --- a/tests/models/test_registry.py +++ 
b/tests/models/test_registry.py @@ -6,4 +6,4 @@ @pytest.mark.parametrize("model_cls", _MODELS) def test_registry_imports(model_cls): # Ensure all model classes can be imported successfully - ModelRegistry.load_model_cls(model_cls) + ModelRegistry.resolve_model_cls([model_cls]) diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index a5c5cb87bc460..44c04c9ba8ddc 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -16,7 +16,7 @@ import torch from huggingface_hub import HfApi, hf_hub_download from torch import nn -from transformers import AutoModelForCausalLM +from transformers import AutoModelForCausalLM, PretrainedConfig from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoadFormat, LoRAConfig, ModelConfig, MultiModalConfig, @@ -143,6 +143,22 @@ def _get_model_initialization_kwargs( return extra_kwargs +def build_model(model_class: Type[nn.Module], hf_config: PretrainedConfig, + cache_config: Optional[CacheConfig], + quant_config: Optional[QuantizationConfig], *, + lora_config: Optional[LoRAConfig], + multimodal_config: Optional[MultiModalConfig], + scheduler_config: Optional[SchedulerConfig]) -> nn.Module: + extra_kwargs = _get_model_initialization_kwargs(model_class, lora_config, + multimodal_config, + scheduler_config) + + return model_class(config=hf_config, + cache_config=cache_config, + quant_config=quant_config, + **extra_kwargs) + + def _initialize_model( model_config: ModelConfig, load_config: LoadConfig, @@ -151,15 +167,17 @@ def _initialize_model( cache_config: CacheConfig, scheduler_config: Optional[SchedulerConfig] = None) -> nn.Module: """Initialize a model with the given configurations.""" - model_class = get_model_architecture(model_config)[0] - quant_config = _get_quantization_config(model_config, load_config) - - return model_class(config=model_config.hf_config, - cache_config=cache_config, - quant_config=quant_config, - **_get_model_initialization_kwargs( - model_class, lora_config, multimodal_config, - scheduler_config)) + model_class, _ = get_model_architecture(model_config) + + return build_model( + model_class, + model_config.hf_config, + quant_config=_get_quantization_config(model_config, load_config), + lora_config=lora_config, + multimodal_config=multimodal_config, + cache_config=cache_config, + scheduler_config=scheduler_config, + ) class BaseModelLoader(ABC): diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py index f7e0f56c1a46e..331b859d2adec 100644 --- a/vllm/model_executor/model_loader/utils.py +++ b/vllm/model_executor/model_loader/utils.py @@ -28,13 +28,7 @@ def get_model_architecture( and "MixtralForCausalLM" in architectures): architectures = ["QuantMixtralForCausalLM"] - for arch in architectures: - model_cls = ModelRegistry.load_model_cls(arch) - if model_cls is not None: - return (model_cls, arch) - raise ValueError( - f"Model architectures {architectures} are not supported for now. 
" - f"Supported architectures: {ModelRegistry.get_supported_archs()}") + return ModelRegistry.resolve_model_cls(architectures) def get_architecture_class_name(model_config: ModelConfig) -> str: diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 94c3cea98be7b..ebb77a802d5cb 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -1,6 +1,6 @@ import functools import importlib -from typing import Dict, List, Optional, Type +from typing import Dict, List, Optional, Tuple, Type import torch.nn as nn @@ -126,7 +126,7 @@ def _get_model(model_arch: str): return getattr(module, model_cls_name, None) @staticmethod - def load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]: + def _try_load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]: if model_arch in _OOT_MODELS: return _OOT_MODELS[model_arch] if model_arch not in _MODELS: @@ -143,6 +143,18 @@ def load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]: return ModelRegistry._get_model(model_arch) + @staticmethod + def resolve_model_cls( + architectures: List[str]) -> Tuple[Type[nn.Module], str]: + for arch in architectures: + model_cls = ModelRegistry._try_load_model_cls(arch) + if model_cls is not None: + return (model_cls, arch) + + raise ValueError( + f"Model architectures {architectures} are not supported for now. " + f"Supported architectures: {ModelRegistry.get_supported_archs()}") + @staticmethod def get_supported_archs() -> List[str]: return list(_MODELS.keys()) diff --git a/vllm/model_executor/models/clip.py b/vllm/model_executor/models/clip.py index b4f628061f19c..805ade39389de 100644 --- a/vllm/model_executor/models/clip.py +++ b/vllm/model_executor/models/clip.py @@ -1,6 +1,6 @@ """Minimal implementation of CLIPVisionModel intended to be only used within a vision language model.""" -from typing import Optional +from typing import Iterable, Optional, Tuple import torch import torch.nn as nn @@ -14,6 +14,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, RowParallelLinear) from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.multimodal.image import (cached_get_tokenizer, repeat_and_pad_image_tokens) from vllm.sequence import SequenceData @@ -32,7 +33,7 @@ def get_clip_num_patches(*, image_size: int, patch_size: int) -> int: def get_clip_image_feature_size(hf_config: CLIPVisionConfig) -> int: return get_clip_num_patches(image_size=hf_config.image_size, - patch_size=hf_config.patch_size) + patch_size=hf_config.patch_size) + 1 def get_max_clip_image_tokens(hf_config: CLIPVisionConfig) -> int: @@ -291,3 +292,22 @@ def forward(self, pixel_values: Optional[torch.Tensor] = None): @property def device(self): return next(self.parameters()).device + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + params_dict = dict(self.named_parameters()) + layer_count = len(self.vision_model.encoder.layers) + + for name, loaded_weight in weights: + # post_layernorm is not needed in CLIPVisionModel + if "vision_model.post_layernorm" in name: + continue + # omit layers when num_hidden_layers_override is set + if "vision_model.encoder.layers." 
in name: + layer_idx = int(name.split(".")[3]) + if layer_idx >= layer_count: + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index 4749251271487..8850fd7c6763b 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -18,7 +18,6 @@ from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models import ModelRegistry from vllm.model_executor.models.intern_vit import InternVisionModel from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY @@ -29,7 +28,8 @@ from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip, get_clip_num_patches) from .interfaces import SupportsVision -from .utils import merge_vision_embeddings +from .utils import (filter_weights, init_vllm_registered_model, + merge_vision_embeddings) IMG_START = '' IMG_END = '' @@ -283,10 +283,8 @@ def __init__(self, self.vision_model = InternVisionModel( config.vision_config, num_hidden_layers_override=num_hidden_layers) - llm_class = ModelRegistry.load_model_cls( - config.text_config.architectures[0]) - self.language_model = llm_class(config.text_config, cache_config, - quant_config) + self.language_model = init_vllm_registered_model( + config.text_config, cache_config, quant_config) vit_hidden_size = config.vision_config.hidden_size llm_hidden_size = config.text_config.hidden_size @@ -415,24 +413,16 @@ def sample( ) -> Optional[SamplerOutput]: return self.language_model.sample(logits, sampling_metadata) - def _filter_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], - prefix: str): - for name, loaded_weight in weights: - name = name.split(".") - if prefix == name.pop(0): - name = ".".join(name) - yield name, loaded_weight - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # prepare weight iterators for components vit_weights, mlp_weights, llm_weights = itertools.tee(weights, 3) # load vision encoder - vit_weights = self._filter_weights(vit_weights, "vision_model") + vit_weights = filter_weights(vit_weights, "vision_model") self.vision_model.load_weights(vit_weights) # load mlp projector - mlp_weights = self._filter_weights(mlp_weights, "mlp1") + mlp_weights = filter_weights(mlp_weights, "mlp1") mlp_params_dict = dict(self.mlp1.named_parameters()) for name, loaded_weight in mlp_weights: param = mlp_params_dict[name] @@ -441,5 +431,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader(param, loaded_weight) # load llm backbone - llm_weights = self._filter_weights(llm_weights, "language_model") + llm_weights = filter_weights(llm_weights, "language_model") self.language_model.load_weights(llm_weights) diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 4e7e6c47f0a0b..9a11bcc4c54ce 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -1,34 +1,30 @@ -from typing import Iterable, List, Literal, Optional, Tuple, TypedDict +import itertools +from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union import torch import torch.nn as nn -from transformers import CLIPVisionConfig, LlavaConfig +from 
transformers import CLIPVisionConfig, LlavaConfig, SiglipVisionConfig from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) -from vllm.model_executor.layers.sampler import Sampler -from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.clip import CLIPVisionModel -from vllm.model_executor.models.llama import LlamaModel from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.sequence import IntermediateTensors, SamplerOutput -from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip, - get_max_clip_image_tokens, input_processor_for_clip) +from .clip import (CLIPVisionModel, dummy_image_for_clip, + dummy_seq_data_for_clip, get_max_clip_image_tokens, + input_processor_for_clip) from .interfaces import SupportsVision -from .utils import merge_vision_embeddings - -_KEYS_TO_MODIFY_MAPPING = { - "language_model.lm_head": "lm_head", - "language_model.model": "language_model", -} +from .siglip import (SiglipVisionModel, dummy_image_for_siglip, + dummy_seq_data_for_siglip, get_max_siglip_image_tokens, + input_processor_for_siglip) +from .utils import (filter_weights, init_vllm_registered_model, + merge_vision_embeddings) # TODO(xwjiang): Run benchmark and decide if TP. @@ -67,25 +63,48 @@ def get_max_llava_image_tokens(ctx: InputContext): vision_config = hf_config.vision_config if isinstance(vision_config, CLIPVisionConfig): - return get_max_clip_image_tokens(vision_config) - - msg = f"Unsupported vision config: {type(vision_config)}" - raise NotImplementedError(msg) + num_image_tokens = get_max_clip_image_tokens(vision_config) + elif isinstance(vision_config, SiglipVisionConfig): + num_image_tokens = get_max_siglip_image_tokens(vision_config) + else: + msg = f"Unsupported vision config: {type(vision_config)}" + raise NotImplementedError(msg) + + strategy = hf_config.vision_feature_select_strategy + if strategy == "default": + return num_image_tokens - 1 + elif strategy == "full": + return num_image_tokens + else: + raise ValueError(f"Unexpected select feature strategy: {strategy}") def dummy_data_for_llava(ctx: InputContext, seq_len: int): hf_config = ctx.get_hf_config(LlavaConfig) vision_config = hf_config.vision_config + image_feature_size = get_max_llava_image_tokens(ctx) + if isinstance(vision_config, CLIPVisionConfig): seq_data = dummy_seq_data_for_clip( vision_config, seq_len, image_token_id=hf_config.image_token_index, + image_feature_size_override=image_feature_size, ) mm_data = dummy_image_for_clip(vision_config) return seq_data, mm_data + elif isinstance(vision_config, SiglipVisionConfig): + seq_data = dummy_seq_data_for_siglip( + vision_config, + seq_len, + image_token_id=hf_config.image_token_index, + image_feature_size_override=image_feature_size, + ) + + mm_data = dummy_image_for_siglip(vision_config) + return seq_data, mm_data msg = f"Unsupported vision config: {type(vision_config)}" raise NotImplementedError(msg) @@ -100,12 +119,49 @@ def input_processor_for_llava(ctx: InputContext, llm_inputs: LLMInputs): hf_config = ctx.get_hf_config(LlavaConfig) 
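The strategy handling in get_max_llava_image_tokens is easiest to verify with concrete numbers. A small standalone sketch, assuming the usual CLIP ViT-L/14 tower at 336px (these figures are assumptions, not read from any config):

def max_llava_image_tokens(image_size: int, patch_size: int, strategy: str) -> int:
    # Patch positions plus the CLS token, mirroring
    # get_clip_image_feature_size() = get_clip_num_patches() + 1.
    num_image_tokens = (image_size // patch_size) ** 2 + 1
    if strategy == "default":   # drop the CLS position
        return num_image_tokens - 1
    elif strategy == "full":    # keep every position
        return num_image_tokens
    raise ValueError(f"Unexpected select feature strategy: {strategy}")

# CLIP ViT-L/14 at 336px: 24 * 24 = 576 patches, plus one CLS token.
assert max_llava_image_tokens(336, 14, "default") == 576
assert max_llava_image_tokens(336, 14, "full") == 577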
vision_config = hf_config.vision_config + image_feature_size = get_max_llava_image_tokens(ctx) + if isinstance(vision_config, CLIPVisionConfig): return input_processor_for_clip( model_config, vision_config, llm_inputs, image_token_id=hf_config.image_token_index, + image_feature_size_override=image_feature_size, + ) + elif isinstance(vision_config, SiglipVisionConfig): + return input_processor_for_siglip( + model_config, + vision_config, + llm_inputs, + image_token_id=hf_config.image_token_index, + image_feature_size_override=image_feature_size, + ) + + msg = f"Unsupported vision config: {type(vision_config)}" + raise NotImplementedError(msg) + + +def _init_vision_tower(hf_config: LlavaConfig): + vision_config = hf_config.vision_config + + # Initialize the vision tower only up to the required feature layer + vision_feature_layer = hf_config.vision_feature_layer + if vision_feature_layer < 0: + num_hidden_layers = hf_config.vision_config.num_hidden_layers \ + + vision_feature_layer + 1 + else: + num_hidden_layers = vision_feature_layer + 1 + + if isinstance(vision_config, CLIPVisionConfig): + return CLIPVisionModel( + vision_config, + num_hidden_layers_override=num_hidden_layers, + ) + elif isinstance(vision_config, SiglipVisionConfig): + return SiglipVisionModel( + vision_config, + num_hidden_layers_override=num_hidden_layers, ) msg = f"Unsupported vision config: {type(vision_config)}" @@ -128,36 +184,15 @@ def __init__(self, self.config = config self.multimodal_config = multimodal_config - # Initialize the vision tower only up to the required feature layer - vision_feature_layer = config.vision_feature_layer - if vision_feature_layer < 0: - num_hidden_layers = config.vision_config.num_hidden_layers \ - + vision_feature_layer + 1 - else: - num_hidden_layers = vision_feature_layer + 1 - # TODO: Optionally initializes this for supporting embeddings. 
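The layer-count arithmetic in _init_vision_tower can be checked the same way. A minimal sketch, assuming a 24-layer CLIP tower and the commonly used vision_feature_layer of -2:

def layers_to_instantiate(total_layers: int, vision_feature_layer: int) -> int:
    # Negative indices count from the end of the encoder: with 24 layers,
    # -2 selects the output of the 23rd layer, so the 24th need not be built.
    if vision_feature_layer < 0:
        return total_layers + vision_feature_layer + 1
    return vision_feature_layer + 1

assert layers_to_instantiate(24, -2) == 23   # typical LLaVA setting: skip the last layer
assert layers_to_instantiate(24, -1) == 24   # last layer -> keep the full tower
assert layers_to_instantiate(24, 5) == 6     # positive index is 0-based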
- self.vision_tower = CLIPVisionModel( - config.vision_config, num_hidden_layers_override=num_hidden_layers) + self.vision_tower = _init_vision_tower(config) self.multi_modal_projector = LlavaMultiModalProjector( vision_hidden_size=config.vision_config.hidden_size, text_hidden_size=config.text_config.hidden_size, projector_hidden_act=config.projector_hidden_act) - self.quant_config = quant_config - self.language_model = LlamaModel(config.text_config, cache_config, - quant_config) - self.unpadded_vocab_size = config.text_config.vocab_size - self.lm_head = ParallelLMHead( - self.unpadded_vocab_size, - config.text_config.hidden_size, - org_num_embeddings=self.language_model.org_vocab_size, - quant_config=quant_config) - logit_scale = getattr(config, "logit_scale", 1.0) - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.text_config.vocab_size, - logit_scale) - self.sampler = Sampler() + self.language_model = init_vllm_registered_model( + config.text_config, cache_config, quant_config) def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: h = w = self.config.vision_config.image_size @@ -198,8 +233,11 @@ def _select_image_features(self, image_features: torch.Tensor, *, raise ValueError(f"Unexpected select feature strategy: {strategy}") - def _image_pixels_to_features(self, vision_tower: CLIPVisionModel, - pixel_values: torch.Tensor) -> torch.Tensor: + def _image_pixels_to_features( + self, + vision_tower: Union[CLIPVisionModel, SiglipVisionModel], + pixel_values: torch.Tensor, + ) -> torch.Tensor: # NOTE: we skip the step to select the vision feature layer since # this is already done inside the vision tower @@ -272,7 +310,8 @@ def forward( if image_input is not None: vision_embeddings = self._process_image_input(image_input) - inputs_embeds = self.language_model.get_input_embeddings(input_ids) + inputs_embeds = self.language_model.model.get_input_embeddings( + input_ids) inputs_embeds = merge_vision_embeddings( input_ids, inputs_embeds, vision_embeddings, @@ -282,68 +321,44 @@ def forward( else: inputs_embeds = None - hidden_states = self.language_model(input_ids, - positions, - kv_caches, - attn_metadata, - None, - inputs_embeds=inputs_embeds) + hidden_states = self.language_model.model(input_ids, + positions, + kv_caches, + attn_metadata, + None, + inputs_embeds=inputs_embeds) return hidden_states def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata) -> torch.Tensor: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) - return logits + return self.language_model.compute_logits(hidden_states, + sampling_metadata) def sample( self, logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens + return self.language_model.sample(logits, sampling_metadata) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - # only doing this for language model part for now. 
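With the language model now built via init_vllm_registered_model, the outer LLaVA module no longer owns lm_head, logits_processor or sampler; it simply forwards those steps to the inner model and uses its bare decoder (.model) for embeddings and hidden states. A rough sketch of the composition, where the class and attribute names are illustrative rather than vLLM's actual API:

class ComposedVLM:
    def __init__(self, language_model):
        # `language_model` is a *ForCausalLM-style wrapper built by
        # init_vllm_registered_model(); its bare decoder is `.model`.
        self.language_model = language_model

    def compute_logits(self, hidden_states, sampling_metadata):
        # The LM head and logits processing live in the inner model now.
        return self.language_model.compute_logits(hidden_states, sampling_metadata)

    def sample(self, logits, sampling_metadata):
        return self.language_model.sample(logits, sampling_metadata)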
- stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ("gate_up_proj", "gate_proj", 0), - ("gate_up_proj", "up_proj", 1), - ] - params_dict = dict(self.named_parameters()) - for name, loaded_weight in weights: - if "rotary_emb.inv_freq" in name: - continue - # post_layernorm is not needed in CLIPVisionModel - if "vision_model.post_layernorm" in name: - continue - for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items(): - if key_to_modify in name: - name = name.replace(key_to_modify, new_key) - use_default_weight_loading = False - if "vision" in name: - if self.vision_tower is not None: - # We only do sharding for language model and - # not vision model for now. - use_default_weight_loading = True - else: - for (param_name, weight_name, - shard_id) in stacked_params_mapping: - if weight_name not in name: - continue - param = params_dict[name.replace(weight_name, param_name)] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - use_default_weight_loading = True - if use_default_weight_loading and name in params_dict: - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) + # prepare weight iterators for components + vit_weights, mlp_weights, llm_weights = itertools.tee(weights, 3) + + # load vision encoder + vit_weights = filter_weights(vit_weights, "vision_tower") + self.vision_tower.load_weights(vit_weights) + + # load mlp projector + mlp_weights = filter_weights(mlp_weights, "multi_modal_projector") + mlp_params_dict = dict(self.multi_modal_projector.named_parameters()) + for name, loaded_weight in mlp_weights: + param = mlp_params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + + # load llm backbone + llm_weights = filter_weights(llm_weights, "language_model") + self.language_model.load_weights(llm_weights) diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py index 4a67b9a583ea8..9abc480f60dec 100644 --- a/vllm/model_executor/models/llava_next.py +++ b/vllm/model_executor/models/llava_next.py @@ -1,9 +1,10 @@ +import itertools from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union import torch import torch.nn as nn from PIL import Image -from transformers import CLIPVisionConfig, LlavaNextConfig +from transformers import CLIPVisionConfig, LlavaNextConfig, SiglipVisionConfig from transformers.models.llava_next.modeling_llava_next import ( get_anyres_image_grid_shape, unpad_image) from typing_extensions import NotRequired @@ -12,23 +13,23 @@ from vllm.config import CacheConfig, MultiModalConfig from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.logger import init_logger -from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) -from vllm.model_executor.layers.sampler import Sampler -from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.clip import CLIPVisionModel -from vllm.model_executor.models.llama import LlamaModel from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from 
vllm.sequence import IntermediateTensors, SamplerOutput -from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip, +from .clip import (CLIPVisionModel, dummy_image_for_clip, + dummy_seq_data_for_clip, get_clip_image_feature_size, get_clip_patch_grid_length, input_processor_for_clip) from .interfaces import SupportsVision from .llava import LlavaMultiModalProjector -from .utils import merge_vision_embeddings +from .siglip import (SiglipVisionModel, dummy_image_for_siglip, + dummy_seq_data_for_siglip, get_siglip_image_feature_size, + get_siglip_patch_grid_length, input_processor_for_siglip) +from .utils import (filter_weights, init_vllm_registered_model, + merge_vision_embeddings) logger = init_logger(__name__) @@ -104,30 +105,42 @@ def get_llava_next_image_feature_size( image_size=vision_config.image_size, patch_size=vision_config.patch_size, ) - base_feature_size = num_patches * num_patches - - num_patch_height, num_patch_width = get_anyres_image_grid_shape( - image_size=(input_height, input_width), - grid_pinpoints=hf_config.image_grid_pinpoints, - patch_size=vision_config.image_size, + base_feature_size = get_clip_image_feature_size(vision_config) + elif isinstance(vision_config, SiglipVisionConfig): + num_patches = get_siglip_patch_grid_length( + image_size=vision_config.image_size, + patch_size=vision_config.patch_size, ) + base_feature_size = get_siglip_image_feature_size(vision_config) + else: + msg = f"Unsupported vision config: {type(vision_config)}" + raise NotImplementedError(msg) + + strategy = hf_config.vision_feature_select_strategy + if strategy == "default": + base_feature_size -= 1 + elif strategy == "full": + pass + else: + raise ValueError(f"Unexpected select feature strategy: {strategy}") - ( - unpadded_feature_size, - newline_feature_size, - ) = _get_llava_next_num_unpadded_features(input_height, input_width, - num_patches, - num_patch_height, - num_patch_width) + num_patch_height, num_patch_width = get_anyres_image_grid_shape( + image_size=(input_height, input_width), + grid_pinpoints=hf_config.image_grid_pinpoints, + patch_size=vision_config.image_size, + ) - return unpadded_feature_size + newline_feature_size + base_feature_size + ( + unpadded_feature_size, + newline_feature_size, + ) = _get_llava_next_num_unpadded_features(input_height, input_width, + num_patches, num_patch_height, + num_patch_width) - msg = f"Unsupported vision config: {type(vision_config)}" - raise NotImplementedError(msg) + return unpadded_feature_size + newline_feature_size + base_feature_size def get_max_llava_next_image_tokens(ctx: InputContext): - return get_llava_next_image_feature_size( ctx.get_hf_config(LlavaNextConfig), input_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT, @@ -155,6 +168,21 @@ def dummy_data_for_llava_next(ctx: InputContext, seq_len: int): image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT, ) + return seq_data, mm_data + elif isinstance(vision_config, SiglipVisionConfig): + seq_data = dummy_seq_data_for_siglip( + vision_config, + seq_len, + image_token_id=hf_config.image_token_index, + image_feature_size_override=image_feature_size, + ) + + mm_data = dummy_image_for_siglip( + vision_config, + image_width_override=MAX_IMAGE_FEATURE_SIZE_WIDTH, + image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT, + ) + return seq_data, mm_data msg = f"Unsupported vision config: {type(vision_config)}" @@ -194,6 +222,40 @@ def input_processor_for_llava_next(ctx: InputContext, llm_inputs: LLMInputs): image_token_id=hf_config.image_token_index, image_feature_size_override=image_feature_size, ) 
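The SigLIP branch added above uses the same floor-division patch grid as CLIP, except there is no class token and, because position embeddings are interpolated, the image size need not divide evenly by the patch size. A quick check, assuming the so400m patch-14/384px configuration that Mantis-style checkpoints typically use (an assumption, not taken from a config):

def siglip_patch_grid_length(image_size: int, patch_size: int) -> int:
    # Mirrors get_siglip_patch_grid_length(): plain floor division, since
    # interpolation tolerates a non-divisible image size (384 % 14 == 6).
    return image_size // patch_size

grid = siglip_patch_grid_length(image_size=384, patch_size=14)
assert grid == 27
# No CLS token in SigLIP, so the base feature size is just the patch grid.
assert grid * grid == 729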
+ elif isinstance(vision_config, SiglipVisionConfig): + return input_processor_for_siglip( + model_config, + vision_config, + llm_inputs, + image_token_id=hf_config.image_token_index, + image_feature_size_override=image_feature_size, + ) + + msg = f"Unsupported vision config: {type(vision_config)}" + raise NotImplementedError(msg) + + +def _init_vision_tower(hf_config: LlavaNextConfig): + vision_config = hf_config.vision_config + + # Initialize the vision tower only up to the required feature layer + vision_feature_layer = hf_config.vision_feature_layer + if vision_feature_layer < 0: + num_hidden_layers = hf_config.vision_config.num_hidden_layers \ + + vision_feature_layer + 1 + else: + num_hidden_layers = vision_feature_layer + 1 + + if isinstance(vision_config, CLIPVisionConfig): + return CLIPVisionModel( + vision_config, + num_hidden_layers_override=num_hidden_layers, + ) + elif isinstance(vision_config, SiglipVisionConfig): + return SiglipVisionModel( + vision_config, + num_hidden_layers_override=num_hidden_layers, + ) msg = f"Unsupported vision config: {type(vision_config)}" raise NotImplementedError(msg) @@ -215,36 +277,15 @@ def __init__(self, self.config = config self.multimodal_config = multimodal_config - # Initialize the vision tower only up to the required feature layer - vision_feature_layer = config.vision_feature_layer - if vision_feature_layer < 0: - num_hidden_layers = config.vision_config.num_hidden_layers \ - + vision_feature_layer + 1 - else: - num_hidden_layers = vision_feature_layer + 1 - # TODO: Optionally initializes this for supporting embeddings. - self.vision_tower = CLIPVisionModel( - config.vision_config, num_hidden_layers_override=num_hidden_layers) + self.vision_tower = _init_vision_tower(config) self.multi_modal_projector = LlavaMultiModalProjector( vision_hidden_size=config.vision_config.hidden_size, text_hidden_size=config.text_config.hidden_size, projector_hidden_act=config.projector_hidden_act) - self.quant_config = quant_config - self.language_model = LlamaModel(config.text_config, cache_config, - quant_config) - self.unpadded_vocab_size = config.text_config.vocab_size - self.lm_head = ParallelLMHead( - self.unpadded_vocab_size, - config.text_config.hidden_size, - org_num_embeddings=self.language_model.org_vocab_size, - quant_config=quant_config) - logit_scale = getattr(config, "logit_scale", 1.0) - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.text_config.vocab_size, - logit_scale) - self.sampler = Sampler() + self.language_model = init_vllm_registered_model( + config.text_config, cache_config, quant_config) self.image_newline = nn.Parameter( torch.empty(config.text_config.hidden_size)) @@ -310,8 +351,11 @@ def _select_image_features(self, image_features: torch.Tensor, *, raise ValueError(f"Unexpected select feature strategy: {strategy}") - def _image_pixels_to_features(self, vision_tower: CLIPVisionModel, - pixel_values: torch.Tensor) -> torch.Tensor: + def _image_pixels_to_features( + self, + vision_tower: Union[CLIPVisionModel, SiglipVisionModel], + pixel_values: torch.Tensor, + ) -> torch.Tensor: # NOTE: we skip the step to select the vision feature layer since # this is already done inside the vision tower @@ -496,7 +540,8 @@ def forward( if image_input is not None: vision_embeddings = self._process_image_input(image_input) - inputs_embeds = self.language_model.get_input_embeddings(input_ids) + inputs_embeds = self.language_model.model.get_input_embeddings( + input_ids) inputs_embeds = merge_vision_embeddings( 
input_ids, inputs_embeds, vision_embeddings, @@ -506,68 +551,54 @@ def forward( else: inputs_embeds = None - hidden_states = self.language_model(input_ids, - positions, - kv_caches, - attn_metadata, - None, - inputs_embeds=inputs_embeds) + hidden_states = self.language_model.model(input_ids, + positions, + kv_caches, + attn_metadata, + None, + inputs_embeds=inputs_embeds) return hidden_states def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata) -> torch.Tensor: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) - return logits + return self.language_model.compute_logits(hidden_states, + sampling_metadata) def sample( self, logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens + return self.language_model.sample(logits, sampling_metadata) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - # only doing this for language model part for now. - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ("gate_up_proj", "gate_proj", 0), - ("gate_up_proj", "up_proj", 1), - ] - params_dict = dict(self.named_parameters()) - for name, loaded_weight in weights: - if "rotary_emb.inv_freq" in name: - continue - # post_layernorm is not needed in CLIPVisionModel - if "vision_model.post_layernorm" in name: - continue - for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items(): - if key_to_modify in name: - name = name.replace(key_to_modify, new_key) - use_default_weight_loading = False - if "vision" in name: - if self.vision_tower is not None: - # We only do sharding for language model and - # not vision model for now. 
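The merge_vision_embeddings call in the forward passes above boils down to a boolean mask over the image placeholder id. A self-contained approximation (the real helper also handles lists of per-image tensors and validates the expected token count):

import torch

def splice_vision_embeddings(input_ids: torch.Tensor,
                             inputs_embeds: torch.Tensor,
                             vision_embeddings: torch.Tensor,
                             image_token_id: int) -> torch.Tensor:
    # Overwrite the embeddings of placeholder image tokens in place,
    # leaving text-token embeddings untouched.
    mask = input_ids == image_token_id
    assert mask.sum() == vision_embeddings.shape[0]
    inputs_embeds[mask] = vision_embeddings.to(inputs_embeds.dtype)
    return inputs_embeds

# Toy example: three placeholder positions (token id 32000 is illustrative).
ids = torch.tensor([1, 32000, 32000, 32000, 42])
embeds = torch.zeros(5, 8)
vision = torch.ones(3, 8)
out = splice_vision_embeddings(ids, embeds, vision, image_token_id=32000)
assert out[1:4].eq(1).all() and out[0].eq(0).all()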
- use_default_weight_loading = True - else: - for (param_name, weight_name, - shard_id) in stacked_params_mapping: - if weight_name not in name: - continue - param = params_dict[name.replace(weight_name, param_name)] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - use_default_weight_loading = True - if use_default_weight_loading and name in params_dict: - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) + # prepare weight iterators for components + vit_weights, mlp_weights, newline_weights, llm_weights = itertools.tee( + weights, 4) + + # load vision encoder + vit_weights = filter_weights(vit_weights, "vision_tower") + self.vision_tower.load_weights(vit_weights) + + # load mlp projector + mlp_weights = filter_weights(mlp_weights, "multi_modal_projector") + mlp_params_dict = dict(self.multi_modal_projector.named_parameters()) + for name, loaded_weight in mlp_weights: + param = mlp_params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + + # load newline + newline_weights = filter_weights(newline_weights, "image_newline") + for name, loaded_weight in newline_weights: + assert name == "" + param = self.image_newline + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + + # load llm backbone + llm_weights = filter_weights(llm_weights, "language_model") + self.language_model.load_weights(llm_weights) diff --git a/vllm/model_executor/models/siglip.py b/vllm/model_executor/models/siglip.py index 6faef45c9a6d3..5ba14f73394f3 100644 --- a/vllm/model_executor/models/siglip.py +++ b/vllm/model_executor/models/siglip.py @@ -2,12 +2,12 @@ within a vision language model.""" import math -from typing import Optional, Tuple +from typing import Iterable, Optional, Tuple import torch from PIL import Image from torch import nn -from transformers import SiglipConfig, SiglipVisionConfig +from transformers import SiglipVisionConfig from transformers.models.siglip.modeling_siglip import SiglipAttention from vllm_flash_attn import flash_attn_func from xformers.ops import memory_efficient_attention @@ -22,13 +22,15 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.multimodal.image import (cached_get_tokenizer, repeat_and_pad_image_tokens) from vllm.sequence import SequenceData def get_siglip_patch_grid_length(*, image_size: int, patch_size: int) -> int: - assert image_size % patch_size == 0 + # Since interpolation is applied, the image size need not be divisible + # assert image_size % patch_size == 0 return image_size // patch_size @@ -454,7 +456,7 @@ class SiglipEncoderLayer(nn.Module): def __init__( self, - config: SiglipConfig, + config: SiglipVisionConfig, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -474,7 +476,7 @@ def __init__( def forward( self, hidden_states: torch.Tensor, - ) -> Tuple[torch.Tensor]: + ) -> Tuple[torch.Tensor, None]: residual = hidden_states hidden_states = self.layer_norm1(hidden_states) @@ -493,22 +495,27 @@ class SiglipEncoder(nn.Module): def __init__( self, - config: SiglipConfig, + config: SiglipVisionConfig, quant_config: Optional[QuantizationConfig] = None, + 
num_hidden_layers_override: Optional[int] = None, ): super().__init__() self.config = config + + if num_hidden_layers_override is None: + num_hidden_layers = config.num_hidden_layers + else: + num_hidden_layers = num_hidden_layers_override + self.layers = nn.ModuleList([ - SiglipEncoderLayer( - config, - quant_config=quant_config, - ) for _ in range(config.num_hidden_layers) + SiglipEncoderLayer(config, quant_config=quant_config) + for _ in range(num_hidden_layers) ]) def forward( self, inputs_embeds: torch.Tensor, - ) -> Tuple: + ) -> torch.Tensor: hidden_states = inputs_embeds for encoder_layer in self.layers: hidden_states, _ = encoder_layer(hidden_states) @@ -553,6 +560,7 @@ def __init__( self, config: SiglipVisionConfig, quant_config: Optional[QuantizationConfig] = None, + num_hidden_layers_override: Optional[int] = None, ): super().__init__() self.config = config @@ -562,6 +570,7 @@ def __init__( self.encoder = SiglipEncoder( config, quant_config=quant_config, + num_hidden_layers_override=num_hidden_layers_override, ) self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) @@ -600,11 +609,13 @@ def __init__( self, config: SiglipVisionConfig, quant_config: Optional[QuantizationConfig] = None, + num_hidden_layers_override: Optional[int] = None, ): super().__init__() self.vision_model = SiglipVisionTransformer( config, quant_config, + num_hidden_layers_override=num_hidden_layers_override, ) def get_input_embeddings(self) -> nn.Module: @@ -619,3 +630,19 @@ def forward( pixel_values=pixel_values, interpolate_pos_encoding=interpolate_pos_encoding, ) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + params_dict = dict(self.named_parameters()) + layer_count = len(self.vision_model.encoder.layers) + + for name, loaded_weight in weights: + # omit layers when num_hidden_layers_override is set + if "vision_model.encoder.layers." in name: + layer_idx = int(name.split(".")[3]) + if layer_idx >= layer_count: + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 91b4a27814bf4..d1bb030c6c90f 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -1,22 +1,70 @@ -from typing import Dict, List, Protocol, Tuple +from typing import Dict, Iterable, List, Optional, Protocol, Tuple import torch +import torch.nn as nn from torch.func import functional_call +from transformers import PretrainedConfig +from vllm.config import (CacheConfig, LoRAConfig, MultiModalConfig, + SchedulerConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.model_loader.loader import build_model +from vllm.model_executor.models import ModelRegistry from vllm.multimodal import BatchedTensors from vllm.utils import is_pin_memory_available +def filter_weights(weights: Iterable[Tuple[str, torch.Tensor]], prefix: str): + """ + Helper function to load weights for inner vLLM models. 
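Both the CLIP and SigLIP load_weights implementations above skip checkpoint tensors for encoder layers that were never instantiated by parsing the layer index out of the dotted weight name. A quick illustration, using a representative (assumed) weight name:

layer_count = 23  # e.g. a 24-layer tower truncated for vision_feature_layer = -2

name = "vision_model.encoder.layers.23.self_attn.q_proj.weight"
layer_idx = int(name.split(".")[3])  # ["vision_model", "encoder", "layers", "23", ...][3]
assert layer_idx >= layer_count      # layer 23 was never built, so this weight is dropped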
+ + See also: + :ref:`init_vllm_registered_model` + """ + for name, loaded_weight in weights: + name = name.split(".") + if prefix == name.pop(0): + name = ".".join(name) + yield name, loaded_weight + + +def init_vllm_registered_model( + hf_config: PretrainedConfig, + cache_config: Optional[CacheConfig], + quant_config: Optional[QuantizationConfig], + *, + lora_config: Optional[LoRAConfig] = None, + multimodal_config: Optional[MultiModalConfig] = None, + scheduler_config: Optional[SchedulerConfig] = None, +) -> nn.Module: + """ + Helper function to initialize an inner model registered to vLLM, + based on the arguments passed to the outer vLLM model. + """ + model_class, _ = ModelRegistry.resolve_model_cls(hf_config.architectures) + + return build_model( + model_class, + hf_config, + cache_config, + quant_config, + lora_config=lora_config, + multimodal_config=multimodal_config, + scheduler_config=scheduler_config, + ) + + def merge_vision_embeddings(input_ids: torch.Tensor, inputs_embeds: torch.Tensor, vision_embeddings: BatchedTensors, image_token_id: int) -> torch.Tensor: """ - Merge `vision_embeddings` into `inputs_embeds` by overwriting the positions - in `inputs_embeds` corresponding to placeholder image tokens in `input_ids`. + Merge ``vision_embeddings`` into ``inputs_embeds`` by overwriting the + positions in ``inputs_embeds`` corresponding to placeholder image tokens in + ``input_ids``. Note: - This updates `inputs_embeds` in place. + This updates ``inputs_embeds`` in place. """ mask = (input_ids == image_token_id) num_expected_tokens = mask.sum()
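Finally, the per-component loading used by the LLaVA, LLaVA-NeXT and InternVL models relies on filter_weights plus itertools.tee to split one checkpoint stream into prefix-scoped streams. A minimal standalone reimplementation of the idea, with tensor payloads replaced by strings and representative (assumed) HF-style weight names:

import itertools
from typing import Iterable, Tuple

def filter_by_prefix(weights: Iterable[Tuple[str, str]], prefix: str):
    # Keep only names whose first dotted component equals `prefix`,
    # and strip that component, matching filter_weights() above.
    for name, payload in weights:
        parts = name.split(".")
        if parts[0] == prefix:
            yield ".".join(parts[1:]), payload

checkpoint = [
    ("vision_tower.vision_model.embeddings.patch_embedding.weight", "w0"),
    ("multi_modal_projector.linear_1.weight", "w1"),
    ("language_model.model.embed_tokens.weight", "w2"),
]

# tee() gives each component its own view of the single-pass weight iterator.
vit, mlp, llm = itertools.tee(iter(checkpoint), 3)
vit_names = [n for n, _ in filter_by_prefix(vit, "vision_tower")]
mlp_names = [n for n, _ in filter_by_prefix(mlp, "multi_modal_projector")]
llm_names = [n for n, _ in filter_by_prefix(llm, "language_model")]
assert vit_names == ["vision_model.embeddings.patch_embedding.weight"]
assert mlp_names == ["linear_1.weight"]
assert llm_names == ["model.embed_tokens.weight"]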