diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index ab59b418657..218e05a9866 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -79,6 +79,12 @@ SAMPLER_NAME_VALUES = Literal[tuple(SCHEDULER_MAP.keys())] +# HACK: Many nodes are currently hard-coded to use a fixed latent scale factor of 8. This is fragile, and will need to +# be addressed if future models use a different latent scale factor. Also, note that there may be places where the scale +# factor is hard-coded to a literal '8' rather than using this constant. +# The ratio of image:latent dimensions is LATENT_SCALE_FACTOR:1, or 8:1. +LATENT_SCALE_FACTOR = 8 + @invocation_output("scheduler_output") class SchedulerOutput(BaseInvocationOutput): @@ -394,9 +400,9 @@ def prep_control_data( exit_stack: ExitStack, do_classifier_free_guidance: bool = True, ) -> List[ControlNetData]: - # assuming fixed dimensional scaling of 8:1 for image:latents - control_height_resize = latents_shape[2] * 8 - control_width_resize = latents_shape[3] * 8 + # Assuming fixed dimensional scaling of LATENT_SCALE_FACTOR. 
+ control_height_resize = latents_shape[2] * LATENT_SCALE_FACTOR + control_width_resize = latents_shape[3] * LATENT_SCALE_FACTOR if control_input is None: control_list = None elif isinstance(control_input, list) and len(control_input) == 0: @@ -909,12 +915,12 @@ class ResizeLatentsInvocation(BaseInvocation): ) width: int = InputField( ge=64, - multiple_of=8, + multiple_of=LATENT_SCALE_FACTOR, description=FieldDescriptions.width, ) height: int = InputField( ge=64, - multiple_of=8, + multiple_of=LATENT_SCALE_FACTOR, description=FieldDescriptions.width, ) mode: LATENTS_INTERPOLATION_MODE = InputField(default="bilinear", description=FieldDescriptions.interp_mode) @@ -928,7 +934,7 @@ def invoke(self, context: InvocationContext) -> LatentsOutput: resized_latents = torch.nn.functional.interpolate( latents.to(device), - size=(self.height // 8, self.width // 8), + size=(self.height // LATENT_SCALE_FACTOR, self.width // LATENT_SCALE_FACTOR), mode=self.mode, antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False, ) @@ -1166,3 +1172,60 @@ def slerp(t, v0, v1, DOT_THRESHOLD=0.9995): # context.services.latents.set(name, resized_latents) context.services.latents.save(name, blended_latents) return build_latents_output(latents_name=name, latents=blended_latents) + + +# The Crop Latents node was copied from @skunkworxdark's implementation here: +# https://github.com/skunkworxdark/XYGrid_nodes/blob/74647fa9c1fa57d317a94bd43ca689af7f0aae5e/images_to_grids.py#L1117C1-L1167C80 +@invocation( + "crop_latents", + title="Crop Latents", + tags=["latents", "crop"], + category="latents", + version="1.0.0", +) +# TODO(ryand): Named `CropLatentsCoreInvocation` to prevent a conflict with custom node `CropLatentsInvocation`. +# Currently, if the class names conflict then 'GET /openapi.json' fails. +class CropLatentsCoreInvocation(BaseInvocation): + """Crops a latent-space tensor to a box specified in image-space. 
The box dimensions and coordinates must be + divisible by the latent scale factor of 8. + """ + + latents: LatentsField = InputField( + description=FieldDescriptions.latents, + input=Input.Connection, + ) + x: int = InputField( + ge=0, + multiple_of=LATENT_SCALE_FACTOR, + description="The left x coordinate (in px) of the crop rectangle in image space. This value will be converted to a dimension in latent space.", + ) + y: int = InputField( + ge=0, + multiple_of=LATENT_SCALE_FACTOR, + description="The top y coordinate (in px) of the crop rectangle in image space. This value will be converted to a dimension in latent space.", + ) + width: int = InputField( + ge=1, + multiple_of=LATENT_SCALE_FACTOR, + description="The width (in px) of the crop rectangle in image space. This value will be converted to a dimension in latent space.", + ) + height: int = InputField( + ge=1, + multiple_of=LATENT_SCALE_FACTOR, + description="The height (in px) of the crop rectangle in image space. This value will be converted to a dimension in latent space.", + ) + + def invoke(self, context: InvocationContext) -> LatentsOutput: + latents = context.services.latents.get(self.latents.latents_name) + + x1 = self.x // LATENT_SCALE_FACTOR + y1 = self.y // LATENT_SCALE_FACTOR + x2 = x1 + (self.width // LATENT_SCALE_FACTOR) + y2 = y1 + (self.height // LATENT_SCALE_FACTOR) + + cropped_latents = latents[..., y1:y2, x1:x2] + + name = f"{context.graph_execution_state_id}__{self.id}" + context.services.latents.save(name, cropped_latents) + + return build_latents_output(latents_name=name, latents=cropped_latents) diff --git a/invokeai/app/invocations/tiles.py b/invokeai/app/invocations/tiles.py new file mode 100644 index 00000000000..3055c1baaeb --- /dev/null +++ b/invokeai/app/invocations/tiles.py @@ -0,0 +1,181 @@ +import numpy as np +from PIL import Image +from pydantic import BaseModel + +from invokeai.app.invocations.baseinvocation import ( + BaseInvocation, + BaseInvocationOutput, + InputField, + 
InvocationContext, + OutputField, + WithMetadata, + WithWorkflow, + invocation, + invocation_output, +) +from invokeai.app.invocations.primitives import ImageField, ImageOutput +from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin +from invokeai.backend.tiles.tiles import calc_tiles_with_overlap, merge_tiles_with_linear_blending +from invokeai.backend.tiles.utils import Tile + + +class TileWithImage(BaseModel): + tile: Tile + image: ImageField + + +@invocation_output("calculate_image_tiles_output") +class CalculateImageTilesOutput(BaseInvocationOutput): + tiles: list[Tile] = OutputField(description="The tiles coordinates that cover a particular image shape.") + + +@invocation("calculate_image_tiles", title="Calculate Image Tiles", tags=["tiles"], category="tiles", version="1.0.0") +class CalculateImageTilesInvocation(BaseInvocation): + """Calculate the coordinates and overlaps of tiles that cover a target image shape.""" + + image_width: int = InputField(ge=1, default=1024, description="The image width, in pixels, to calculate tiles for.") + image_height: int = InputField( + ge=1, default=1024, description="The image height, in pixels, to calculate tiles for." + ) + tile_width: int = InputField(ge=1, default=576, description="The tile width, in pixels.") + tile_height: int = InputField(ge=1, default=576, description="The tile height, in pixels.") + overlap: int = InputField( + ge=0, + default=128, + description="The target overlap, in pixels, between adjacent tiles. 
Adjacent tiles will overlap by at least this amount", + ) + + def invoke(self, context: InvocationContext) -> CalculateImageTilesOutput: + tiles = calc_tiles_with_overlap( + image_height=self.image_height, + image_width=self.image_width, + tile_height=self.tile_height, + tile_width=self.tile_width, + overlap=self.overlap, + ) + return CalculateImageTilesOutput(tiles=tiles) + + +@invocation_output("tile_to_properties_output") +class TileToPropertiesOutput(BaseInvocationOutput): + coords_left: int = OutputField(description="Left coordinate of the tile relative to its parent image.") + coords_right: int = OutputField(description="Right coordinate of the tile relative to its parent image.") + coords_top: int = OutputField(description="Top coordinate of the tile relative to its parent image.") + coords_bottom: int = OutputField(description="Bottom coordinate of the tile relative to its parent image.") + + # HACK: The width and height fields are 'meta' fields that can easily be calculated from the other fields on this + # object. Including redundant fields that can cheaply/easily be re-calculated goes against conventional API design + # principles. These fields are included, because 1) they are often useful in tiled workflows, and 2) they are + # difficult to calculate in a workflow (even though it's just a couple of subtraction nodes the graph gets + # surprisingly complicated). + width: int = OutputField(description="The width of the tile. Equal to coords_right - coords_left.") + height: int = OutputField(description="The height of the tile. 
Equal to coords_bottom - coords_top.") + + overlap_top: int = OutputField(description="Overlap between this tile and its top neighbor.") + overlap_bottom: int = OutputField(description="Overlap between this tile and its bottom neighbor.") + overlap_left: int = OutputField(description="Overlap between this tile and its left neighbor.") + overlap_right: int = OutputField(description="Overlap between this tile and its right neighbor.") + + +@invocation("tile_to_properties", title="Tile to Properties", tags=["tiles"], category="tiles", version="1.0.0") +class TileToPropertiesInvocation(BaseInvocation): + """Split a Tile into its individual properties.""" + + tile: Tile = InputField(description="The tile to split into properties.") + + def invoke(self, context: InvocationContext) -> TileToPropertiesOutput: + return TileToPropertiesOutput( + coords_left=self.tile.coords.left, + coords_right=self.tile.coords.right, + coords_top=self.tile.coords.top, + coords_bottom=self.tile.coords.bottom, + width=self.tile.coords.right - self.tile.coords.left, + height=self.tile.coords.bottom - self.tile.coords.top, + overlap_top=self.tile.overlap.top, + overlap_bottom=self.tile.overlap.bottom, + overlap_left=self.tile.overlap.left, + overlap_right=self.tile.overlap.right, + ) + + +@invocation_output("pair_tile_image_output") +class PairTileImageOutput(BaseInvocationOutput): + tile_with_image: TileWithImage = OutputField(description="A tile description with its corresponding image.") + + +@invocation("pair_tile_image", title="Pair Tile with Image", tags=["tiles"], category="tiles", version="1.0.0") +class PairTileImageInvocation(BaseInvocation): + """Pair an image with its tile properties.""" + + # TODO(ryand): The only reason that PairTileImage is needed is because the iterate/collect nodes don't preserve + # order. Can this be fixed? 
+ + image: ImageField = InputField(description="The tile image.") + tile: Tile = InputField(description="The tile properties.") + + def invoke(self, context: InvocationContext) -> PairTileImageOutput: + return PairTileImageOutput( + tile_with_image=TileWithImage( + tile=self.tile, + image=self.image, + ) + ) + + +@invocation("merge_tiles_to_image", title="Merge Tiles to Image", tags=["tiles"], category="tiles", version="1.0.0") +class MergeTilesToImageInvocation(BaseInvocation, WithMetadata, WithWorkflow): + """Merge multiple tile images into a single image.""" + + # Inputs + tiles_with_images: list[TileWithImage] = InputField(description="A list of tile images with tile properties.") + blend_amount: int = InputField( + ge=0, + description="The amount to blend adjacent tiles in pixels. Must be <= the amount of overlap between adjacent tiles.", + ) + + def invoke(self, context: InvocationContext) -> ImageOutput: + images = [twi.image for twi in self.tiles_with_images] + tiles = [twi.tile for twi in self.tiles_with_images] + + # Infer the output image dimensions from the max/min tile limits. + height = 0 + width = 0 + for tile in tiles: + height = max(height, tile.coords.bottom) + width = max(width, tile.coords.right) + + # Get all tile images for processing. + # TODO(ryand): It pains me that we spend time PNG decoding each tile from disk when they almost certainly + # existed in memory at an earlier point in the graph. + tile_np_images: list[np.ndarray] = [] + for image in images: + pil_image = context.services.images.get_pil_image(image.image_name) + pil_image = pil_image.convert("RGB") + tile_np_images.append(np.array(pil_image)) + + # Prepare the output image buffer. + # Check the first tile to determine how many image channels are expected in the output. 
+ channels = tile_np_images[0].shape[-1] + dtype = tile_np_images[0].dtype + np_image = np.zeros(shape=(height, width, channels), dtype=dtype) + + merge_tiles_with_linear_blending( + dst_image=np_image, tiles=tiles, tile_images=tile_np_images, blend_amount=self.blend_amount + ) + pil_image = Image.fromarray(np_image) + + image_dto = context.services.images.create( + image=pil_image, + image_origin=ResourceOrigin.INTERNAL, + image_category=ImageCategory.GENERAL, + node_id=self.id, + session_id=context.graph_execution_state_id, + is_intermediate=self.is_intermediate, + metadata=self.metadata, + workflow=self.workflow, + ) + return ImageOutput( + image=ImageField(image_name=image_dto.image_name), + width=image_dto.width, + height=image_dto.height, + ) diff --git a/invokeai/backend/tiles/__init__.py b/invokeai/backend/tiles/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/invokeai/backend/tiles/tiles.py b/invokeai/backend/tiles/tiles.py new file mode 100644 index 00000000000..3a678d825ec --- /dev/null +++ b/invokeai/backend/tiles/tiles.py @@ -0,0 +1,201 @@ +import math +from typing import Union + +import numpy as np + +from invokeai.backend.tiles.utils import TBLR, Tile, paste + + +def calc_tiles_with_overlap( + image_height: int, image_width: int, tile_height: int, tile_width: int, overlap: int = 0 +) -> list[Tile]: + """Calculate the tile coordinates for a given image shape under a simple tiling scheme with overlaps. + + Args: + image_height (int): The image height in px. + image_width (int): The image width in px. + tile_height (int): The tile height in px. All tiles will have this height. + tile_width (int): The tile width in px. All tiles will have this width. + overlap (int, optional): The target overlap between adjacent tiles. If the tiles do not evenly cover the image + shape, then the last row/column of tiles will overlap more than this. Defaults to 0. + + Returns: + list[Tile]: A list of tiles that cover the image shape. 
Ordered from left-to-right, top-to-bottom. + """ + assert image_height >= tile_height + assert image_width >= tile_width + assert overlap < tile_height + assert overlap < tile_width + + non_overlap_per_tile_height = tile_height - overlap + non_overlap_per_tile_width = tile_width - overlap + + num_tiles_y = math.ceil((image_height - overlap) / non_overlap_per_tile_height) + num_tiles_x = math.ceil((image_width - overlap) / non_overlap_per_tile_width) + + # tiles[y * num_tiles_x + x] is the tile for the y'th row, x'th column. + tiles: list[Tile] = [] + + # Calculate tile coordinates. (Ignore overlap values for now.) + for tile_idx_y in range(num_tiles_y): + for tile_idx_x in range(num_tiles_x): + tile = Tile( + coords=TBLR( + top=tile_idx_y * non_overlap_per_tile_height, + bottom=tile_idx_y * non_overlap_per_tile_height + tile_height, + left=tile_idx_x * non_overlap_per_tile_width, + right=tile_idx_x * non_overlap_per_tile_width + tile_width, + ), + overlap=TBLR(top=0, bottom=0, left=0, right=0), + ) + + if tile.coords.bottom > image_height: + # If this tile would go off the bottom of the image, shift it so that it is aligned with the bottom + # of the image. + tile.coords.bottom = image_height + tile.coords.top = image_height - tile_height + + if tile.coords.right > image_width: + # If this tile would go off the right edge of the image, shift it so that it is aligned with the + # right edge of the image. + tile.coords.right = image_width + tile.coords.left = image_width - tile_width + + tiles.append(tile) + + def get_tile_or_none(idx_y: int, idx_x: int) -> Union[Tile, None]: + if idx_y < 0 or idx_y > num_tiles_y or idx_x < 0 or idx_x > num_tiles_x: + return None + return tiles[idx_y * num_tiles_x + idx_x] + + # Iterate over tiles again and calculate overlaps. 
+ for tile_idx_y in range(num_tiles_y): + for tile_idx_x in range(num_tiles_x): + cur_tile = get_tile_or_none(tile_idx_y, tile_idx_x) + top_neighbor_tile = get_tile_or_none(tile_idx_y - 1, tile_idx_x) + left_neighbor_tile = get_tile_or_none(tile_idx_y, tile_idx_x - 1) + + assert cur_tile is not None + + # Update cur_tile top-overlap and corresponding top-neighbor bottom-overlap. + if top_neighbor_tile is not None: + cur_tile.overlap.top = max(0, top_neighbor_tile.coords.bottom - cur_tile.coords.top) + top_neighbor_tile.overlap.bottom = cur_tile.overlap.top + + # Update cur_tile left-overlap and corresponding left-neighbor right-overlap. + if left_neighbor_tile is not None: + cur_tile.overlap.left = max(0, left_neighbor_tile.coords.right - cur_tile.coords.left) + left_neighbor_tile.overlap.right = cur_tile.overlap.left + + return tiles + + +def merge_tiles_with_linear_blending( + dst_image: np.ndarray, tiles: list[Tile], tile_images: list[np.ndarray], blend_amount: int +): + """Merge a set of image tiles into `dst_image` with linear blending between the tiles. + + We expect every tile edge to either: + 1) have an overlap of 0, because it is aligned with the image edge, or + 2) have an overlap >= blend_amount. + If neither of these conditions are satisfied, we raise an exception. + + The linear blending is centered at the halfway point of the overlap between adjacent tiles. + + Args: + dst_image (np.ndarray): The destination image. Shape: (H, W, C). + tiles (list[Tile]): The list of tiles describing the locations of the respective `tile_images`. + tile_images (list[np.ndarray]): The tile images to merge into `dst_image`. + blend_amount (int): The amount of blending (in px) between adjacent overlapping tiles. + """ + # Sort tiles and images first by left x coordinate, then by top y coordinate. During tile processing, we want to + # iterate over tiles left-to-right, top-to-bottom. 
+ tiles_and_images = list(zip(tiles, tile_images, strict=True)) + tiles_and_images = sorted(tiles_and_images, key=lambda x: x[0].coords.left) + tiles_and_images = sorted(tiles_and_images, key=lambda x: x[0].coords.top) + + # Organize tiles into rows. + tile_and_image_rows: list[list[tuple[Tile, np.ndarray]]] = [] + cur_tile_and_image_row: list[tuple[Tile, np.ndarray]] = [] + first_tile_in_cur_row, _ = tiles_and_images[0] + for tile_and_image in tiles_and_images: + tile, _ = tile_and_image + if not ( + tile.coords.top == first_tile_in_cur_row.coords.top + and tile.coords.bottom == first_tile_in_cur_row.coords.bottom + ): + # Store the previous row, and start a new one. + tile_and_image_rows.append(cur_tile_and_image_row) + cur_tile_and_image_row = [] + first_tile_in_cur_row, _ = tile_and_image + + cur_tile_and_image_row.append(tile_and_image) + tile_and_image_rows.append(cur_tile_and_image_row) + + # Prepare 1D linear gradients for blending. + gradient_left_x = np.linspace(start=0.0, stop=1.0, num=blend_amount) + gradient_top_y = np.linspace(start=0.0, stop=1.0, num=blend_amount) + # Convert shape: (blend_amount, ) -> (blend_amount, 1). The extra dimension enables the gradient to be applied + # to a 2D image via broadcasting. Note that no additional dimension is needed on gradient_left_x for + # broadcasting to work correctly. + gradient_top_y = np.expand_dims(gradient_top_y, axis=1) + + for tile_and_image_row in tile_and_image_rows: + first_tile_in_row, _ = tile_and_image_row[0] + row_height = first_tile_in_row.coords.bottom - first_tile_in_row.coords.top + row_image = np.zeros((row_height, dst_image.shape[1], dst_image.shape[2]), dtype=dst_image.dtype) + + # Blend the tiles in the row horizontally. + for tile, tile_image in tile_and_image_row: + # We expect the tiles to be ordered left-to-right. For each tile, we construct a mask that applies linear + # blending to the left of the current tile. 
The inverse linear blending is automatically applied to the + # right of the tiles that have already been pasted by the paste(...) operation. + tile_height, tile_width, _ = tile_image.shape + mask = np.ones(shape=(tile_height, tile_width), dtype=np.float64) + + # Left blending: + if tile.overlap.left > 0: + assert tile.overlap.left >= blend_amount + # Center the blending gradient in the middle of the overlap. + blend_start_left = tile.overlap.left // 2 - blend_amount // 2 + # The region left of the blending region is masked completely. + mask[:, :blend_start_left] = 0.0 + # Apply the blend gradient to the mask. + mask[:, blend_start_left : blend_start_left + blend_amount] = gradient_left_x + # For visual debugging: + # tile_image[:, blend_start_left : blend_start_left + blend_amount] = 0 + + paste( + dst_image=row_image, + src_image=tile_image, + box=TBLR( + top=0, bottom=tile.coords.bottom - tile.coords.top, left=tile.coords.left, right=tile.coords.right + ), + mask=mask, + ) + + # Blend the row into the dst_image vertically. + # We construct a mask that applies linear blending to the top of the current row. The inverse linear blending is + # automatically applied to the bottom of the tiles that have already been pasted by the paste(...) operation. + mask = np.ones(shape=(row_image.shape[0], row_image.shape[1]), dtype=np.float64) + # Top blending: + # (See comments under 'Left blending' for an explanation of the logic.) + # We assume that the entire row has the same vertical overlaps as the first_tile_in_row. 
+ if first_tile_in_row.overlap.top > 0: + assert first_tile_in_row.overlap.top >= blend_amount + blend_start_top = first_tile_in_row.overlap.top // 2 - blend_amount // 2 + mask[:blend_start_top, :] = 0.0 + mask[blend_start_top : blend_start_top + blend_amount, :] = gradient_top_y + # For visual debugging: + # row_image[blend_start_top : blend_start_top + blend_amount, :] = 0 + paste( + dst_image=dst_image, + src_image=row_image, + box=TBLR( + top=first_tile_in_row.coords.top, + bottom=first_tile_in_row.coords.bottom, + left=0, + right=row_image.shape[1], + ), + mask=mask, + ) diff --git a/invokeai/backend/tiles/utils.py b/invokeai/backend/tiles/utils.py new file mode 100644 index 00000000000..4ad40ffa358 --- /dev/null +++ b/invokeai/backend/tiles/utils.py @@ -0,0 +1,47 @@ +from typing import Optional + +import numpy as np +from pydantic import BaseModel, Field + + +class TBLR(BaseModel): + top: int + bottom: int + left: int + right: int + + def __eq__(self, other): + return ( + self.top == other.top + and self.bottom == other.bottom + and self.left == other.left + and self.right == other.right + ) + + +class Tile(BaseModel): + coords: TBLR = Field(description="The coordinates of this tile relative to its parent image.") + overlap: TBLR = Field(description="The amount of overlap with adjacent tiles on each side of this tile.") + + def __eq__(self, other): + return self.coords == other.coords and self.overlap == other.overlap + + +def paste(dst_image: np.ndarray, src_image: np.ndarray, box: TBLR, mask: Optional[np.ndarray] = None): + """Paste a source image into a destination image. + + Args: + dst_image (np.ndarray): The destination image to paste into. Shape: (H, W, C). + src_image (np.ndarray): The source image to paste. Shape: (H, W, C). H and W must be compatible with 'box'. + box (TBLR): Box defining the region in the 'dst_image' where 'src_image' will be pasted. 
+ mask (Optional[np.ndarray]): A mask that defines the blending between 'src_image' and 'dst_image'. + Range: [0.0, 1.0], Shape: (H, W). The output is calculated per-pixel according to + `src * mask + dst * (1 - mask)`. + """ + + if mask is None: + dst_image[box.top : box.bottom, box.left : box.right] = src_image + else: + mask = np.expand_dims(mask, -1) + dst_image_box = dst_image[box.top : box.bottom, box.left : box.right] + dst_image[box.top : box.bottom, box.left : box.right] = src_image * mask + dst_image_box * (1.0 - mask) diff --git a/tests/backend/tiles/test_tiles.py b/tests/backend/tiles/test_tiles.py new file mode 100644 index 00000000000..353e65d3368 --- /dev/null +++ b/tests/backend/tiles/test_tiles.py @@ -0,0 +1,224 @@ +import numpy as np +import pytest + +from invokeai.backend.tiles.tiles import calc_tiles_with_overlap, merge_tiles_with_linear_blending +from invokeai.backend.tiles.utils import TBLR, Tile + +#################################### +# Test calc_tiles_with_overlap(...) +#################################### + + +def test_calc_tiles_with_overlap_single_tile(): + """Test calc_tiles_with_overlap() behavior when a single tile covers the image.""" + tiles = calc_tiles_with_overlap(image_height=512, image_width=1024, tile_height=512, tile_width=1024, overlap=64) + + expected_tiles = [ + Tile(coords=TBLR(top=0, bottom=512, left=0, right=1024), overlap=TBLR(top=0, bottom=0, left=0, right=0)) + ] + + assert tiles == expected_tiles + + +def test_calc_tiles_with_overlap_evenly_divisible(): + """Test calc_tiles_with_overlap() behavior when the image is evenly covered by multiple tiles.""" + # Parameters chosen so that image is evenly covered by 2 rows, 3 columns of tiles. 
+ tiles = calc_tiles_with_overlap(image_height=576, image_width=1600, tile_height=320, tile_width=576, overlap=64) + + expected_tiles = [ + # Row 0 + Tile(coords=TBLR(top=0, bottom=320, left=0, right=576), overlap=TBLR(top=0, bottom=64, left=0, right=64)), + Tile(coords=TBLR(top=0, bottom=320, left=512, right=1088), overlap=TBLR(top=0, bottom=64, left=64, right=64)), + Tile(coords=TBLR(top=0, bottom=320, left=1024, right=1600), overlap=TBLR(top=0, bottom=64, left=64, right=0)), + # Row 1 + Tile(coords=TBLR(top=256, bottom=576, left=0, right=576), overlap=TBLR(top=64, bottom=0, left=0, right=64)), + Tile(coords=TBLR(top=256, bottom=576, left=512, right=1088), overlap=TBLR(top=64, bottom=0, left=64, right=64)), + Tile(coords=TBLR(top=256, bottom=576, left=1024, right=1600), overlap=TBLR(top=64, bottom=0, left=64, right=0)), + ] + + assert tiles == expected_tiles + + +def test_calc_tiles_with_overlap_not_evenly_divisible(): + """Test calc_tiles_with_overlap() behavior when the image requires 'uneven' overlaps to achieve proper coverage.""" + # Parameters chosen so that image is covered by 2 rows and 3 columns of tiles, with uneven overlaps. 
+ tiles = calc_tiles_with_overlap(image_height=400, image_width=1200, tile_height=256, tile_width=512, overlap=64) + + expected_tiles = [ + # Row 0 + Tile(coords=TBLR(top=0, bottom=256, left=0, right=512), overlap=TBLR(top=0, bottom=112, left=0, right=64)), + Tile(coords=TBLR(top=0, bottom=256, left=448, right=960), overlap=TBLR(top=0, bottom=112, left=64, right=272)), + Tile(coords=TBLR(top=0, bottom=256, left=688, right=1200), overlap=TBLR(top=0, bottom=112, left=272, right=0)), + # Row 1 + Tile(coords=TBLR(top=144, bottom=400, left=0, right=512), overlap=TBLR(top=112, bottom=0, left=0, right=64)), + Tile( + coords=TBLR(top=144, bottom=400, left=448, right=960), overlap=TBLR(top=112, bottom=0, left=64, right=272) + ), + Tile( + coords=TBLR(top=144, bottom=400, left=688, right=1200), overlap=TBLR(top=112, bottom=0, left=272, right=0) + ), + ] + + assert tiles == expected_tiles + + +@pytest.mark.parametrize( + ["image_height", "image_width", "tile_height", "tile_width", "overlap", "raises"], + [ + (128, 128, 128, 128, 127, False), # OK + (128, 128, 128, 128, 0, False), # OK + (128, 128, 64, 64, 0, False), # OK + (128, 128, 129, 128, 0, True), # tile_height exceeds image_height. + (128, 128, 128, 129, 0, True), # tile_width exceeds image_width. + (128, 128, 64, 128, 64, True), # overlap equals tile_height. + (128, 128, 128, 64, 64, True), # overlap equals tile_width. + ], +) +def test_calc_tiles_with_overlap_input_validation( + image_height: int, image_width: int, tile_height: int, tile_width: int, overlap: int, raises: bool +): + """Test that calc_tiles_with_overlap() raises an exception if the inputs are invalid.""" + if raises: + with pytest.raises(AssertionError): + calc_tiles_with_overlap(image_height, image_width, tile_height, tile_width, overlap) + else: + calc_tiles_with_overlap(image_height, image_width, tile_height, tile_width, overlap) + + +############################################# +# Test merge_tiles_with_linear_blending(...) 
+############################################# + + +@pytest.mark.parametrize("blend_amount", [0, 32]) +def test_merge_tiles_with_linear_blending_horizontal(blend_amount: int): + """Test merge_tiles_with_linear_blending(...) behavior when merging horizontally.""" + # Initialize 2 tiles side-by-side. + tiles = [ + Tile(coords=TBLR(top=0, bottom=512, left=0, right=512), overlap=TBLR(top=0, bottom=0, left=0, right=64)), + Tile(coords=TBLR(top=0, bottom=512, left=448, right=960), overlap=TBLR(top=0, bottom=0, left=64, right=0)), + ] + + dst_image = np.zeros((512, 960, 3), dtype=np.uint8) + + # Prepare tile_images that match tiles. Pixel values are set based on the tile index. + tile_images = [ + np.zeros((512, 512, 3)) + 64, + np.zeros((512, 512, 3)) + 128, + ] + + # Calculate expected output. + expected_output = np.zeros((512, 960, 3), dtype=np.uint8) + expected_output[:, : 480 - (blend_amount // 2), :] = 64 + if blend_amount > 0: + gradient = np.linspace(start=64, stop=128, num=blend_amount, dtype=np.uint8).reshape((1, blend_amount, 1)) + expected_output[:, 480 - (blend_amount // 2) : 480 + (blend_amount // 2), :] = gradient + expected_output[:, 480 + (blend_amount // 2) :, :] = 128 + + merge_tiles_with_linear_blending( + dst_image=dst_image, tiles=tiles, tile_images=tile_images, blend_amount=blend_amount + ) + + np.testing.assert_array_equal(dst_image, expected_output, strict=True) + + +@pytest.mark.parametrize("blend_amount", [0, 32]) +def test_merge_tiles_with_linear_blending_vertical(blend_amount: int): + """Test merge_tiles_with_linear_blending(...) behavior when merging vertically.""" + # Initialize 2 tiles stacked vertically. + tiles = [ + Tile(coords=TBLR(top=0, bottom=512, left=0, right=512), overlap=TBLR(top=0, bottom=64, left=0, right=0)), + Tile(coords=TBLR(top=448, bottom=960, left=0, right=512), overlap=TBLR(top=64, bottom=0, left=0, right=0)), + ] + + dst_image = np.zeros((960, 512, 3), dtype=np.uint8) + + # Prepare tile_images that match tiles. 
Pixel values are set based on the tile index. + tile_images = [ + np.zeros((512, 512, 3)) + 64, + np.zeros((512, 512, 3)) + 128, + ] + + # Calculate expected output. + expected_output = np.zeros((960, 512, 3), dtype=np.uint8) + expected_output[: 480 - (blend_amount // 2), :, :] = 64 + if blend_amount > 0: + gradient = np.linspace(start=64, stop=128, num=blend_amount, dtype=np.uint8).reshape((blend_amount, 1, 1)) + expected_output[480 - (blend_amount // 2) : 480 + (blend_amount // 2), :, :] = gradient + expected_output[480 + (blend_amount // 2) :, :, :] = 128 + + merge_tiles_with_linear_blending( + dst_image=dst_image, tiles=tiles, tile_images=tile_images, blend_amount=blend_amount + ) + + np.testing.assert_array_equal(dst_image, expected_output, strict=True)
+
+
+def test_merge_tiles_with_linear_blending_blend_amount_exceeds_vertical_overlap():
+    """Test that merge_tiles_with_linear_blending(...) raises an exception if 'blend_amount' exceeds the overlap between
+    any vertically adjacent tiles.
+    """
+    # Initialize 2 tiles stacked vertically.
+    tiles = [
+        Tile(coords=TBLR(top=0, bottom=512, left=0, right=512), overlap=TBLR(top=0, bottom=64, left=0, right=0)),
+        Tile(coords=TBLR(top=448, bottom=960, left=0, right=512), overlap=TBLR(top=64, bottom=0, left=0, right=0)),
+    ]
+
+    dst_image = np.zeros((960, 512, 3), dtype=np.uint8)
+
+    # Prepare tile_images that match tiles.
+    tile_images = [np.zeros((512, 512, 3)), np.zeros((512, 512, 3))]
+
+    # blend_amount=128 exceeds overlap of 64, so should raise exception.
+    with pytest.raises(AssertionError):
+        merge_tiles_with_linear_blending(dst_image=dst_image, tiles=tiles, tile_images=tile_images, blend_amount=128)
+
+
+def test_merge_tiles_with_linear_blending_blend_amount_exceeds_horizontal_overlap():
+    """Test that merge_tiles_with_linear_blending(...) raises an exception if 'blend_amount' exceeds the overlap between
+    any horizontally adjacent tiles.
+    """
+    # Initialize 2 tiles side-by-side.
+    tiles = [
+        Tile(coords=TBLR(top=0, bottom=512, left=0, right=512), overlap=TBLR(top=0, bottom=0, left=0, right=64)),
+        Tile(coords=TBLR(top=0, bottom=512, left=448, right=960), overlap=TBLR(top=0, bottom=0, left=64, right=0)),
+    ]
+
+    dst_image = np.zeros((512, 960, 3), dtype=np.uint8)
+
+    # Prepare tile_images that match tiles.
+    tile_images = [np.zeros((512, 512, 3)), np.zeros((512, 512, 3))]
+
+    # blend_amount=128 exceeds overlap of 64, so should raise exception.
+    with pytest.raises(AssertionError):
+        merge_tiles_with_linear_blending(dst_image=dst_image, tiles=tiles, tile_images=tile_images, blend_amount=128)
+
+
+def test_merge_tiles_with_linear_blending_tiles_overflow_dst_image():
+    """Test that merge_tiles_with_linear_blending(...) raises an exception if any of the tiles overflows the
+    dst_image.
+    """
+    tiles = [Tile(coords=TBLR(top=0, bottom=512, left=0, right=512), overlap=TBLR(top=0, bottom=0, left=0, right=0))]
+
+    dst_image = np.zeros((256, 512, 3), dtype=np.uint8)
+
+    # Prepare tile_images that match tiles.
+    tile_images = [np.zeros((512, 512, 3))]
+
+    with pytest.raises(ValueError):
+        merge_tiles_with_linear_blending(dst_image=dst_image, tiles=tiles, tile_images=tile_images, blend_amount=0)
+
+
+def test_merge_tiles_with_linear_blending_mismatched_list_lengths():
+    """Test that merge_tiles_with_linear_blending(...) raises an exception if the lengths of 'tiles' and 'tile_images'
+    do not match.
+    """
+    tiles = [Tile(coords=TBLR(top=0, bottom=512, left=0, right=512), overlap=TBLR(top=0, bottom=0, left=0, right=0))]
+
+    dst_image = np.zeros((512, 512, 3), dtype=np.uint8)  # Fits the tile, so only the list lengths are invalid.
+
+    # tile_images is longer than tiles, so should cause an exception.
+    tile_images = [np.zeros((512, 512, 3)), np.zeros((512, 512, 3))]
+
+    with pytest.raises(ValueError):
+        merge_tiles_with_linear_blending(dst_image=dst_image, tiles=tiles, tile_images=tile_images, blend_amount=0)
diff --git a/tests/backend/tiles/test_utils.py b/tests/backend/tiles/test_utils.py
new file mode 100644
index 00000000000..bbef233ca51
--- /dev/null
+++ b/tests/backend/tiles/test_utils.py
@@ -0,0 +1,101 @@
+import numpy as np
+import pytest
+
+from invokeai.backend.tiles.utils import TBLR, paste
+
+
+def test_paste_no_mask_success():
+    """Test successful paste with mask=None."""
+    dst_image = np.zeros((5, 5, 3), dtype=np.uint8)
+
+    # Create src_image with a pattern that can be used to validate that it was pasted correctly.
+    src_image = np.zeros((3, 3, 3), dtype=np.uint8)
+    src_image[0, :, 0] = 1  # Row of 1s in channel 0.
+    src_image[:, 0, 1] = 2  # Column of 2s in channel 1.
+
+    # Paste in bottom-center of dst_image.
+    box = TBLR(top=2, bottom=5, left=1, right=4)
+
+    # Construct expected output image.
+    expected_output = np.zeros((5, 5, 3), dtype=np.uint8)
+    expected_output[2, 1:4, 0] = 1
+    expected_output[2:5, 1, 1] = 2
+
+    paste(dst_image=dst_image, src_image=src_image, box=box)
+
+    np.testing.assert_array_equal(dst_image, expected_output, strict=True)
+
+
+def test_paste_with_mask_success():
+    """Test successful paste with a mask."""
+    dst_image = np.zeros((5, 5, 3), dtype=np.uint8)
+
+    # Create src_image with a pattern that can be used to validate that it was pasted correctly.
+    src_image = np.zeros((3, 3, 3), dtype=np.uint8)
+    src_image[0, :, 0] = 64  # Row of 64s in channel 0.
+    src_image[:, 0, 1] = 128  # Column of 128s in channel 1.
+
+    # Paste in bottom-center of dst_image.
+    box = TBLR(top=2, bottom=5, left=1, right=4)
+
+    # Create a mask that blends the top-left corner of 'src_image' at 50%, and ignores the rest of src_image.
+    mask = np.zeros((3, 3))
+    mask[0, 0] = 0.5
+
+    # Construct expected output image.
+    expected_output = np.zeros((5, 5, 3), dtype=np.uint8)
+    expected_output[2, 1, 0] = 32
+    expected_output[2, 1, 1] = 64
+
+    paste(dst_image=dst_image, src_image=src_image, box=box, mask=mask)
+
+    np.testing.assert_array_equal(dst_image, expected_output, strict=True)
+
+
+@pytest.mark.parametrize("use_mask", [True, False])
+def test_paste_box_overflows_dst_image(use_mask: bool):
+    """Test that an exception is raised if 'box' overflows the 'dst_image'."""
+    dst_image = np.zeros((5, 5, 3), dtype=np.uint8)
+    src_image = np.zeros((3, 3, 3), dtype=np.uint8)
+    mask = None
+    if use_mask:
+        mask = np.zeros((3, 3))
+
+    # Construct box that overflows bottom of dst_image.
+    top = 3
+    left = 0
+    box = TBLR(top=top, bottom=top + src_image.shape[0], left=left, right=left + src_image.shape[1])
+
+    with pytest.raises(ValueError):
+        paste(dst_image=dst_image, src_image=src_image, box=box, mask=mask)
+
+
+@pytest.mark.parametrize("use_mask", [True, False])
+def test_paste_src_image_does_not_match_box(use_mask: bool):
+    """Test that an exception is raised if the 'src_image' shape does not match the 'box' dimensions."""
+    dst_image = np.zeros((5, 5, 3), dtype=np.uint8)
+    src_image = np.zeros((3, 3, 3), dtype=np.uint8)
+    mask = None
+    if use_mask:
+        mask = np.zeros((3, 3))
+
+    # Construct box that is smaller than src_image.
+    box = TBLR(top=0, bottom=src_image.shape[0] - 1, left=0, right=src_image.shape[1])
+
+    with pytest.raises(ValueError):
+        paste(dst_image=dst_image, src_image=src_image, box=box, mask=mask)
+
+
+def test_paste_mask_does_not_match_src_image():
+    """Test that an exception is raised if the 'mask' shape is different than the 'src_image' shape."""
+    dst_image = np.zeros((5, 5, 3), dtype=np.uint8)
+    src_image = np.zeros((3, 3, 3), dtype=np.uint8)
+
+    # Construct mask that is smaller than the src_image.
+    mask = np.zeros((src_image.shape[0] - 1, src_image.shape[1]))
+
+    # Construct box that matches src_image shape.
+    box = TBLR(top=0, bottom=src_image.shape[0], left=0, right=src_image.shape[1])
+
+    with pytest.raises(ValueError):
+        paste(dst_image=dst_image, src_image=src_image, box=box, mask=mask)