[Model] Support SigLIP encoder and alternative decoders for LLaVA models (#7153)

Co-authored-by: Roger Wang <136131678+ywang96@users.noreply.github.com>
DarkLight1337 and ywang96 authored Aug 6, 2024
1 parent 9118217 commit 1f26efb
Showing 14 changed files with 455 additions and 269 deletions.
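
Note: the remaining model files touched by this commit (presumably including llava.py and a new siglip.py) are collapsed in this view, so the sketch below only illustrates the pattern described by the commit title rather than the exact code in those files. The idea is that a LLaVA-family model now chooses its vision tower from the HF vision_config instead of assuming CLIP; the helper name _init_vision_tower and the constructor signatures here are assumptions.

# Illustrative sketch only; the real llava.py/siglip.py changes are not shown
# in this view, and the helper name and signatures here are assumptions.
import torch.nn as nn
from transformers import CLIPVisionConfig, LlavaConfig, SiglipVisionConfig


def _init_vision_tower(config: LlavaConfig) -> nn.Module:
    vision_config = config.vision_config
    if isinstance(vision_config, CLIPVisionConfig):
        from vllm.model_executor.models.clip import CLIPVisionModel
        return CLIPVisionModel(vision_config)
    if isinstance(vision_config, SiglipVisionConfig):
        from vllm.model_executor.models.siglip import SiglipVisionModel
        return SiglipVisionModel(vision_config)
    raise NotImplementedError(
        f"Unsupported vision config: {type(vision_config)}")
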
3 changes: 3 additions & 0 deletions requirements-test.txt
@@ -20,6 +20,9 @@ sentence-transformers # required for embedding
compressed-tensors==0.4.0 # required for compressed-tensors
timm # required for internvl test

# TODO: Add this after fully implementing llava(mantis)
# git+https://github.com/TIGER-AI-Lab/Mantis.git # required for llava(mantis) test

# Benchmarking
aiohttp

36 changes: 31 additions & 5 deletions tests/models/test_llava.py
@@ -1,10 +1,11 @@
from typing import List, Optional, Tuple, Type

import pytest
from transformers import AutoTokenizer
from transformers import AutoConfig, AutoTokenizer

from vllm.multimodal.utils import rescale_image_size
from vllm.sequence import SampleLogprobs
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE

from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
from .utils import check_logprobs_close
@@ -18,9 +19,11 @@
"USER: <image>\nWhat is the season?\nASSISTANT:",
})

IMAGE_TOKEN_ID = 32000

models = ["llava-hf/llava-1.5-7b-hf"]
models = [
"llava-hf/llava-1.5-7b-hf",
# TODO: Get this model to produce meaningful output in vLLM
# "TIGER-Lab/Mantis-8B-siglip-llama3",
]


def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
@@ -29,12 +32,15 @@ def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
"""Sanitize vllm output to be comparable with hf output."""
output_ids, output_str, out_logprobs = vllm_output

config = AutoConfig.from_pretrained(model)
image_token_id = config.image_token_index

tokenizer = AutoTokenizer.from_pretrained(model)
eos_token_id = tokenizer.eos_token_id

hf_output_ids = [
token_id for idx, token_id in enumerate(output_ids)
if token_id != IMAGE_TOKEN_ID or output_ids[idx - 1] != IMAGE_TOKEN_ID
if token_id != image_token_id or output_ids[idx - 1] != image_token_id
]

assert output_str[0] == " "
@@ -67,6 +73,17 @@ def run_test(
Note, the text input is also adjusted to abide by vllm contract.
The text output is sanitized to be able to compare with hf.
"""
# NOTE: For local use; this isn't tested in CI yet (see TODO above)
if model.startswith("TIGER-Lab/Mantis"):
from mantis.models.mllava import MLlavaProcessor

torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]
mantis_processor = MLlavaProcessor.from_pretrained(
model, torch_dtype=torch_dtype)
assert isinstance(mantis_processor, MLlavaProcessor)
else:
mantis_processor = None

images = [asset.pil_image for asset in image_assets]

inputs_per_image = [(
@@ -94,6 +111,15 @@
]

with hf_runner(model, dtype=dtype, is_vision_model=True) as hf_model:
if mantis_processor is not None:

def process(*args, **kwargs):
output = mantis_processor(*args, **kwargs)
output["pixel_values"] = output["pixel_values"].to(torch_dtype)
return output

hf_model.processor = process

hf_outputs_per_image = [
hf_model.generate_greedy_logprobs_limit(prompts,
max_tokens,
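
The sanitization above keeps an image token only when the token before it is not also an image token, so the run of repeated image placeholder ids in the vLLM prompt collapses back to a single <image> token for comparison with the HF side. A standalone rendition of the same filter, with made-up ids:

# Standalone rendition of the hf_output_ids filter above; 32000 is the
# LLaVA-1.5 image token id, the other ids are arbitrary examples.
def collapse_image_tokens(output_ids, image_token_id=32000):
    return [
        token_id for idx, token_id in enumerate(output_ids)
        if token_id != image_token_id or output_ids[idx - 1] != image_token_id
    ]


print(collapse_image_tokens([1, 32000, 32000, 32000, 3148, 1001]))
# -> [1, 32000, 3148, 1001]
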
9 changes: 5 additions & 4 deletions tests/models/test_llava_next.py
@@ -1,7 +1,7 @@
from typing import List, Optional, Tuple, Type, overload

import pytest
from transformers import AutoTokenizer
from transformers import AutoConfig, AutoTokenizer

from vllm.multimodal.utils import rescale_image_size
from vllm.sequence import SampleLogprobs
@@ -23,8 +23,6 @@
f"{_PREFACE} USER: <image>\nWhat is the season? ASSISTANT:",
})

IMAGE_TOKEN_ID = 32000

models = ["llava-hf/llava-v1.6-vicuna-7b-hf"]


@@ -34,12 +32,15 @@ def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
"""Sanitize vllm output to be comparable with hf output."""
output_ids, output_str, out_logprobs = vllm_output

config = AutoConfig.from_pretrained(model)
image_token_id = config.image_token_index

tokenizer = AutoTokenizer.from_pretrained(model)
eos_token_id = tokenizer.eos_token_id

hf_output_ids = [
token_id for idx, token_id in enumerate(output_ids)
if token_id != IMAGE_TOKEN_ID or output_ids[idx - 1] != IMAGE_TOKEN_ID
if token_id != image_token_id or output_ids[idx - 1] != image_token_id
]

assert output_str[0] == " "
9 changes: 5 additions & 4 deletions tests/models/test_paligemma.py
@@ -2,7 +2,7 @@
from typing import List, Optional, Tuple, Type

import pytest
from transformers import AutoTokenizer
from transformers import AutoConfig, AutoTokenizer

from vllm.multimodal.utils import rescale_image_size
from vllm.sequence import SampleLogprobs
@@ -20,8 +20,6 @@
"What is in the picture?",
})

IMAGE_TOKEN_ID = 257152

models = ["google/paligemma-3b-mix-224"]

# ROCm Triton FA can run into compilation issues with these models due to,
@@ -37,12 +35,15 @@ def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
"""Sanitize vllm output to be comparable with hf output."""
output_ids, output_str, out_logprobs = vllm_output

config = AutoConfig.from_pretrained(model)
image_token_id = config.image_token_index

tokenizer = AutoTokenizer.from_pretrained(model)
eos_token_id = tokenizer.eos_token_id

hf_output_ids = [
token_id for idx, token_id in enumerate(output_ids)
if token_id != IMAGE_TOKEN_ID or output_ids[idx - 1] != IMAGE_TOKEN_ID
if token_id != image_token_id or output_ids[idx - 1] != image_token_id
]

hf_output_str = output_str
2 changes: 1 addition & 1 deletion tests/models/test_registry.py
@@ -6,4 +6,4 @@
@pytest.mark.parametrize("model_cls", _MODELS)
def test_registry_imports(model_cls):
# Ensure all model classes can be imported successfully
ModelRegistry.load_model_cls(model_cls)
ModelRegistry.resolve_model_cls([model_cls])
38 changes: 28 additions & 10 deletions vllm/model_executor/model_loader/loader.py
@@ -16,7 +16,7 @@
import torch
from huggingface_hub import HfApi, hf_hub_download
from torch import nn
from transformers import AutoModelForCausalLM
from transformers import AutoModelForCausalLM, PretrainedConfig

from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoadFormat,
LoRAConfig, ModelConfig, MultiModalConfig,
@@ -143,6 +143,22 @@ def _get_model_initialization_kwargs(
return extra_kwargs


def build_model(model_class: Type[nn.Module], hf_config: PretrainedConfig,
cache_config: Optional[CacheConfig],
quant_config: Optional[QuantizationConfig], *,
lora_config: Optional[LoRAConfig],
multimodal_config: Optional[MultiModalConfig],
scheduler_config: Optional[SchedulerConfig]) -> nn.Module:
extra_kwargs = _get_model_initialization_kwargs(model_class, lora_config,
multimodal_config,
scheduler_config)

return model_class(config=hf_config,
cache_config=cache_config,
quant_config=quant_config,
**extra_kwargs)


def _initialize_model(
model_config: ModelConfig,
load_config: LoadConfig,
@@ -151,15 +167,17 @@ def _initialize_model(
cache_config: CacheConfig,
scheduler_config: Optional[SchedulerConfig] = None) -> nn.Module:
"""Initialize a model with the given configurations."""
model_class = get_model_architecture(model_config)[0]
quant_config = _get_quantization_config(model_config, load_config)

return model_class(config=model_config.hf_config,
cache_config=cache_config,
quant_config=quant_config,
**_get_model_initialization_kwargs(
model_class, lora_config, multimodal_config,
scheduler_config))
model_class, _ = get_model_architecture(model_config)

return build_model(
model_class,
model_config.hf_config,
quant_config=_get_quantization_config(model_config, load_config),
lora_config=lora_config,
multimodal_config=multimodal_config,
cache_config=cache_config,
scheduler_config=scheduler_config,
)


class BaseModelLoader(ABC):
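
Factoring build_model out of _initialize_model lets composite vision-language models construct their registered text backbone with the same extra-kwargs handling. The init_vllm_registered_model helper imported in internvl.py further down presumably wraps it roughly like this (a sketch; the exact signature in vllm/model_executor/models/utils.py may differ):

# Sketch of how init_vllm_registered_model (used in internvl.py below) can be
# layered on build_model; treat the signature as an assumption, not the actual code.
from typing import Optional

from torch import nn
from transformers import PretrainedConfig

from vllm.config import CacheConfig
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.loader import build_model
from vllm.model_executor.models import ModelRegistry


def init_vllm_registered_model(
        hf_config: PretrainedConfig,
        cache_config: Optional[CacheConfig],
        quant_config: Optional[QuantizationConfig]) -> nn.Module:
    # Resolve the inner architecture from the HF config, then build it with
    # the shared kwargs handling provided by build_model.
    model_class, _ = ModelRegistry.resolve_model_cls(hf_config.architectures)
    return build_model(model_class,
                       hf_config,
                       cache_config,
                       quant_config,
                       lora_config=None,
                       multimodal_config=None,
                       scheduler_config=None)
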
8 changes: 1 addition & 7 deletions vllm/model_executor/model_loader/utils.py
@@ -28,13 +28,7 @@ def get_model_architecture(
and "MixtralForCausalLM" in architectures):
architectures = ["QuantMixtralForCausalLM"]

for arch in architectures:
model_cls = ModelRegistry.load_model_cls(arch)
if model_cls is not None:
return (model_cls, arch)
raise ValueError(
f"Model architectures {architectures} are not supported for now. "
f"Supported architectures: {ModelRegistry.get_supported_archs()}")
return ModelRegistry.resolve_model_cls(architectures)


def get_architecture_class_name(model_config: ModelConfig) -> str:
16 changes: 14 additions & 2 deletions vllm/model_executor/models/__init__.py
@@ -1,6 +1,6 @@
import functools
import importlib
from typing import Dict, List, Optional, Type
from typing import Dict, List, Optional, Tuple, Type

import torch.nn as nn

@@ -126,7 +126,7 @@ def _get_model(model_arch: str):
return getattr(module, model_cls_name, None)

@staticmethod
def load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]:
def _try_load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]:
if model_arch in _OOT_MODELS:
return _OOT_MODELS[model_arch]
if model_arch not in _MODELS:
@@ -143,6 +143,18 @@ def load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]:

return ModelRegistry._get_model(model_arch)

@staticmethod
def resolve_model_cls(
architectures: List[str]) -> Tuple[Type[nn.Module], str]:
for arch in architectures:
model_cls = ModelRegistry._try_load_model_cls(arch)
if model_cls is not None:
return (model_cls, arch)

raise ValueError(
f"Model architectures {architectures} are not supported for now. "
f"Supported architectures: {ModelRegistry.get_supported_archs()}")

@staticmethod
def get_supported_archs() -> List[str]:
return list(_MODELS.keys())
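
Callers now get back both the resolved class and the architecture name that matched, and a ValueError listing the supported architectures when nothing matches. A short usage sketch (LlamaForCausalLM is a registered architecture; the other name is a deliberately unknown example):

# Usage sketch for the new registry API.
from vllm.model_executor.models import ModelRegistry

# Unknown entries are skipped; the first resolvable architecture wins.
model_cls, arch = ModelRegistry.resolve_model_cls(
    ["NotARealArchitecture", "LlamaForCausalLM"])
assert arch == "LlamaForCausalLM"

try:
    ModelRegistry.resolve_model_cls(["NotARealArchitecture"])
except ValueError as err:
    print(err)  # lists ModelRegistry.get_supported_archs()
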
24 changes: 22 additions & 2 deletions vllm/model_executor/models/clip.py
@@ -1,6 +1,6 @@
"""Minimal implementation of CLIPVisionModel intended to be only used
within a vision language model."""
from typing import Optional
from typing import Iterable, Optional, Tuple

import torch
import torch.nn as nn
@@ -14,6 +14,7 @@
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal.image import (cached_get_tokenizer,
repeat_and_pad_image_tokens)
from vllm.sequence import SequenceData
@@ -32,7 +33,7 @@ def get_clip_num_patches(*, image_size: int, patch_size: int) -> int:

def get_clip_image_feature_size(hf_config: CLIPVisionConfig) -> int:
return get_clip_num_patches(image_size=hf_config.image_size,
patch_size=hf_config.patch_size)
patch_size=hf_config.patch_size) + 1


def get_max_clip_image_tokens(hf_config: CLIPVisionConfig) -> int:
@@ -291,3 +292,22 @@ def forward(self, pixel_values: Optional[torch.Tensor] = None):
@property
def device(self):
return next(self.parameters()).device

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters())
layer_count = len(self.vision_model.encoder.layers)

for name, loaded_weight in weights:
# post_layernorm is not needed in CLIPVisionModel
if "vision_model.post_layernorm" in name:
continue
# omit layers when num_hidden_layers_override is set
if "vision_model.encoder.layers." in name:
layer_idx = int(name.split(".")[3])
if layer_idx >= layer_count:
continue

param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
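
The +1 added to get_clip_image_feature_size accounts for the class-embedding position that the CLIP vision tower prepends to the patch embeddings (downstream feature-select logic may drop it again, depending on the strategy). A quick check with the ViT-L/14-336px settings used by llava-hf/llava-1.5-7b-hf:

# Worked example for the +1 change above (image_size=336, patch_size=14 are
# the CLIP ViT-L/14-336px values used by llava-hf/llava-1.5-7b-hf).
image_size, patch_size = 336, 14

num_patches = (image_size // patch_size) ** 2  # 24 * 24 = 576
feature_size = num_patches + 1                 # plus the class embedding -> 577

assert (num_patches, feature_size) == (576, 577)
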
24 changes: 7 additions & 17 deletions vllm/model_executor/models/internvl.py
@@ -18,7 +18,6 @@
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.models.intern_vit import InternVisionModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
@@ -29,7 +28,8 @@
from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
get_clip_num_patches)
from .interfaces import SupportsVision
from .utils import merge_vision_embeddings
from .utils import (filter_weights, init_vllm_registered_model,
merge_vision_embeddings)

IMG_START = '<img>'
IMG_END = '</img>'
@@ -283,10 +283,8 @@ def __init__(self,
self.vision_model = InternVisionModel(
config.vision_config, num_hidden_layers_override=num_hidden_layers)

llm_class = ModelRegistry.load_model_cls(
config.text_config.architectures[0])
self.language_model = llm_class(config.text_config, cache_config,
quant_config)
self.language_model = init_vllm_registered_model(
config.text_config, cache_config, quant_config)

vit_hidden_size = config.vision_config.hidden_size
llm_hidden_size = config.text_config.hidden_size
@@ -415,24 +413,16 @@ def sample(
) -> Optional[SamplerOutput]:
return self.language_model.sample(logits, sampling_metadata)

def _filter_weights(self, weights: Iterable[Tuple[str, torch.Tensor]],
prefix: str):
for name, loaded_weight in weights:
name = name.split(".")
if prefix == name.pop(0):
name = ".".join(name)
yield name, loaded_weight

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# prepare weight iterators for components
vit_weights, mlp_weights, llm_weights = itertools.tee(weights, 3)

# load vision encoder
vit_weights = self._filter_weights(vit_weights, "vision_model")
vit_weights = filter_weights(vit_weights, "vision_model")
self.vision_model.load_weights(vit_weights)

# load mlp projector
mlp_weights = self._filter_weights(mlp_weights, "mlp1")
mlp_weights = filter_weights(mlp_weights, "mlp1")
mlp_params_dict = dict(self.mlp1.named_parameters())
for name, loaded_weight in mlp_weights:
param = mlp_params_dict[name]
@@ -441,5 +431,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
weight_loader(param, loaded_weight)

# load llm backbone
llm_weights = self._filter_weights(llm_weights, "language_model")
llm_weights = filter_weights(llm_weights, "language_model")
self.language_model.load_weights(llm_weights)
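
The per-model _filter_weights method removed above has been promoted to a shared filter_weights helper in vllm/model_executor/models/utils.py. Judging from the deleted body, it presumably keeps only the weights under a given prefix and strips that prefix before handing them to the sub-module loaders:

# Sketch reconstructed from the removed InternVLChatModel._filter_weights;
# the shared helper in models/utils.py should behave the same way.
from typing import Iterable, Tuple

import torch


def filter_weights(weights: Iterable[Tuple[str, torch.Tensor]], prefix: str):
    for name, loaded_weight in weights:
        name = name.split(".")
        if prefix == name.pop(0):
            name = ".".join(name)
            yield name, loaded_weight
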