Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add padding_mask_crop to all inpaint pipelines #6360

Merged
merged 14 commits into from
Jan 22, 2024

Conversation

rootonchair
Copy link
Contributor

What does this PR do?

Add padding_mask_crop to inpaint pipelines: SDXL, ControlNet, ControlNet SDXL

Fixes #6345 (issue)

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Copy link
Collaborator

@yiyixuxu yiyixuxu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thank you!

Copy link
Member

@sayakpaul sayakpaul left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! Could I see some results too?

@yiyixuxu shouldn't we add tests too?

@yiyixuxu
Copy link
Collaborator

I think it is fine not to add tests for these auto1111 features. We are not currently testing all the value combinations for all pipeline arguments

@rootonchair
Copy link
Contributor Author

I will try to get the result of padding_mask_crop with new pipeline. It will help if someone could provide an example code for running ControlNet inpaint

@rootonchair
Copy link
Contributor Author

Here are the result of SDXL

import torch
from diffusers import AutoPipelineForInpainting
from diffusers.utils import load_image
from PIL import Image

model = "stabilityai/stable-diffusion-xl-base-1.0"
blur_factor = 33
seed = 0

img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl-text2img.png"
mask_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl-inpaint-mask.png"
base = load_image(img_url)
mask = load_image(mask_url)

# create inpaint pipeline
pipe1 = AutoPipelineForInpainting.from_pretrained(model, torch_dtype=torch.float16).to('cuda')

# this is baseline, no mask blur, no inpant_full_res
generator = torch.Generator(device='cuda').manual_seed(seed)    
inpaint = pipe1('boat', image=base, mask_image=mask, strength=0.75,generator=generator).images[0]
inpaint.save(f'out_base.png')

# create blurred nask
mask_blurred = pipe1.mask_processor.blur(mask, blur_factor=blur_factor)
mask_blurred.save(f'mask_blurred.png')

# with mask blur
generator = torch.Generator(device='cuda').manual_seed(seed) 
inpaint = pipe1('boat', image=base, mask_image=mask_blurred, strength=0.75,generator=generator).images[0]
inpaint.save(f'out_mask_blur.png')

# with both mask_blur and inpaint_full_res
generator = torch.Generator(device='cuda').manual_seed(seed) 
inpaint = pipe1('boat', image=base, mask_image=mask_blurred, strength=0.75,generator=generator, padding_mask_crop=32).images[0]
inpaint.save(f'out_mask_blur_full_res.png')

base
base
mask
mask

out_base
out_base
out_base_blur
out_mask_blur
out_mask_blur_full_res
out_mask_blur_full_res

@rootonchair
Copy link
Contributor Author

Here the result of ControlNet sd1.5 inpaint

from diffusers import StableDiffusionControlNetInpaintPipeline, ControlNetModel, DDIMScheduler
from diffusers.utils import load_image
import numpy as np
import cv2
from PIL import Image
import torch

init_image = load_image(
    "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_inpaint/boy.png"
)
init_image = init_image.resize((512, 512))
init_image.save("input.png")

generator = torch.Generator(device="cpu").manual_seed(1)

mask_image = load_image(
    "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_inpaint/boy_mask.png"
)
mask_image.save("input_mask.png")
mask_image = mask_image.resize((512, 512))


def make_canny_condition(image):
    image = np.array(image)
    image = cv2.Canny(image, 100, 200)
    image = image[:, :, None]
    image = np.concatenate([image, image, image], axis=2)
    image = Image.fromarray(image)
    return image


control_image = make_canny_condition(init_image)
control_image.save("control_image.png")

controlnet = ControlNetModel.from_pretrained(
    "lllyasviel/control_v11p_sd15_canny", torch_dtype=torch.float16
)
pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
)

pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
pipe.enable_model_cpu_offload()

# generate image
image = pipe(
    "a handsome man with ray-ban sunglasses",
    num_inference_steps=30,
    generator=generator,
    width=512,
    height=512,
    eta=1.0,
    image=init_image,
    mask_image=mask_image,
    control_image=control_image,
    padding_mask_crop=32
).images[0]

image.save("image_out.png")

