Adding 'strength' parameter to StableDiffusionInpaintingPipeline #3424

Merged
Commits: 21 total (changes shown from 12 commits)
e4d82f2  Added explanation of 'strength' parameter (May 12, 2023)
a0a04b8  Added get_timesteps function which relies on new strength parameter (May 12, 2023)
43fce93  Added `strength` parameter which defaults to 1. (May 12, 2023)
4176047  Swapped ordering so `noise_timestep` can be calculated before masking… (May 12, 2023)
db47974  Added strength to check_inputs, throws error if out of range (May 12, 2023)
81660d0  Changed `prepare_latents` to initialise latents w.r.t strength (May 12, 2023)
73b2d20  WIP: Added a unit test for the new strength parameter in the StableDi… (May 12, 2023)
d900fb8  Created a is_strength_max to initialise from pure random noise (May 12, 2023)
4fe9a26  Updated unit tests w.r.t new strength parameter + fixed new strength … (May 12, 2023)
8aa9489  renamed parameter to avoid confusion with variable of same name (May 12, 2023)
cd3101b  Updated regression values for new strength test - now passes (May 12, 2023)
aca884f  removed 'copied from' comment as this method is now different and div… (May 12, 2023)
f245d6e  Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffu… (rupertmenneer, May 16, 2023)
60c1a35  Ensure backwards compatibility for prepare_mask_and_masked_image (May 16, 2023)
b0f874b  Ensure backwards compatibility for prepare_latents (May 16, 2023)
46583cc  Fixed copy check typo (May 16, 2023)
c14ecc6  Fixes w.r.t backward compatibility changes (May 16, 2023)
dc65be6  make style (williamberman, May 17, 2023)
13f7c94  Merge branch 'main' into add_strength_param_to_inpainting_pipeline (williamberman, May 17, 2023)
26f0c2e  keep function argument ordering same for backwards compatibility in c… (williamberman, May 17, 2023)
934974a  make fix-copies (williamberman, May 17, 2023)
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py

@@ -146,7 +146,7 @@ def prepare_mask_and_masked_image(image, mask, height, width):

     masked_image = image * (mask < 0.5)

-    return mask, masked_image
+    return mask, masked_image, image


 class StableDiffusionInpaintPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
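A quick sketch of the consuming side of this change: the helper now hands back the resized, normalized image as a third value. The inputs below are placeholders, and the import path is the one the updated tests use:

```python
from PIL import Image

from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint import (
    prepare_mask_and_masked_image,
)

# placeholder inputs; any of the PIL/numpy/tensor combinations exercised in the tests work
image = Image.new("RGB", (64, 64), color="white")
mask_image = Image.new("L", (64, 64), color=255)

# before this PR the helper returned two values; it now also returns the preprocessed image
mask, masked_image, init_image = prepare_mask_and_masked_image(image, mask_image, 64, 64)
print(mask.shape, masked_image.shape, init_image.shape)  # (1, 1, 64, 64) (1, 3, 64, 64) (1, 3, 64, 64)
```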
@@ -558,11 +558,15 @@ def check_inputs(
         prompt,
         height,
         width,
+        strength,
         callback_steps,
         negative_prompt=None,
         prompt_embeds=None,
         negative_prompt_embeds=None,
     ):
+        if strength < 0 or strength > 1:
+            raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}")
+
         if height % 8 != 0 or width % 8 != 0:
             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
@@ -600,8 +604,7 @@ def check_inputs(
                     f" {negative_prompt_embeds.shape}."
                 )

-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
-    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+    def prepare_latents(self, image, timestep, batch_size, num_channels_latents, height, width, dtype, device, generator, is_strength_max, latents=None):
         shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
         if isinstance(generator, list) and len(generator) != batch_size:
             raise ValueError(
@@ -610,12 +613,31 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype
             )

         if latents is None:
-            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+            noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+            if is_strength_max:
+                # if strength is 100% then simply initialise the latents to noise
+                latents = noise
+            else:
+                # otherwise initialise latents as init image + noise
+                image = image.to(device=device, dtype=dtype)
+                if isinstance(generator, list):
+                    image_latents = [
+                        self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i])
+                        for i in range(batch_size)
+                    ]
+                else:
+                    image_latents = self.vae.encode(image).latent_dist.sample(generator=generator)
+
+                image_latents = self.vae.config.scaling_factor * image_latents
+
+                latents = self.scheduler.add_noise(image_latents, noise, timestep)
         else:
             latents = latents.to(device)

         # scale the initial noise by the standard deviation required by the scheduler
         latents = latents * self.scheduler.init_noise_sigma

         return latents

     def prepare_mask_latents(
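The branch added above is the core of the feature: at `strength == 1.0` the latents start from pure noise exactly as before, while at lower strengths the VAE-encoded init image is noised only up to the first retained timestep. A self-contained sketch of that logic, using a stand-in tensor instead of a real VAE encode (shapes and the timestep index are illustrative):

```python
import torch
from diffusers import DDPMScheduler

scheduler = DDPMScheduler(num_train_timesteps=1000)
scheduler.set_timesteps(50)

shape = (1, 4, 64, 64)              # illustrative latent shape
noise = torch.randn(shape)
image_latents = torch.randn(shape)  # stand-in for vae.encode(image).latent_dist.sample() * scaling_factor
strength = 0.75
is_strength_max = strength == 1.0
timestep = scheduler.timesteps[13:14]  # first retained timestep for strength=0.75 (see get_timesteps below)

if is_strength_max:
    latents = noise  # pure-noise start: identical to the old pipeline behaviour
else:
    latents = scheduler.add_noise(image_latents, noise, timestep)  # init image + noise
```

The pipeline then scales the result by `scheduler.init_noise_sigma` either way; for DDPM-style schedulers that sigma is 1.0.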
@@ -668,6 +690,16 @@ def prepare_mask_latents(
         # aligning device to prevent device errors when concating it with the latent model input
         masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
         return mask, masked_image_latents

+    # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if_img2img.IFImg2ImgPipeline.get_timesteps
+    def get_timesteps(self, num_inference_steps, strength):
+        # get the original timestep using init_timestep
+        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+
+        t_start = max(num_inference_steps - init_timestep, 0)
+        timesteps = self.scheduler.timesteps[t_start:]
+
+        return timesteps, num_inference_steps - t_start

     @torch.no_grad()
     def __call__(
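`get_timesteps` simply drops the earliest (noisiest) part of the schedule. A worked example of the arithmetic for 50 inference steps and `strength=0.75`:

```python
from diffusers import DDPMScheduler

scheduler = DDPMScheduler(num_train_timesteps=1000)
scheduler.set_timesteps(50)

num_inference_steps, strength = 50, 0.75
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)  # min(37, 50) = 37
t_start = max(num_inference_steps - init_timestep, 0)                          # 50 - 37 = 13
timesteps = scheduler.timesteps[t_start:]

assert len(timesteps) == 37  # the denoising loop now runs 37 of the 50 scheduled steps
```

With `strength=1.0`, `t_start` is 0 and the full schedule is kept, matching the old behaviour.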
@@ -677,6 +709,7 @@ def __call__(
         mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
         height: Optional[int] = None,
         width: Optional[int] = None,
+        strength: float = 1.,
         num_inference_steps: int = 50,
         guidance_scale: float = 7.5,
         negative_prompt: Optional[Union[str, List[str]]] = None,
@@ -710,6 +743,13 @@ def __call__(
                 The height in pixels of the generated image.
             width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                 The width in pixels of the generated image.
+            strength (`float`, *optional*, defaults to 1.):
+                Conceptually, indicates how much to transform the masked portion of the reference `image`.
+                Must be between 0 and 1. `image` will be used as a starting point, adding more noise to it the
+                larger the `strength`. The number of denoising steps depends on the amount of noise initially
+                added. When `strength` is 1, added noise will be maximum and the denoising process will run for
+                the full number of iterations specified in `num_inference_steps`. A value of 1, therefore,
+                essentially ignores the masked portion of the reference `image`.
             num_inference_steps (`int`, *optional*, defaults to 50):
                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                 expense of slower inference.
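Taken together with the docstring, a typical call after this change might look like the following sketch (the model id, file names, and prompt are illustrative):

```python
import torch
from diffusers import StableDiffusionInpaintPipeline
from PIL import Image

pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16
).to("cuda")

init_image = Image.open("photo.png").convert("RGB").resize((512, 512))
mask_image = Image.open("mask.png").convert("RGB").resize((512, 512))

# strength < 1.0 keeps some of the original content in the masked area;
# the default of 1.0 reproduces the previous behaviour (pure-noise start)
result = pipe(
    prompt="a red brick fireplace",
    image=init_image,
    mask_image=mask_image,
    strength=0.75,
).images[0]
```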
@@ -802,6 +842,7 @@ def __call__(
             prompt,
             height,
             width,
+            strength,
             callback_steps,
             negative_prompt,
             prompt_embeds,
@@ -833,23 +874,30 @@ def __call__(
             negative_prompt_embeds=negative_prompt_embeds,
         )

-        # 4. Preprocess mask and image - resizes image and mask w.r.t height and width
-        mask, masked_image = prepare_mask_and_masked_image(image, mask_image, height, width)
-
-        # 5. set timesteps
+        # 4. set timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
-        timesteps = self.scheduler.timesteps
+        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps=num_inference_steps, strength=strength)
+        # at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
+        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+        # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
+        is_strength_max = strength == 1.
+
+        # 5. Preprocess mask and image
+        mask, masked_image, init_image = prepare_mask_and_masked_image(image, mask_image, height, width)

         # 6. Prepare latent variables
         num_channels_latents = self.vae.config.latent_channels
         latents = self.prepare_latents(
+            init_image,
+            latent_timestep,
             batch_size * num_images_per_prompt,
             num_channels_latents,
             height,
             width,
             prompt_embeds.dtype,
             device,
             generator,
+            is_strength_max,
             latents,
         )
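One subtlety in the reordering above: the schedule is truncated before mask/image preprocessing because `prepare_latents` now needs `latent_timestep`, the first retained timestep, repeated once per generated sample. The broadcast in isolation (with illustrative values):

```python
import torch

timesteps = torch.tensor([761, 741, 721])  # truncated schedule, noisiest first (illustrative values)
batch_size, num_images_per_prompt = 2, 2

# every generated sample is noised to the same (first) retained timestep
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
print(latent_timestep)  # tensor([761, 761, 761, 761])
```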
65 changes: 50 additions & 15 deletions tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py
@@ -325,6 +325,25 @@ def test_stable_diffusion_inpaint_pil_input_resolution_test(self):
         # verify that the returned image has the same height and width as the input height and width
         assert image.shape == (1, inputs["height"], inputs["width"], 3)

+    def test_stable_diffusion_inpaint_strength_test(self):
+        pipe = StableDiffusionInpaintPipeline.from_pretrained(
+            "runwayml/stable-diffusion-inpainting", safety_checker=None
+        )
+        pipe.scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config)
+        pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+        pipe.enable_attention_slicing()
+
+        inputs = self.get_inputs(torch_device)
+        # change input strength
+        inputs["strength"] = 0.75
+        image = pipe(**inputs).images
+        # verify that the returned image has the expected height and width
+        assert image.shape == (1, 512, 512, 3)
+
+        image_slice = image[0, 253:256, 253:256, -1].flatten()
+        expected_slice = np.array([0.0021, 0.2350, 0.3712, 0.0575, 0.2485, 0.3451, 0.1857, 0.3156, 0.3943])
+        assert np.abs(expected_slice - image_slice).max() < 3e-3

     @nightly
     @require_torch_gpu
@@ -428,24 +447,30 @@ def test_pil_inputs(self):
         mask = np.random.randint(0, 255, (height, width), dtype=np.uint8) > 127.5
         mask = Image.fromarray((mask * 255).astype(np.uint8))

-        t_mask, t_masked = prepare_mask_and_masked_image(im, mask, height, width)
+        t_mask, t_masked, t_image = prepare_mask_and_masked_image(im, mask, height, width)

         self.assertTrue(isinstance(t_mask, torch.Tensor))
         self.assertTrue(isinstance(t_masked, torch.Tensor))
+        self.assertTrue(isinstance(t_image, torch.Tensor))

         self.assertEqual(t_mask.ndim, 4)
         self.assertEqual(t_masked.ndim, 4)
+        self.assertEqual(t_image.ndim, 4)

         self.assertEqual(t_mask.shape, (1, 1, height, width))
         self.assertEqual(t_masked.shape, (1, 3, height, width))
+        self.assertEqual(t_image.shape, (1, 3, height, width))

         self.assertTrue(t_mask.dtype == torch.float32)
         self.assertTrue(t_masked.dtype == torch.float32)
+        self.assertTrue(t_image.dtype == torch.float32)

         self.assertTrue(t_mask.min() >= 0.0)
         self.assertTrue(t_mask.max() <= 1.0)
         self.assertTrue(t_masked.min() >= -1.0)
         self.assertTrue(t_masked.max() <= 1.0)
+        self.assertTrue(t_image.min() >= -1.0)
+        self.assertTrue(t_image.max() <= 1.0)

         self.assertTrue(t_mask.sum() > 0.0)

@@ -468,11 +493,12 @@ def test_np_inputs(self):
         )
         mask_pil = Image.fromarray((mask_np * 255).astype(np.uint8))

-        t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np, height, width)
-        t_mask_pil, t_masked_pil = prepare_mask_and_masked_image(im_pil, mask_pil, height, width)
+        t_mask_np, t_masked_np, t_image_np = prepare_mask_and_masked_image(im_np, mask_np, height, width)
+        t_mask_pil, t_masked_pil, t_image_pil = prepare_mask_and_masked_image(im_pil, mask_pil, height, width)

         self.assertTrue((t_mask_np == t_mask_pil).all())
         self.assertTrue((t_masked_np == t_masked_pil).all())
+        self.assertTrue((t_image_np == t_image_pil).all())

     def test_torch_3D_2D_inputs(self):
         height, width = 32, 32
@@ -502,13 +528,14 @@ def test_torch_3D_2D_inputs(self):
         im_np = im_tensor.numpy().transpose(1, 2, 0)
         mask_np = mask_tensor.numpy()

-        t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(
+        t_mask_tensor, t_masked_tensor, t_image_tensor = prepare_mask_and_masked_image(
             im_tensor / 127.5 - 1, mask_tensor, height, width
         )
-        t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np, height, width)
+        t_mask_np, t_masked_np, t_image_np = prepare_mask_and_masked_image(im_np, mask_np, height, width)

         self.assertTrue((t_mask_tensor == t_mask_np).all())
         self.assertTrue((t_masked_tensor == t_masked_np).all())
+        self.assertTrue((t_image_tensor == t_image_np).all())

     def test_torch_3D_3D_inputs(self):
         height, width = 32, 32
@@ -539,13 +566,14 @@ def test_torch_3D_3D_inputs(self):
         im_np = im_tensor.numpy().transpose(1, 2, 0)
         mask_np = mask_tensor.numpy()[0]

-        t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(
+        t_mask_tensor, t_masked_tensor, t_image_tensor = prepare_mask_and_masked_image(
             im_tensor / 127.5 - 1, mask_tensor, height, width
         )
-        t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np, height, width)
+        t_mask_np, t_masked_np, t_image_np = prepare_mask_and_masked_image(im_np, mask_np, height, width)

         self.assertTrue((t_mask_tensor == t_mask_np).all())
         self.assertTrue((t_masked_tensor == t_masked_np).all())
+        self.assertTrue((t_image_tensor == t_image_np).all())

     def test_torch_4D_2D_inputs(self):
         height, width = 32, 32
@@ -576,13 +604,14 @@ def test_torch_4D_2D_inputs(self):
         im_np = im_tensor.numpy()[0].transpose(1, 2, 0)
         mask_np = mask_tensor.numpy()

-        t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(
+        t_mask_tensor, t_masked_tensor, t_image_tensor = prepare_mask_and_masked_image(
             im_tensor / 127.5 - 1, mask_tensor, height, width
         )
-        t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np, height, width)
+        t_mask_np, t_masked_np, t_image_np = prepare_mask_and_masked_image(im_np, mask_np, height, width)

         self.assertTrue((t_mask_tensor == t_mask_np).all())
         self.assertTrue((t_masked_tensor == t_masked_np).all())
+        self.assertTrue((t_image_tensor == t_image_np).all())

     def test_torch_4D_3D_inputs(self):
         height, width = 32, 32
@@ -614,13 +643,14 @@ def test_torch_4D_3D_inputs(self):
         im_np = im_tensor.numpy()[0].transpose(1, 2, 0)
         mask_np = mask_tensor.numpy()[0]

-        t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(
+        t_mask_tensor, t_masked_tensor, t_image_tensor = prepare_mask_and_masked_image(
             im_tensor / 127.5 - 1, mask_tensor, height, width
         )
-        t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np, height, width)
+        t_mask_np, t_masked_np, t_image_np = prepare_mask_and_masked_image(im_np, mask_np, height, width)

         self.assertTrue((t_mask_tensor == t_mask_np).all())
         self.assertTrue((t_masked_tensor == t_masked_np).all())
+        self.assertTrue((t_image_tensor == t_image_np).all())

     def test_torch_4D_4D_inputs(self):
         height, width = 32, 32
@@ -653,13 +683,14 @@ def test_torch_4D_4D_inputs(self):
         im_np = im_tensor.numpy()[0].transpose(1, 2, 0)
         mask_np = mask_tensor.numpy()[0][0]

-        t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(
+        t_mask_tensor, t_masked_tensor, t_image_tensor = prepare_mask_and_masked_image(
             im_tensor / 127.5 - 1, mask_tensor, height, width
         )
-        t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np, height, width)
+        t_mask_np, t_masked_np, t_image_np = prepare_mask_and_masked_image(im_np, mask_np, height, width)

         self.assertTrue((t_mask_tensor == t_mask_np).all())
         self.assertTrue((t_masked_tensor == t_masked_np).all())
+        self.assertTrue((t_image_tensor == t_image_np).all())

     def test_torch_batch_4D_3D(self):
         height, width = 32, 32
@@ -692,15 +723,17 @@ def test_torch_batch_4D_3D(self):
         im_nps = [im.numpy().transpose(1, 2, 0) for im in im_tensor]
         mask_nps = [mask.numpy() for mask in mask_tensor]

-        t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(
+        t_mask_tensor, t_masked_tensor, t_image_tensor = prepare_mask_and_masked_image(
             im_tensor / 127.5 - 1, mask_tensor, height, width
         )
         nps = [prepare_mask_and_masked_image(i, m, height, width) for i, m in zip(im_nps, mask_nps)]
         t_mask_np = torch.cat([n[0] for n in nps])
         t_masked_np = torch.cat([n[1] for n in nps])
+        t_image_np = torch.cat([n[2] for n in nps])

         self.assertTrue((t_mask_tensor == t_mask_np).all())
         self.assertTrue((t_masked_tensor == t_masked_np).all())
+        self.assertTrue((t_image_tensor == t_image_np).all())

     def test_torch_batch_4D_4D(self):
         height, width = 32, 32
@@ -734,15 +767,17 @@ def test_torch_batch_4D_4D(self):
         im_nps = [im.numpy().transpose(1, 2, 0) for im in im_tensor]
         mask_nps = [mask.numpy()[0] for mask in mask_tensor]

-        t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(
+        t_mask_tensor, t_masked_tensor, t_image_tensor = prepare_mask_and_masked_image(
             im_tensor / 127.5 - 1, mask_tensor, height, width
         )
         nps = [prepare_mask_and_masked_image(i, m, height, width) for i, m in zip(im_nps, mask_nps)]
         t_mask_np = torch.cat([n[0] for n in nps])
         t_masked_np = torch.cat([n[1] for n in nps])
+        t_image_np = torch.cat([n[2] for n in nps])

         self.assertTrue((t_mask_tensor == t_mask_np).all())
         self.assertTrue((t_masked_tensor == t_masked_np).all())
+        self.assertTrue((t_image_tensor == t_image_np).all())

     def test_shape_mismatch(self):
         height, width = 32, 32