Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

improve get_crop_region #14709

Merged
merged 1 commit into from
Jan 21, 2024
Merged

improve get_crop_region #14709

merged 1 commit into from
Jan 21, 2024

Conversation

w-e-w
Copy link
Collaborator

@w-e-w w-e-w commented Jan 20, 2024

Description

simplify and improve get_crop_region

test code

performance comparison code

from PIL import Image, ImageDraw
import numpy as np
import timeit


def get_crop_region(mask, pad=0):
    """finds a rectangular region that contains all masked ares in an image. Returns (x1, y1, x2, y2) coordinates of the rectangle.
    For example, if a user has painted the top-right part of a 512x512 image, the result may be (256, 0, 512, 256)"""

    h, w = mask.shape

    crop_left = 0
    for i in range(w):
        if not (mask[:, i] == 0).all():
            break
        crop_left += 1

    crop_right = 0
    for i in reversed(range(w)):
        if not (mask[:, i] == 0).all():
            break
        crop_right += 1

    crop_top = 0
    for i in range(h):
        if not (mask[i] == 0).all():
            break
        crop_top += 1

    crop_bottom = 0
    for i in reversed(range(h)):
        if not (mask[i] == 0).all():
            break
        crop_bottom += 1

    return (
        int(max(crop_left-pad, 0)),
        int(max(crop_top-pad, 0)),
        int(min(w - crop_right + pad, w)),
        int(min(h - crop_bottom + pad, h))
    )


def current(in_image, pad=0):
    return get_crop_region(np.array(in_image), pad)


def new_method(mask, pad=0):
    mask_img = mask if isinstance(mask, Image.Image) else Image.fromarray(mask)
    box = mask_img.getbbox()
    if box:
        x1, y1, x2, y2 = box
    else:  # when no box is found
        x1, y1 = mask_img.size
        x2 = y2 = 0
    return max(x1 - pad, 0), max(y1 - pad, 0), min(x2 + pad, mask_img.size[0]), min(y2 + pad, mask_img.size[1])

def new_method_with_np_input(mask, pad=0):
    # simulate if input is numpy array
    return new_method(np.array(mask), pad)


if __name__ == '__main__':
    img = Image.new('L', (1000, 1000), color='black')
    np_mask = np.array(img)

    padding = 10

    print(get_crop_region(np_mask, padding))            # (990, 990, 10, 10)
    print(current(img, padding))                        # (990, 990, 10, 10)
    print(new_method(img, padding))                     # (990, 990, 10, 10)
    print(new_method_with_np_input(img, padding))       # (990, 990, 10, 10)

    draw = ImageDraw.Draw(img)
    draw.ellipse((300, 400, 500, 600), fill='white')
    np_mask = np.array(img)

    print(get_crop_region(np_mask, padding))            # (290, 390, 511, 611)
    print(current(img, padding))                        # (290, 390, 511, 611)
    print(new_method(img, padding))                     # (290, 390, 511, 611)
    print(new_method_with_np_input(img, padding))       # (290, 390, 511, 611)

    iterations = 1000

    timeit_0 = timeit.timeit(lambda: get_crop_region(np_mask, padding), number=iterations)
    timeit_1 = timeit.timeit(lambda: current(img, padding), number=iterations)
    timeit_2 = timeit.timeit(lambda: new_method(img, padding), number=iterations)
    timeit_3 = timeit.timeit(lambda: new_method_with_np_input(img, padding), number=iterations)
    print(f"method_0 took {timeit_0:.6f} seconds for {iterations} iterations.")
    print(f"method_1 took {timeit_1:.6f} seconds for {iterations} iterations.")
    print(f"method_2 took {timeit_2:.6f} seconds for {iterations} iterations.")
    print(f"method_3 took {timeit_3:.6f} seconds for {iterations} iterations.")
    # method_0 took 3.534935 seconds for 1000 iterations.
    # method_1 took 4.385469 seconds for 1000 iterations.
    # method_2 took 0.224629 seconds for 1000 iterations.
    # method_3 took 0.967969 seconds for 1000 iterations.

conclusion
all round faster even in compatibility mode

Checklist:

