Skip to content

Commit

Permalink
fix padding issue
Browse files Browse the repository at this point in the history
  • Loading branch information
noskill committed Jan 21, 2025
1 parent 35e46a1 commit 2ddd394
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 28 deletions.
23 changes: 16 additions & 7 deletions multigen/pipes.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,14 @@ def prepare_inputs(self, inputs):
self.try_set_scheduler(inputs)
return kwargs

@property
def pad(self):
    """Return the multiple that input image dimensions are padded to.

    Uses ``self.pipe.image_processor.vae_scale_factor`` when the wrapped
    pipeline exposes both attributes; otherwise falls back to 8.
    """
    # getattr with a default mirrors the hasattr checks: a missing
    # attribute at either level yields the fallback of 8.
    processor = getattr(self.pipe, 'image_processor', None)
    return getattr(processor, 'vae_scale_factor', 8)


class Prompt2ImPipe(BasePipe):
"""
Expand Down Expand Up @@ -446,7 +454,7 @@ def setup(self, fimage, image=None, strength=0.75,
self._input_image = self.scale_image(self._input_image, scale)
self._original_size = self._input_image.size
logging.debug("origin image size {self._original_size}")
self._input_image = util.pad_image_to_multiple_of_8(self._input_image)
self._input_image = util.pad_image_to_multiple(self._input_image, self.pad)
self.pipe_params.update({
"width": self._input_image.width if width is None else width,
"height": self._input_image.height if height is None else height,
Expand Down Expand Up @@ -599,12 +607,13 @@ def setup(self, image=None, image_painted=None, mask=None, blur=4,
input_image = self._image_painted if self._image_painted is not None else self._original_image

super().setup(fimage=None, image=input_image, scale=scale, **kwargs)

if self._original_image is not None:
self._original_image = self.scale_image(self._original_image, scale)
self._original_image = util.pad_image_to_multiple_of_8(self._original_image)
self._original_image = util.pad_image_to_multiple(self._original_image, self.pad)
if self._image_painted is not None:
self._image_painted = self.scale_image(self._image_painted, scale)
self._image_painted = util.pad_image_to_multiple_of_8(self._image_painted)
self._image_painted = util.pad_image_to_multiple(self._image_painted, self.pad)

# there are two options:
# 1. mask is provided
Expand All @@ -621,7 +630,7 @@ def setup(self, image=None, image_painted=None, mask=None, blur=4,
pil_mask = Image.fromarray(mask)
if pil_mask.mode != "L":
pil_mask = pil_mask.convert("L")
pil_mask = util.pad_image_to_multiple_of_8(pil_mask)
pil_mask = util.pad_image_to_multiple(pil_mask, self.pad)
self._mask = pil_mask
self._mask_blur = self.blur_mask(pil_mask, blur)
self._mask_compose = self.blur_mask(pil_mask.crop((0, 0, self._original_size[0], self._original_size[1]))
Expand Down Expand Up @@ -867,7 +876,7 @@ def setup(self, fimage, width=None, height=None,
image = Image.open(fimage).convert("RGB") if image is None else image
self._original_size = image.size
self._use_input_size = width is None or height is None
image = util.pad_image_to_multiple_of_8(image)
image = util.pad_image_to_multiple(image, self.pad)
self._condition_image = [image]
self._input_image = [image]
if cscales is None:
Expand Down Expand Up @@ -922,7 +931,7 @@ def gen(self, inputs):
for image in self.pipe(**inputs).images:
result = image.crop((0, 0, self._original_size[0] if self._use_input_size else inputs.get('height'),
self._original_size[1] if self._use_input_size else inputs.get('width') ))
res.append(image)
res.append(result)
return res


Expand Down Expand Up @@ -1044,7 +1053,7 @@ def _proc_cimg(self, oriImg):
condition_image += [Image.fromarray(formatted)]
else:
condition_image += [Image.fromarray(oriImg)]
return condition_image
return [c.resize((oriImg.shape[1], oriImg.shape[0])) for c in condition_image]


class InpaintingPipe(MaskedIm2ImPipe):
Expand Down
55 changes: 34 additions & 21 deletions multigen/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,37 +28,50 @@ def create_exif_metadata(im: Image, custom_metadata):
return exif


def pad_image_to_multiple(image: Image, padding_size: int = 8) -> Image:
    """
    Pads the input image by repeating the bottom and right-most rows and columns of pixels
    so that its dimensions are divisible by 'padding_size'.

    Args:
        image (Image): The input PIL Image.
        padding_size (int): The multiple to which dimensions are padded.
            Must be a positive integer.

    Returns:
        Image: The padded PIL Image. A new image is always returned, even
        when no padding is required.

    Raises:
        ValueError: If 'padding_size' is less than 1.
    """
    # Zero would divide by zero below; a negative value would produce a
    # target size smaller than the image and silently crop it on paste.
    if padding_size < 1:
        raise ValueError(f"padding_size must be a positive integer, got {padding_size}")

    # Round each dimension up to the nearest multiple of padding_size
    new_width = ((image.width + padding_size - 1) // padding_size) * padding_size
    new_height = ((image.height + padding_size - 1) // padding_size) * padding_size

    # Calculate padding amounts
    pad_right = new_width - image.width
    pad_bottom = new_height - image.height

    # Create a new image with the new dimensions and place the original at the top-left
    padded_image = Image.new(image.mode, (new_width, new_height))
    padded_image.paste(image, (0, 0))

    if pad_right > 0:
        # Stretch the last column across the right padding area
        last_column = image.crop((image.width - 1, 0, image.width, image.height))
        right_padding = last_column.resize((pad_right, image.height), Image.NEAREST)
        padded_image.paste(right_padding, (image.width, 0))

    if pad_bottom > 0:
        # Stretch the last row across the bottom padding area
        last_row = image.crop((0, image.height - 1, image.width, image.height))
        bottom_padding = last_row.resize((image.width, pad_bottom), Image.NEAREST)
        padded_image.paste(bottom_padding, (0, image.height))

    if pad_right > 0 and pad_bottom > 0:
        # The two strips above do not cover the bottom-right corner;
        # fill it with the image's bottom-right pixel.
        last_pixel = image.getpixel((image.width - 1, image.height - 1))
        corner = Image.new(image.mode, (pad_right, pad_bottom), last_pixel)
        padded_image.paste(corner, (image.width, image.height))

    return padded_image


def pad_image_to_multiple_of_8(image: Image) -> Image:
    """
    Backward-compatible wrapper that pads the image so that its height and
    width are divisible by 8.

    Args:
        image (Image): The input PIL image.

    Returns:
        Image: The padded PIL image.
    """
    return pad_image_to_multiple(image, 8)

Expand Down Expand Up @@ -97,7 +110,7 @@ def awailable_ram():

def quantize(pipe, dtype=qfloat8):
components = ['unet', 'transformer', 'text_encoder', 'text_encoder_2', 'vae']

for component in components:
if hasattr(pipe, component):
component_obj = getattr(pipe, component)
Expand Down

0 comments on commit 2ddd394

Please sign in to comment.