[Model] Port over CLIPVisionModel for VLMs #5591

Merged (16 commits) on Jun 20, 2024
12 changes: 12 additions & 0 deletions csrc/activation_kernels.cu
@@ -135,6 +135,12 @@ __device__ __forceinline__ T gelu_fast_kernel(const T& x) {
return ((T)0.5) * x * (((T)1.0) + t);
}

template <typename T>
__device__ __forceinline__ T gelu_quick_kernel(const T& x) {
// x * sigmoid(1.702 * x)
return (T)(((float)x) / (1.0f + expf(-1.702f * (float)x)));
}

} // namespace vllm

void gelu_new(torch::Tensor& out, // [..., d]
@@ -148,3 +154,9 @@ void gelu_fast(torch::Tensor& out, // [..., d]
{
LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel);
}

void gelu_quick(torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., d]
{
LAUNCH_ACTIVATION_KERNEL(vllm::gelu_quick_kernel);
}
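For reference, the closed form used in the kernel is just the comment's definition with the sigmoid written out: since sigmoid(z) = 1 / (1 + exp(-z)), x / (1 + exp(-1.702 * x)) equals x * sigmoid(1.702 * x). A minimal PyTorch sanity check of that identity (a sketch, not part of this diff):

```python
import torch

x = torch.randn(1024, dtype=torch.float32)

# Closed form used by the CUDA kernel above.
kernel_form = x / (1.0 + torch.exp(-1.702 * x))
# Definition quoted in the kernel comment: x * sigmoid(1.702 * x).
reference = x * torch.sigmoid(1.702 * x)

torch.testing.assert_close(kernel_form, reference, atol=1e-6, rtol=1e-6)
```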
2 changes: 2 additions & 0 deletions csrc/ops.h
@@ -49,6 +49,8 @@ void gelu_new(torch::Tensor& out, torch::Tensor& input);

void gelu_fast(torch::Tensor& out, torch::Tensor& input);

void gelu_quick(torch::Tensor& out, torch::Tensor& input);

#ifndef USE_ROCM
torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
const torch::Tensor& codebooks,
4 changes: 4 additions & 0 deletions csrc/torch_bindings.cpp
@@ -68,6 +68,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def("gelu_fast(Tensor! out, Tensor input) -> ()");
ops.impl("gelu_fast", torch::kCUDA, &gelu_fast);

// Quick GELU implementation.
ops.def("gelu_quick(Tensor! out, Tensor input) -> ()");
ops.impl("gelu_quick", torch::kCUDA, &gelu_quick);

// Layernorm
// Apply Root Mean Square (RMS) Normalization to the input tensor.
ops.def(
4 changes: 4 additions & 0 deletions vllm/_custom_ops.py
@@ -66,6 +66,10 @@ def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None:
torch.ops._C.gelu_new(out, x)


def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None:
torch.ops._C.gelu_quick(out, x)


# page attention ops
def paged_attention_v1(
out: torch.Tensor,
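Like the other activation wrappers in this file, gelu_quick follows an out-parameter convention: the caller allocates out and the CUDA kernel writes into it. A hedged usage sketch (assumes a CUDA build of vLLM and a GPU; not part of this diff):

```python
import torch

from vllm import _custom_ops as ops

x = torch.randn(4, 4096, dtype=torch.float16, device="cuda")
out = torch.empty_like(x)

# The op is registered for CUDA only; it fills `out` in place.
ops.gelu_quick(out, x)

# Cross-check against the eager definition x * sigmoid(1.702 * x).
expected = x * torch.sigmoid(1.702 * x)
torch.testing.assert_close(out, expected, atol=1e-3, rtol=1e-3)
```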
16 changes: 16 additions & 0 deletions vllm/model_executor/layers/activation.py
@@ -141,6 +141,21 @@ def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
return out


class QuickGELU(CustomOp):

# https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py#L90
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
"""PyTorch-native implementation equivalent to forward()."""
return x * torch.sigmoid(1.702 * x)

def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
from vllm import _custom_ops as ops

out = torch.empty_like(x)
ops.gelu_quick(out, x)
return out


class ScaledActivation(nn.Module):
"""An activation function with post-scale parameters.

@@ -189,6 +204,7 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
"gelu_new": NewGELU(),
"gelu_pytorch_tanh": nn.GELU(approximate="tanh"),
"relu": nn.ReLU(),
"quick_gelu": QuickGELU(),
}


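With the registry entry above, any config whose hidden_act is "quick_gelu" (the default for CLIP vision configs) resolves to QuickGELU through get_act_fn, which is how the ported CLIPMLP below picks it up. A small sketch (assuming a working vLLM install; inside the engine CustomOp routes to forward_cuda, while forward_native is the reference path):

```python
import torch

from vllm.model_executor.layers.activation import QuickGELU, get_act_fn

# "quick_gelu" is the hidden_act of CLIP vision configs, so CLIPMLP picks
# this layer up automatically through get_act_fn(config.hidden_act).
act = get_act_fn("quick_gelu")
assert isinstance(act, QuickGELU)

x = torch.randn(2, 768)
# forward_native is the reference implementation; on CUDA inputs the engine
# dispatches to forward_cuda, i.e. the new gelu_quick kernel.
y = act.forward_native(x)
torch.testing.assert_close(y, x * torch.sigmoid(1.702 * x))
```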
203 changes: 203 additions & 0 deletions vllm/model_executor/models/clip.py
@@ -0,0 +1,203 @@
"""Minimal implementation of CLIPVisionModel intended to be only used
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be nice to unit test the model in isolation to ensure its consistency with HF.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure. The vision language model test should already serve as an end-to-end integration test for CLIPVisionModel since it's only intended to be used there, but I can add separate unit tests for this model itself.

within a vision language model."""
from typing import Optional

import torch
import torch.nn as nn
from transformers import CLIPVisionConfig
from transformers.models.clip.modeling_clip import CLIPAttention

from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)


def get_clip_num_patches(image_size: int, patch_size: int) -> int:
assert image_size % patch_size == 0
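# e.g. image_size=336, patch_size=14 (the CLIP-ViT-L/14-336 tower used by
# LLaVA-1.5) gives (336 // 14) ** 2 = 576 patches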
return (image_size // patch_size)**2


# Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/clip/modeling_clip.py#L164 # noqa
class CLIPVisionEmbeddings(nn.Module):

def __init__(self, config: CLIPVisionConfig):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.image_size = config.image_size
self.patch_size = config.patch_size

self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))

self.patch_embedding = nn.Conv2d(
in_channels=config.num_channels,
out_channels=self.embed_dim,
kernel_size=self.patch_size,
stride=self.patch_size,
bias=False,
)

self.num_patches = get_clip_num_patches(self.image_size,
self.patch_size)
self.num_positions = self.num_patches + 1
self.position_embedding = nn.Embedding(self.num_positions,
self.embed_dim)
self.register_buffer("position_ids",
torch.arange(self.num_positions).expand((1, -1)),
persistent=False)

def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
batch_size = pixel_values.shape[0]
target_dtype = self.patch_embedding.weight.dtype
patch_embeds = self.patch_embedding(pixel_values.to(
dtype=target_dtype)) # shape = [*, width, grid, grid]
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
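# patch_embeds: [batch_size, embed_dim, grid, grid] -> [batch_size, num_patches, embed_dim]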

class_embeds = self.class_embedding.expand(batch_size, 1, -1)
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
embeddings = embeddings + self.position_embedding(self.position_ids)

return embeddings


class CLIPMLP(nn.Module):

def __init__(self,
config: CLIPVisionConfig,
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.config = config
self.activation_fn = get_act_fn(config.hidden_act)
self.fc1 = ColumnParallelLinear(config.hidden_size,
config.intermediate_size,
bias=True,
quant_config=quant_config)
self.fc2 = RowParallelLinear(config.intermediate_size,
config.hidden_size,
bias=True,
quant_config=quant_config)

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states, _ = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states, _ = self.fc2(hidden_states)

return hidden_states


class CLIPEncoderLayer(nn.Module):

def __init__(self,
config: CLIPVisionConfig,
quant_config: Optional[QuantizationConfig] = None):
super().__init__()

self.self_attn = CLIPAttention(config)
self.layer_norm1 = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.mlp = CLIPMLP(config, quant_config=quant_config)
self.layer_norm2 = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:

residual = hidden_states

hidden_states = self.layer_norm1(hidden_states)
hidden_states, _ = self.self_attn(hidden_states=hidden_states)
hidden_states = residual + hidden_states

residual = hidden_states
hidden_states = self.layer_norm2(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states

return hidden_states


class CLIPEncoder(nn.Module):
"""
Transformer encoder consisting of `config.num_hidden_layers` self-attention
layers. Each layer is a [`CLIPEncoderLayer`].

Args:
config: CLIPVisionConfig
"""

def __init__(self,
config: CLIPVisionConfig,
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.config = config
self.layers = nn.ModuleList([
CLIPEncoderLayer(config=config, quant_config=quant_config)
for _ in range(config.num_hidden_layers)
])

def forward(self,
inputs_embeds: torch.Tensor,
vision_feature_layer: int = -1):

# Encoder forward pass only up to the required layer
num_layer = len(self.layers) + vision_feature_layer + 1
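# e.g. with 24 layers and vision_feature_layer=-2 (LLaVA's default),
# num_layer = 24 - 2 + 1 = 23, so only layers[:23] are run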
hidden_states = inputs_embeds
for encoder_layer in self.layers[:num_layer]:
hidden_states = encoder_layer(hidden_states)

return hidden_states


class CLIPVisionTransformer(nn.Module):

def __init__(self,
config: CLIPVisionConfig,
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.config = config
embed_dim = config.hidden_size

self.embeddings = CLIPVisionEmbeddings(config)

# NOTE: The "layrnorm" typo is kept on purpose to match the original
# transformers code and the names of the model weights.
self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
self.encoder = CLIPEncoder(config=config, quant_config=quant_config)
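# NOTE: HF's post_layernorm is omitted here: LLaVA-style callers consume an
# encoder hidden state, which never passes through that final norm, so the
# corresponding weights are skipped in load_weights.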

def forward(
self,
pixel_values: torch.Tensor,
vision_feature_layer: int = -1,
) -> torch.Tensor:

hidden_states = self.embeddings(pixel_values)
hidden_states = self.pre_layrnorm(hidden_states)
hidden_states = self.encoder(inputs_embeds=hidden_states,
vision_feature_layer=vision_feature_layer)

return hidden_states


class CLIPVisionModel(nn.Module):

config_class = CLIPVisionConfig
main_input_name = "pixel_values"

def __init__(self,
config: CLIPVisionConfig,
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.vision_model = CLIPVisionTransformer(config=config,
quant_config=quant_config)

def forward(self,
pixel_values: Optional[torch.Tensor] = None,
vision_feature_layer: int = -1):

return self.vision_model(pixel_values=pixel_values,
vision_feature_layer=vision_feature_layer)

@property
def device(self):
return next(self.parameters()).device
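To make the reviewer's suggestion at the top of clip.py concrete, below is a rough sketch of an isolated consistency check against the HF implementation. Assumptions not established by this PR: a single process with vLLM's tensor-parallel state already initialized at world size 1 (so the parallel linear layers construct), weights copied by matching state-dict keys, and a CUDA device; the unit test eventually added to vLLM may look different.

```python
import torch
from transformers import CLIPVisionConfig
from transformers import CLIPVisionModel as HfCLIPVisionModel

from vllm.model_executor.models.clip import CLIPVisionModel

device = "cuda"  # the ported model dispatches QuickGELU to the CUDA kernel
config = CLIPVisionConfig()  # default ViT-B/32 geometry, hidden_act="quick_gelu"

hf_model = HfCLIPVisionModel(config).to(device).eval()
vllm_model = CLIPVisionModel(config).to(device).eval()

# Parameter names match HF 1:1 except post_layernorm, which the port drops,
# so (without tensor parallelism) a plain state_dict copy is enough.
vllm_model.load_state_dict(
    {k: v for k, v in hf_model.state_dict().items()
     if "post_layernorm" not in k},
    strict=False,
)

pixel_values = torch.randn(1, 3, config.image_size, config.image_size,
                           device=device)
with torch.no_grad():
    hf_hidden = hf_model(pixel_values, output_hidden_states=True).hidden_states
    ported = vllm_model(pixel_values, vision_feature_layer=-2)

# vision_feature_layer=-2 corresponds to hidden_states[-2] on the HF side;
# neither path applies post_layernorm to intermediate hidden states.
torch.testing.assert_close(ported, hf_hidden[-2], atol=1e-4, rtol=1e-4)
```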
17 changes: 9 additions & 8 deletions vllm/model_executor/models/llava.py
@@ -2,9 +2,7 @@

import torch
import torch.nn as nn
- # TODO(xwjiang): We should port CLIPVisionModel's code over to not depend on
- # transformers' impl.
- from transformers import CLIPVisionModel, LlavaConfig
from transformers import LlavaConfig

from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, VisionLanguageConfig
@@ -15,6 +13,7 @@
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.clip import CLIPVisionModel
from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
@@ -189,12 +188,11 @@ def _select_image_features(self, image_features: torch.Tensor, *,

def _image_pixels_to_features(self, vision_tower: CLIPVisionModel,
pixel_values: torch.Tensor) -> torch.Tensor:
- # TODO(xwjiang): Maybe port minimal CLIPVisionModel over.
- image_outputs = vision_tower(pixel_values.to(vision_tower.device),
- output_hidden_states=True)

- image_features = image_outputs.hidden_states[
- self.config.vision_feature_layer]
# NOTE: we skip the step to select the vision feature layer since
# this is already done inside the vision tower
image_features = vision_tower(pixel_values.to(vision_tower.device),
self.config.vision_feature_layer)

return self._select_image_features(
image_features,
@@ -317,6 +315,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
# post_layernorm is not needed in CLIPVisionModel
if "vision_model.post_layernorm" in name:
continue
for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
if key_to_modify in name:
name = name.replace(key_to_modify, new_key)
19 changes: 10 additions & 9 deletions vllm/model_executor/models/llava_next.py
@@ -4,9 +4,7 @@
import torch
import torch.nn as nn
from PIL import Image
- # TODO(xwjiang): We should port CLIPVisionModel's code over to not depend on
- # transformers' impl.
- from transformers import CLIPVisionModel, LlavaNextConfig
from transformers import LlavaNextConfig
from transformers.models.llava_next.modeling_llava_next import (
get_anyres_image_grid_shape, unpad_image)
from typing_extensions import NotRequired
@@ -20,6 +18,7 @@
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.clip import CLIPVisionModel
from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalData
@@ -121,7 +120,7 @@ def __init__(self,

if self.vision_language_config.image_input_type == (
VisionLanguageConfig.ImageInputType.PIXEL_VALUES):
- self.vision_tower = CLIPVisionModel(config.vision_config)
self.vision_tower = CLIPVisionModel(config=config.vision_config)
else:
raise TypeError("Image features are not supported by LLaVA-NeXT")

@@ -219,12 +218,11 @@ def _select_image_features(self, image_features: torch.Tensor, *,

def _image_pixels_to_features(self, vision_tower: CLIPVisionModel,
pixel_values: torch.Tensor) -> torch.Tensor:
- # TODO(xwjiang): Maybe port minimal CLIPVisionModel over.
- image_outputs = vision_tower(pixel_values.to(vision_tower.device),
- output_hidden_states=True)

- image_features = image_outputs.hidden_states[
- self.config.vision_feature_layer]
# NOTE: we skip the step to select the vision feature layer since
# this is already done inside the vision tower
image_features = vision_tower(pixel_values.to(vision_tower.device),
self.config.vision_feature_layer)

return self._select_image_features(
image_features,
@@ -430,6 +428,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
# post_layernorm is not needed in CLIPVisionModel
if "vision_model.post_layernorm" in name:
continue
for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
if key_to_modify in name:
name = name.replace(key_to_modify, new_key)