Input image
input
Mask image
input_mask
Control image
control_image
Run without padding_mask_crop
image_out_no_pad

Run with padding_mask_crop
image_out

@rootonchair
Copy link
Contributor Author

Finally, SDXL control net output

from diffusers import StableDiffusionXLControlNetInpaintPipeline, ControlNetModel, DDIMScheduler
from diffusers.utils import load_image
import cv2
from PIL import Image
import numpy as np
import torch

init_image = load_image(
    "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_inpaint/boy.png"
)
init_image = init_image.resize((1024, 1024))

generator = torch.Generator(device="cpu").manual_seed(1)

mask_image = load_image(
    "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_inpaint/boy_mask.png"
)
mask_image = mask_image.resize((1024, 1024))


def make_canny_condition(image):
    image = np.array(image)
    image = cv2.Canny(image, 100, 200)
    image = image[:, :, None]
    image = np.concatenate([image, image, image], axis=2)
    image = Image.fromarray(image)
    return image


control_image = make_canny_condition(init_image)

controlnet = ControlNetModel.from_pretrained(
    "diffusers/controlnet-canny-sdxl-1.0", torch_dtype=torch.float16
)
pipe = StableDiffusionXLControlNetInpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16
)

pipe.enable_model_cpu_offload()

# generate image
image = pipe(
    "a handsome man with ray-ban sunglasses",
    num_inference_steps=20,
    generator=generator,
    eta=1.0,
    image=init_image,
    mask_image=mask_image,
    control_image=control_image,
).images[0]
image.save("sdxl_controlnet_no_pad.png")


image = pipe(
    "a handsome man with ray-ban sunglasses",
    num_inference_steps=20,
    generator=generator,
    eta=1.0,
    width=1024,
    height=1024,
    image=init_image,
    mask_image=mask_image,
    control_image=control_image,
    padding_mask_crop=32
).images[0]
image.save("sdxl_controlnet_pad.png")

No padding_mask_crop

sdxl_controlnet_no_pad

Add padding_mask_crop
sdxl_controlnet_pad

@patrickvonplaten
Copy link
Contributor

Can we run make style here?

@yiyixuxu
Copy link
Collaborator

yiyixuxu commented Jan 2, 2024

we still see some of the astronauts in the sdxl example; wonder if it is related to this. #6417

can you run it again with the fix you proposed?

@rootonchair
Copy link
Contributor Author

@yiyixuxu I think it due to the size of padding_mask_crop

padding_mask_crop=32
out_mask_blur_full_res_32

padding_mask_crop=128
out_mask_blur_full_res_128

@rootonchair
Copy link
Contributor Author

Can we run make style here?

@patrickvonplaten I did run make style but it doesn't change any files. It seems like the error requires running make fix-copies but it changes many unrelated files

