From e45cc3a8049fdcea89293c614a0b3ce234c8e186 Mon Sep 17 00:00:00 2001 From: yonigozlan Date: Thu, 17 Oct 2024 13:45:31 +0000 Subject: [PATCH] add support for non nested images and add tests --- .../models/idefics2/processing_idefics2.py | 15 +++++++- .../models/idefics3/processing_idefics3.py | 35 ++++++++++++------- .../idefics2/test_processor_idefics2.py | 29 +++++++++------ .../idefics3/test_processor_idefics3.py | 30 ++++++++++------ 4 files changed, 76 insertions(+), 33 deletions(-) diff --git a/src/transformers/models/idefics2/processing_idefics2.py b/src/transformers/models/idefics2/processing_idefics2.py index 68566d182678c2..bb063476f28265 100644 --- a/src/transformers/models/idefics2/processing_idefics2.py +++ b/src/transformers/models/idefics2/processing_idefics2.py @@ -218,7 +218,20 @@ def __call__( if is_image_or_image_url(images): images = [[images]] elif isinstance(images, list) and is_image_or_image_url(images[0]): - images = [images] + if text is not None: + if sum(n_images_in_text) != len(images): + raise ValueError( + f"The total number of {image_token} tokens in the prompts should be the same as the number of images passed." + f" Found {sum(n_images_in_text)} {image_token} tokens and {len(images)} images." + ) + # Reorganize the images to match the prompts + images = [ + images[sum(n_images_in_text[:i]) : sum(n_images_in_text[: i + 1])] + for i in range(len(n_images_in_text)) + ] + else: + images = [images] + elif ( not isinstance(images, list) and not isinstance(images[0], list) diff --git a/src/transformers/models/idefics3/processing_idefics3.py b/src/transformers/models/idefics3/processing_idefics3.py index ceafa26a8b1187..e13e8664d5e455 100644 --- a/src/transformers/models/idefics3/processing_idefics3.py +++ b/src/transformers/models/idefics3/processing_idefics3.py @@ -241,11 +241,29 @@ def __call__( n_images_in_images = [] inputs = BatchFeature() + if text is not None: + if isinstance(text, str): + text = [text] + elif not isinstance(text, list) and not isinstance(text[0], str): + raise ValueError("Invalid input text. Please provide a string, or a list of strings") + n_images_in_text = [sample.count(self.image_token.content) for sample in text] + if images is not None: if is_image_or_image_url(images): images = [[images]] elif isinstance(images, list) and is_image_or_image_url(images[0]): - images = [images] + if text is not None: + if sum(n_images_in_text) != len(images): + raise ValueError( + f"The total number of {self.image_token.content} tokens in the prompts should be the same as the number of images passed." + f" Found {sum(n_images_in_text)} {self.image_token.content} tokens and {len(images)} images." + ) + images = [ + images[sum(n_images_in_text[:i]) : sum(n_images_in_text[: i + 1])] + for i in range(len(n_images_in_text)) + ] + else: + images = [images] elif ( not isinstance(images, list) and not isinstance(images[0], list) @@ -263,10 +281,10 @@ def __call__( inputs.update(image_inputs) if text is not None: - if isinstance(text, str): - text = [text] - elif not isinstance(text, list) and not isinstance(text[0], str): - raise ValueError("Invalid input text. Please provide a string, or a list of strings") + if n_images_in_images != n_images_in_text: + raise ValueError( + f"The number of images in the text {n_images_in_text} and images {n_images_in_images} should be the same." + ) image_rows = inputs.pop("rows", [[0] * len(text)]) image_cols = inputs.pop("cols", [[0] * len(text)]) @@ -277,8 +295,6 @@ def __call__( prompt_strings = [] for sample, sample_rows, sample_cols in zip(text, image_rows, image_cols): - n_images_in_text.append(sample.count(image_token)) - # Replace the image token with fake tokens around the expanded image token sequence of length `image_seq_len` image_prompt_strings = [] for n_rows, n_cols in zip(sample_rows, sample_cols): @@ -305,11 +321,6 @@ def __call__( text_inputs = self.tokenizer(text=prompt_strings, **output_kwargs["text_kwargs"]) inputs.update(text_inputs) - if n_images_in_images != n_images_in_text: - raise ValueError( - f"The number of images in the text {n_images_in_text} and images {n_images_in_images} should be the same." - ) - return inputs def batch_decode(self, *args, **kwargs): diff --git a/tests/models/idefics2/test_processor_idefics2.py b/tests/models/idefics2/test_processor_idefics2.py index bf713c6fb8cfbb..ed9f489c0dd268 100644 --- a/tests/models/idefics2/test_processor_idefics2.py +++ b/tests/models/idefics2/test_processor_idefics2.py @@ -226,6 +226,25 @@ def test_add_special_tokens_processor(self): self.assertEqual(inputs["input_ids"], expected_input_ids) # fmt: on + def test_non_nested_images_with_batched_text(self): + processor = self.get_processor() + processor.image_processor.do_image_splitting = False + + image_str = "" + text_str_1 = "In this image, we see" + text_str_2 = "bla, bla" + + text = [ + image_str + text_str_1, + text_str_2 + image_str + image_str, + ] + images = [self.image1, self.image2, self.image3] + + inputs = processor(text=text, images=images, padding=True) + + self.assertEqual(inputs["pixel_values"].shape, (2, 2, 3, 767, 980)) + self.assertEqual(inputs["pixel_attention_mask"].shape, (2, 2, 767, 980)) + def test_apply_chat_template(self): # Message contains content which a mix of lists with images and image urls and string messages = [ @@ -275,13 +294,3 @@ def prepare_text_inputs(self, batch_size: Optional[int] = None): return ["lower newer ", " upper older longer string"] + [" lower newer"] * ( batch_size - 2 ) - - # Override as PixtralProcessor needs nested images to work properly with batched inputs - @require_vision - def prepare_image_inputs(self, batch_size: Optional[int] = None): - """This function prepares a list of PIL images for testing""" - if batch_size is None: - return super().prepare_image_inputs() - if batch_size < 1: - raise ValueError("batch_size must be greater than 0") - return [[super().prepare_image_inputs()]] * batch_size diff --git a/tests/models/idefics3/test_processor_idefics3.py b/tests/models/idefics3/test_processor_idefics3.py index a53109b02b6951..365edd048cb9b7 100644 --- a/tests/models/idefics3/test_processor_idefics3.py +++ b/tests/models/idefics3/test_processor_idefics3.py @@ -250,6 +250,25 @@ def test_add_special_tokens_processor(self): self.assertEqual(inputs["input_ids"], expected_input_ids) # fmt: on + def test_non_nested_images_with_batched_text(self): + processor = self.get_processor() + processor.image_processor.do_image_splitting = False + + image_str = "" + text_str_1 = "In this image, we see" + text_str_2 = "In this image, we see" + + text = [ + image_str + text_str_1, + image_str + image_str + text_str_2, + ] + images = [self.image1, self.image2, self.image3] + + inputs = processor(text=text, images=images, padding=True) + + self.assertEqual(np.array(inputs["pixel_values"]).shape, (2, 2, 3, 364, 364)) + self.assertEqual(np.array(inputs["pixel_attention_mask"]).shape, (2, 2, 364, 364)) + def test_apply_chat_template(self): # Message contains content which a mix of lists with images and image urls and string messages = [ @@ -299,16 +318,7 @@ def prepare_text_inputs(self, batch_size: Optional[int] = None): batch_size - 2 ) - # Override as Idefics3Processor needs nested images to work properly with batched inputs - @require_vision - def prepare_image_inputs(self, batch_size: Optional[int] = None): - """This function prepares a list of PIL images for testing""" - if batch_size is None: - return super().prepare_image_inputs() - if batch_size < 1: - raise ValueError("batch_size must be greater than 0") - return [[super().prepare_image_inputs()]] * batch_size - + # Override tests as inputs_ids padded dimension is the second one but not the last one @require_vision @require_torch def test_kwargs_overrides_default_tokenizer_kwargs(self):