@w-e-w w-e-w requested a review from AUTOMATIC1111 as a code owner January 20, 2024 22:26
@w-e-w w-e-w marked this pull request as draft January 20, 2024 22:45
@w-e-w w-e-w force-pushed the improve-get_crop_region branch from 3efbfd2 to 5f13560 Compare January 20, 2024 22:58
@w-e-w w-e-w marked this pull request as ready for review January 20, 2024 23:05
@w-e-w w-e-w force-pushed the improve-get_crop_region branch 5 times, most recently from 3ed1149 to cc1e7e7 Compare January 20, 2024 23:30
@w-e-w
Copy link
Collaborator Author

w-e-w commented Jan 20, 2024

it can be compacted down even more at the cost of readability

def get_crop_region(mask, pad=0):
    mask_img = mask if isinstance(mask, Image.Image) else Image.fromarray(mask)
    box = mask_img.getbbox()
    return max((box[0] - pad, 0), max(box[1] - pad, 0), min(box[2] + pad, mask_img.size[0]), max(box[3] + pad, mask_img.size[1])) if box else (min(mask_img.size[0] - pad, 0), min(mask_img.size[1] - pad, 0), max(0 + pad, mask_img.size[0]), max(0 + pad, mask_img.size[1])) 
def stupid(mask, pad=0):
    mask_img = mask if isinstance(mask, Image.Image) else Image.fromarray(mask)
    return (max(box[0] - pad, 0), max(box[1] - pad, 0), min(box[2] + pad, mask_img.size[0]), min(box[3] + pad, mask_img.size[1])) if (box := mask_img.getbbox()) else (max(mask_img.size[0] - pad, 0), max(mask_img.size[1] - pad, 0), min(pad, mask_img.size[0]), min(pad, mask_img.size[1]))
def shouldnt_ever_exist(mask, pad=0):
    return (max(box[0] - pad, 0), max(box[1] - pad, 0), min(box[2] + pad, mask_img.size[0]), min(box[3] + pad, mask_img.size[1])) if (box := (mask_img := mask if isinstance(mask, Image.Image) else Image.fromarray(mask)).getbbox()) else (max(mask_img.size[0] - pad, 0), max(mask_img.size[1] - pad, 0), min(pad, mask_img.size[0]), min(pad, mask_img.size[1]))
def just_why(mask, pad=0): return (max(box[0] - pad, 0), max(box[1] - pad, 0), min(box[2] + pad, mask_img.size[0]), min(box[3] + pad, mask_img.size[1])) if (box := (mask_img := mask if isinstance(mask, Image.Image) else Image.fromarray(mask)).getbbox()) else (max(mask_img.size[0] - pad, 0), max(mask_img.size[1] - pad, 0), min(pad, mask_img.size[0]), min(pad, mask_img.size[1]))
lambda_version = lambda mask, pad=0: (max(box[0] - pad, 0), max(box[1] - pad, 0), min(box[2] + pad, mask_img.size[0]), min(box[3] + pad, mask_img.size[1])) if (box := (mask_img := mask if isinstance(mask, Image.Image) else Image.fromarray(mask)).getbbox()) else (max(mask_img.size[0] - pad, 0), max(mask_img.size[1] - pad, 0), min(pad, mask_img.size[0]), min(pad, mask_img.size[1]))
c = lambda m, p=0: (max(b[0] - p, 0), max(b[1] - p, 0), min(b[2] + p, i.size[0]), min(b[3] + p, i.size[1])) if (b := (i := m if isinstance(m, Image.Image) else Image.fromarray(m)).getbbox()) else (max(i.size[0] - p, 0), max(i.size[1] - p, 0), min(p, i.size[0]), min(p, i.size[1]))

@w-e-w w-e-w marked this pull request as draft January 20, 2024 23:57
@w-e-w w-e-w force-pushed the improve-get_crop_region branch from cc1e7e7 to e36827a Compare January 21, 2024 00:02
@w-e-w w-e-w marked this pull request as ready for review January 21, 2024 00:25
@AUTOMATIC1111 AUTOMATIC1111 merged commit 8a6a4ad into dev Jan 21, 2024
6 checks passed
@AUTOMATIC1111 AUTOMATIC1111 deleted the improve-get_crop_region branch January 21, 2024 13:01
@w-e-w w-e-w mentioned this pull request Feb 17, 2024
@pawel665j pawel665j mentioned this pull request Apr 16, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants