Patch notes (leading text ahead of the first "diff --git" header, not part of either hunk):
  * image_processing_perception_lm_fast.py, tail of _preprocess: a blank line is replaced by a
    statement that indexes every 3-D entry of processed_images with p[None] (described by the
    patch's own inline comment as adding the tiles dimension) before the optional torch.stack.
  * test_processing_perception_lm.py: test_image_token_filling additionally asserts that
    pixel_values.ndim == 5, and a new test_vanilla_image_with_no_tiles_token_filling sets
    vision_input_type = "vanilla" and asserts pixel_values.ndim == 5 and pixel_values.shape[1] == 1.
  * NOTE(review): line breaks, indentation and blank context lines were restored from the hunk
    line counts and Python block structure; confirm against the target files before applying.

diff --git a/src/transformers/models/perception_lm/image_processing_perception_lm_fast.py b/src/transformers/models/perception_lm/image_processing_perception_lm_fast.py
index 919a1203128e..7fe3293ea28f 100644
--- a/src/transformers/models/perception_lm/image_processing_perception_lm_fast.py
+++ b/src/transformers/models/perception_lm/image_processing_perception_lm_fast.py
@@ -310,7 +310,7 @@ def _preprocess(
             )
             processed_images_grouped[shape] = stacked_images
         processed_images = reorder_images(processed_images_grouped, grouped_images_index)
-
+        processed_images = [p[None] if p.ndim == 3 else p for p in processed_images]  # add tiles dimension if needed
         processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
         return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)
 
diff --git a/tests/models/perception_lm/test_processing_perception_lm.py b/tests/models/perception_lm/test_processing_perception_lm.py
index c6384e4b456c..a0d2c19fbf0e 100644
--- a/tests/models/perception_lm/test_processing_perception_lm.py
+++ b/tests/models/perception_lm/test_processing_perception_lm.py
@@ -115,6 +115,36 @@ def test_image_token_filling(self):
         )
         image_tokens = (inputs["input_ids"] == image_token_index).sum().item()
         self.assertEqual(expected_image_tokens, image_tokens)
+        self.assertEqual(inputs["pixel_values"].ndim, 5)
+
+    def test_vanilla_image_with_no_tiles_token_filling(self):
+        processor = self.processor_class.from_pretrained(self.tmpdirname)
+        processor.image_processor.vision_input_type = "vanilla"
+        # Important to check with non square image
+        image = torch.randn((1, 3, 450, 500))
+        # 1 tile
+        # 448/patch_size/pooling_ratio = 16 => 16*16 tokens per tile
+        expected_image_tokens = 16 * 16 * 1
+        image_token_index = processor.image_token_id
+
+        messages = [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "image"},
+                    {"type": "text", "text": "What is shown in this image?"},
+                ],
+            },
+        ]
+        inputs = processor(
+            text=[processor.apply_chat_template(messages)],
+            images=[image],
+            return_tensors="pt",
+        )
+        image_tokens = (inputs["input_ids"] == image_token_index).sum().item()
+        self.assertEqual(expected_image_tokens, image_tokens)
+        self.assertEqual(inputs["pixel_values"].ndim, 5)
+        self.assertEqual(inputs["pixel_values"].shape[1], 1)  # 1 tile
 
 
 CHAT_TEMPLATE = (