Skip to content

Commit

Permalink
Add sigmas to Flux pipelines
Browse files Browse the repository at this point in the history
  • Loading branch information
hlky committed Dec 2, 2024
1 parent 827b6c2 commit 7e5ad6b
Show file tree
Hide file tree
Showing 9 changed files with 63 additions and 68 deletions.
15 changes: 7 additions & 8 deletions src/diffusers/pipelines/flux/pipeline_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,7 +554,7 @@ def __call__(
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 28,
timesteps: List[int] = None,
sigmas: Optional[List[float]] = None,
guidance_scale: float = 3.5,
num_images_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
Expand Down Expand Up @@ -585,10 +585,10 @@ def __call__(
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
sigmas (`List[float]`, *optional*):
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
guidance_scale (`float`, *optional*, defaults to 7.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Expand Down Expand Up @@ -699,7 +699,7 @@ def __call__(
)

# 5. Prepare timesteps
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
sigmas = sigmas or np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
image_seq_len = latents.shape[1]
mu = calculate_shift(
image_seq_len,
Expand All @@ -712,8 +712,7 @@ def __call__(
self.scheduler,
num_inference_steps,
device,
timesteps,
sigmas,
sigmas=sigmas,
mu=mu,
)
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
Expand Down
15 changes: 7 additions & 8 deletions src/diffusers/pipelines/flux/pipeline_flux_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,7 +621,7 @@ def __call__(
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 28,
timesteps: List[int] = None,
sigmas: Optional[List[float]] = None,
guidance_scale: float = 3.5,
num_images_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
Expand Down Expand Up @@ -660,10 +660,10 @@ def __call__(
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
sigmas (`List[float]`, *optional*):
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
guidance_scale (`float`, *optional*, defaults to 7.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Expand Down Expand Up @@ -799,7 +799,7 @@ def __call__(
)

# 5. Prepare timesteps
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
sigmas = sigmas or np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
image_seq_len = latents.shape[1]
mu = calculate_shift(
image_seq_len,
Expand All @@ -812,8 +812,7 @@ def __call__(
self.scheduler,
num_inference_steps,
device,
timesteps,
sigmas,
sigmas=sigmas,
mu=mu,
)
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
Expand Down
15 changes: 7 additions & 8 deletions src/diffusers/pipelines/flux/pipeline_flux_control_img2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -647,7 +647,7 @@ def __call__(
width: Optional[int] = None,
strength: float = 0.6,
num_inference_steps: int = 28,
timesteps: List[int] = None,
sigmas: Optional[List[float]] = None,
guidance_scale: float = 7.0,
num_images_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
Expand Down Expand Up @@ -698,10 +698,10 @@ def __call__(
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
sigmas (`List[float]`, *optional*):
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
guidance_scale (`float`, *optional*, defaults to 7.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Expand Down Expand Up @@ -805,7 +805,7 @@ def __call__(
)

# 4.Prepare timesteps
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
sigmas = sigmas or np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2)
mu = calculate_shift(
image_seq_len,
Expand All @@ -818,8 +818,7 @@ def __call__(
self.scheduler,
num_inference_steps,
device,
timesteps,
sigmas,
sigmas=sigmas,
mu=mu,
)
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
Expand Down
15 changes: 7 additions & 8 deletions src/diffusers/pipelines/flux/pipeline_flux_controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,7 +602,7 @@ def __call__(
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 28,
timesteps: List[int] = None,
sigmas: Optional[List[float]] = None,
guidance_scale: float = 7.0,
control_guidance_start: Union[float, List[float]] = 0.0,
control_guidance_end: Union[float, List[float]] = 1.0,
Expand Down Expand Up @@ -638,10 +638,10 @@ def __call__(
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
sigmas (`List[float]`, *optional*):
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
guidance_scale (`float`, *optional*, defaults to 7.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Expand Down Expand Up @@ -872,7 +872,7 @@ def __call__(
)

# 5. Prepare timesteps
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
sigmas = sigmas or np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
image_seq_len = latents.shape[1]
mu = calculate_shift(
image_seq_len,
Expand All @@ -885,8 +885,7 @@ def __call__(
self.scheduler,
num_inference_steps,
device,
timesteps,
sigmas,
sigmas=sigmas,
mu=mu,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -646,7 +646,7 @@ def __call__(
width: Optional[int] = None,
strength: float = 0.6,
num_inference_steps: int = 28,
timesteps: List[int] = None,
sigmas: Optional[List[float]] = None,
guidance_scale: float = 7.0,
control_guidance_start: Union[float, List[float]] = 0.0,
control_guidance_end: Union[float, List[float]] = 1.0,
Expand Down Expand Up @@ -685,8 +685,10 @@ def __call__(
num_inference_steps (`int`, *optional*, defaults to 28):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process.
sigmas (`List[float]`, *optional*):
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
guidance_scale (`float`, *optional*, defaults to 7.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
control_mode (`int` or `List[int]`, *optional*):
Expand Down Expand Up @@ -858,7 +860,7 @@ def __call__(
control_mode = torch.tensor(control_mode_).to(device, dtype=torch.long)
control_mode = control_mode.reshape([-1, 1])

sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
sigmas = sigmas or np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2)
mu = calculate_shift(
image_seq_len,
Expand All @@ -871,8 +873,7 @@ def __call__(
self.scheduler,
num_inference_steps,
device,
timesteps,
sigmas,
sigmas=sigmas,
mu=mu,
)
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -752,7 +752,7 @@ def __call__(
width: Optional[int] = None,
strength: float = 0.6,
padding_mask_crop: Optional[int] = None,
timesteps: List[int] = None,
sigmas: Optional[List[float]] = None,
num_inference_steps: int = 28,
guidance_scale: float = 7.0,
control_guidance_start: Union[float, List[float]] = 0.0,
Expand Down Expand Up @@ -799,8 +799,10 @@ def __call__(
num_inference_steps (`int`, *optional*, defaults to 28):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process.
sigmas (`List[float]`, *optional*):
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
guidance_scale (`float`, *optional*, defaults to 7.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
Expand Down Expand Up @@ -1009,7 +1011,7 @@ def __call__(

# 6. Prepare timesteps

sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
sigmas = sigmas or np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
image_seq_len = (int(global_height) // self.vae_scale_factor // 2) * (
int(global_width) // self.vae_scale_factor // 2
)
Expand All @@ -1024,8 +1026,7 @@ def __call__(
self.scheduler,
num_inference_steps,
device,
timesteps,
sigmas,
sigmas=sigmas,
mu=mu,
)
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
Expand Down
15 changes: 7 additions & 8 deletions src/diffusers/pipelines/flux/pipeline_flux_fill.py
Original file line number Diff line number Diff line change
Expand Up @@ -689,7 +689,7 @@ def __call__(
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
timesteps: List[int] = None,
sigmas: Optional[List[float]] = None,
guidance_scale: float = 30.0,
num_images_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
Expand Down Expand Up @@ -735,10 +735,10 @@ def __call__(
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
sigmas (`List[float]`, *optional*):
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
guidance_scale (`float`, *optional*, defaults to 7.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Expand Down Expand Up @@ -878,7 +878,7 @@ def __call__(
masked_image_latents = torch.cat((masked_image_latents, mask), dim=-1)

# 6. Prepare timesteps
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
sigmas = sigmas or np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
image_seq_len = latents.shape[1]
mu = calculate_shift(
image_seq_len,
Expand All @@ -891,8 +891,7 @@ def __call__(
self.scheduler,
num_inference_steps,
device,
timesteps,
sigmas,
sigmas=sigmas,
mu=mu,
)
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
Expand Down
15 changes: 7 additions & 8 deletions src/diffusers/pipelines/flux/pipeline_flux_img2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -593,7 +593,7 @@ def __call__(
width: Optional[int] = None,
strength: float = 0.6,
num_inference_steps: int = 28,
timesteps: List[int] = None,
sigmas: Optional[List[float]] = None,
guidance_scale: float = 7.0,
num_images_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
Expand Down Expand Up @@ -636,10 +636,10 @@ def __call__(
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
sigmas (`List[float]`, *optional*):
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
guidance_scale (`float`, *optional*, defaults to 7.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Expand Down Expand Up @@ -742,7 +742,7 @@ def __call__(
)

# 4.Prepare timesteps
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
sigmas = sigmas or np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2)
mu = calculate_shift(
image_seq_len,
Expand All @@ -755,8 +755,7 @@ def __call__(
self.scheduler,
num_inference_steps,
device,
timesteps,
sigmas,
sigmas=sigmas,
mu=mu,
)
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
Expand Down
15 changes: 7 additions & 8 deletions src/diffusers/pipelines/flux/pipeline_flux_inpaint.py
Original file line number Diff line number Diff line change
Expand Up @@ -693,7 +693,7 @@ def __call__(
padding_mask_crop: Optional[int] = None,
strength: float = 0.6,
num_inference_steps: int = 28,
timesteps: List[int] = None,
sigmas: Optional[List[float]] = None,
guidance_scale: float = 7.0,
num_images_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
Expand Down Expand Up @@ -753,10 +753,10 @@ def __call__(
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
sigmas (`List[float]`, *optional*):
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
guidance_scale (`float`, *optional*, defaults to 7.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Expand Down Expand Up @@ -873,7 +873,7 @@ def __call__(
)

# 4.Prepare timesteps
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
sigmas = sigmas or np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2)
mu = calculate_shift(
image_seq_len,
Expand All @@ -886,8 +886,7 @@ def __call__(
self.scheduler,
num_inference_steps,
device,
timesteps,
sigmas,
sigmas=sigmas,
mu=mu,
)
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
Expand Down

0 comments on commit 7e5ad6b

Please sign in to comment.