Skip to content

Commit

Permalink
Adding 'strength' parameter to StableDiffusionInpaintingPipeline (hug…
Browse files Browse the repository at this point in the history
…gingface#3424)

* Added explanation of 'strength' parameter

* Added get_timesteps function which relies on new strength parameter

* Added `strength` parameter which defaults to 1.

* Swapped ordering so `noise_timestep` can be calculated before masking the image

this is required when you aren't applying 100% noise to the masked region, e.g. strength < 1.

* Added strength to check_inputs, throws error if out of range

* Changed `prepare_latents` to initialise latents w.r.t strength

Inspired by the Stable Diffusion img2img pipeline: the initial latents are created by encoding the init image into a VAE latent and adding noise according to the `strength` parameter passed in, e.g. pure random noise when strength = 1, or the init image when strength = 0.

* WIP: Added a unit test for the new strength parameter in the StableDiffusionInpaintingPipeline

still need to add correct regression values

* Created an `is_strength_max` boolean to initialise the latents from pure random noise

* Updated unit tests w.r.t new strength parameter + fixed new strength unit test

* renamed parameter to avoid confusion with variable of same name

* Updated regression values for new strength test - now passes

* removed 'copied from' comment as this method is now different and divergent from the copy

* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Ensure backwards compatibility for prepare_mask_and_masked_image

created a return_image boolean and initialised to false

* Ensure backwards compatibility for prepare_latents

* Fixed copy check typo

* Fixes w.r.t backward compatibility changes

* make style

* keep function argument ordering same for backwards compatibility in callees with copied from statements

* make fix-copies

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: William Berman <WLBberman@gmail.com>
  • Loading branch information
3 people authored and hari10599 committed May 20, 2023
1 parent fa9a44a commit 8804fee
Show file tree
Hide file tree
Showing 3 changed files with 211 additions and 37 deletions.
47 changes: 44 additions & 3 deletions src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.prepare_mask_and_masked_image
def prepare_mask_and_masked_image(image, mask, height, width):
def prepare_mask_and_masked_image(image, mask, height, width, return_image=False):
"""
Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be
converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the
Expand Down Expand Up @@ -209,6 +209,10 @@ def prepare_mask_and_masked_image(image, mask, height, width):

masked_image = image * (mask < 0.5)

# n.b. ensure backwards compatibility as old function does not return image
if return_image:
return mask, masked_image, image

return mask, masked_image


Expand Down Expand Up @@ -795,21 +799,58 @@ def prepare_control_image(
return image

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline.prepare_latents
def prepare_latents(
    self,
    batch_size,
    num_channels_latents,
    height,
    width,
    dtype,
    device,
    generator,
    latents=None,
    image=None,
    timestep=None,
    is_strength_max=True,
):
    """Build the latents the denoising loop starts from.

    If `latents` is given it is only moved to `device`. Otherwise noise is drawn with `randn_tensor`; when
    `is_strength_max` is true that noise is the starting point, else `image` is encoded with the VAE, scaled by
    `self.vae.config.scaling_factor` and noised to `timestep` through `self.scheduler.add_noise`.

    In every case the result is multiplied by `self.scheduler.init_noise_sigma` before being returned.

    Raises:
        ValueError: if `generator` is a list whose length differs from `batch_size`, or if `is_strength_max` is
            false while `image` or `timestep` is missing.
    """
    shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
    if isinstance(generator, list) and len(generator) != batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    if (image is None or timestep is None) and not is_strength_max:
        raise ValueError(
            "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise. "
            "However, either the image or the noise timestep has not been provided."
        )

    if latents is None:
        noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        if is_strength_max:
            # if strength is 100% then simply initialise the latents to noise
            latents = noise
        else:
            # otherwise initialise latents as init image + noise
            image = image.to(device=device, dtype=dtype)
            if isinstance(generator, list):
                # NOTE(review): indexes `image` per sample, so this presumably expects the image batch
                # dimension to equal `batch_size` - confirm against the caller.
                per_sample_latents = [
                    self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i])
                    for i in range(batch_size)
                ]
                # join the per-sample results into one tensor; multiplying a Python list by the float
                # scaling factor below would raise a TypeError
                image_latents = torch.cat(per_sample_latents, dim=0)
            else:
                image_latents = self.vae.encode(image).latent_dist.sample(generator=generator)

            image_latents = self.vae.config.scaling_factor * image_latents

            latents = self.scheduler.add_noise(image_latents, noise, timestep)
    else:
        latents = latents.to(device)

    # scale the initial noise by the standard deviation required by the scheduler
    latents = latents * self.scheduler.init_noise_sigma

    return latents

def _default_height_width(self, height, width, image):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
logger = logging.get_logger(__name__) # pylint: disable=invalid-name


def prepare_mask_and_masked_image(image, mask, height, width):
def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool = False):
"""
Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be
converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the
Expand Down Expand Up @@ -146,6 +146,10 @@ def prepare_mask_and_masked_image(image, mask, height, width):

masked_image = image * (mask < 0.5)

# n.b. ensure backwards compatibility as old function does not return image
if return_image:
return mask, masked_image, image

return mask, masked_image


Expand Down Expand Up @@ -552,17 +556,20 @@ def decode_latents(self, latents):
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
return image

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs
def check_inputs(
self,
prompt,
height,
width,
strength,
callback_steps,
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
):
if strength < 0 or strength > 1:
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")

if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

Expand Down Expand Up @@ -600,22 +607,58 @@ def check_inputs(
f" {negative_prompt_embeds.shape}."
)

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(
    self,
    batch_size,
    num_channels_latents,
    height,
    width,
    dtype,
    device,
    generator,
    latents=None,
    image=None,
    timestep=None,
    is_strength_max=True,
):
    """Build the latents the denoising loop starts from.

    If `latents` is given it is only moved to `device`. Otherwise noise is drawn with `randn_tensor`; when
    `is_strength_max` is true that noise is the starting point, else `image` is encoded with the VAE, scaled by
    `self.vae.config.scaling_factor` and noised to `timestep` through `self.scheduler.add_noise`.

    In every case the result is multiplied by `self.scheduler.init_noise_sigma` before being returned.

    Raises:
        ValueError: if `generator` is a list whose length differs from `batch_size`, or if `is_strength_max` is
            false while `image` or `timestep` is missing.
    """
    shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
    if isinstance(generator, list) and len(generator) != batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    if (image is None or timestep is None) and not is_strength_max:
        raise ValueError(
            "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise. "
            "However, either the image or the noise timestep has not been provided."
        )

    if latents is None:
        noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        if is_strength_max:
            # if strength is 100% then simply initialise the latents to noise
            latents = noise
        else:
            # otherwise initialise latents as init image + noise
            image = image.to(device=device, dtype=dtype)
            if isinstance(generator, list):
                # NOTE(review): indexes `image` per sample, so this presumably expects the image batch
                # dimension to equal `batch_size` - confirm against the caller.
                per_sample_latents = [
                    self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i])
                    for i in range(batch_size)
                ]
                # join the per-sample results into one tensor; multiplying a Python list by the float
                # scaling factor below would raise a TypeError
                image_latents = torch.cat(per_sample_latents, dim=0)
            else:
                image_latents = self.vae.encode(image).latent_dist.sample(generator=generator)

            image_latents = self.vae.config.scaling_factor * image_latents

            latents = self.scheduler.add_noise(image_latents, noise, timestep)
    else:
        latents = latents.to(device)

    # scale the initial noise by the standard deviation required by the scheduler
    latents = latents * self.scheduler.init_noise_sigma

    return latents

