VLM: fixes after refactor #32907

Merged: 11 commits, Sep 10, 2024
8 changes: 5 additions & 3 deletions src/transformers/models/llava/modeling_llava.py
@@ -476,6 +476,7 @@ def forward(
                    inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
                        image_features, inputs_embeds, input_ids, attention_mask, labels
                    )
+                   cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)
                else:
                    # Retrieve the first layer to inspect the logits and mask out the hidden states
                    # that are set to 0
@@ -506,6 +507,9 @@ def forward(

                    attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
                    position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
+                   cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[
+                       -target_length:
+                   ]

            # TODO: @raushan retain only the new behavior after v4.47
            else:
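
Aside, not from the PR: a minimal sketch with made-up shapes of what the added `cache_position` lines evaluate to. On the prefill branch the positions cover the whole merged sequence; on the legacy decoding branch only the trailing `target_length` positions are kept, matching the newly appended tokens.

import torch

# Hypothetical shapes: one sample, a merged sequence of 8 tokens,
# of which the last 2 are new (target_length = 2).
attention_mask = torch.ones(1, 8, dtype=torch.long)
target_length = 2

# Prefill branch: one position per token of the merged sequence.
cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)
print(cache_position)                   # tensor([0, 1, 2, 3, 4, 5, 6, 7])

# Legacy decoding branch: only the positions of the new tokens.
print(cache_position[-target_length:])  # tensor([6, 7])

# position_ids as computed alongside it: index of the last attended token.
print(attention_mask.sum(dim=1).unsqueeze(-1) - 1)  # tensor([[7]])
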
@@ -585,9 +589,7 @@ def prepare_inputs_for_generation(
            **kwargs,
        )

-       if legacy_processing:
-           model_inputs["pixel_values"] = pixel_values
-       elif cache_position[0] == 0:
+       if legacy_processing or cache_position[0] == 0:
            # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
            # Otherwise we need pixel values to be passed to model
            model_inputs["pixel_values"] = pixel_values
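
The merged condition is easier to read as a predicate. Below is a minimal sketch of the behavior the comment describes, with an invented helper name and made-up cache positions (nothing here is from the PR): pixel values are forwarded on the first forward pass or on the legacy path, and dropped during cached decoding, when the input ids no longer contain the image token.

import torch

def should_pass_pixel_values(legacy_processing: bool, cache_position: torch.Tensor) -> bool:
    # Prefill starts at position 0; later cached decoding steps start at the
    # current cache length, so pixel values are only needed on the first pass
    # (or on the legacy path, which merges image features inside forward()).
    return legacy_processing or cache_position[0].item() == 0

print(should_pass_pixel_values(False, torch.arange(0, 12)))   # prefill -> True
print(should_pass_pixel_values(False, torch.arange(12, 13)))  # decoding step -> False
print(should_pass_pixel_values(True, torch.arange(12, 13)))   # legacy path -> True
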
2 changes: 1 addition & 1 deletion src/transformers/models/llava/processing_llava.py
@@ -136,6 +136,7 @@ def __call__(
            raise ValueError("Invalid input text. Please provide a string, or a list of strings")

        # try to expand inputs in processing if we have the necessary parts
+       prompt_strings = text
        if image_inputs.get("pixel_values") is not None:
            if self.patch_size is not None and self.vision_feature_select_strategy is not None:
                # Replace the image token with the expanded image token sequence
@@ -150,7 +151,6 @@ def __call__(
                    sample = sample.replace(self.image_token, self.image_token * num_image_tokens)
                    prompt_strings.append(sample)
            else:
-               prompt_strings = text
                logger.warning_once(
                    "Expanding inputs for image tokens in LLaVa should be done in processing. "
                    "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly "
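
For intuition, a standalone sketch of the expansion branch above, with assumed numbers (a 336-pixel image and 14-pixel patches; the real values come from the processor config): each `<image>` token is repeated once per image feature, so tokenization yields one placeholder id per feature.

image_token = "<image>"

# Assumed geometry: (336 // 14) ** 2 = 576 patch features plus one CLS
# feature; the "default" feature-select strategy then drops CLS again.
num_image_tokens = (336 // 14) ** 2 + 1
num_image_tokens -= 1  # vision_feature_select_strategy == "default"

sample = "USER: <image>\nWhat is shown in this image? ASSISTANT:"
expanded = sample.replace(image_token, image_token * num_image_tokens)
print(expanded.count(image_token))  # 576
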
13 changes: 7 additions & 6 deletions src/transformers/models/llava_next/modeling_llava_next.py
@@ -848,6 +848,7 @@ def forward(
                        position_ids,
                        labels=labels,
                    )
+                   cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)
                else:
                    # Retrieve the first layer to inspect the logits and mask out the hidden states
                    # that are set to 0
@@ -877,6 +878,9 @@ def forward(
                    extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
                    attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
                    position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
+                   cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[
+                       -target_length:
+                   ]

            # TODO: @raushan retain only the new behavior after v4.47
            else:
@@ -956,12 +960,9 @@ def prepare_inputs_for_generation(
            **kwargs,
        )

-       if legacy_processing:
-           model_inputs["pixel_values"] = pixel_values
-           model_inputs["image_sizes"] = image_sizes
-       elif cache_position[0] == 0:
-           # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
-           # Otherwise we need pixel values to be passed to model
+       # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
+       # Otherwise we need pixel values to be passed to model
+       if legacy_processing or cache_position[0] == 0:
            model_inputs["pixel_values"] = pixel_values
            model_inputs["image_sizes"] = image_sizes

47 changes: 23 additions & 24 deletions src/transformers/models/llava_next/processing_llava_next.py
@@ -140,30 +140,29 @@ def __call__(
        elif not isinstance(text, list) and not isinstance(text[0], str):
            raise ValueError("Invalid input text. Please provide a string, or a list of strings")

-       if self.patch_size is None or self.vision_feature_select_strategy is None:
-           prompt_strings = text
-           logger.warning_once(
-               "Expanding inputs for image tokens in LLaVa-NeXT should be done in processing. "
-               "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly "
-               "with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. "
-               "Using processors without these attributes in the config is deprecated and will throw an error in v4.47."
-           )
-       # cannot infer image expansion length if no images are found
-       elif not image_inputs:
-           prompt_strings = text
-       else:
-           image_sizes = image_inputs["image_sizes"]
-           height, width = get_image_size(to_numpy_array(image_inputs["pixel_values"][0][0]))
-           prompt_strings = []
-           for image_size, sample in zip(image_sizes, text):
-               # Replace the image token with the expanded image token sequence
-               orig_height, orig_width = image_size
-               num_image_tokens = self._get_number_of_features(orig_height, orig_width, height, width)
-               if self.vision_feature_select_strategy == "default":
-                   num_image_tokens -= 1
-
-               sample = sample.replace(self.image_token, self.image_token * num_image_tokens)
-               prompt_strings.append(sample)
+       prompt_strings = text
+       if image_inputs:
+           if self.patch_size is None or self.vision_feature_select_strategy is None:
+               logger.warning_once(
+                   "Expanding inputs for image tokens in LLaVa-NeXT should be done in processing. "
+                   "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly "
+                   "with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. "
+                   "Using processors without these attributes in the config is deprecated and will throw an error in v4.47."

[Member Author, comment on lines +143 to +150] This was needed for cases with multi-image inputs, where we cannot be sure that the number of image sizes equals the number of texts; for example, one text and two images.

+               )
+           else:
+               image_sizes = iter(image_inputs["image_sizes"])
+               height, width = get_image_size(to_numpy_array(image_inputs["pixel_values"][0][0]))
+               prompt_strings = []
+               for sample in text:
+                   while self.image_token in sample:
+                       image_size = next(image_sizes)

[Collaborator, inline comment on `image_size = next(image_sizes)`] Is this right? The previous logic implies len(image_sizes) == len(text), whereas this implies we have as many image_sizes as the number of image tokens per sample times the number of samples.

[Member Author] Yes, I verified twice to be sure. The number of images equals the number of image sizes. The newly added test multiimage_expansion fails with the previous logic.

[Member Author] @amyeroberts do you have any other concerns regarding this PR?

+                       orig_height, orig_width = image_size
+                       num_image_tokens = self._get_number_of_features(orig_height, orig_width, height, width)
+                       if self.vision_feature_select_strategy == "default":
+                           num_image_tokens -= 1
+                       sample = sample.replace(self.image_token, "<placeholder>" * num_image_tokens, 1)
+                   prompt_strings.append(sample)
+               prompt_strings = [sample.replace("<placeholder>", self.image_token) for sample in prompt_strings]

        text_inputs = self.tokenizer(
            prompt_strings,
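
The review thread above is easy to verify with a standalone sketch. Everything below is invented for illustration; in particular, `stub_number_of_features` is a stand-in for the processor's `_get_number_of_features`. One text sample carries two `<image>` tokens, so the iterator must yield one size per image rather than one per text, and the intermediate `<placeholder>` string keeps the `while` loop from matching the tokens it has just inserted (which would drain the iterator and never terminate).

image_token = "<image>"

def stub_number_of_features(orig_height: int, orig_width: int) -> int:
    # Stand-in for LlavaNextProcessor._get_number_of_features(); the real
    # count depends on the image grid, so these numbers are arbitrary.
    return 4 if orig_height >= orig_width else 3

# One text with two image tokens: image_sizes has one entry per image,
# not one per text sample.
text = ["USER: <image><image>\nCompare the two images. ASSISTANT:"]
image_sizes = iter([(768, 512), (512, 768)])

prompt_strings = []
for sample in text:
    while image_token in sample:
        orig_height, orig_width = next(image_sizes)
        num_image_tokens = stub_number_of_features(orig_height, orig_width)
        # Expand only the first remaining token, via a placeholder so the
        # loop condition does not match the freshly inserted tokens.
        sample = sample.replace(image_token, "<placeholder>" * num_image_tokens, 1)
    prompt_strings.append(sample)
prompt_strings = [s.replace("<placeholder>", image_token) for s in prompt_strings]

print(prompt_strings[0].count(image_token))  # 7 == 4 + 3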