[LLaVa] Some improvements #27895

Merged · 3 commits · Dec 11, 2023
Changes from 2 commits
22 changes: 11 additions & 11 deletions src/transformers/models/llava/modeling_llava.py
@@ -270,22 +270,22 @@ def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_m
def _merge_input_ids_with_image_features(
self, image_features, inputs_embeds, input_ids, attention_mask, position_ids
):
- nb_images, image_hidden_dim, embed_dim = image_features.shape
+ num_images, num_image_patches, embed_dim = image_features.shape
Comment from the PR author:

This is not really a "hidden dim" but rather the number of image patches. Since "num_image_tokens" is already used below and could easily be confused with "num_image_patches", and since it actually refers to the number of special image tokens, I've renamed it to "num_special_image_tokens".
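To make the distinction concrete, here is a minimal sketch of the two quantities being separated (the token id 32000 and the CLIP-style shapes are illustrative assumptions, not taken from this diff):

```python
import torch

# shape of the projected vision features: one row of patch embeddings per image
num_images, num_image_patches, embed_dim = 2, 576, 4096  # assumed CLIP ViT-L/14-like values
image_features = torch.randn(num_images, num_image_patches, embed_dim)

# the "<image>" placeholder count is a property of the text, not of the vision features
image_token_index = 32000  # hypothetical id of the "<image>" special token
input_ids = torch.tensor([[1, 32000, 306, 4091],
                          [1, 32000, 32000, 306]])
special_image_token_mask = input_ids == image_token_index
num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)  # tensor([1, 2])
```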

batch_size, sequence_length = input_ids.shape
left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id))
- # 1. Create a mask to know where image tokens are
- image_token_mask = input_ids == self.config.image_token_index
- num_image_tokens = torch.sum(image_token_mask, dim=-1)
+ # 1. Create a mask to know where special image tokens are
+ special_image_token_mask = input_ids == self.config.image_token_index
+ num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
# Compute the maximum embed dimension
- max_embed_dim = (num_image_tokens.max() * (image_hidden_dim - 1)) + sequence_length
+ max_embed_dim = (num_special_image_tokens.max() * (num_image_patches - 1)) + sequence_length
batch_indices, non_image_indices = torch.where(input_ids != self.config.image_token_index)

# 2. Compute the positions where text should be written
# Calculate new positions for text tokens in merged image-text sequence.
- # `image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens.
+ # `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens.
# `torch.cumsum` computes how each image token shifts subsequent text token positions.
# - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
- new_token_positions = torch.cumsum((image_token_mask * (image_hidden_dim - 1) + 1), -1) - 1
+ new_token_positions = torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1) - 1
nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1]
if left_padding:
new_token_positions += nb_image_pad[:, None] # offset for left padding
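A tiny worked example of the position arithmetic above (the token ids and `num_image_patches = 3` are toy values chosen for illustration):

```python
import torch

num_image_patches = 3
input_ids = torch.tensor([[1, 32000, 306, 4091]])   # [bos, <image>, text, text] -- toy ids
special_image_token_mask = input_ids == 32000        # [[False, True, False, False]]

# each image token reserves (num_image_patches - 1) extra slots for its patch embeddings
new_token_positions = torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1) - 1
# tensor([[0, 3, 4, 5]]): bos stays at 0, the image's patches fill slots 1..3,
# and every following text token is shifted right by num_image_patches - 1
```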
@@ -310,8 +310,8 @@ def _merge_input_ids_with_image_features(

if image_to_overwrite.sum() != image_features.shape[:-1].numel():
raise ValueError(
f"The input provided to the model are wrong. The number of image tokens is {torch.sum(image_token_mask)} while"
f" the number of image given to the model is {nb_images}. This prevents correct indexing and breaks batch generation."
f"The input provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while"
f" the number of image given to the model is {num_images}. This prevents correct indexing and breaks batch generation."
)

final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim)
@@ -353,8 +353,8 @@ def forward(
>>> import requests
>>> from transformers import AutoProcessor, LlavaForConditionalGeneration

- >>> model = LlavaForConditionalGeneration.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
- >>> processor = AutoProcessor.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
+ >>> model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf")
+ >>> processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")

>>> prompt = "<image>\nUSER: What's the content of the image?\nASSISTANT:"
>>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
30 changes: 3 additions & 27 deletions src/transformers/models/llava/processing_llava.py
@@ -17,8 +17,7 @@
"""


- import warnings
- from typing import Callable, List, Optional, Union
+ from typing import List, Optional, Union

from ...feature_extraction_utils import BatchFeature
from ...image_utils import ImageInput
@@ -45,23 +44,7 @@ class LlavaProcessor(ProcessorMixin):
image_processor_class = "CLIPImageProcessor"
tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast")

- # Copied from transformers.models.clip.processing_clip.CLIPProcessor.__init__
- def __init__(self, image_processor=None, tokenizer=None, **kwargs):
- feature_extractor = None
- if "feature_extractor" in kwargs:
- warnings.warn(
- "The `feature_extractor` argument is deprecated and will be removed in v5, use `image_processor`"
- " instead.",
- FutureWarning,
- )
- feature_extractor = kwargs.pop("feature_extractor")
-
- image_processor = image_processor if image_processor is not None else feature_extractor
- if image_processor is None:
- raise ValueError("You need to specify an `image_processor`.")
- if tokenizer is None:
- raise ValueError("You need to specify a `tokenizer`.")
-
+ def __init__(self, image_processor=None, tokenizer=None):
super().__init__(image_processor, tokenizer)

def __call__(
@@ -70,7 +53,6 @@ def __call__(
images: ImageInput = None,
padding: Union[bool, str, PaddingStrategy] = False,
truncation: Union[bool, str, TruncationStrategy] = None,
- transform: Callable = None,
NielsRogge marked this conversation as resolved.
max_length=None,
return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
) -> BatchFeature:
@@ -103,10 +85,6 @@ def __call__(
Maximum length of the returned list and optionally padding length (see above).
truncation (`bool`, *optional*):
Activates truncation to cut input sequences longer than `max_length` to `max_length`.
- transform (`Callable`, *optional*):
- A custom transform function that accepts a single image can be passed for training. For example,
- `torchvision.Compose` can be used to compose multiple functions. If `None` a preset inference-specific
- set of transforms will be applied to the images
return_tensors (`str` or [`~utils.TensorType`], *optional*):
If set, will return tensors of a particular framework. Acceptable values are:

@@ -125,9 +103,7 @@
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
"""
if images is not None:
- pixel_values = self.image_processor(images, transform=transform, return_tensors=return_tensors)[
- "pixel_values"
- ]
+ pixel_values = self.image_processor(images, return_tensors=return_tensors)["pixel_values"]
else:
pixel_values = None
text_inputs = self.tokenizer(
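For completeness, a minimal sketch of calling the simplified processor after this change (checkpoint name and image URL are taken from the updated docstring above; the printed keys are what the processor is expected to return):

```python
import requests
from PIL import Image
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")

prompt = "<image>\nUSER: What's the content of the image?\nASSISTANT:"
url = "https://www.ilankelman.org/stopsigns/australia.jpg"
image = Image.open(requests.get(url, stream=True).raw)

# images now go straight through the CLIP image processor -- no `transform` argument anymore
inputs = processor(text=prompt, images=image, return_tensors="pt")
print(inputs.keys())  # expected: input_ids, attention_mask, pixel_values
```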