allow multiple generations per prompt (open-mmlab#741)
* compute text embeds per prompt

* don't repeat uncond prompts

* repeat separately

* update image2image

* fix repeat uncond embeds

* adapt inpaint pipeline

* fix uncond tokens in img2img

* add tests and fix uncond embeds in img2img and inpaint pipe
patil-suraj authored Oct 6, 2022
1 parent 367a671 commit c119dc4
Showing 4 changed files with 236 additions and 13 deletions.
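The change adds a `num_images_per_prompt` argument to the Stable Diffusion text-to-image, image-to-image, and inpainting pipelines. As a rough sketch of the user-facing effect (assuming the usual `from_pretrained` loading path; the model id, prompts, and step count below are illustrative and not part of the diff):

from diffusers import StableDiffusionPipeline

# load the text-to-image pipeline the standard way (checkpoint id is illustrative)
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
pipe = pipe.to("cuda")

prompts = ["a photo of an astronaut riding a horse", "a watercolor painting of a fox"]
# after this commit, one call can return several samples per prompt
output = pipe(prompts, num_images_per_prompt=3, num_inference_steps=50)
print(len(output.images))  # 2 prompts * 3 images per prompt = 6 images
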
@@ -117,6 +117,7 @@ def __call__(
num_inference_steps: int = 50,
guidance_scale: float = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[torch.Generator] = None,
latents: Optional[torch.FloatTensor] = None,
@@ -148,6 +149,8 @@ def __call__(
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
@@ -215,6 +218,9 @@ def __call__(
text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]

+ # duplicate text embeddings for each generation per prompt
+ text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)

# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@@ -223,14 +229,14 @@ def __call__(
if do_classifier_free_guidance:
uncond_tokens: List[str]
if negative_prompt is None:
uncond_tokens = [""] * batch_size
uncond_tokens = [""]
elif type(prompt) is not type(negative_prompt):
raise TypeError(
"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
" {type(prompt)}."
)
elif isinstance(negative_prompt, str):
- uncond_tokens = [negative_prompt] * batch_size
+ uncond_tokens = [negative_prompt]
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
@@ -250,6 +256,9 @@ def __call__(
)
uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]

+ # duplicate unconditional embeddings for each generation per prompt
+ uncond_embeddings = uncond_embeddings.repeat_interleave(batch_size * num_images_per_prompt, dim=0)

# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
@@ -260,7 +269,7 @@ def __call__(
# Unlike in other pipelines, latents need to be generated in the target device
# for 1-to-1 results reproducibility with the CompVis implementation.
# However this currently doesn't work in `mps`.
- latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
+ latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8)
latents_dtype = text_embeddings.dtype
if latents is None:
if self.device.type == "mps":
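To make the shape bookkeeping in the text-to-image hunks above concrete, here is a small stand-alone sketch (tensor sizes are illustrative, not taken from the real CLIP encoder) of how `repeat_interleave` expands the conditional and unconditional embeddings so that both end up with `batch_size * num_images_per_prompt` rows before they are concatenated for classifier-free guidance:

import torch

batch_size, num_images_per_prompt, seq_len, dim = 2, 3, 77, 768  # illustrative sizes

text_embeddings = torch.randn(batch_size, seq_len, dim)    # one row per prompt
uncond_embeddings = torch.randn(1, seq_len, dim)            # single "" (or negative prompt) embedding

# each prompt's embedding is repeated num_images_per_prompt times, preserving prompt order
text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
# the single unconditional embedding is tiled to cover the whole enlarged batch
uncond_embeddings = uncond_embeddings.repeat_interleave(batch_size * num_images_per_prompt, dim=0)

# both halves now match, so they can be stacked into one batch for a single UNet forward pass
cfg_embeddings = torch.cat([uncond_embeddings, text_embeddings])
assert cfg_embeddings.shape[0] == 2 * batch_size * num_images_per_prompt
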
@@ -129,6 +129,7 @@ def __call__(
num_inference_steps: Optional[int] = 50,
guidance_scale: Optional[float] = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
eta: Optional[float] = 0.0,
generator: Optional[torch.Generator] = None,
output_type: Optional[str] = "pil",
@@ -164,6 +165,8 @@ def __call__(
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
@@ -220,15 +223,15 @@ def __call__(
init_latents = 0.18215 * init_latents

# expand init_latents for batch_size
- init_latents = torch.cat([init_latents] * batch_size)
+ init_latents = torch.cat([init_latents] * batch_size * num_images_per_prompt, dim=0)

# get the original timestep using init_timestep
offset = self.scheduler.config.get("steps_offset", 0)
init_timestep = int(num_inference_steps * strength) + offset
init_timestep = min(init_timestep, num_inference_steps)

timesteps = self.scheduler.timesteps[-init_timestep]
- timesteps = torch.tensor([timesteps] * batch_size, device=self.device)
+ timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device)

# add noise to latents using the timesteps
noise = torch.randn(init_latents.shape, generator=generator, device=self.device)
@@ -252,6 +255,9 @@ def __call__(
text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]

+ # duplicate text embeddings for each generation per prompt
+ text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)

# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@@ -260,14 +266,14 @@ def __call__(
if do_classifier_free_guidance:
uncond_tokens: List[str]
if negative_prompt is None:
uncond_tokens = [""] * batch_size
uncond_tokens = [""]
elif type(prompt) is not type(negative_prompt):
raise TypeError(
"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
" {type(prompt)}."
)
elif isinstance(negative_prompt, str):
- uncond_tokens = [negative_prompt] * batch_size
+ uncond_tokens = [negative_prompt]
elif batch_size != len(negative_prompt):
raise ValueError("The length of `negative_prompt` should be equal to batch_size.")
else:
@@ -283,6 +289,9 @@ def __call__(
)
uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]

+ # duplicate unconditional embeddings for each generation per prompt
+ uncond_embeddings = uncond_embeddings.repeat_interleave(batch_size * num_images_per_prompt, dim=0)

# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
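The image-to-image pipeline gains the same argument. A minimal usage sketch, assuming the standard StableDiffusionImg2ImgPipeline loading path (the checkpoint id, input file name, and strength value are illustrative, not taken from the diff):

from PIL import Image
from diffusers import StableDiffusionImg2ImgPipeline

pipe = StableDiffusionImg2ImgPipeline.from_pretrained("CompVis/stable-diffusion-v1-4").to("cuda")

init_image = Image.open("sketch.png").convert("RGB").resize((512, 512))
output = pipe(
    prompt="a fantasy landscape, detailed oil painting",
    init_image=init_image,
    strength=0.75,
    num_images_per_prompt=4,  # new in this commit: four variations from one init image
)
print(len(output.images))  # 4
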
@@ -145,6 +145,7 @@ def __call__(
num_inference_steps: Optional[int] = 50,
guidance_scale: Optional[float] = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
eta: Optional[float] = 0.0,
generator: Optional[torch.Generator] = None,
output_type: Optional[str] = "pil",
@@ -184,6 +185,8 @@ def __call__(
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
@@ -242,15 +245,15 @@ def __call__(

init_latents = 0.18215 * init_latents

- # Expand init_latents for batch_size
- init_latents = torch.cat([init_latents] * batch_size)
+ # Expand init_latents for batch_size and num_images_per_prompt
+ init_latents = torch.cat([init_latents] * batch_size * num_images_per_prompt, dim=0)
init_latents_orig = init_latents

# preprocess mask
if not isinstance(mask_image, torch.FloatTensor):
mask_image = preprocess_mask(mask_image)
mask_image = mask_image.to(self.device)
- mask = torch.cat([mask_image] * batch_size)
+ mask = torch.cat([mask_image] * batch_size * num_images_per_prompt)

# check sizes
if not mask.shape == init_latents.shape:
@@ -262,7 +265,7 @@ def __call__(
init_timestep = min(init_timestep, num_inference_steps)

timesteps = self.scheduler.timesteps[-init_timestep]
- timesteps = torch.tensor([timesteps] * batch_size, device=self.device)
+ timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device)

# add noise to latents using the timesteps
noise = torch.randn(init_latents.shape, generator=generator, device=self.device)
@@ -286,6 +289,9 @@ def __call__(
text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]

+ # duplicate text embeddings for each generation per prompt
+ text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)

# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@@ -294,14 +300,14 @@ def __call__(
if do_classifier_free_guidance:
uncond_tokens: List[str]
if negative_prompt is None:
uncond_tokens = [""] * batch_size
uncond_tokens = [""]
elif type(prompt) is not type(negative_prompt):
raise TypeError(
"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
" {type(prompt)}."
)
elif isinstance(negative_prompt, str):
- uncond_tokens = [negative_prompt] * batch_size
+ uncond_tokens = [negative_prompt]
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
@@ -321,6 +327,9 @@ def __call__(
)
uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]

+ # duplicate unconditional embeddings for each generation per prompt
+ uncond_embeddings = uncond_embeddings.repeat_interleave(batch_size * num_images_per_prompt, dim=0)

# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
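Likewise for the inpainting pipeline. A sketch under the same assumptions (checkpoint id, image and mask file names, and strength are illustrative; the init_image/mask_image argument names follow the pipeline shown in the diff above):

from PIL import Image
from diffusers import StableDiffusionInpaintPipeline

pipe = StableDiffusionInpaintPipeline.from_pretrained("CompVis/stable-diffusion-v1-4").to("cuda")

init_image = Image.open("photo.png").convert("RGB").resize((512, 512))
mask_image = Image.open("mask.png").resize((512, 512))

output = pipe(
    prompt="a red sports car parked on the street",
    init_image=init_image,
    mask_image=mask_image,
    strength=0.8,
    num_images_per_prompt=2,  # two candidate fills for the masked region
)
print(len(output.images))  # 2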