Run python utils/check_copies.py
Traceback (most recent call last):
  File "utils/check_copies.py", line 222, in <module>
    check_copies(args.fix_and_overwrite)
  File "utils/check_copies.py", line 206, in check_copies
    raise Exception(
Exception: Found the following copy inconsistencies:
- src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py: copy does not match pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image at line 884
Run `make fix-copies` or `python utils/check_copies.py --fix_and_overwrite` to fix them.

@GoGiants1
Copy link

GoGiants1 commented Jan 3, 2024

Hi @rootonchair
Some inpaint models have Unet that in_channels==9 (realistic vision 5.1, runwayml/stable-diffusion-inpainting).

in check_inputs
    raise ValueError(
ValueError: The UNet should have 4 input channels for inpainting mask crop, but has 9 input channels.

I got error from StableDiffusionControlNetInpaintPipeline when using realistic vision inpainting model..!

@rootonchair
Copy link
Contributor Author

Hi @rootonchair Some inpaint models have Unet that in_channels==9 (realistic vision 5.1, runwayml/stable-diffusion-inpainting).

in check_inputs
    raise ValueError(
ValueError: The UNet should have 4 input channels for inpainting mask crop, but has 9 input channels.

I got error from StableDiffusionControlNetInpaintPipeline when using realistic vision inpainting model..!

@yiyixuxu should we change check_inputs condition? Running inpaint with 9 channel input seem to not raise any error

Below is the result of running padding_mask_crop with realisticVision
image_out

@patrickvonplaten
Copy link
Contributor

Can we run make style here?

@patrickvonplaten I did run make style but it doesn't change any files. It seems like the error requires running make fix-copies but it changes many unrelated files

Run python utils/check_copies.py
Traceback (most recent call last):
  File "utils/check_copies.py", line 222, in <module>
    check_copies(args.fix_and_overwrite)
  File "utils/check_copies.py", line 206, in check_copies
    raise Exception(
Exception: Found the following copy inconsistencies:
- src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py: copy does not match pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image at line 884
Run `make fix-copies` or `python utils/check_copies.py --fix_and_overwrite` to fix them.

Actually you should indeed run make fix-copies, it's expected that this changes unrelated other files due to our # Copied from mechanism

@rootonchair
Copy link
Contributor Author

@patrickvonplaten I see. I will make an update on that

@yiyixuxu
Copy link
Collaborator

yiyixuxu commented Jan 8, 2024

@yiyixuxu should we change check_inputs condition? Running inpaint with 9 channel input seem to not raise any error

let's change it! you can change for inpaint pipeline too :)

let's also pass output_type to check_input() method and make sure we only support PIL output with padding_mask_crop feature see comments #6072 (comment)

@yiyixuxu
Copy link
Collaborator

yiyixuxu commented Jan 8, 2024

great job!
I'm going to look into the sdxl example a little bit more to see what's going on there. Other than that looks good to merge soon:)

@rootonchair
Copy link
Contributor Author

@yiyixuxu should we change check_inputs condition? Running inpaint with 9 channel input seem to not raise any error

let's change it! you can change for inpaint pipeline too :)

let's also pass output_type to check_input() method and make sure we only support PIL output with padding_mask_crop feature see comments #6072 (comment)

@yiyixuxu all done 😄

@@ -1264,6 +1298,13 @@ def __call__(
else:
batch_size = prompt_embeds.shape[0]

if padding_mask_crop is not None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

um just saw your issue #6435
maybe we need to move this code into prepare_control_image()?

see my comment here #6435 (comment)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think it would work. width and height are still None in there. Do you think we should handle None in get_crop_region?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok!
you can use self.image_processor. get_default_height_width(image) to get it

@rootonchair
Copy link
Contributor Author

@yiyixuxu could you help me review this PR?

Copy link
Collaborator

@yiyixuxu yiyixuxu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sorry I'm a little bit slow in reviewing this.
looks good and I left a few comments. Thanks again for working on this!

f"The mask image should be a PIL image when inpainting mask crop, but is of type"
f" {type(mask_image)}."
)
if output_type != "pil":
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

Comment on lines 696 to 701
if self.unet.config.in_channels != 4:
if self.unet.config.in_channels != 4 and self.unet.config.in_channels != 9:
raise ValueError(
f"The UNet should have 4 input channels for inpainting mask crop, but has"
f"The UNet should have 4 or 9 input channels for inpainting mask crop, but has"
f" {self.unet.config.in_channels} input channels."
)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you can remove this warning

@@ -1527,10 +1559,22 @@ def denoising_value_valid(dnv):
is_strength_max = strength == 1.0

# 5. Preprocess mask and image
init_image = self.image_processor.preprocess(image, height=height, width=width)
if padding_mask_crop is not None:
crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we don't need height, width = self.image_processor.get_default_height_width(image, height, width) here?

Copy link
Contributor Author

@rootonchair rootonchair Jan 21, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rootonchair and others added 3 commits January 21, 2024 22:31
Co-authored-by: YiYi Xu <yixu310@gmail.com>
@rootonchair
Copy link
Contributor Author

@yiyixuxu fixed. Thank you for your reviews

Copy link
Collaborator

@yiyixuxu yiyixuxu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks

@yiyixuxu yiyixuxu merged commit 8e7bbfb into huggingface:main Jan 22, 2024
14 checks passed
@rootonchair rootonchair deleted the add_padding_mask_crop branch January 22, 2024 09:21
AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request Apr 26, 2024
* add padding_mask_crop
---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

add padding_mask_crop to all inpaint pipelines
6 participants