Skip to content

Commit

Permalink
Use numpy
Browse files Browse the repository at this point in the history
  • Loading branch information
NielsRogge committed Aug 31, 2024
1 parent d4d38f2 commit b4d527e
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 41 deletions.
82 changes: 41 additions & 41 deletions src/transformers/models/llava/image_processing_llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,11 @@
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
from PIL import Image

from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import (
PaddingMode,
convert_to_rgb,
get_resize_output_image_size,
pad,
resize,
to_channel_dimension_format,
)
Expand Down Expand Up @@ -157,25 +154,6 @@ def __init__(
# `shortest_edge` key.
delattr(self, "use_square_size")

def pad_to_square_original(
self, image: Image.Image, background_color: Union[int, Tuple[int, int, int]] = 0
) -> Image.Image:
"""
Pads an image to make it square.
"""
print("Image size:", image.size)
width, height = image.size
if width == height:
return image
elif width > height:
result = Image.new(image.mode, (width, width), background_color)
result.paste(image, (0, (width - height) // 2))
return result
else:
result = Image.new(image.mode, (height, height), background_color)
result.paste(image, ((height - width) // 2, 0))
return result

def pad_to_square(
self,
image: np.array,
Expand All @@ -199,27 +177,42 @@ def pad_to_square(
Returns:
`np.ndarray`: The padded image.
"""
height, width = get_image_size(image, input_data_format)
h, w = get_image_size(image, input_data_format)
c = image.shape[0] if input_data_format == ChannelDimension.FIRST else image.shape[-1]

if height == width:
if h == w:
return image

max_dim = max(height, width)
pad_height = max_dim - height
pad_width = max_dim - width

padding = (
(pad_height // 2, pad_height - pad_height // 2),
(pad_width // 2, pad_width - pad_width // 2),
)

return pad(
image=image,
padding=padding,
mode=PaddingMode.CONSTANT,
constant_values=background_color,
input_data_format=input_data_format,
)
max_dim = max(h, w)

# Ensure background_color is the correct shape
if isinstance(background_color, (int, float)):
background_color = [background_color] * c
elif len(background_color) != c:
raise ValueError(f"background_color must have {c} elements to match the number of channels")

if input_data_format == ChannelDimension.FIRST:
result = np.zeros((c, max_dim, max_dim), dtype=image.dtype)
for i, color in enumerate(background_color):
result[i, :, :] = color
if w > h:
start = (max_dim - h) // 2
result[:, start : start + h, :] = image
else:
start = (max_dim - w) // 2
result[:, :, start : start + w] = image
else:
result = np.zeros((max_dim, max_dim, c), dtype=image.dtype)
for i, color in enumerate(background_color):
result[:, :, i] = color
if w > h:
start = (max_dim - h) // 2
result[start : start + h, :, :] = image
else:
start = (max_dim - w) // 2
result[:, start : start + w, :] = image

return result

def resize(
self,
Expand Down Expand Up @@ -368,7 +361,6 @@ def preprocess(
"torch.Tensor, tf.Tensor or jax.ndarray."
)
validate_preprocess_arguments(
do_pad=do_pad,
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_normalize=do_normalize,
Expand Down Expand Up @@ -397,6 +389,10 @@ def preprocess(
# We assume that all images have the same channel dimension format.
input_data_format = infer_channel_dimension_format(images[0])

print("Input data format:", input_data_format)
for image in images:
print(image.shape)

if do_pad:
images = [
self.pad_to_square(
Expand All @@ -407,6 +403,10 @@ def preprocess(
for image in images
]

print("After padding:")
for image in images:
print(image.shape)

if do_resize:
images = [
self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
Expand Down
33 changes: 33 additions & 0 deletions tests/models/llava/test_image_processing_llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@


import unittest
from typing import Tuple, Union

import numpy as np

from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_vision_available
Expand All @@ -23,6 +26,8 @@


if is_vision_available():
from PIL import Image

from transformers import LlavaImageProcessor


Expand Down Expand Up @@ -128,3 +133,31 @@ def test_image_processor_from_dict_with_kwargs(self):
image_processor = self.image_processing_class.from_dict(self.image_processor_dict, size=42, crop_size=84)
self.assertEqual(image_processor.size, {"shortest_edge": 42})
self.assertEqual(image_processor.crop_size, {"height": 84, "width": 84})

# Ignore copy
def test_padding(self):
# taken from original implementation
def pad_to_square_original(
image: Image.Image, background_color: Union[int, Tuple[int, int, int]] = 0
) -> Image.Image:
width, height = image.size
if width == height:
return image
elif width > height:
result = Image.new(image.mode, (width, width), background_color)
result.paste(image, (0, (width - height) // 2))
return result
else:
result = Image.new(image.mode, (height, height), background_color)
result.paste(image, ((height - width) // 2, 0))
return result

image_processor = self.image_processing_class.from_dict(self.image_processor_dict)

# create random numpy tensors
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, numpify=True)
for image in image_inputs:
padded_image = image_processor.pad_to_square(image)
padded_image_original = pad_to_square_original(image)

self.assertTrue(np.assert_array_equal(padded_image, padded_image_original))

0 comments on commit b4d527e

Please sign in to comment.