Skip to content

Commit

Permalink
add support for non-nested images and add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
yonigozlan committed Oct 21, 2024
1 parent ca541bd commit e45cc3a
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 33 deletions.
15 changes: 14 additions & 1 deletion src/transformers/models/idefics2/processing_idefics2.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,20 @@ def __call__(
if is_image_or_image_url(images):
images = [[images]]
elif isinstance(images, list) and is_image_or_image_url(images[0]):
images = [images]
if text is not None:
if sum(n_images_in_text) != len(images):
raise ValueError(
f"The total number of {image_token} tokens in the prompts should be the same as the number of images passed."
f" Found {sum(n_images_in_text)} {image_token} tokens and {len(images)} images."
)
# Reorganize the images to match the prompts
images = [
images[sum(n_images_in_text[:i]) : sum(n_images_in_text[: i + 1])]
for i in range(len(n_images_in_text))
]
else:
images = [images]

elif (
not isinstance(images, list)
and not isinstance(images[0], list)
Expand Down
35 changes: 23 additions & 12 deletions src/transformers/models/idefics3/processing_idefics3.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,11 +241,29 @@ def __call__(
n_images_in_images = []
inputs = BatchFeature()

if text is not None:
if isinstance(text, str):
text = [text]
elif not isinstance(text, list) and not isinstance(text[0], str):
raise ValueError("Invalid input text. Please provide a string, or a list of strings")
n_images_in_text = [sample.count(self.image_token.content) for sample in text]

if images is not None:
if is_image_or_image_url(images):
images = [[images]]
elif isinstance(images, list) and is_image_or_image_url(images[0]):
images = [images]
if text is not None:
if sum(n_images_in_text) != len(images):
raise ValueError(
f"The total number of {self.image_token.content} tokens in the prompts should be the same as the number of images passed."
f" Found {sum(n_images_in_text)} {self.image_token.content} tokens and {len(images)} images."
)
images = [
images[sum(n_images_in_text[:i]) : sum(n_images_in_text[: i + 1])]
for i in range(len(n_images_in_text))
]
else:
images = [images]
elif (
not isinstance(images, list)
and not isinstance(images[0], list)
Expand All @@ -263,10 +281,10 @@ def __call__(
inputs.update(image_inputs)

if text is not None:
if isinstance(text, str):
text = [text]
elif not isinstance(text, list) and not isinstance(text[0], str):
raise ValueError("Invalid input text. Please provide a string, or a list of strings")
if n_images_in_images != n_images_in_text:
raise ValueError(
f"The number of images in the text {n_images_in_text} and images {n_images_in_images} should be the same."
)

image_rows = inputs.pop("rows", [[0] * len(text)])
image_cols = inputs.pop("cols", [[0] * len(text)])
Expand All @@ -277,8 +295,6 @@ def __call__(

prompt_strings = []
for sample, sample_rows, sample_cols in zip(text, image_rows, image_cols):
n_images_in_text.append(sample.count(image_token))

# Replace the image token with fake tokens around the expanded image token sequence of length `image_seq_len`
image_prompt_strings = []
for n_rows, n_cols in zip(sample_rows, sample_cols):
Expand All @@ -305,11 +321,6 @@ def __call__(
text_inputs = self.tokenizer(text=prompt_strings, **output_kwargs["text_kwargs"])
inputs.update(text_inputs)

if n_images_in_images != n_images_in_text:
raise ValueError(
f"The number of images in the text {n_images_in_text} and images {n_images_in_images} should be the same."
)

return inputs

def batch_decode(self, *args, **kwargs):
Expand Down
29 changes: 19 additions & 10 deletions tests/models/idefics2/test_processor_idefics2.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,25 @@ def test_add_special_tokens_processor(self):
self.assertEqual(inputs["input_ids"], expected_input_ids)
# fmt: on

def test_non_nested_images_with_batched_text(self):
    """Flat (non-nested) image list with batched prompts: the processor should
    distribute the images across prompts by counting <image> tokens per prompt."""
    processor = self.get_processor()
    # Disable image splitting so each image yields exactly one patch group.
    processor.image_processor.do_image_splitting = False

    image_token = "<image>"
    prompts = [
        image_token + "In this image, we see",
        "bla, bla" + image_token + image_token,
    ]
    # Three images total, matching the one + two <image> tokens above.
    flat_images = [self.image1, self.image2, self.image3]

    outputs = processor(text=prompts, images=flat_images, padding=True)

    self.assertEqual(outputs["pixel_values"].shape, (2, 2, 3, 767, 980))
    self.assertEqual(outputs["pixel_attention_mask"].shape, (2, 2, 767, 980))

def test_apply_chat_template(self):
# Message contains content which a mix of lists with images and image urls and string
messages = [
Expand Down Expand Up @@ -275,13 +294,3 @@ def prepare_text_inputs(self, batch_size: Optional[int] = None):
return ["lower newer <image>", "<image> upper older longer string"] + ["<image> lower newer"] * (
batch_size - 2
)

# Override as Idefics2Processor needs nested images to work properly with batched inputs
# (NOTE(review): the previous comment said "PixtralProcessor" — looks like a copy/paste
# slip, as this file is the idefics2 processor test; confirm against the Pixtral tests)
@require_vision
def prepare_image_inputs(self, batch_size: Optional[int] = None):
    """This function prepares a list of PIL images for testing"""
    # No batch size requested: fall back to the base class's flat image list.
    if batch_size is None:
        return super().prepare_image_inputs()
    if batch_size < 1:
        raise ValueError("batch_size must be greater than 0")
    # `[[...]] * batch_size` replicates a reference to ONE inner list — every
    # batch entry shares the same single-image list object (fine for fixtures).
    return [[super().prepare_image_inputs()]] * batch_size
30 changes: 20 additions & 10 deletions tests/models/idefics3/test_processor_idefics3.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,25 @@ def test_add_special_tokens_processor(self):
self.assertEqual(inputs["input_ids"], expected_input_ids)
# fmt: on

def test_non_nested_images_with_batched_text(self):
    """Flat image list + batched prompts: images are nested per prompt according
    to the number of <image> tokens each prompt contains."""
    processor = self.get_processor()
    # One image per <image> token when splitting is off.
    processor.image_processor.do_image_splitting = False

    image_token = "<image>"
    prompts = [
        image_token + "In this image, we see",
        image_token + image_token + "In this image, we see",
    ]
    # Three images total: one for the first prompt, two for the second.
    flat_images = [self.image1, self.image2, self.image3]

    outputs = processor(text=prompts, images=flat_images, padding=True)

    self.assertEqual(np.array(outputs["pixel_values"]).shape, (2, 2, 3, 364, 364))
    self.assertEqual(np.array(outputs["pixel_attention_mask"]).shape, (2, 2, 364, 364))

def test_apply_chat_template(self):
# Message contains content which a mix of lists with images and image urls and string
messages = [
Expand Down Expand Up @@ -299,16 +318,7 @@ def prepare_text_inputs(self, batch_size: Optional[int] = None):
batch_size - 2
)

# Override as Idefics3Processor needs nested images to work properly with batched inputs
@require_vision
def prepare_image_inputs(self, batch_size: Optional[int] = None):
    """Prepare image fixtures: a flat list when no batch size is given,
    otherwise a nested list with one single-image list per batch entry."""
    if batch_size is None:
        return super().prepare_image_inputs()
    if batch_size < 1:
        raise ValueError("batch_size must be greater than 0")
    # The original evaluates super().prepare_image_inputs() once before
    # replication, so hoisting it is behaviorally identical.
    single_image = super().prepare_image_inputs()
    return [[single_image]] * batch_size

# Override tests as inputs_ids padded dimension is the second one but not the last one
@require_vision
@require_torch
def test_kwargs_overrides_default_tokenizer_kwargs(self):
Expand Down

0 comments on commit e45cc3a

Please sign in to comment.