Skip to content

Commit 3ff2e98

Browse files
authored
Fix PerceptionLM image preprocessing for non-tiled image input. (#40006)
* Fix PerceptionLM image preprocessing for non-tiled image input.
* Add test for single tile vanilla image processing.
* ruff format
* recover missing test skip
* Simplify test.
* minor test name fix
1 parent 4668ef1 commit 3ff2e98

File tree

2 files changed

+31
-1
lines changed

2 files changed

+31
-1
lines changed

src/transformers/models/perception_lm/image_processing_perception_lm_fast.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -310,7 +310,7 @@ def _preprocess(
310310
)
311311
processed_images_grouped[shape] = stacked_images
312312
processed_images = reorder_images(processed_images_grouped, grouped_images_index)
313-
313+
processed_images = [p[None] if p.ndim == 3 else p for p in processed_images] # add tiles dimension if needed
314314
processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
315315
return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)
316316

tests/models/perception_lm/test_processing_perception_lm.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,36 @@ def test_image_token_filling(self):
115115
)
116116
image_tokens = (inputs["input_ids"] == image_token_index).sum().item()
117117
self.assertEqual(expected_image_tokens, image_tokens)
118+
self.assertEqual(inputs["pixel_values"].ndim, 5)
119+
120+
def test_vanilla_image_with_no_tiles_token_filling(self):
121+
processor = self.processor_class.from_pretrained(self.tmpdirname)
122+
processor.image_processor.vision_input_type = "vanilla"
123+
# Important to check with non square image
124+
image = torch.randn((1, 3, 450, 500))
125+
# 1 tile
126+
# 448/patch_size/pooling_ratio = 16 => 16*16 tokens per tile
127+
expected_image_tokens = 16 * 16 * 1
128+
image_token_index = processor.image_token_id
129+
130+
messages = [
131+
{
132+
"role": "user",
133+
"content": [
134+
{"type": "image"},
135+
{"type": "text", "text": "What is shown in this image?"},
136+
],
137+
},
138+
]
139+
inputs = processor(
140+
text=[processor.apply_chat_template(messages)],
141+
images=[image],
142+
return_tensors="pt",
143+
)
144+
image_tokens = (inputs["input_ids"] == image_token_index).sum().item()
145+
self.assertEqual(expected_image_tokens, image_tokens)
146+
self.assertEqual(inputs["pixel_values"].ndim, 5)
147+
self.assertEqual(inputs["pixel_values"].shape[1], 1) # 1 tile
118148

119149

120150
CHAT_TEMPLATE = (

0 commit comments

Comments
 (0)