
Commit dae1ccf

lambertwjh, Your Name, and yonigozlan authored
fix_image_processing_fast_for_glm4v (#40483)
* fix_image_processing_fast_for_glm4v
* fix(format): auto-ruff format
* add test image processing glm4v
* fix quality

Co-authored-by: Your Name <you@example.com>
Co-authored-by: yonigozlan <yoni.gozlan@huggingface.co>
1 parent 7d57b31 commit dae1ccf
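
What the fix means in practice: before this commit, the fast path resized every image in a batch to the single largest smart_resize target in that batch, so mixed-resolution inputs were all stretched to one shared size; after it, each image keeps its own target. A minimal sketch of the new behavior, with illustrative kwargs rather than GLM-4V's shipped defaults:

import torch
from transformers import Glm4vImageProcessorFast

# Illustrative configuration; real checkpoints ship their own values.
processor = Glm4vImageProcessorFast(
    patch_size=14,
    merge_size=2,
    temporal_patch_size=2,
    size={"shortest_edge": 56 * 56, "longest_edge": 28 * 28 * 1280},
)

# Two images with different resolutions in the same batch.
images = [
    torch.randint(0, 256, (3, 224, 336), dtype=torch.uint8),
    torch.randint(0, 256, (3, 448, 448), dtype=torch.uint8),
]
out = processor(images=images, return_tensors="pt")

# Each row of image_grid_thw now reflects that image's own resize target.
print(out.image_grid_thw)      # e.g. tensor([[ 1, 16, 24], [ 1, 32, 32]])
print(out.pixel_values.shape)  # (sum over images of t*h*w, C * tp * ph * pw)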

2 files changed: +310 -39 lines

src/transformers/models/glm4v/image_processing_glm4v_fast.py

Lines changed: 56 additions & 39 deletions
@@ -22,6 +22,8 @@
 from ...image_processing_utils_fast import (
     BaseImageProcessorFast,
     DefaultFastImageProcessorKwargs,
+    group_images_by_shape,
+    reorder_images,
 )
 from ...image_utils import (
     OPENAI_CLIP_MEAN,
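
For readers unfamiliar with these two helpers: they let the fast path run each tensor op once per distinct image shape instead of once per image. A minimal sketch of their semantics (not the actual implementation in image_processing_utils_fast, which also accepts a disable_grouping flag):

import torch

def group_images_by_shape_sketch(images):
    # Stack same-shaped images into one tensor per shape, remembering
    # where each original image landed so the order can be restored.
    grouped, index = {}, {}
    for i, image in enumerate(images):
        shape = tuple(image.shape)
        grouped.setdefault(shape, []).append(image)
        index[i] = (shape, len(grouped[shape]) - 1)
    return {s: torch.stack(imgs) for s, imgs in grouped.items()}, index

def reorder_images_sketch(processed_grouped, index):
    # Invert the grouping: per-image results in the original order.
    return [processed_grouped[index[i][0]][index[i][1]] for i in sorted(index)]

images = [torch.rand(3, 32, 32), torch.rand(3, 64, 48), torch.rand(3, 32, 32)]
grouped, idx = group_images_by_shape_sketch(images)
restored = reorder_images_sketch(grouped, idx)
assert all(torch.equal(a, b) for a, b in zip(images, restored))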
@@ -128,46 +130,54 @@ def _preprocess(
         Preprocess an image or batch of images. Copy of the `preprocess` method from `CLIPImageProcessor`.
         """

-        processed_images = []
-        processed_grids = []
-
-        all_target_sizes = []
-        for image in images:
-            height, width = image.shape[-2:]
-            resized_height, resized_width = smart_resize(
-                num_frames=temporal_patch_size,
-                height=height,
-                width=width,
-                temporal_factor=temporal_patch_size,
-                factor=patch_size * merge_size,
-                min_pixels=size.shortest_edge,
-                max_pixels=size.longest_edge,
-            )
-            all_target_sizes.append((resized_height, resized_width))
-
-        target_height = max([s[0] for s in all_target_sizes])
-        target_width = max([s[1] for s in all_target_sizes])
-
-        for image in images:
+        grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
+        resized_images_grouped = {}
+        for shape, stacked_images in grouped_images.items():
+            height, width = stacked_images.shape[-2:]
             if do_resize:
-                image = self.resize(
-                    image,
-                    size=SizeDict(height=target_height, width=target_width),
+                resized_height, resized_width = smart_resize(
+                    num_frames=temporal_patch_size,
+                    height=height,
+                    width=width,
+                    temporal_factor=temporal_patch_size,
+                    factor=patch_size * merge_size,
+                    min_pixels=size.shortest_edge,
+                    max_pixels=size.longest_edge,
+                )
+                stacked_images = self.resize(
+                    stacked_images,
+                    size=SizeDict(height=resized_height, width=resized_width),
                     interpolation=interpolation,
                 )
+            resized_images_grouped[shape] = stacked_images
+
+        resized_images = reorder_images(resized_images_grouped, grouped_images_index)
+
+        grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
+        processed_images_grouped = {}
+        processed_grids = {}
+
+        for shape, stacked_images in grouped_images.items():
+            resized_height, resized_width = stacked_images.shape[-2:]
+
+            patches = self.rescale_and_normalize(
+                stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
+            )
+            if patches.ndim == 4:  # (B, C, H, W)
+                patches = patches.unsqueeze(1)  # (B, T=1, C, H, W)
+
+            if patches.shape[1] % temporal_patch_size != 0:
+                repeats = patches[:, -1:].repeat(
+                    1, temporal_patch_size - (patches.shape[1] % temporal_patch_size), 1, 1, 1
+                )
+                patches = torch.cat([patches, repeats], dim=1)
+
+            batch_size, t_len, channel = patches.shape[:3]
+            grid_t = t_len // temporal_patch_size
+            grid_h, grid_w = resized_height // patch_size, resized_width // patch_size

-            image = self.rescale_and_normalize(
-                image.unsqueeze(0), do_rescale, rescale_factor, do_normalize, image_mean, image_std
-            ).squeeze(0)
-
-            patches = image.unsqueeze(0)
-            if patches.shape[0] % temporal_patch_size != 0:
-                repeats = patches[-1:].repeat(temporal_patch_size - (patches.shape[0] % temporal_patch_size), 1, 1, 1)
-                patches = torch.cat([patches, repeats], dim=0)
-            channel = patches.shape[1]
-            grid_t = patches.shape[0] // temporal_patch_size
-            grid_h, grid_w = target_height // patch_size, target_width // patch_size
             patches = patches.view(
+                batch_size,
                 grid_t,
                 temporal_patch_size,
                 channel,
@@ -178,15 +188,22 @@ def _preprocess(
                 merge_size,
                 patch_size,
             )
-            patches = patches.permute(0, 3, 6, 4, 7, 2, 1, 5, 8)
+            # (B, grid_t, gh, gw, mh, mw, C, tp, ph, pw)
+            patches = patches.permute(0, 1, 4, 7, 5, 8, 3, 2, 6, 9)
+
             flatten_patches = patches.reshape(
+                batch_size,
                 grid_t * grid_h * grid_w,
                 channel * temporal_patch_size * patch_size * patch_size,
             )
-            processed_images.append(flatten_patches)
-            processed_grids.append([grid_t, grid_h, grid_w])

-        pixel_values = torch.stack(processed_images, dim=0)
+            processed_images_grouped[shape] = flatten_patches
+            processed_grids[shape] = [[grid_t, grid_h, grid_w]] * batch_size
+
+        processed_images = reorder_images(processed_images_grouped, grouped_images_index)
+        processed_grids = reorder_images(processed_grids, grouped_images_index)
+
+        pixel_values = torch.cat(processed_images, dim=0)
         image_grid_thw = torch.tensor(processed_grids)

         return BatchFeature(
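
The reworked pipeline above turns a (B, T, C, H, W) block into flattened patch tokens via view, permute, and reshape. A self-contained sketch of the same shape arithmetic with illustrative sizes (all values hypothetical; the real ones come from the processor config):

import torch

# Illustrative values; the real defaults come from the GLM-4V config.
B, C, tp, p, m = 2, 3, 2, 14, 2
H = W = 56                      # already smart_resize'd: divisible by p * m
gh, gw, gt = H // p, W // p, 1  # 4 x 4 patch grid, one temporal group

patches = torch.rand(B, gt * tp, C, H, W)
patches = patches.view(B, gt, tp, C, gh // m, m, p, gw // m, m, p)
# -> (B, grid_t, gh, gw, mh, mw, C, tp, ph, pw), matching the diff's comment
patches = patches.permute(0, 1, 4, 7, 5, 8, 3, 2, 6, 9)
flat = patches.reshape(B, gt * gh * gw, C * tp * p * p)

print(flat.shape)  # torch.Size([2, 16, 1176]): 16 tokens of 3*2*14*14 = 1176 dims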
Lines changed: 254 additions & 0 deletions
# Copyright 2021 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import unittest

import numpy as np

from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available

from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs


if is_torch_available():
    import torch


if is_vision_available():
    from PIL import Image

    from transformers import Glm4vImageProcessor
    from transformers.models.glm4v.image_processing_glm4v import smart_resize

    if is_torchvision_available():
        from transformers import Glm4vImageProcessorFast


class Glm4vImageProcessingTester:
    def __init__(
        self,
        parent,
        batch_size=7,
        num_channels=3,
        min_resolution=30,
        max_resolution=80,
        do_resize=True,
        size=None,
        do_normalize=True,
        image_mean=[0.5, 0.5, 0.5],
        image_std=[0.5, 0.5, 0.5],
        temporal_patch_size=2,
        patch_size=14,
        merge_size=2,
    ):
        size = size if size is not None else {"longest_edge": 20, "shortest_edge": 10}
        self.parent = parent
        self.batch_size = batch_size
        self.num_channels = num_channels
        self.min_resolution = min_resolution
        self.max_resolution = max_resolution
        self.do_resize = do_resize
        self.size = size
        self.do_normalize = do_normalize
        self.image_mean = image_mean
        self.image_std = image_std
        self.temporal_patch_size = temporal_patch_size
        self.patch_size = patch_size
        self.merge_size = merge_size

    def prepare_image_processor_dict(self):
        return {
            "image_mean": self.image_mean,
            "image_std": self.image_std,
            "do_normalize": self.do_normalize,
            "do_resize": self.do_resize,
            "size": self.size,
            "temporal_patch_size": self.temporal_patch_size,
            "patch_size": self.patch_size,
            "merge_size": self.merge_size,
        }

    def expected_output_image_shape(self, images):
        grid_t = 1
        hidden_dim = self.num_channels * self.temporal_patch_size * self.patch_size * self.patch_size
        seq_len = 0
        for image in images:
            if isinstance(image, list) and isinstance(image[0], Image.Image):
                image = np.stack([np.array(frame) for frame in image])
            elif hasattr(image, "shape"):
                pass
            else:
                image = np.array(image)
            if hasattr(image, "shape") and len(image.shape) >= 3:
                if isinstance(image, np.ndarray):
                    if len(image.shape) == 4:
                        height, width = image.shape[1:3]
                    elif len(image.shape) == 3:
                        height, width = image.shape[:2]
                    else:
                        height, width = self.min_resolution, self.min_resolution
                else:
                    height, width = image.shape[-2:]
            else:
                height, width = self.min_resolution, self.min_resolution

            resized_height, resized_width = smart_resize(
                self.temporal_patch_size,
                height,
                width,
                factor=self.patch_size * self.merge_size,
                min_pixels=self.size["shortest_edge"],
                max_pixels=self.size["longest_edge"],
            )
            grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size
            seq_len += grid_t * grid_h * grid_w
        return (seq_len, hidden_dim)

    def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=False):
        return prepare_image_inputs(
            batch_size=self.batch_size,
            num_channels=self.num_channels,
            min_resolution=self.min_resolution,
            max_resolution=self.max_resolution,
            equal_resolution=equal_resolution,
            numpify=numpify,
            torchify=torchify,
        )


@require_torch
@require_vision
class Glm4vImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
    image_processing_class = Glm4vImageProcessor if is_vision_available() else None
    fast_image_processing_class = Glm4vImageProcessorFast if is_torchvision_available() else None

    def setUp(self):
        super().setUp()
        self.image_processor_tester = Glm4vImageProcessingTester(self)

    @property
    def image_processor_dict(self):
        return self.image_processor_tester.prepare_image_processor_dict()

    def test_image_processor_properties(self):
        for image_processing_class in self.image_processor_list:
            image_processing = image_processing_class(**self.image_processor_dict)
            self.assertTrue(hasattr(image_processing, "image_mean"))
            self.assertTrue(hasattr(image_processing, "image_std"))
            self.assertTrue(hasattr(image_processing, "do_normalize"))
            self.assertTrue(hasattr(image_processing, "do_resize"))
            self.assertTrue(hasattr(image_processing, "size"))

    def test_image_processor_from_dict_with_kwargs(self):
        for image_processing_class in self.image_processor_list:
            image_processor = image_processing_class.from_dict(self.image_processor_dict)
            self.assertEqual(image_processor.size, {"shortest_edge": 10, "longest_edge": 20})

            image_processor = image_processing_class.from_dict(
                self.image_processor_dict, size={"shortest_edge": 42, "longest_edge": 42}
            )
            self.assertEqual(image_processor.size, {"shortest_edge": 42, "longest_edge": 42})

    # batch size is flattened
    def test_call_pil(self):
        for image_processing_class in self.image_processor_list:
            # Initialize image_processing
            image_processing = image_processing_class(**self.image_processor_dict)
            # create random PIL images
            image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False)
            for image in image_inputs:
                self.assertIsInstance(image, Image.Image)

            # Test not batched input
            encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
            expected_output_image_shape = self.image_processor_tester.expected_output_image_shape([image_inputs[0]])
            self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)

            # Test batched
            encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
            expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs)
            self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)

    def test_call_numpy(self):
        for image_processing_class in self.image_processor_list:
            # Initialize image_processing
            image_processing = image_processing_class(**self.image_processor_dict)
            # create random numpy tensors
            image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, numpify=True)
            for image in image_inputs:
                self.assertIsInstance(image, np.ndarray)

            # Test not batched input
            encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
            expected_output_image_shape = self.image_processor_tester.expected_output_image_shape([image_inputs[0]])
            self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)

            # Test batched
            encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
            expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs)
            self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)

    def test_call_pytorch(self):
        for image_processing_class in self.image_processor_list:
            # Initialize image_processing
            image_processing = image_processing_class(**self.image_processor_dict)
            # create random PyTorch tensors
            image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True)

            for image in image_inputs:
                self.assertIsInstance(image, torch.Tensor)

            # Test not batched input
            encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
            expected_output_image_shape = self.image_processor_tester.expected_output_image_shape([image_inputs[0]])
            self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)

            # Test batched
            expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs)
            encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
            self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)

    def test_call_numpy_4_channels(self):
        for image_processing_class in self.image_processor_list:
            # Test that can process images which have an arbitrary number of channels
            # Initialize image_processing
            image_processor = image_processing_class(**self.image_processor_dict)

            # create random numpy tensors
            self.image_processor_tester.num_channels = 4
            image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, numpify=True)

            # Test not batched input
            encoded_images = image_processor(
                image_inputs[0],
                return_tensors="pt",
                input_data_format="channels_last",
                image_mean=0,
                image_std=1,
            ).pixel_values
            expected_output_image_shape = self.image_processor_tester.expected_output_image_shape([image_inputs[0]])
            self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)

            # Test batched
            encoded_images = image_processor(
                image_inputs,
                return_tensors="pt",
                input_data_format="channels_last",
                image_mean=0,
                image_std=1,
            ).pixel_values
            expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs)
            self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
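
Beyond the unit tests, a quick parity sketch (not part of the commit) is to build the slow and fast processors from the same kwargs the tester uses and compare outputs; grids should match exactly, while pixel values match only approximately because PIL and torchvision resampling differ:

import torch
from transformers import Glm4vImageProcessor, Glm4vImageProcessorFast

# Same kwargs as prepare_image_processor_dict() above.
kwargs = {
    "do_resize": True,
    "do_normalize": True,
    "image_mean": [0.5, 0.5, 0.5],
    "image_std": [0.5, 0.5, 0.5],
    "size": {"shortest_edge": 10, "longest_edge": 20},
    "temporal_patch_size": 2,
    "patch_size": 14,
    "merge_size": 2,
}
slow = Glm4vImageProcessor(**kwargs)
fast = Glm4vImageProcessorFast(**kwargs)

# Random uint8 images in the tester's 30-80 resolution range.
images = [torch.randint(0, 256, (3, 60, 76), dtype=torch.uint8) for _ in range(3)]
out_slow = slow(images=images, return_tensors="pt")
out_fast = fast(images=images, return_tensors="pt")

# Both paths call the same smart_resize, so the grids must agree exactly.
assert torch.equal(out_slow.image_grid_thw, out_fast.image_grid_thw)
# Pixel values should only be close, not identical, across resampling backends.
print((out_slow.pixel_values - out_fast.pixel_values).abs().max())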

0 commit comments

Comments
 (0)