def prepare_mask_latents(
Expand Down Expand Up @@ -669,6 +712,16 @@ def prepare_mask_latents(
masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
return mask, masked_image_latents

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
def get_timesteps(self, num_inference_steps, strength, device):
    """Select the tail of the scheduler's timesteps that corresponds to `strength`.

    Returns a tuple of the remaining timesteps and how many inference steps they represent. `device` is accepted
    but not used by this method.
    """
    # steps that will actually be denoised, capped at the requested total
    steps_to_run = min(int(num_inference_steps * strength), num_inference_steps)
    first_step = max(num_inference_steps - steps_to_run, 0)

    # each inference step spans `scheduler.order` entries of the timestep schedule
    offset = first_step * self.scheduler.order
    remaining = self.scheduler.timesteps[offset:]

    return remaining, num_inference_steps - first_step

@torch.no_grad()
def __call__(
self,
Expand All @@ -677,6 +730,7 @@ def __call__(
mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
height: Optional[int] = None,
width: Optional[int] = None,
strength: float = 1.0,
num_inference_steps: int = 50,
guidance_scale: float = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
Expand Down Expand Up @@ -710,6 +764,13 @@ def __call__(
The height in pixels of the generated image.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image.
strength (`float`, *optional*, defaults to 1.):
Conceptually, indicates how much to transform the masked portion of the reference `image`. Must be
between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the
`strength`. The number of denoising steps depends on the amount of noise initially added. When
`strength` is 1, added noise will be maximum and the denoising process will run for the full number of
iterations specified in `num_inference_steps`. A value of 1, therefore, essentially ignores the masked
portion of the reference `image`.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
Expand Down Expand Up @@ -802,6 +863,7 @@ def __call__(
prompt,
height,
width,
strength,
callback_steps,
negative_prompt,
prompt_embeds,
Expand Down Expand Up @@ -833,12 +895,20 @@ def __call__(
negative_prompt_embeds=negative_prompt_embeds,
)

# 4. Preprocess mask and image - resizes image and mask w.r.t height and width
mask, masked_image = prepare_mask_and_masked_image(image, mask_image, height, width)

# 5. set timesteps
# 4. set timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
timesteps, num_inference_steps = self.get_timesteps(
num_inference_steps=num_inference_steps, strength=strength, device=device
)
# at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
# create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
is_strength_max = strength == 1.0

# 5. Preprocess mask and image
mask, masked_image, init_image = prepare_mask_and_masked_image(
image, mask_image, height, width, return_image=True
)

# 6. Prepare latent variables
num_channels_latents = self.vae.config.latent_channels
Expand All @@ -851,6 +921,9 @@ def __call__(
device,
generator,
latents,
image=init_image,
timestep=latent_timestep,
is_strength_max=is_strength_max,
)

# 7. Prepare mask latent variables
Expand Down
Loading

0 comments on commit 8804fee

Please sign in to comment.