[Model] Support SigLIP encoder and alternative decoders for LLaVA models #7153

Merged (19 commits) on Aug 6, 2024. Changes shown from all commits.
3 changes: 3 additions & 0 deletions requirements-test.txt
@@ -20,6 +20,9 @@ sentence-transformers # required for embedding
compressed-tensors==0.4.0 # required for compressed-tensors
timm # required for internvl test

# TODO: Add this after fully implementing llava(mantis)
# git+https://github.com/TIGER-AI-Lab/Mantis.git # required for llava(mantis) test

# Benchmarking
aiohttp

36 changes: 31 additions & 5 deletions tests/models/test_llava.py
@@ -1,10 +1,11 @@
from typing import List, Optional, Tuple, Type

import pytest
from transformers import AutoTokenizer
from transformers import AutoConfig, AutoTokenizer

from vllm.multimodal.utils import rescale_image_size
from vllm.sequence import SampleLogprobs
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE

from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
from .utils import check_logprobs_close
@@ -18,9 +19,11 @@
"USER: <image>\nWhat is the season?\nASSISTANT:",
})

IMAGE_TOKEN_ID = 32000

models = ["llava-hf/llava-1.5-7b-hf"]
models = [
"llava-hf/llava-1.5-7b-hf",
# TODO: Get this model to produce meaningful output in vLLM
# "TIGER-Lab/Mantis-8B-siglip-llama3",
]


def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
@@ -29,12 +32,15 @@ def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
"""Sanitize vllm output to be comparable with hf output."""
output_ids, output_str, out_logprobs = vllm_output

config = AutoConfig.from_pretrained(model)
image_token_id = config.image_token_index

tokenizer = AutoTokenizer.from_pretrained(model)
eos_token_id = tokenizer.eos_token_id

hf_output_ids = [
token_id for idx, token_id in enumerate(output_ids)
if token_id != IMAGE_TOKEN_ID or output_ids[idx - 1] != IMAGE_TOKEN_ID
if token_id != image_token_id or output_ids[idx - 1] != image_token_id
]

assert output_str[0] == " "
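
For reference, a toy standalone run of the filter above (made-up output IDs; in the updated test the image token ID comes from config.image_token_index rather than a hard-coded constant):

# Consecutive image-placeholder tokens emitted by vLLM are collapsed to a
# single occurrence so the sequence lines up with HF's output.
image_token_id = 32000
output_ids = [1, 32000, 32000, 32000, 306, 626, 263]
hf_output_ids = [
    token_id for idx, token_id in enumerate(output_ids)
    if token_id != image_token_id or output_ids[idx - 1] != image_token_id
]
assert hf_output_ids == [1, 32000, 306, 626, 263]
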
@@ -67,6 +73,17 @@ def run_test(
Note, the text input is also adjusted to abide by vllm contract.
The text output is sanitized to be able to compare with hf.
"""
# NOTE: For local use; this isn't tested in CI yet (see TODO above)
if model.startswith("TIGER-Lab/Mantis"):
from mantis.models.mllava import MLlavaProcessor

torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]
mantis_processor = MLlavaProcessor.from_pretrained(
model, torch_dtype=torch_dtype)
assert isinstance(mantis_processor, MLlavaProcessor)
else:
mantis_processor = None

images = [asset.pil_image for asset in image_assets]

inputs_per_image = [(
@@ -94,6 +111,15 @@
]

with hf_runner(model, dtype=dtype, is_vision_model=True) as hf_model:
if mantis_processor is not None:

def process(*args, **kwargs):
output = mantis_processor(*args, **kwargs)
output["pixel_values"] = output["pixel_values"].to(torch_dtype)
return output

hf_model.processor = process

hf_outputs_per_image = [
hf_model.generate_greedy_logprobs_limit(prompts,
max_tokens,
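
The wrapper above exists because HF image processors typically return float32 pixel_values, while the test may run the model in half precision; casting inside the processor keeps the HF runner's inputs consistent with the model dtype. A minimal sketch of the same wrap-and-cast pattern against the stock HF LLaVA processor (illustrative only; the Mantis processor is substituted the same way when running locally):

import torch
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")

def process(*args, torch_dtype=torch.half, **kwargs):
    # Delegate to the underlying processor, then cast the image tensor.
    output = processor(*args, **kwargs)
    output["pixel_values"] = output["pixel_values"].to(torch_dtype)
    return output

# Example call (prompt/image are placeholders):
# batch = process(text=prompt, images=image, return_tensors="pt")
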
9 changes: 5 additions & 4 deletions tests/models/test_llava_next.py
@@ -1,7 +1,7 @@
from typing import List, Optional, Tuple, Type, overload

import pytest
from transformers import AutoTokenizer
from transformers import AutoConfig, AutoTokenizer

from vllm.multimodal.utils import rescale_image_size
from vllm.sequence import SampleLogprobs
@@ -23,8 +23,6 @@
f"{_PREFACE} USER: <image>\nWhat is the season? ASSISTANT:",
})

IMAGE_TOKEN_ID = 32000

models = ["llava-hf/llava-v1.6-vicuna-7b-hf"]


@@ -34,12 +32,15 @@ def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
"""Sanitize vllm output to be comparable with hf output."""
output_ids, output_str, out_logprobs = vllm_output

config = AutoConfig.from_pretrained(model)
image_token_id = config.image_token_index

tokenizer = AutoTokenizer.from_pretrained(model)
eos_token_id = tokenizer.eos_token_id

hf_output_ids = [
token_id for idx, token_id in enumerate(output_ids)
if token_id != IMAGE_TOKEN_ID or output_ids[idx - 1] != IMAGE_TOKEN_ID
if token_id != image_token_id or output_ids[idx - 1] != image_token_id
]

assert output_str[0] == " "
9 changes: 5 additions & 4 deletions tests/models/test_paligemma.py
@@ -2,7 +2,7 @@
from typing import List, Optional, Tuple, Type

import pytest
from transformers import AutoTokenizer
from transformers import AutoConfig, AutoTokenizer

from vllm.multimodal.utils import rescale_image_size
from vllm.sequence import SampleLogprobs
@@ -20,8 +20,6 @@
"What is in the picture?",
})

IMAGE_TOKEN_ID = 257152

models = ["google/paligemma-3b-mix-224"]

# ROCm Triton FA can run into compilation issues with these models due to,
@@ -37,12 +35,15 @@ def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
"""Sanitize vllm output to be comparable with hf output."""
output_ids, output_str, out_logprobs = vllm_output

config = AutoConfig.from_pretrained(model)
image_token_id = config.image_token_index

tokenizer = AutoTokenizer.from_pretrained(model)
eos_token_id = tokenizer.eos_token_id

hf_output_ids = [
token_id for idx, token_id in enumerate(output_ids)
if token_id != IMAGE_TOKEN_ID or output_ids[idx - 1] != IMAGE_TOKEN_ID
if token_id != image_token_id or output_ids[idx - 1] != image_token_id
]

hf_output_str = output_str
2 changes: 1 addition & 1 deletion tests/models/test_registry.py
@@ -6,4 +6,4 @@
@pytest.mark.parametrize("model_cls", _MODELS)
def test_registry_imports(model_cls):
# Ensure all model classes can be imported successfully
ModelRegistry.load_model_cls(model_cls)
ModelRegistry.resolve_model_cls([model_cls])
38 changes: 28 additions & 10 deletions vllm/model_executor/model_loader/loader.py
@@ -16,7 +16,7 @@
import torch
from huggingface_hub import HfApi, hf_hub_download
from torch import nn
from transformers import AutoModelForCausalLM
from transformers import AutoModelForCausalLM, PretrainedConfig

from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoadFormat,
LoRAConfig, ModelConfig, MultiModalConfig,
@@ -143,6 +143,22 @@ def _get_model_initialization_kwargs(
return extra_kwargs


def build_model(model_class: Type[nn.Module], hf_config: PretrainedConfig,
cache_config: Optional[CacheConfig],
quant_config: Optional[QuantizationConfig], *,
lora_config: Optional[LoRAConfig],
multimodal_config: Optional[MultiModalConfig],
scheduler_config: Optional[SchedulerConfig]) -> nn.Module:
extra_kwargs = _get_model_initialization_kwargs(model_class, lora_config,
multimodal_config,
scheduler_config)

return model_class(config=hf_config,
cache_config=cache_config,
quant_config=quant_config,
**extra_kwargs)


def _initialize_model(
model_config: ModelConfig,
load_config: LoadConfig,
@@ -151,15 +167,17 @@ def _initialize_model(
cache_config: CacheConfig,
scheduler_config: Optional[SchedulerConfig] = None) -> nn.Module:
"""Initialize a model with the given configurations."""
model_class = get_model_architecture(model_config)[0]
quant_config = _get_quantization_config(model_config, load_config)

return model_class(config=model_config.hf_config,
cache_config=cache_config,
quant_config=quant_config,
**_get_model_initialization_kwargs(
model_class, lora_config, multimodal_config,
scheduler_config))
model_class, _ = get_model_architecture(model_config)

return build_model(
model_class,
model_config.hf_config,
quant_config=_get_quantization_config(model_config, load_config),
lora_config=lora_config,
multimodal_config=multimodal_config,
cache_config=cache_config,
scheduler_config=scheduler_config,
)


class BaseModelLoader(ABC):
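
build_model is the piece of _initialize_model that composite (e.g. vision-language) models can now reuse to construct their text backbone through the same code path as top-level models. A rough sketch of such a caller (hypothetical function name; the helper this PR actually adds for this purpose is init_vllm_registered_model in vllm/model_executor/models/utils.py, which is not shown in this diff):

from vllm.model_executor.model_loader.loader import build_model
from vllm.model_executor.models import ModelRegistry

def init_text_backbone(text_config, cache_config, quant_config):
    # Resolve the decoder class named in the sub-config, then build it with
    # the same arguments the top-level loader would pass.
    model_cls, _ = ModelRegistry.resolve_model_cls(text_config.architectures)
    return build_model(model_cls,
                       text_config,
                       cache_config,
                       quant_config,
                       lora_config=None,
                       multimodal_config=None,
                       scheduler_config=None)
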
8 changes: 1 addition & 7 deletions vllm/model_executor/model_loader/utils.py
@@ -28,13 +28,7 @@ def get_model_architecture(
and "MixtralForCausalLM" in architectures):
architectures = ["QuantMixtralForCausalLM"]

for arch in architectures:
model_cls = ModelRegistry.load_model_cls(arch)
if model_cls is not None:
return (model_cls, arch)
raise ValueError(
f"Model architectures {architectures} are not supported for now. "
f"Supported architectures: {ModelRegistry.get_supported_archs()}")
return ModelRegistry.resolve_model_cls(architectures)


def get_architecture_class_name(model_config: ModelConfig) -> str:
16 changes: 14 additions & 2 deletions vllm/model_executor/models/__init__.py
@@ -1,6 +1,6 @@
import functools
import importlib
from typing import Dict, List, Optional, Type
from typing import Dict, List, Optional, Tuple, Type

import torch.nn as nn

@@ -126,7 +126,7 @@ def _get_model(model_arch: str):
return getattr(module, model_cls_name, None)

@staticmethod
def load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]:
def _try_load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]:
if model_arch in _OOT_MODELS:
return _OOT_MODELS[model_arch]
if model_arch not in _MODELS:
@@ -143,6 +143,18 @@ def load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]:

return ModelRegistry._get_model(model_arch)

@staticmethod
def resolve_model_cls(
architectures: List[str]) -> Tuple[Type[nn.Module], str]:
for arch in architectures:
model_cls = ModelRegistry._try_load_model_cls(arch)
if model_cls is not None:
return (model_cls, arch)

raise ValueError(
f"Model architectures {architectures} are not supported for now. "
f"Supported architectures: {ModelRegistry.get_supported_archs()}")

@staticmethod
def get_supported_archs() -> List[str]:
return list(_MODELS.keys())
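
Callers get the resolved class together with the architecture name that matched, and the "unsupported architecture" error is now raised in one place instead of being duplicated in the loader. Illustrative usage:

from vllm.model_executor.models import ModelRegistry

model_cls, arch = ModelRegistry.resolve_model_cls(
    ["LlavaForConditionalGeneration"])
assert arch == "LlavaForConditionalGeneration"
# An unknown architecture list raises ValueError listing the supported archs.
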
24 changes: 22 additions & 2 deletions vllm/model_executor/models/clip.py
@@ -1,6 +1,6 @@
"""Minimal implementation of CLIPVisionModel intended to be only used
within a vision language model."""
from typing import Optional
from typing import Iterable, Optional, Tuple

import torch
import torch.nn as nn
@@ -14,6 +14,7 @@
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal.image import (cached_get_tokenizer,
repeat_and_pad_image_tokens)
from vllm.sequence import SequenceData
@@ -32,7 +33,7 @@ def get_clip_num_patches(*, image_size: int, patch_size: int) -> int:

def get_clip_image_feature_size(hf_config: CLIPVisionConfig) -> int:
return get_clip_num_patches(image_size=hf_config.image_size,
patch_size=hf_config.patch_size)
patch_size=hf_config.patch_size) + 1


def get_max_clip_image_tokens(hf_config: CLIPVisionConfig) -> int:
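
The extra + 1 accounts for the class (CLS) embedding that CLIPVisionModel prepends to the patch embeddings, so the feature size is the patch count plus one. A quick sanity check of the arithmetic for the CLIP ViT-L/14 tower at 336x336 used by the LLaVA-1.5 models above (not part of this diff):

image_size, patch_size = 336, 14
num_patches = (image_size // patch_size) ** 2  # 24 * 24 = 576 patches
feature_size = num_patches + 1                 # + CLS token = 577
assert feature_size == 577
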
@@ -291,3 +292,22 @@ def forward(self, pixel_values: Optional[torch.Tensor] = None):
@property
def device(self):
return next(self.parameters()).device

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters())
layer_count = len(self.vision_model.encoder.layers)

for name, loaded_weight in weights:
# post_layernorm is not needed in CLIPVisionModel
if "vision_model.post_layernorm" in name:
continue
# omit layers when num_hidden_layers_override is set
if "vision_model.encoder.layers." in name:
layer_idx = int(name.split(".")[3])
if layer_idx >= layer_count:
continue

param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
24 changes: 7 additions & 17 deletions vllm/model_executor/models/internvl.py
@@ -18,7 +18,6 @@
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.models.intern_vit import InternVisionModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
@@ -29,7 +28,8 @@
from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
get_clip_num_patches)
from .interfaces import SupportsVision
from .utils import merge_vision_embeddings
from .utils import (filter_weights, init_vllm_registered_model,
merge_vision_embeddings)

IMG_START = '<img>'
IMG_END = '</img>'
@@ -283,10 +283,8 @@ def __init__(self,
self.vision_model = InternVisionModel(
config.vision_config, num_hidden_layers_override=num_hidden_layers)

llm_class = ModelRegistry.load_model_cls(
config.text_config.architectures[0])
self.language_model = llm_class(config.text_config, cache_config,
quant_config)
self.language_model = init_vllm_registered_model(
config.text_config, cache_config, quant_config)

vit_hidden_size = config.vision_config.hidden_size
llm_hidden_size = config.text_config.hidden_size
@@ -415,24 +413,16 @@ def sample(
) -> Optional[SamplerOutput]:
return self.language_model.sample(logits, sampling_metadata)

def _filter_weights(self, weights: Iterable[Tuple[str, torch.Tensor]],
prefix: str):
for name, loaded_weight in weights:
name = name.split(".")
if prefix == name.pop(0):
name = ".".join(name)
yield name, loaded_weight

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# prepare weight iterators for components
vit_weights, mlp_weights, llm_weights = itertools.tee(weights, 3)

# load vision encoder
vit_weights = self._filter_weights(vit_weights, "vision_model")
vit_weights = filter_weights(vit_weights, "vision_model")
self.vision_model.load_weights(vit_weights)

# load mlp projector
mlp_weights = self._filter_weights(mlp_weights, "mlp1")
mlp_weights = filter_weights(mlp_weights, "mlp1")
mlp_params_dict = dict(self.mlp1.named_parameters())
for name, loaded_weight in mlp_weights:
param = mlp_params_dict[name]
@@ -441,5 +431,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
weight_loader(param, loaded_weight)

# load llm backbone
llm_weights = self._filter_weights(llm_weights, "language_model")
llm_weights = filter_weights(llm_weights, "language_model")
self.language_model.load_weights(llm_weights)
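
filter_weights and init_vllm_registered_model come from vllm/model_executor/models/utils.py, which is not shown in this diff; presumably they generalize the code removed above so that other composite models (LLaVA with a SigLIP encoder, for instance) can share it. A rough sketch under that assumption:

from typing import Iterable, Tuple

import torch

from vllm.model_executor.model_loader.loader import build_model
from vllm.model_executor.models import ModelRegistry


def filter_weights(weights: Iterable[Tuple[str, torch.Tensor]], prefix: str):
    # Yield only the weights belonging to `prefix`, with the prefix stripped,
    # mirroring the _filter_weights method removed above.
    for name, loaded_weight in weights:
        parts = name.split(".")
        if prefix == parts.pop(0):
            yield ".".join(parts), loaded_weight


def init_vllm_registered_model(hf_config, cache_config, quant_config):
    # Resolve the architecture named in the sub-config and build it the same
    # way the top-level loader does (see build_model in loader.py above).
    model_cls, _ = ModelRegistry.resolve_model_cls(hf_config.architectures)
    return build_model(model_cls, hf_config, cache_config, quant_config,
                       lora_config=None, multimodal_config=None,
                       scheduler_config=None)
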