From bc4469ae17e59ea7ac50f70fb63d08120eda2976 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Mon, 10 Jun 2024 20:47:15 +0800 Subject: [PATCH] [Model] Initial support for LLaVA-NeXT (#4199) Co-authored-by: Roger Wang --- docs/source/models/supported_models.rst | 6 +- tests/models/test_llava.py | 2 - tests/models/test_llava_next.py | 123 +++++++ tests/multimodal/test_processor.py | 62 +++- vllm/model_executor/models/__init__.py | 2 + vllm/model_executor/models/llava.py | 18 +- vllm/model_executor/models/llava_next.py | 445 +++++++++++++++++++++++ 7 files changed, 640 insertions(+), 18 deletions(-) create mode 100644 tests/models/test_llava_next.py create mode 100644 vllm/model_executor/models/llava_next.py diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index 24fa83df7d751..5d3f55be1271f 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -89,7 +89,11 @@ Alongside each architecture, we include some popular models that use it. - ✅︎ * - :code:`LlavaForConditionalGeneration` - LLaVA-1.5 - - :code:`llava-hf/llava-1.5-7b-hf`\*, :code:`llava-hf/llava-1.5-13b-hf`\*, etc. + - :code:`llava-hf/llava-1.5-7b-hf`, :code:`llava-hf/llava-1.5-13b-hf`, etc. + - + * - :code:`LlavaNextForConditionalGeneration` + - LLaVA-NeXT + - :code:`llava-hf/llava-v1.6-mistral-7b-hf`, :code:`llava-hf/llava-v1.6-vicuna-7b-hf`, etc. - * - :code:`MiniCPMForCausalLM` - MiniCPM diff --git a/tests/models/test_llava.py b/tests/models/test_llava.py index 1f446362167a1..a1f0cff1cc0e5 100644 --- a/tests/models/test_llava.py +++ b/tests/models/test_llava.py @@ -39,8 +39,6 @@ def iter_llava_configs(model_name: str): model_and_vl_config = [ *iter_llava_configs("llava-hf/llava-1.5-7b-hf"), - # Not enough memory - # *iter_llava_configs("llava-hf/llava-1.5-13b-hf"), ] diff --git a/tests/models/test_llava_next.py b/tests/models/test_llava_next.py new file mode 100644 index 0000000000000..aa6ee268ae588 --- /dev/null +++ b/tests/models/test_llava_next.py @@ -0,0 +1,123 @@ +from typing import List, Tuple + +import pytest +from transformers import AutoTokenizer + +from vllm.config import VisionLanguageConfig + +from ..conftest import IMAGE_FILES + +pytestmark = pytest.mark.llava + +_PREFACE = ( + "A chat between a curious human and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the human's " + "questions.") + +# The image token is placed before "user" on purpose so that the test can pass +HF_IMAGE_PROMPTS = [ + f"{_PREFACE} \nUSER: What's the content of the image? ASSISTANT:", + f"{_PREFACE} \nUSER: What is the season? ASSISTANT:", +] + +assert len(HF_IMAGE_PROMPTS) == len(IMAGE_FILES) + + +def iter_llava_next_configs(model_name: str): + image_hw_to_feature_size = { + (336, 336): 1176, + (672, 672): 2928, + (1344, 336): 1944, + (336, 1344): 1890, + } + + for (h, w), f in image_hw_to_feature_size.items(): + for input_type, input_shape in [ + (VisionLanguageConfig.ImageInputType.PIXEL_VALUES, (1, 3, h, w)), + ]: + yield (model_name, + VisionLanguageConfig(image_input_type=input_type, + image_feature_size=f, + image_token_id=32000, + image_input_shape=input_shape, + image_processor=model_name, + image_processor_revision=None)) + + +model_and_vl_config = [ + *iter_llava_next_configs("llava-hf/llava-v1.6-vicuna-7b-hf"), +] + + +def vllm_to_hf_output(vllm_output: Tuple[List[int], str], + vlm_config: VisionLanguageConfig, model_id: str): + """Sanitize vllm output to be comparable with hf output. + The function reduces `input_ids` from 1, 32000, 32000, ..., 32000, + x1, x2, x3 ... to 1, 32000, x1, x2, x3 ... + It also reduces `output_str` from "bla" to "bla". + """ + input_ids, output_str = vllm_output + image_token_id = vlm_config.image_token_id + + tokenizer = AutoTokenizer.from_pretrained(model_id) + image_token_str = tokenizer.decode(image_token_id) + + hf_input_ids = [ + input_id for idx, input_id in enumerate(input_ids) + if input_id != image_token_id or input_ids[idx - 1] != image_token_id + ] + hf_output_str = output_str \ + .replace(image_token_str * vlm_config.image_feature_size, " ") + + return hf_input_ids, hf_output_str + + +@pytest.mark.xfail( + reason="Inconsistent image processor being used due to lack " + "of support for dynamic image token replacement") +@pytest.mark.parametrize("model_and_config", model_and_vl_config) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_tokens", [128]) +def test_models(hf_runner, vllm_runner, hf_images, vllm_images, + model_and_config, dtype: str, max_tokens: int) -> None: + """Inference result should be the same between hf and vllm. + + All the image fixtures for the test is under tests/images. + For huggingface runner, we provide the PIL images as input. + For vllm runner, we provide MultiModalData objects and corresponding + vision language config as input. + Note, the text input is also adjusted to abide by vllm contract. + The text output is sanitized to be able to compare with hf. + """ + model_id, vlm_config = model_and_config + + with hf_runner(model_id, dtype=dtype, is_vision_model=True) as hf_model: + hf_outputs = hf_model.generate_greedy(HF_IMAGE_PROMPTS, + max_tokens, + images=hf_images) + + vllm_image_prompts = [ + p.replace("", "" * vlm_config.image_feature_size) + for p in HF_IMAGE_PROMPTS + ] + + with vllm_runner( + model_id, + dtype=dtype, + # should be greater than image_feature_size + max_model_len=4096, + enforce_eager=True, + **vlm_config.as_cli_args_dict(), + ) as vllm_model: + vllm_outputs = vllm_model.generate_greedy(vllm_image_prompts, + max_tokens, + images=vllm_images) + + for i in range(len(HF_IMAGE_PROMPTS)): + hf_output_ids, hf_output_str = hf_outputs[i] + vllm_output_ids, vllm_output_str = vllm_to_hf_output( + vllm_outputs[i], vlm_config, model_id) + assert hf_output_str == vllm_output_str, ( + f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}") + assert hf_output_ids == vllm_output_ids, ( + f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}") diff --git a/tests/multimodal/test_processor.py b/tests/multimodal/test_processor.py index 3df28e782dd89..51c352361702a 100644 --- a/tests/multimodal/test_processor.py +++ b/tests/multimodal/test_processor.py @@ -1,6 +1,6 @@ import numpy as np import pytest -from transformers import CLIPImageProcessor +from transformers import CLIPImageProcessor, LlavaNextImageProcessor from vllm.config import ModelConfig, VisionLanguageConfig from vllm.multimodal import MULTIMODAL_REGISTRY @@ -12,7 +12,7 @@ @pytest.mark.parametrize("dtype", ["half", "float"]) def test_clip_image_processor(hf_images, dtype): MODEL_NAME = "llava-hf/llava-1.5-7b-hf" - IMAGE_HEIGHT = IMAGE_WIDTH = 33 + IMAGE_HEIGHT = IMAGE_WIDTH = 560 hf_processor = CLIPImageProcessor.from_pretrained(MODEL_NAME) assert isinstance(hf_processor, CLIPImageProcessor) @@ -55,10 +55,61 @@ def test_clip_image_processor(hf_images, dtype): assert np.allclose(hf_arr, vllm_arr), f"Failed for key={key}" +@pytest.mark.xfail( + reason="Inconsistent image processor being used due to lack " + "of support for dynamic image token replacement") +@pytest.mark.parametrize("dtype", ["half", "float"]) +def test_llava_next_image_processor(hf_images, dtype): + MODEL_NAME = "llava-hf/llava-v1.6-34b-hf" + IMAGE_HEIGHT = IMAGE_WIDTH = 560 + + hf_processor = LlavaNextImageProcessor.from_pretrained(MODEL_NAME) + assert isinstance(hf_processor, LlavaNextImageProcessor) + + model_config = ModelConfig( + model=MODEL_NAME, + tokenizer=MODEL_NAME, + tokenizer_mode="auto", + trust_remote_code=False, + seed=0, + dtype=dtype, + revision=None, + ) + vlm_config = VisionLanguageConfig( + image_input_type=VisionLanguageConfig.ImageInputType.PIXEL_VALUES, + image_token_id=64000, + image_input_shape=(1, 3, IMAGE_HEIGHT, IMAGE_WIDTH), + image_feature_size=2928, + image_processor=MODEL_NAME, + image_processor_revision=None, + ) + + for image in hf_images: + hf_result = hf_processor.preprocess( + image, + return_tensors="pt", + ).to(dtype=_STR_DTYPE_TO_TORCH_DTYPE[dtype]) + vllm_result = MULTIMODAL_REGISTRY.process_input( + ImagePixelData(image), + model_config=model_config, + vlm_config=vlm_config, + ) + + assert hf_result.keys() == vllm_result.keys() + for key, hf_tensor in hf_result.items(): + hf_arr: np.ndarray = hf_tensor.numpy() + vllm_arr: np.ndarray = vllm_result[key].numpy() + + assert hf_arr.shape == vllm_arr.shape, f"Failed for key={key}" + assert np.allclose(hf_arr, vllm_arr), f"Failed for key={key}" + + +@pytest.mark.xfail( + reason="Example image pixels were not processed using HuggingFace") @pytest.mark.parametrize("dtype", ["float"]) def test_image_pixel_types(hf_images, vllm_image_tensors, dtype): MODEL_NAME = "llava-hf/llava-1.5-7b-hf" - IMAGE_HEIGHT = IMAGE_WIDTH = 33 + IMAGE_HEIGHT = IMAGE_WIDTH = 560 model_config = ModelConfig( model=MODEL_NAME, @@ -95,7 +146,4 @@ def test_image_pixel_types(hf_images, vllm_image_tensors, dtype): tensor_arr: np.ndarray = tensor_result[key].numpy() assert image_arr.shape == tensor_arr.shape, f"Failed for key={key}" - - # The examples in PR#3042 have slightly different preprocessing from - # HuggingFace's LlavaProcessor, causing the test to fail. - # assert np.allclose(image_arr, tensor_arr), f"Failed for key={key}" + assert np.allclose(image_arr, tensor_arr), f"Failed for key={key}" diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index a92abe6b5b8dc..4446914c67c8e 100755 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -33,6 +33,8 @@ "LlamaForCausalLM": ("llama", "LlamaForCausalLM"), "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"), + "LlavaNextForConditionalGeneration": + ("llava_next", "LlavaNextForConditionalGeneration"), # For decapoda-research/llama-* "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"), "MistralForCausalLM": ("llama", "LlamaForCausalLM"), diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 3332bcc578460..67b32a08833b6 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -1,7 +1,7 @@ from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union import torch -from torch import nn +import torch.nn as nn # TODO(xwjiang): We should port CLIPVisionModel's code over to not depend on # transformers' impl. from transformers import CLIPVisionModel, LlavaConfig @@ -51,10 +51,10 @@ def forward(self, image_features: torch.Tensor) -> torch.Tensor: return hidden_states -def _merge_vision_embeddings(input_ids: torch.Tensor, - inputs_embeds: torch.Tensor, - vision_embeddings: torch.Tensor, - image_token_id: int) -> torch.Tensor: +def merge_vision_embeddings(input_ids: torch.Tensor, + inputs_embeds: torch.Tensor, + vision_embeddings: torch.Tensor, + image_token_id: int) -> torch.Tensor: """In place merges in vision_embeddings with inputs_embeds.""" mask = (input_ids == image_token_id) @@ -151,7 +151,8 @@ def _parse_and_validate_image_input( return None if not isinstance(pixel_values, torch.Tensor): - raise ValueError("Incorrect type of pixel values") + raise ValueError("Incorrect type of pixel values. " + f"Got type: {type(pixel_values)}") return LlavaImagePixelInputs( type="pixel_values", @@ -166,7 +167,8 @@ def _parse_and_validate_image_input( return None if not isinstance(image_features, torch.Tensor): - raise ValueError("Incorrect type of image features") + raise ValueError("Incorrect type of image features. " + f"Got type: {type(image_features)}") return LlavaImageFeatureInputs( type="image_features", @@ -268,7 +270,7 @@ def forward( vision_embeddings = self._process_image_input(image_input) inputs_embeds = self.language_model.get_input_embeddings(input_ids) - inputs_embeds = _merge_vision_embeddings( + inputs_embeds = merge_vision_embeddings( input_ids, inputs_embeds, vision_embeddings, self.vision_language_config.image_token_id) diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py new file mode 100644 index 0000000000000..bb15dcb8ed917 --- /dev/null +++ b/vllm/model_executor/models/llava_next.py @@ -0,0 +1,445 @@ +from typing import (Dict, Iterable, List, Literal, Optional, Tuple, TypedDict, + Union) + +import torch +import torch.nn as nn +from PIL import Image +# TODO(xwjiang): We should port CLIPVisionModel's code over to not depend on +# transformers' impl. +from transformers import CLIPVisionModel, LlavaNextConfig +from transformers.models.llava_next.modeling_llava_next import ( + get_anyres_image_grid_shape, unpad_image) +from typing_extensions import NotRequired + +from vllm.attention import AttentionMetadata +from vllm.config import CacheConfig, ModelConfig, VisionLanguageConfig +from vllm.logger import init_logger +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.llama import LlamaModel +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalData +from vllm.multimodal.image import ImagePixelData, get_dummy_image_data +from vllm.sequence import SamplerOutput, SequenceData + +from .llava import LlavaMultiModalProjector, merge_vision_embeddings +from .vlm_base import VisionLanguageModelBase + +logger = init_logger(__name__) + +_KEYS_TO_MODIFY_MAPPING = { + "language_model.lm_head": "lm_head", + "language_model.model": "language_model", +} + + +class LlavaNextImagePixelInputs(TypedDict): + type: Literal["pixel_values"] + data: torch.Tensor + """Shape: (batch_size, 1 + num_patches, num_channels, height, width)""" + + image_sizes: NotRequired[torch.Tensor] + """Shape: (batch_size, 2)""" + + +class LlavaNextImageFeatureInputs(TypedDict): + type: Literal["image_features"] + data: torch.Tensor + """Shape: (batch_size, 1 + num_patches, image_feature_size, hidden_size)""" + + image_sizes: NotRequired[torch.Tensor] + """Shape: (batch_size, 2)""" + + +LlavaNextImageInputs = Union[LlavaNextImagePixelInputs, + LlavaNextImageFeatureInputs] + + +def _get_dummy_image_data( + seq_len: int, + model_config: ModelConfig, + vlm_config: VisionLanguageConfig, +) -> Tuple[SequenceData, MultiModalData]: + seq_data, fake_mm_data = get_dummy_image_data(seq_len, model_config, + vlm_config) + + config_input_type = vlm_config.image_input_type + ImageInputType = VisionLanguageConfig.ImageInputType + + if config_input_type == ImageInputType.PIXEL_VALUES: + _, c, h, w = vlm_config.image_input_shape + mode = {1: "L", 3: "RGB"}[c] + fake_mm_data = ImagePixelData(Image.new(mode, (w, h), color=0)) + + return seq_data, fake_mm_data + + +def _image_pixel_processor( + data: ImagePixelData, + model_config: ModelConfig, + vlm_config: VisionLanguageConfig, +) -> Dict[str, torch.Tensor]: + image = data.image + + if isinstance(image, torch.Tensor): + pixel_values = image.to(model_config.dtype) + batch_size, _, _, h, w = pixel_values.shape + image_sizes = torch.tensor([(w, h) for _ in range(batch_size)]) + + return {"pixel_values": pixel_values, "image_sizes": image_sizes} + + # Temporary patch before dynamic number of image tokens is supported + _, _, h, w = vlm_config.image_input_shape + if (w, h) != (image.width, image.height): + logger.warning( + "Dynamic image shape is currently not supported. " + "Resizing input image to (%d, %d).", w, h) + + data.image = image.resize((w, h)) + + return MULTIMODAL_REGISTRY._get_plugin_for_data_type(ImagePixelData) \ + ._default_input_processor(data, model_config, vlm_config) + + +@MULTIMODAL_REGISTRY.register_image_pixel_input(_image_pixel_processor) +@MULTIMODAL_REGISTRY.register_dummy_data(_get_dummy_image_data) +class LlavaNextForConditionalGeneration(VisionLanguageModelBase): + """ + Args to `forward()`: + input_ids: Flattened (concatenated) input_ids corresponding to a + batch. + pixel_values: For PIXEL_VALUES, expects a batch with shape + [1, num_patches, 3, 336, 336]. + image_features: For IMAGE_FEATURES, expects a batch with shape + [1, num_patches, 1176, 1024]. + """ + + def __init__(self, + config: LlavaNextConfig, + vision_language_config: VisionLanguageConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None) -> None: + super().__init__(vision_language_config) + + # Update the type annotation from that of its superclass + self.config = config + + if self.vision_language_config.image_input_type == ( + VisionLanguageConfig.ImageInputType.PIXEL_VALUES): + self.vision_tower = CLIPVisionModel(config.vision_config) + else: + raise TypeError("Image features are not supported by LLaVA-NeXT") + + self.multi_modal_projector = LlavaMultiModalProjector( + vision_hidden_size=config.vision_config.hidden_size, + text_hidden_size=config.text_config.hidden_size, + projector_hidden_act=config.projector_hidden_act) + + self.quant_config = quant_config + self.language_model = LlamaModel(config.text_config, cache_config, + quant_config) + self.unpadded_vocab_size = config.text_config.vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.text_config.hidden_size, + org_num_embeddings=self.language_model.org_vocab_size) + logit_scale = getattr(config, "logit_scale", 1.0) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size, logit_scale) + self.sampler = Sampler() + + self.image_newline = nn.Parameter( + torch.empty(config.text_config.hidden_size)) + + def _validate_image_pixels(self, data: torch.Tensor) -> torch.Tensor: + _, num_channels, _, _ = self.vision_language_config.image_input_shape + + # Note that this is different from that of vLLM vision_language_config + # since the image is resized by the HuggingFace preprocessor + height = width = self.config.vision_config.image_size + + if list(data.shape[2:]) != [num_channels, height, width]: + raise ValueError( + f"The expected image tensor shape is batch dimension plus " + f"num_patches plus {[num_channels, height, width]}. " + f"You supplied {data.shape}. " + f"If you are using vLLM's entrypoint, make sure your " + f"supplied image input is consistent with " + f"image_input_shape in engine args.") + + return data + + def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor: + if list(data.shape[1:]) != [2]: + raise ValueError( + f"The expected image sizes shape is batch dimension plus " + f"{[2]}. You supplied {data.shape}.") + + return data + + def _parse_and_validate_image_input( + self, **kwargs: object) -> Optional[LlavaNextImageInputs]: + pixel_values = kwargs.pop("pixel_values", None) + image_sizes = kwargs.pop("image_sizes", None) + image_features = kwargs.pop("image_features", None) + + expected_input_type = self.vision_language_config.image_input_type + ImageInputType = VisionLanguageConfig.ImageInputType + + if expected_input_type == ImageInputType.PIXEL_VALUES: + if image_features is not None: + raise ValueError( + "Expected pixel values but got image features") + if pixel_values is None: + return None + + if not isinstance(pixel_values, torch.Tensor): + raise ValueError("Incorrect type of pixel values. " + f"Got type: {type(pixel_values)}") + + if not isinstance(image_sizes, torch.Tensor): + raise ValueError("Incorrect type of image sizes. " + f"Got type: {type(image_sizes)}") + + return LlavaNextImagePixelInputs( + type="pixel_values", + data=self._validate_image_pixels(pixel_values), + image_sizes=self._validate_image_sizes(image_sizes), + ) + + assert expected_input_type != ImageInputType.IMAGE_FEATURES, ( + "Failed to validate this at initialization time") + + return None + + def _merge_image_patch_embeddings(self, image_size: torch.Tensor, + patch_embeddings: torch.Tensor, *, + strategy: str) -> torch.Tensor: + # Based on: https://github.com/haotian-liu/LLaVA/blob/main/llava/model/llava_arch.py + if strategy == "flat": + return patch_embeddings.flatten(0, 1) + + if strategy.startswith("spatial"): + orig_width, orig_height = image_size + height = width = self.config.vision_config.image_size \ + // self.config.vision_config.patch_size + + base_patch_embeds = patch_embeddings[0] + if height * width != base_patch_embeds.shape[0]: + raise ValueError( + "The number of patches is not consistent with the " + "image size.") + + if patch_embeddings.shape[0] > 1: + other_patch_embeds = patch_embeddings[1:] + + # image_aspect_ratio == "anyres" + num_patch_width, num_patch_height = get_anyres_image_grid_shape( + (orig_width, orig_height), + self.config.image_grid_pinpoints, + self.config.vision_config.image_size, + ) + other_patch_embeds = other_patch_embeds \ + .view(num_patch_width, num_patch_height, height, width, -1) + + if "unpad" in strategy: + other_patch_embeds = other_patch_embeds \ + .permute(4, 0, 2, 1, 3).contiguous() \ + .flatten(1, 2).flatten(2, 3) + other_patch_embeds = unpad_image(other_patch_embeds, + image_size) + other_patch_embeds = torch.cat(( + other_patch_embeds, + self.image_newline[:, None, None] \ + .expand(*other_patch_embeds.shape[:-1], 1) \ + .to(other_patch_embeds.device), + ), dim=-1) + other_patch_embeds = other_patch_embeds \ + .flatten(1, 2).transpose(0, 1) + else: + other_patch_embeds = other_patch_embeds \ + .permute(0, 2, 1, 3, 4).contiguous() \ + .flatten(0, 3) + + merged_patch_embeddings = torch.cat( + (base_patch_embeds, other_patch_embeds), dim=0) + else: + if "unpad" in strategy: + merged_patch_embeddings = torch.cat( + (base_patch_embeds, + self.image_newline[None] \ + .to(base_patch_embeds.device) + ), dim=0) + else: + merged_patch_embeddings = base_patch_embeds + + return merged_patch_embeddings + + raise ValueError(f"Unexpected patch merge strategy: {strategy}") + + def _process_image_pixels( + self, inputs: LlavaNextImagePixelInputs) -> torch.Tensor: + assert self.vision_tower is not None + + pixel_values = inputs["data"] + + b, num_patches, c, h, w = pixel_values.shape + stacked_pixel_values = pixel_values.view(b * num_patches, c, h, w) + + stacked_image_features = self._image_pixels_to_features( + self.vision_tower, stacked_pixel_values) + + return stacked_image_features.view(b, num_patches, + *stacked_image_features.shape[-2:]) + + def _process_image_input( + self, image_input: LlavaNextImageInputs) -> torch.Tensor: + if image_input["type"] == "pixel_values": + assert self.vision_tower is not None + image_features = self._process_image_pixels(image_input) + else: + image_features = image_input["data"] + + patch_embeddings = self.multi_modal_projector(image_features) + + image_sizes = image_input.get("image_sizes") + if image_sizes is None: + batch_size = image_input["data"].shape[0] + vision_config = self.config.vision_config + default_width = default_height = vision_config.image_size + image_sizes = torch.as_tensor([[default_width, default_height] + for _ in range(batch_size)]) + + merged_patch_embeddings = [ + self._merge_image_patch_embeddings(image_sizes[i], + patch_features, + strategy="spatial_unpad") + for i, patch_features in enumerate(patch_embeddings) + ] + + return torch.stack(merged_patch_embeddings, dim=0) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + **kwargs: object, + ) -> SamplerOutput: + """Run forward pass for Llava 1.5. + + One key thing to understand is the `input_ids` already accounts for the + positions of the to-be-inserted image embeddings. + Concretely, consider a text prompt: + "\nUSER: What's the content of the image?\nASSISTANT:". + Tokenizer outputs: + [1, 32000, 29871, 13, 11889, 29901, 1724, 29915, 29879, 278, + 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901]. + The to-be-inserted image has a size of 576 (24 * 24) along the context + length dimension. + `input_ids` is thus [1, 32000, ..., 32000, 29871, 13, 11889, 29901, + 1724, 29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933, + 9047, 13566, 29901]. + There will be 576 `32000` in the `input_ids`. + (32000 is the token id for ``.) + + This way, the `positions` and `attn_metadata` are consistent + with the `input_ids`. + + The model takes two types of image inputs: + PIXEL_VALUES and IMAGE_FEATURES. + The following shows how each maps to huggingface implementation. + PIXEL_VALUES: + - https://github.com/huggingface/transformers/blob/07bdbeb/src/transformers/models/llava/modeling_llava.py#L353 + IMAGE_FEATURES: + - https://github.com/huggingface/transformers/blob/07bdbeb/src/transformers/models/llava/modeling_llava.py#L430 + before going through the multi modal projector. + + Args: + input_ids: Flattened (concatenated) input_ids corresponding to a + batch. + pixel_values: For PIXEL_VALUES, expects a batch with shape + [1, 3, 336, 336]. + image_features: For IMAGE_FEATURES, expects a batch with shape + [1, 576, 1024]. + """ + image_input = self._parse_and_validate_image_input(**kwargs) + + if image_input is not None: + vision_embeddings = self._process_image_input(image_input) + inputs_embeds = self.language_model.get_input_embeddings(input_ids) + + inputs_embeds = merge_vision_embeddings( + input_ids, inputs_embeds, vision_embeddings, + self.vision_language_config.image_token_id) + + input_ids = None + else: + inputs_embeds = None + + hidden_states = self.language_model(input_ids, + positions, + kv_caches, + attn_metadata, + inputs_embeds=inputs_embeds) + + return hidden_states + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head.weight, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + # only doing this for language model part for now. + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items(): + if key_to_modify in name: + name = name.replace(key_to_modify, new_key) + use_default_weight_loading = False + if "vision" in name: + if self.vision_tower is not None: + # We only do sharding for language model and + # not vision model for now. + use_default_weight_loading = True + else: + for (param_name, weight_name, + shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + param = params_dict[name.replace(weight_name, param_name)] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + use_default_weight_loading = True + if use_default_weight_loading: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight)