From 70a70b1d14cb9b5168a1493ec12406f7e611ddd4 Mon Sep 17 00:00:00 2001 From: raushan Date: Fri, 30 Aug 2024 16:01:10 +0200 Subject: [PATCH 1/9] leave only half of the changes --- .../models/llava/modeling_llava.py | 8 +- .../models/llava_next/modeling_llava_next.py | 13 +- .../llava_next/processing_llava_next.py | 51 ++-- .../llava_next_video/diff_llava_next_video.py | 231 ++++++++--------- .../modeling_llava_next_video.py | 232 ++++++++---------- .../processing_llava_next_video.py | 77 ++++-- .../video_llava/modeling_video_llava.py | 43 ++-- .../video_llava/processing_video_llava.py | 2 +- .../models/vipllava/modeling_vipllava.py | 8 +- src/transformers/pytorch_utils.py | 2 +- tests/models/llava/test_modeling_llava.py | 20 ++ .../llava_next/test_modeling_llava_next.py | 47 +++- .../test_modeling_llava_next_video.py | 71 ++++++ .../video_llava/test_modeling_video_llava.py | 86 ++++--- 14 files changed, 530 insertions(+), 361 deletions(-) diff --git a/src/transformers/models/llava/modeling_llava.py b/src/transformers/models/llava/modeling_llava.py index 394c80edb540..f1203d69ee84 100644 --- a/src/transformers/models/llava/modeling_llava.py +++ b/src/transformers/models/llava/modeling_llava.py @@ -469,6 +469,7 @@ def forward( inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features( image_features, inputs_embeds, input_ids, attention_mask, labels ) + cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device) else: # Retrieve the first layer to inspect the logits and mask out the hidden states # that are set to 0 @@ -499,6 +500,9 @@ def forward( attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1) position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1 + cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[ + -target_length: + ] # TODO: @raushan retain only the new behavior after v4.47 else: @@ -575,9 +579,7 @@ def prepare_inputs_for_generation( **kwargs, ) - if legacy_processing: - model_inputs["pixel_values"] = pixel_values - elif cache_position[0] == 0: + if legacy_processing or cache_position[0] == 0: # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore # Otherwise we need pixel values to be passed to model model_inputs["pixel_values"] = pixel_values diff --git a/src/transformers/models/llava_next/modeling_llava_next.py b/src/transformers/models/llava_next/modeling_llava_next.py index 723d54c92dd9..0ed9bb91162b 100644 --- a/src/transformers/models/llava_next/modeling_llava_next.py +++ b/src/transformers/models/llava_next/modeling_llava_next.py @@ -842,6 +842,7 @@ def forward( position_ids, labels=labels, ) + cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device) else: # Retrieve the first layer to inspect the logits and mask out the hidden states # that are set to 0 @@ -871,6 +872,9 @@ def forward( extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0 attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1) position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1 + cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[ + -target_length: + ] # TODO: @raushan retain only the new behavior after v4.47 else: @@ -947,12 +951,9 @@ def prepare_inputs_for_generation( **kwargs, ) - if legacy_processing: - model_inputs["pixel_values"] = pixel_values - 
model_inputs["image_sizes"] = image_sizes
-        elif cache_position[0] == 0:
-            # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
-            # Otherwise we need pixel values to be passed to model
+        # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
+        # Otherwise we need pixel values to be passed to model
+        if legacy_processing or cache_position[0] == 0:
             model_inputs["pixel_values"] = pixel_values
             model_inputs["image_sizes"] = image_sizes
 
diff --git a/src/transformers/models/llava_next/processing_llava_next.py b/src/transformers/models/llava_next/processing_llava_next.py
index c043b8bc7ed6..4bf4429baf78 100644
--- a/src/transformers/models/llava_next/processing_llava_next.py
+++ b/src/transformers/models/llava_next/processing_llava_next.py
@@ -140,30 +140,29 @@ def __call__(
         elif not isinstance(text, list) and not isinstance(text[0], str):
             raise ValueError("Invalid input text. Please provide a string, or a list of strings")
 
-        if self.patch_size is None or self.vision_feature_select_strategy is None:
-            prompt_strings = text
-            logger.warning_once(
-                "Expanding inputs for image tokens in LLaVa-NeXT should be done in processing. "
-                "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly "
-                "with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. "
-                "Using processors without these attributes in the config is deprecated and will throw an error in v4.47."
-            )
-        # cannot infer image expansion length if no images are found
-        elif not image_inputs:
-            prompt_strings = text
-        else:
-            image_sizes = image_inputs["image_sizes"]
-            height, width = get_image_size(to_numpy_array(image_inputs["pixel_values"][0][0]))
-            prompt_strings = []
-            for image_size, sample in zip(image_sizes, text):
-                # Replace the image token with the expanded image token sequence
-                orig_height, orig_width = image_size
-                num_image_tokens = self._get_number_of_features(orig_height, orig_width, height, width)
-                if self.vision_feature_select_strategy == "default":
-                    num_image_tokens -= 1
-
-                sample = sample.replace(self.image_token, self.image_token * num_image_tokens)
-                prompt_strings.append(sample)
+        prompt_strings = text
+        if image_inputs:
+            if self.patch_size is None or self.vision_feature_select_strategy is None:
+                logger.warning_once(
+                    "Expanding inputs for image tokens in LLaVa-NeXT should be done in processing. "
+                    "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly "
+                    "with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. "
+                    "Using processors without these attributes in the config is deprecated and will throw an error in v4.47."
+                )
+            else:
+                image_sizes = iter(image_inputs["image_sizes"])
+                height, width = get_image_size(to_numpy_array(image_inputs["pixel_values"][0][0]))
+                prompt_strings = []
+                for sample in text:
+                    while self.image_token in sample:
+                        image_size = next(image_sizes)
+                        orig_height, orig_width = image_size
+                        num_image_tokens = self._get_number_of_features(orig_height, orig_width, height, width)
+                        if self.vision_feature_select_strategy == "default":
+                            num_image_tokens -= 1
+                        sample = sample.replace(self.image_token, "<placeholder>" * num_image_tokens, 1)
+                    prompt_strings.append(sample)
+                prompt_strings = [sample.replace("<placeholder>", self.image_token) for sample in prompt_strings]
 
         text_inputs = self.tokenizer(
             prompt_strings,
@@ -199,8 +198,8 @@ def _get_unpadded_features(self, height, width, patches_height, patches_width, s
         because it divided each image into patches depending on its resolution. Therefore we need to calculate how many
         patches an image is divided into and get the number of features from that.
         """
-        current_width = patches_height * scale_height
-        current_height = patches_width * scale_width
+        current_height = patches_height * scale_height
+        current_width = patches_width * scale_width
 
         original_aspect_ratio = width / height
         current_aspect_ratio = current_width / current_height
diff --git a/src/transformers/models/llava_next_video/diff_llava_next_video.py b/src/transformers/models/llava_next_video/diff_llava_next_video.py
index b4018db586e7..e765dfb95cc3 100644
--- a/src/transformers/models/llava_next_video/diff_llava_next_video.py
+++ b/src/transformers/models/llava_next_video/diff_llava_next_video.py
@@ -29,7 +29,6 @@
     image_size_to_num_patches,
 )
 
-from ...cache_utils import Cache
 from ...utils import (
     logging,
     replace_return_docstrings,
@@ -389,13 +388,17 @@ def forward(
 
         # if the number of image/video tokens is more than image embeddings seq length, then prob we expanded it in processing
         # not very reliable, but we don't expect one to actually pass 500+ images for one prompt
-        img_token_count = (input_ids == self.config.image_token_index).sum(1).max()
-        video_token_count = (input_ids == self.config.video_token_index).sum(1).max()
-        inputs_expanded = (
-            img_token_count < self.config.image_seq_length and video_token_count < self.config.video_seq_length
+        img_token_not_enough = (input_ids == self.config.image_token_index).sum(
+            1
+        ).max() < self.config.image_seq_length
+        video_token_not_enough = (input_ids == self.config.video_token_index).sum(
+            1
+        ).max() < self.config.video_seq_length
+        inputs_not_expanded = (img_token_not_enough and pixel_values is not None) or (
+            video_token_not_enough and pixel_values_videos is not None
         )
-        pixels_present = input_ids.shape[-1] == 1 and pixel_values is not None and pixel_values_videos is not None
-        legacy_processing = inputs_expanded or pixels_present
+        pixels_present = input_ids.shape[-1] == 1 and (pixel_values is not None or pixel_values_videos is not None)
+        legacy_processing = inputs_not_expanded or pixels_present
 
         image_features = feature_lens = None
         if pixel_values is not None and pixel_values.size(0) > 0:
@@ -414,75 +417,76 @@ def forward(
             video_features = torch.cat(video_features, dim=0)
             video_feature_lens = torch.tensor(video_feature_lens, dtype=torch.long, device=video_features.device)
 
-            if legacy_processing:
-                logger.warning_once(
-                    "Expanding inputs for image.video tokens in LLaVa-NeXT-Video should be done in processing. 
" - "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly " - "with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. " - "Using processors without these attributes in the config is deprecated and will throw an error in v4.47." + if legacy_processing: + logger.warning_once( + "Expanding inputs for image.video tokens in LLaVa-NeXT-Video should be done in processing. " + "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly " + "with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. " + "Using processors without these attributes in the config is deprecated and will throw an error in v4.47." + ) + if input_ids.shape[1] != 1: + iterator = ( + (image_features, feature_lens, self.config.image_token_index), + (video_features, video_feature_lens, self.config.video_token_index), ) - if input_ids.shape[1] != 1: - iterator = ( - (image_features, feature_lens, self.config.image_token_index), - (video_features, video_feature_lens, self.config.video_token_index), - ) - for features, lens, special_token in zip(iterator): - if features is not None: - ( - inputs_embeds, - attention_mask, - position_ids, - labels, - input_ids, - ) = self._merge_input_ids_with_image_features( - features, - lens, - inputs_embeds, - input_ids, - attention_mask, - position_ids, - labels=labels, - image_token_index=special_token, - ) - else: - # Retrieve the first layer to inspect the logits and mask out the hidden states that are set to 0 - first_layer_past_key_value = past_key_values[0][0][:, :, :, 0] - # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941 - batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0) - # Get the target length - target_length = input_ids.shape[1] - past_length = first_layer_past_key_value.shape[-1] - extended_attention_mask = torch.ones( - (attention_mask.shape[0], past_length), - dtype=attention_mask.dtype, - device=attention_mask.device, - ) - # Filter out only the tokens that can be un-attended, this can happen - # if one uses Llava + Fused modules where the cache on the - # first iteration is already big enough, or if one passes custom cache - valid_indices = non_attended_tokens < extended_attention_mask.size(-1) - new_batch_index = batch_index[valid_indices] - new_non_attended_tokens = non_attended_tokens[valid_indices] - # Zero-out the places where we don't need to attend - extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0 - attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1) - position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1 - - # TODO: @raushan retain only the new behavior after v4.47 + for features, lens, special_token in iterator: + if features is not None: + ( + inputs_embeds, + attention_mask, + position_ids, + labels, + input_ids, + ) = self._merge_input_ids_with_image_features( + features, + lens, + inputs_embeds, + input_ids, + attention_mask, + position_ids, + labels=labels, + image_token_index=special_token, + ) + cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device) else: - if image_features is not None: - special_image_mask = ( - (input_ids == 
self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds) - ) - image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) - - if video_features is not None: - special_image_mask = ( - (input_ids == self.config.video_token_index).unsqueeze(-1).expand_as(inputs_embeds) - ) - video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features) + # Retrieve the first layer to inspect the logits and mask out the hidden states that are set to 0 + first_layer_past_key_value = past_key_values[0][0][:, :, :, 0] + # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941 + batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0) + # Get the target length + target_length = input_ids.shape[1] + past_length = first_layer_past_key_value.shape[-1] + extended_attention_mask = torch.ones( + (attention_mask.shape[0], past_length), + dtype=attention_mask.dtype, + device=attention_mask.device, + ) + # Filter out only the tokens that can be un-attended, this can happen + # if one uses Llava + Fused modules where the cache on the + # first iteration is already big enough, or if one passes custom cache + valid_indices = non_attended_tokens < extended_attention_mask.size(-1) + new_batch_index = batch_index[valid_indices] + new_non_attended_tokens = non_attended_tokens[valid_indices] + # Zero-out the places where we don't need to attend + extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0 + attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1) + position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1 + cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[-target_length:] + + # TODO: @raushan retain only the new behavior after v4.47 + else: + if image_features is not None: + special_image_mask = ( + (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds) + ) + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) + if video_features is not None: + special_image_mask = ( + (input_ids == self.config.video_token_index).unsqueeze(-1).expand_as(inputs_embeds) + ) + video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features) outputs = self.language_model( attention_mask=attention_mask, @@ -493,6 +497,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) logits = outputs[0] @@ -534,58 +539,34 @@ def prepare_inputs_for_generation( pixel_values_videos=None, image_sizes=None, attention_mask=None, + cache_position=None, **kwargs, ): - if past_key_values is not None: - if isinstance(past_key_values, Cache): - cache_length = past_key_values.get_seq_length() - past_length = past_key_values.seen_tokens - else: - cache_length = past_length = past_key_values[0][0].shape[2] - - # Keep only the unprocessed tokens: - # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where - # some of the inputs are exclusively passed as 
part of the cache (e.g. when passing input_embeds as - # input) - if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: - input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] - # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard - # input_ids based on the past_length. - elif past_length < input_ids.shape[1]: - input_ids = input_ids[:, past_length:] - - # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. - elif self.config.image_token_index in input_ids or self.config.video_token_index in input_ids: - input_ids = input_ids[:, input_ids.shape[1] - 1 :] - - # If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the - # older attention values, as their corresponding values are not part of the input. - if cache_length < past_length and attention_mask is not None: - attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :] - - position_ids = kwargs.get("position_ids", None) - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1] :] - - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_key_values is None: - model_inputs = {"inputs_embeds": inputs_embeds} - else: - model_inputs = {"input_ids": input_ids} - - model_inputs.update( - { - "position_ids": position_ids, - "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), - "attention_mask": attention_mask, - "pixel_values": pixel_values, - "pixel_values_videos": pixel_values_videos, - "image_sizes": image_sizes, - } + if input_ids is not None: + img_token_not_enough = (input_ids == self.config.image_token_index).sum( + 1 + ).max() < self.config.image_seq_length + video_token_not_enough = (input_ids == self.config.video_token_index).sum( + 1 + ).max() < self.config.video_seq_length + legacy_processing = (img_token_not_enough and pixel_values is not None) or ( + video_token_not_enough and pixel_values_videos is not None + ) + + model_inputs = self.language_model.prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + **kwargs, ) + + # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore + # Otherwise we need pixel values to be passed to model + if legacy_processing or cache_position[0] == 0: + model_inputs["pixel_values"] = pixel_values + model_inputs["pixel_values_videos"] = pixel_values_videos + model_inputs["image_sizes"] = image_sizes + return model_inputs diff --git a/src/transformers/models/llava_next_video/modeling_llava_next_video.py b/src/transformers/models/llava_next_video/modeling_llava_next_video.py index 3430fbe590aa..807937ce5c36 100644 --- a/src/transformers/models/llava_next_video/modeling_llava_next_video.py +++ b/src/transformers/models/llava_next_video/modeling_llava_next_video.py @@ -31,7 +31,6 @@ from ... 
import PreTrainedModel from ...activations import ACT2FN -from ...cache_utils import Cache from ...image_processing_utils import select_best_resolution from ...modeling_outputs import ModelOutput from ...utils import ( @@ -767,6 +766,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, LlavaNextVideoCausalLMOutputWithPast]: r""" Args: @@ -869,13 +869,17 @@ def forward( # if the number of image/video tokens is more than image embeddings seq length, then prob we expanded it in processing # not very reliable, but we don't expect one to actually pass 500+ images for one prompt - img_token_count = (input_ids == self.config.image_token_index).sum(1).max() - video_token_count = (input_ids == self.config.video_token_index).sum(1).max() - inputs_expanded = ( - img_token_count < self.config.image_seq_length and video_token_count < self.config.video_seq_length + img_token_not_enough = (input_ids == self.config.image_token_index).sum( + 1 + ).max() < self.config.image_seq_length + video_token_not_enough = (input_ids == self.config.video_token_index).sum( + 1 + ).max() < self.config.video_seq_length + inputs_not_expanded = (img_token_not_enough and pixel_values is not None) or ( + video_token_not_enough and pixel_values_videos is not None ) - pixels_present = input_ids.shape[-1] == 1 and pixel_values is not None and pixel_values_videos is not None - legacy_processing = inputs_expanded or pixels_present + pixels_present = input_ids.shape[-1] == 1 and (pixel_values is not None or pixel_values_videos is not None) + legacy_processing = inputs_not_expanded or pixels_present image_features = feature_lens = None if pixel_values is not None and pixel_values.size(0) > 0: @@ -894,75 +898,76 @@ def forward( video_features = torch.cat(video_features, dim=0) video_feature_lens = torch.tensor(video_feature_lens, dtype=torch.long, device=video_features.device) - if legacy_processing: - logger.warning_once( - "Expanding inputs for image.video tokens in LLaVa-NeXT-Video should be done in processing. " - "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly " - "with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. " - "Using processors without these attributes in the config is deprecated and will throw an error in v4.47." + if legacy_processing: + logger.warning_once( + "Expanding inputs for image.video tokens in LLaVa-NeXT-Video should be done in processing. " + "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly " + "with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. " + "Using processors without these attributes in the config is deprecated and will throw an error in v4.47." 
+ ) + if input_ids.shape[1] != 1: + iterator = ( + (image_features, feature_lens, self.config.image_token_index), + (video_features, video_feature_lens, self.config.video_token_index), ) - if input_ids.shape[1] != 1: - iterator = ( - (image_features, feature_lens, self.config.image_token_index), - (video_features, video_feature_lens, self.config.video_token_index), - ) - for features, lens, special_token in iterator: - if features is not None: - ( - inputs_embeds, - attention_mask, - position_ids, - labels, - input_ids, - ) = self._merge_input_ids_with_image_features( - features, - lens, - inputs_embeds, - input_ids, - attention_mask, - position_ids, - labels=labels, - image_token_index=special_token, - ) - else: - # Retrieve the first layer to inspect the logits and mask out the hidden states that are set to 0 - first_layer_past_key_value = past_key_values[0][0][:, :, :, 0] - # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941 - batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0) - # Get the target length - target_length = input_ids.shape[1] - past_length = first_layer_past_key_value.shape[-1] - extended_attention_mask = torch.ones( - (attention_mask.shape[0], past_length), - dtype=attention_mask.dtype, - device=attention_mask.device, - ) - # Filter out only the tokens that can be un-attended, this can happen - # if one uses Llava + Fused modules where the cache on the - # first iteration is already big enough, or if one passes custom cache - valid_indices = non_attended_tokens < extended_attention_mask.size(-1) - new_batch_index = batch_index[valid_indices] - new_non_attended_tokens = non_attended_tokens[valid_indices] - # Zero-out the places where we don't need to attend - extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0 - attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1) - position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1 - - # TODO: @raushan retain only the new behavior after v4.47 + for features, lens, special_token in iterator: + if features is not None: + ( + inputs_embeds, + attention_mask, + position_ids, + labels, + input_ids, + ) = self._merge_input_ids_with_image_features( + features, + lens, + inputs_embeds, + input_ids, + attention_mask, + position_ids, + labels=labels, + image_token_index=special_token, + ) + cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device) else: - if image_features is not None: - special_image_mask = ( - (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds) - ) - image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) - - if video_features is not None: - special_image_mask = ( - (input_ids == self.config.video_token_index).unsqueeze(-1).expand_as(inputs_embeds) - ) - video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features) + # Retrieve the first layer to inspect the logits and mask out the hidden states that are set to 0 + first_layer_past_key_value = past_key_values[0][0][:, :, :, 0] + # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941 + batch_index, 
non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0) + # Get the target length + target_length = input_ids.shape[1] + past_length = first_layer_past_key_value.shape[-1] + extended_attention_mask = torch.ones( + (attention_mask.shape[0], past_length), + dtype=attention_mask.dtype, + device=attention_mask.device, + ) + # Filter out only the tokens that can be un-attended, this can happen + # if one uses Llava + Fused modules where the cache on the + # first iteration is already big enough, or if one passes custom cache + valid_indices = non_attended_tokens < extended_attention_mask.size(-1) + new_batch_index = batch_index[valid_indices] + new_non_attended_tokens = non_attended_tokens[valid_indices] + # Zero-out the places where we don't need to attend + extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0 + attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1) + position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1 + cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[-target_length:] + + # TODO: @raushan retain only the new behavior after v4.47 + else: + if image_features is not None: + special_image_mask = ( + (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds) + ) + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) + if video_features is not None: + special_image_mask = ( + (input_ids == self.config.video_token_index).unsqueeze(-1).expand_as(inputs_embeds) + ) + video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features) outputs = self.language_model( attention_mask=attention_mask, @@ -973,6 +978,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) logits = outputs[0] @@ -1014,60 +1020,36 @@ def prepare_inputs_for_generation( pixel_values_videos=None, image_sizes=None, attention_mask=None, + cache_position=None, **kwargs, ): - if past_key_values is not None: - if isinstance(past_key_values, Cache): - cache_length = past_key_values.get_seq_length() - past_length = past_key_values.seen_tokens - else: - cache_length = past_length = past_key_values[0][0].shape[2] - - # Keep only the unprocessed tokens: - # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where - # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as - # input) - if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: - input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] - # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard - # input_ids based on the past_length. - elif past_length < input_ids.shape[1]: - input_ids = input_ids[:, past_length:] - - # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. - elif self.config.image_token_index in input_ids or self.config.video_token_index in input_ids: - input_ids = input_ids[:, input_ids.shape[1] - 1 :] - - # If the cache has seen more tokens than it can hold, then the cache has a size limit. 
Let's discard the
-        # older attention values, as their corresponding values are not part of the input.
-        if cache_length < past_length and attention_mask is not None:
-            attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :]
-
-        position_ids = kwargs.get("position_ids", None)
-        if attention_mask is not None and position_ids is None:
-            # create position_ids on the fly for batch generation
-            position_ids = attention_mask.long().cumsum(-1) - 1
-            position_ids.masked_fill_(attention_mask == 0, 1)
-            if past_key_values:
-                position_ids = position_ids[:, -input_ids.shape[1] :]
-
-        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-        if inputs_embeds is not None and past_key_values is None:
-            model_inputs = {"inputs_embeds": inputs_embeds}
-        else:
-            model_inputs = {"input_ids": input_ids}
-
-        model_inputs.update(
-            {
-                "position_ids": position_ids,
-                "past_key_values": past_key_values,
-                "use_cache": kwargs.get("use_cache"),
-                "attention_mask": attention_mask,
-                "pixel_values": pixel_values,
-                "pixel_values_videos": pixel_values_videos,
-                "image_sizes": image_sizes,
-            }
+        if input_ids is not None:
+            img_token_not_enough = (input_ids == self.config.image_token_index).sum(
+                1
+            ).max() < self.config.image_seq_length
+            video_token_not_enough = (input_ids == self.config.video_token_index).sum(
+                1
+            ).max() < self.config.video_seq_length
+            legacy_processing = (img_token_not_enough and pixel_values is not None) or (
+                video_token_not_enough and pixel_values_videos is not None
+            )
+
+        model_inputs = self.language_model.prepare_inputs_for_generation(
+            input_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            attention_mask=attention_mask,
+            cache_position=cache_position,
+            **kwargs,
         )
+
+        # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
+        # Otherwise we need pixel values to be passed to model
+        if legacy_processing or cache_position[0] == 0:
+            model_inputs["pixel_values"] = pixel_values
+            model_inputs["pixel_values_videos"] = pixel_values_videos
+            model_inputs["image_sizes"] = image_sizes
+
         return model_inputs

     def _get_image_features(self, pixel_values, image_sizes):
diff --git a/src/transformers/models/llava_next_video/processing_llava_next_video.py b/src/transformers/models/llava_next_video/processing_llava_next_video.py
index efbb193ba62a..e0e4534e42b5 100644
--- a/src/transformers/models/llava_next_video/processing_llava_next_video.py
+++ b/src/transformers/models/llava_next_video/processing_llava_next_video.py
@@ -19,6 +19,7 @@
 from typing import TYPE_CHECKING, List, Optional, Union
 
 from ...feature_extraction_utils import BatchFeature
+from ...image_processing_utils import select_best_resolution
 from ...image_utils import ImageInput, VideoInput, get_image_size, to_numpy_array
 from ...processing_utils import ProcessorMixin
 from ...tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
@@ -160,35 +161,29 @@ def __call__(
         elif not isinstance(text, list) and not isinstance(text[0], str):
             raise ValueError("Invalid input text. Please provide a string, or a list of strings")
 
-        print(self.patch_size, self.vision_feature_select_strategy, image_inputs, videos_inputs.keys())
-
         if self.patch_size is None or self.vision_feature_select_strategy is None:
-            prompt_strings = text
             logger.warning_once(
                 "Expanding inputs for image/video tokens in LLaVa-NeXT-Video should be done in processing. "
                 "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly "
                 "with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. "
                 "Using processors without these attributes in the config is deprecated and will throw an error in v4.47."
             )
-        # cannot infer image expansion length if no images/videos are found
-        elif not image_inputs and not videos_inputs:
-            prompt_strings = text
         else:
             # images expand taking into account num_of_patches in each image
             if image_inputs:
-                image_sizes = image_inputs["image_sizes"]
+                image_sizes = iter(image_inputs["image_sizes"])
                 height, width = get_image_size(to_numpy_array(image_inputs["pixel_values"][0][0]))
                 prompt_strings = []
-                for image_size, sample in zip(image_sizes, text):
-                    # Replace the image token with the expanded image token sequence
-                    orig_height, orig_width = image_size
-                    num_image_tokens = self._get_number_of_features(orig_height, orig_width, height, width)
-                    if self.vision_feature_select_strategy == "default":
-                        num_image_tokens -= 1
-
-                    sample = sample.replace(self.image_token, self.image_token * num_image_tokens)
+                for sample in text:
+                    while self.image_token in sample:
+                        image_size = next(image_sizes)
+                        orig_height, orig_width = image_size
+                        num_image_tokens = self._get_number_of_features(orig_height, orig_width, height, width)
+                        if self.vision_feature_select_strategy == "default":
+                            num_image_tokens -= 1
+                        sample = sample.replace(self.image_token, "<placeholder>" * num_image_tokens, 1)
                     prompt_strings.append(sample)
-                text = prompt_strings
+                text = [sample.replace("<placeholder>", self.image_token) for sample in prompt_strings]
 
             # videos are easier, simply get frames and multiply
             if videos_inputs:
@@ -197,23 +192,65 @@ def __call__(
                 num_frames = one_video.shape[0]  # frame dim is always after batch dim
                 num_image_tokens = (height // self.patch_size) * (width // self.patch_size)
                 num_video_tokens = num_image_tokens // 4 * num_frames  # divide by 4 needed for avg pooling layer
-                prompt_strings = []
                 for sample in text:
                     sample = sample.replace(self.video_token, self.video_token * num_video_tokens)
                     prompt_strings.append(sample)
+                text = prompt_strings
 
         text_inputs = self.tokenizer(
-            prompt_strings,
+            text,
             return_tensors=return_tensors,
             padding=padding,
             truncation=truncation,
             max_length=max_length,
         )
-        print(text_inputs.keys())
-
         return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs})
+
+    # Copied from transformers.models.llava_next.processing_llava_next.LlavaNextProcessor._get_number_of_features
+    def _get_number_of_features(self, orig_height: int, orig_width: int, height: int, width: int) -> int:
+        image_grid_pinpoints = self.image_processor.image_grid_pinpoints
+
+        height_best_resolution, width_best_resolution = select_best_resolution(
+            [orig_height, orig_width], image_grid_pinpoints
+        )
+        scale_height, scale_width = height_best_resolution // height, width_best_resolution // width
+
+        patches_height = height // self.patch_size
+        patches_width = width // self.patch_size
+        unpadded_features, newline_features = self._get_unpadded_features(
+            orig_height, orig_width, patches_height, patches_width, scale_height, scale_width
+        )
+        # The base patch covers the entire image (+1 for the CLS)
+        base_features = patches_height * patches_width + 1
+        num_image_tokens = unpadded_features + newline_features + base_features
+        return num_image_tokens
+
+    # Copied from transformers.models.llava_next.processing_llava_next.LlavaNextProcessor._get_unpadded_features
+    
def _get_unpadded_features(self, height, width, patches_height, patches_width, scale_height, scale_width): + """ + Get number of features for a given image with height/width. LLaVA-NeXT is different from LLaVA + because it divided each image into patches depending on its resolution. Therefore we need to calculate how many + patches an image is divided into and get the number of features from that. + """ + current_height = patches_height * scale_height + current_width = patches_width * scale_width + + original_aspect_ratio = width / height + current_aspect_ratio = current_width / current_height + if original_aspect_ratio > current_aspect_ratio: + new_height = (height * current_width) // width + padding = (current_height - new_height) // 2 + current_height -= padding * 2 + else: + new_width = (width * current_height) // height + padding = (current_width - new_width) // 2 + current_width -= padding * 2 + + unpadded_features = current_height * current_width + newline_features = current_height + return (unpadded_features, newline_features) + # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama def batch_decode(self, *args, **kwargs): """ diff --git a/src/transformers/models/video_llava/modeling_video_llava.py b/src/transformers/models/video_llava/modeling_video_llava.py index 425d46bd7741..1180991c23dc 100644 --- a/src/transformers/models/video_llava/modeling_video_llava.py +++ b/src/transformers/models/video_llava/modeling_video_llava.py @@ -523,15 +523,19 @@ def forward( # if the number of image/video tokens is more than image embeddings seq length, then prob we expanded it in processing # not very reliable, but we don't expect one to actually pass 500+ images for one prompt - img_token_count = (input_ids == self.config.image_token_index).sum(1).max() - video_token_count = (input_ids == self.config.video_token_index).sum(1).max() - inputs_expanded = ( - img_token_count < self.config.image_seq_length and video_token_count < self.config.video_seq_length + img_token_not_enough = (input_ids == self.config.image_token_index).sum( + 1 + ).max() < self.config.image_seq_length + video_token_not_enough = (input_ids == self.config.video_token_index).sum( + 1 + ).max() < self.config.video_seq_length + inputs_not_expanded = (img_token_not_enough and pixel_values_images is not None) or ( + video_token_not_enough and pixel_values_videos is not None ) - pixels_present = ( - input_ids.shape[-1] == 1 and pixel_values_images is not None and pixel_values_videos is not None + pixels_present = input_ids.shape[-1] == 1 and ( + pixel_values_images is not None or pixel_values_videos is not None ) - legacy_processing = inputs_expanded or pixels_present + legacy_processing = inputs_not_expanded or pixels_present if pixel_values_images is not None or pixel_values_videos is not None: image_outputs, video_outputs, num_frames = self._get_vision_features( @@ -571,6 +575,7 @@ def forward( labels, num_frames=frames, ) + cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device) else: # Retrieve the first layer to inspect the logits and mask out the hidden states # that are set to 0 @@ -600,6 +605,9 @@ def forward( attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1) position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1 + cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[ + -target_length: + ] # TODO: @raushan retain only the new behavior after v4.47 else: @@ 
-670,11 +678,16 @@ def prepare_inputs_for_generation( cache_position=None, **kwargs, ): - # Trigger the new behavior if we have more than image embeddings seq length tokens for images - legacy_processing = input_ids is not None and ( - (input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length - and (input_ids == self.config.video_token_index).sum(1).max() < self.config.video_seq_length - ) + if input_ids is not None: + img_token_not_enough = (input_ids == self.config.image_token_index).sum( + 1 + ).max() < self.config.image_seq_length + video_token_not_enough = (input_ids == self.config.video_token_index).sum( + 1 + ).max() < self.config.video_seq_length + legacy_processing = (img_token_not_enough and pixel_values_images is not None) or ( + video_token_not_enough and pixel_values_videos is not None + ) model_inputs = self.language_model.prepare_inputs_for_generation( input_ids, @@ -685,11 +698,7 @@ def prepare_inputs_for_generation( **kwargs, ) - if legacy_processing: - model_inputs["pixel_values_images"] = pixel_values_images - model_inputs["pixel_values_videos"] = pixel_values_videos - - elif cache_position[0] == 0: + if legacy_processing or cache_position[0] == 0: # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore # Otherwise we need pixel values to be passed to model model_inputs["pixel_values_images"] = pixel_values_images diff --git a/src/transformers/models/video_llava/processing_video_llava.py b/src/transformers/models/video_llava/processing_video_llava.py index a06913d7acf7..078eba8aee70 100644 --- a/src/transformers/models/video_llava/processing_video_llava.py +++ b/src/transformers/models/video_llava/processing_video_llava.py @@ -155,7 +155,7 @@ def __call__( ) elif encoded_images is not None: # Replace the image token with the expanded image token sequence - if "pixel_values" in encoded_images: + if "pixel_values_images" in encoded_images.keys(): height, width = get_image_size(to_numpy_array(encoded_images.get("pixel_values")[0])) num_frames = 1 else: diff --git a/src/transformers/models/vipllava/modeling_vipllava.py b/src/transformers/models/vipllava/modeling_vipllava.py index b1df10fdb3dc..95c5656e0464 100644 --- a/src/transformers/models/vipllava/modeling_vipllava.py +++ b/src/transformers/models/vipllava/modeling_vipllava.py @@ -464,6 +464,7 @@ def forward( inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features( image_features, inputs_embeds, input_ids, attention_mask, labels ) + cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device) else: # Retrieve the first layer to inspect the logits and mask out the hidden states # that are set to 0 @@ -493,6 +494,9 @@ def forward( attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1) position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1 + cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[ + -target_length: + ] # TODO: @raushan retain only the new behavior after v4.47 else: @@ -569,9 +573,7 @@ def prepare_inputs_for_generation( **kwargs, ) - if legacy_processing: - model_inputs["pixel_values"] = pixel_values - elif cache_position[0] == 0: + if legacy_processing or cache_position[0] == 0: # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore # Otherwise we need pixel values to be passed to model 
model_inputs["pixel_values"] = pixel_values diff --git a/src/transformers/pytorch_utils.py b/src/transformers/pytorch_utils.py index 8c02e0781092..f3663c09902f 100644 --- a/src/transformers/pytorch_utils.py +++ b/src/transformers/pytorch_utils.py @@ -24,7 +24,7 @@ from .utils import is_torch_xla_available, logging -ALL_LAYERNORM_LAYERS = [nn.LayerNorm, nn.RMSNorm] +ALL_LAYERNORM_LAYERS = [nn.LayerNorm] logger = logging.get_logger(__name__) diff --git a/tests/models/llava/test_modeling_llava.py b/tests/models/llava/test_modeling_llava.py index 360bbde29c18..f4f6a9d5c7cb 100644 --- a/tests/models/llava/test_modeling_llava.py +++ b/tests/models/llava/test_modeling_llava.py @@ -515,6 +515,26 @@ def test_generation_no_images(self): # Make sure that `generate` works _ = model.generate(**inputs, max_new_tokens=20) + slow + + @require_bitsandbytes + def test_generation_siglip_backbone(self): + model_id = "llava-hf/llava-interleave-qwen-0.5b-hf" + model = LlavaForConditionalGeneration.from_pretrained(model_id, torch_dtype="float16") + processor = AutoProcessor.from_pretrained(model_id) + + image_file = "http://images.cocodataset.org/val2017/000000039769.jpg" + raw_image = Image.open(requests.get(image_file, stream=True).raw) + inputs = processor("\nUSER: Describe the image.\nASSISTANT:", raw_image, return_tensors="pt").to( + torch_device, torch.float16 + ) + + # Make sure that `generate` works + output = model.generate(**inputs, max_new_tokens=30) + + EXPECTED_DECODED_TEXT = "" + self.assertTrue(processor.batch_decode(output, skip_special_tokens=True)[0] == EXPECTED_DECODED_TEXT) + @slow @require_bitsandbytes def test_expansion_in_processing(self): diff --git a/tests/models/llava_next/test_modeling_llava_next.py b/tests/models/llava_next/test_modeling_llava_next.py index c665631c4033..bc1b3133c118 100644 --- a/tests/models/llava_next/test_modeling_llava_next.py +++ b/tests/models/llava_next/test_modeling_llava_next.py @@ -471,16 +471,17 @@ def test_small_model_integration_test_batch_different_resolutions(self): output = model(**inputs) expected_slice = torch.tensor( - [[-0.0308, -0.0313, -0.0314], [-0.3064, -0.3013, -0.2986], [-0.1226, -0.1246, -0.1210]], + [[-0.1287, -0.1294, -0.1284], [-0.2744, -0.2698, -0.2671], [-0.1071, -0.1091, -0.1056]], dtype=torch.float32, device=torch_device, ) assert torch.allclose(output.logits[0, -3:, -3:], expected_slice, atol=1e-3) - assert torch.allclose(output.loss, torch.tensor(6.8619, device=torch_device)) + assert torch.allclose(output.loss, torch.tensor(7.0206, device=torch_device)) # verify generation output = model.generate(**inputs, max_new_tokens=50) - EXPECTED_DECODED_TEXT = '[INST] \nWhat is shown in this image? [/INST] The image shows a forested area with a misty or foggy atmosphere. In the foreground, there is a grassy field with a few deer grazing. The deer are partially obscured by the fog, and the trees in the background' # fmt: skip + print(self.processor.decode(output[0], skip_special_tokens=True)) + EXPECTED_DECODED_TEXT = '[INST] \nWhat is shown in this image? [/INST] The image shows two deer, likely fawns, in a grassy area with trees in the background. The setting appears to be a forest or woodland, and the photo is taken during what seems to be either dawn or dusk, given' # fmt: skip self.assertEqual( self.processor.decode(output[0], skip_special_tokens=True), EXPECTED_DECODED_TEXT, @@ -567,6 +568,46 @@ def test_padding_side_when_merging_inputs(self): "Padding side is set to 'right' but the model is in inference mode. 
For correct", logs.output[0] ) + @slow + @require_bitsandbytes + def test_expansion_in_processing_multiimage(self): + model_id = "llava-hf/llava-v1.6-mistral-7b-hf" + model = LlavaNextForConditionalGeneration.from_pretrained(model_id, load_in_4bit=True) + processor = AutoProcessor.from_pretrained(model_id) + + prompt = "USER: \nDescribe the similarity between the two images:\nASSISTANT:" + image_file = "http://images.cocodataset.org/val2017/000000039769.jpg" + raw_image = Image.open(requests.get(image_file, stream=True).raw) + deer_image = Image.open( + requests.get( + "https://4.img-dpreview.com/files/p/TS560x560~forums/56876524/03975b28741443319e9a94615e35667e", + stream=True, + ).raw + ) + + # check processing with expansion of inputs + processor.vision_feature_select_strategy = "default" + processor.num_image_tokens = 577 + inputs_expanded = processor(text=prompt, images=[raw_image, deer_image], return_tensors="pt").to( + torch_device, torch.float16 + ) + self.assertTrue(inputs_expanded.input_ids.shape[-1] == 3969) + + # check processing without expansion of inputs (legacy behavior) + processor.vision_feature_select_strategy = None + processor.num_image_tokens = None + inputs = processor(text=prompt, images=[raw_image, deer_image], return_tensors="pt").to( + torch_device, torch.float16 + ) + self.assertTrue(inputs.input_ids.shape[-1] == 23) + + # generate exactly 20 tokens + output = model.generate(**inputs, min_new_tokens=20, max_new_tokens=20) + output_expanded = model.generate(**inputs_expanded, min_new_tokens=20, max_new_tokens=20) + + # check that both inputs are handled correctly and generate the same output + self.assertListEqual(output_expanded[:, -20:].tolist(), output[:, -20:].tolist()) + @slow @require_bitsandbytes def test_expansion_in_processing(self): diff --git a/tests/models/llava_next_video/test_modeling_llava_next_video.py b/tests/models/llava_next_video/test_modeling_llava_next_video.py index 38b1782b75d6..57d1da2fd617 100644 --- a/tests/models/llava_next_video/test_modeling_llava_next_video.py +++ b/tests/models/llava_next_video/test_modeling_llava_next_video.py @@ -18,6 +18,7 @@ import unittest import numpy as np +import requests from huggingface_hub import hf_hub_download from transformers import ( @@ -556,3 +557,73 @@ def test_expansion_in_processing(self): # check that both inputs are handled correctly and generate the same output self.assertListEqual(output_expanded[:, -20:].tolist(), output[:, -20:].tolist()) + + @slow + @require_bitsandbytes + def test_expansion_in_processing_images(self): + model_id = "llava-hf/LLaVA-NeXT-Video-7B-hf" + model = LlavaNextVideoForConditionalGeneration.from_pretrained( + "llava-hf/LLaVA-NeXT-Video-7B-hf", load_in_4bit=True + ) + processor = AutoProcessor.from_pretrained(model_id) + + # check processing with expansion of inputs + processor.vision_feature_select_strategy = "default" + processor.num_image_tokens = 577 + inputs_expanded = processor(self.prompt_image, images=[self.image], return_tensors="pt").to(torch_device) + self.assertTrue(inputs_expanded.input_ids.shape[-1] == 2652) + + # check processing without expansion of inputs (legacy behavior) + processor.vision_feature_select_strategy = None + processor.num_image_tokens = None + inputs = processor(self.prompt_image, images=[self.image], return_tensors="pt").to(torch_device) + self.assertTrue(inputs.input_ids.shape[-1] == 19) + + # generate exactly 20 tokens + output = model.generate(**inputs, min_new_tokens=20, max_new_tokens=20) + output_expanded = 
model.generate(**inputs_expanded, min_new_tokens=20, max_new_tokens=20) + + # check that both inputs are handled correctly and generate the same output + self.assertListEqual(output_expanded[:, -20:].tolist(), output[:, -20:].tolist()) + + @slow + @require_bitsandbytes + def test_expansion_in_processing_multiimage(self): + model_id = "llava-hf/LLaVA-NeXT-Video-7B-hf" + model = LlavaNextVideoForConditionalGeneration.from_pretrained( + "llava-hf/LLaVA-NeXT-Video-7B-hf", load_in_4bit=True + ) + processor = AutoProcessor.from_pretrained(model_id) + + prompt = "USER: \nDescribe the similarity between the two images:\nASSISTANT:" + image_file = "http://images.cocodataset.org/val2017/000000039769.jpg" + raw_image = Image.open(requests.get(image_file, stream=True).raw) + deer_image = Image.open( + requests.get( + "https://4.img-dpreview.com/files/p/TS560x560~forums/56876524/03975b28741443319e9a94615e35667e", + stream=True, + ).raw + ) + + # check processing with expansion of inputs + processor.vision_feature_select_strategy = "default" + processor.num_image_tokens = 577 + inputs_expanded = processor(text=prompt, images=[raw_image, deer_image], return_tensors="pt").to( + torch_device, torch.float16 + ) + self.assertTrue(inputs_expanded.input_ids.shape[-1] == 3968) + + # check processing without expansion of inputs (legacy behavior) + processor.vision_feature_select_strategy = None + processor.num_image_tokens = None + inputs = processor(text=prompt, images=[raw_image, deer_image], return_tensors="pt").to( + torch_device, torch.float16 + ) + self.assertTrue(inputs.input_ids.shape[-1] == 22) + + # generate exactly 20 tokens + output = model.generate(**inputs, min_new_tokens=20, max_new_tokens=20) + output_expanded = model.generate(**inputs_expanded, min_new_tokens=20, max_new_tokens=20) + + # check that both inputs are handled correctly and generate the same output + self.assertListEqual(output_expanded[:, -20:].tolist(), output[:, -20:].tolist()) diff --git a/tests/models/video_llava/test_modeling_video_llava.py b/tests/models/video_llava/test_modeling_video_llava.py index 29fa3b71589a..3772d1d2872e 100644 --- a/tests/models/video_llava/test_modeling_video_llava.py +++ b/tests/models/video_llava/test_modeling_video_llava.py @@ -383,18 +383,19 @@ def test_small_model_integration_test(self): # Let' s make sure we test the preprocessing to replace what is used model = VideoLlavaForConditionalGeneration.from_pretrained("LanguageBind/Video-LLaVA-7B-hf", load_in_4bit=True) - prompt = "USER: