-
-
Notifications
You must be signed in to change notification settings - Fork 5.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Model] Port over CLIPVisionModel for VLMs #5591
Merged
Merged
Changes from all commits
Commits
Show all changes
16 commits
Select commit
Hold shift + click to select a range
6aa5892
initial
ywang96 cb3ac13
format
ywang96 4bf7f7c
iterate
ywang96 7f29478
fix init
ywang96 d59d15d
iterate
ywang96 d2f38be
iterate
ywang96 fd65d87
Merge branch 'main' into clip
ywang96 f44d414
iterate
ywang96 f0a5b55
iterate
ywang96 8625be4
cleanup
ywang96 45ac59a
update forward
ywang96 733f53e
add note on typo
ywang96 d8b08ff
add num_patches helper
ywang96 aeac701
fix init
ywang96 42c4710
Merge branch 'main' into clip
ywang96 7a3d6c8
update phi3v
ywang96 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,203 @@ | ||
"""Minimal implementation of CLIPVisionModel intended to be only used | ||
within a vision language model.""" | ||
from typing import Optional, Tuple | ||
|
||
import torch | ||
import torch.nn as nn | ||
from transformers import CLIPVisionConfig | ||
from transformers.models.clip.modeling_clip import CLIPAttention | ||
|
||
from vllm.model_executor.layers.activation import get_act_fn | ||
from vllm.model_executor.layers.linear import (ColumnParallelLinear, | ||
RowParallelLinear) | ||
from vllm.model_executor.layers.quantization.base_config import ( | ||
QuantizationConfig) | ||
|
||
|
||
def get_clip_num_patches(image_size: int, patch_size: int) -> int: | ||
assert image_size % patch_size == 0 | ||
return (image_size // patch_size)**2 | ||
|
||
|
||
# Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/clip/modeling_clip.py#L164 # noqa | ||
class CLIPVisionEmbeddings(nn.Module): | ||
|
||
def __init__(self, config: CLIPVisionConfig): | ||
super().__init__() | ||
self.config = config | ||
self.embed_dim = config.hidden_size | ||
self.image_size = config.image_size | ||
self.patch_size = config.patch_size | ||
|
||
self.class_embedding = nn.Parameter(torch.randn(self.embed_dim)) | ||
|
||
self.patch_embedding = nn.Conv2d( | ||
in_channels=config.num_channels, | ||
out_channels=self.embed_dim, | ||
kernel_size=self.patch_size, | ||
stride=self.patch_size, | ||
bias=False, | ||
) | ||
|
||
self.num_patches = get_clip_num_patches(self.image_size, | ||
self.patch_size) | ||
self.num_positions = self.num_patches + 1 | ||
self.position_embedding = nn.Embedding(self.num_positions, | ||
self.embed_dim) | ||
self.register_buffer("position_ids", | ||
torch.arange(self.num_positions).expand((1, -1)), | ||
persistent=False) | ||
|
||
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: | ||
batch_size = pixel_values.shape[0] | ||
target_dtype = self.patch_embedding.weight.dtype | ||
patch_embeds = self.patch_embedding(pixel_values.to( | ||
dtype=target_dtype)) # shape = [*, width, grid, grid] | ||
patch_embeds = patch_embeds.flatten(2).transpose(1, 2) | ||
|
||
class_embeds = self.class_embedding.expand(batch_size, 1, -1) | ||
embeddings = torch.cat([class_embeds, patch_embeds], dim=1) | ||
embeddings = embeddings + self.position_embedding(self.position_ids) | ||
|
||
return embeddings | ||
|
||
|
||
class CLIPMLP(nn.Module): | ||
|
||
def __init__(self, | ||
config: CLIPVisionConfig, | ||
quant_config: Optional[QuantizationConfig] = None): | ||
super().__init__() | ||
self.config = config | ||
self.activation_fn = get_act_fn(config.hidden_act) | ||
self.fc1 = ColumnParallelLinear(config.hidden_size, | ||
config.intermediate_size, | ||
bias=True, | ||
quant_config=quant_config) | ||
self.fc2 = RowParallelLinear(config.intermediate_size, | ||
config.hidden_size, | ||
bias=True, | ||
quant_config=quant_config) | ||
|
||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: | ||
hidden_states, _ = self.fc1(hidden_states) | ||
hidden_states = self.activation_fn(hidden_states) | ||
hidden_states, _ = self.fc2(hidden_states) | ||
|
||
return hidden_states | ||
|
||
|
||
class CLIPEncoderLayer(nn.Module): | ||
|
||
def __init__(self, | ||
config: CLIPVisionConfig, | ||
quant_config: Optional[QuantizationConfig] = None): | ||
super().__init__() | ||
|
||
self.self_attn = CLIPAttention(config) | ||
self.layer_norm1 = nn.LayerNorm(config.hidden_size, | ||
eps=config.layer_norm_eps) | ||
self.mlp = CLIPMLP(config, quant_config=quant_config) | ||
self.layer_norm2 = nn.LayerNorm(config.hidden_size, | ||
eps=config.layer_norm_eps) | ||
|
||
def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor]: | ||
|
||
residual = hidden_states | ||
|
||
hidden_states = self.layer_norm1(hidden_states) | ||
hidden_states, _ = self.self_attn(hidden_states=hidden_states) | ||
hidden_states = residual + hidden_states | ||
|
||
residual = hidden_states | ||
hidden_states = self.layer_norm2(hidden_states) | ||
hidden_states = self.mlp(hidden_states) | ||
hidden_states = residual + hidden_states | ||
|
||
return hidden_states | ||
|
||
|
||
class CLIPEncoder(nn.Module): | ||
""" | ||
Transformer encoder consisting of `config.num_hidden_layers` self | ||
attention layers. Each layer is a [`CLIPEncoderLayer`]. | ||
|
||
Args: | ||
config: CLIPConfig | ||
""" | ||
|
||
def __init__(self, | ||
config: CLIPVisionConfig, | ||
ywang96 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
quant_config: Optional[QuantizationConfig] = None): | ||
super().__init__() | ||
self.config = config | ||
self.layers = nn.ModuleList([ | ||
CLIPEncoderLayer(config=config, quant_config=quant_config) | ||
for _ in range(config.num_hidden_layers) | ||
]) | ||
|
||
def forward(self, | ||
inputs_embeds: torch.Tensor, | ||
vision_feature_layer: int = -1): | ||
|
||
# Encoder forward pass only up to the required layer | ||
num_layer = len(self.layers) + vision_feature_layer + 1 | ||
hidden_states = inputs_embeds | ||
for encoder_layer in self.layers[:num_layer]: | ||
hidden_states = encoder_layer(hidden_states) | ||
|
||
return hidden_states | ||
|
||
|
||
class CLIPVisionTransformer(nn.Module): | ||
|
||
def __init__(self, | ||
config: CLIPVisionConfig, | ||
quant_config: Optional[QuantizationConfig] = None): | ||
super().__init__() | ||
self.config = config | ||
embed_dim = config.hidden_size | ||
|
||
self.embeddings = CLIPVisionEmbeddings(config) | ||
|
||
# NOTE: This typo of "layrnorm" is not fixed on purpose to match | ||
# the original transformers code and name of the model weights. | ||
self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) | ||
ywang96 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
self.encoder = CLIPEncoder(config=config, quant_config=quant_config) | ||
|
||
def forward( | ||
self, | ||
pixel_values: torch.Tensor, | ||
vision_feature_layer: int = -1, | ||
) -> torch.Tensor: | ||
|
||
hidden_states = self.embeddings(pixel_values) | ||
hidden_states = self.pre_layrnorm(hidden_states) | ||
hidden_states = self.encoder(inputs_embeds=hidden_states, | ||
vision_feature_layer=vision_feature_layer) | ||
|
||
return hidden_states | ||
|
||
|
||
class CLIPVisionModel(nn.Module): | ||
|
||
config_class = CLIPVisionConfig | ||
main_input_name = "pixel_values" | ||
|
||
def __init__(self, | ||
config: CLIPVisionConfig, | ||
quant_config: Optional[QuantizationConfig] = None): | ||
super().__init__() | ||
self.vision_model = CLIPVisionTransformer(config=config, | ||
quant_config=quant_config) | ||
|
||
def forward(self, | ||
pixel_values: Optional[torch.Tensor] = None, | ||
vision_feature_layer: int = -1): | ||
|
||
return self.vision_model(pixel_values=pixel_values, | ||
vision_feature_layer=vision_feature_layer) | ||
|
||
@property | ||
def device(self): | ||
return next(self.parameters()).device |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It would be nice to unit test the model in isolation to ensure its consistency with HF.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure. The vision language model test should already serve as an end-to-end integration test for
CLIPVisionModel
since it's only intended to be used there, but I can add separate unit tests for this model itself.