update

DN6 committed Nov 14, 2024
1 parent 7204481 commit cc6833c
Showing 10 changed files with 122 additions and 40 deletions.
10 changes: 6 additions & 4 deletions src/diffusers/pipelines/flux/pipeline_flux.py
@@ -197,7 +197,9 @@ def __init__(
         self.vae_scale_factor = (
             2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
         )
-        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+        # Flux latents are turned into 2x2 patches and packed. This means the latent width and height have to be divisible
+        # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
+        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
         self.tokenizer_max_length = (
             self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
         )
@@ -386,9 +388,9 @@ def check_inputs(
         callback_on_step_end_tensor_inputs=None,
         max_sequence_length=None,
     ):
-        if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0:
-            raise ValueError(
-                f"`height` and `width` have to be divisible by {self.vae_scale_factor} but are {height} and {width}."
+        if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
+            logger.warning(
+                f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
             )
 
         if callback_on_step_end_tensor_inputs is not None and not all(
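Note: with these two changes, off-grid resolutions no longer raise. `check_inputs` only warns, and the `VaeImageProcessor` (now constructed with `vae_scale_factor * 2`, i.e. 16 by default) rounds the requested size down to the nearest supported multiple. A minimal sketch of the resulting behavior, assuming the default `vae_scale_factor = 8`; the helper below is illustrative and not part of this commit:

    # Illustrative only: how a requested resolution is adjusted once the
    # image processor is built with vae_scale_factor * 2.
    def adjusted_resolution(height: int, width: int, vae_scale_factor: int = 8) -> tuple[int, int]:
        multiple = vae_scale_factor * 2  # 2x2 latent packing adds a factor of 2
        return height - height % multiple, width - width % multiple

    print(adjusted_resolution(1024, 1024))  # (1024, 1024): already divisible by 16
    print(adjusted_resolution(72, 56))      # (64, 48): rounded down, warning logged at runtime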
16 changes: 10 additions & 6 deletions src/diffusers/pipelines/flux/pipeline_flux_controlnet.py
@@ -218,7 +218,9 @@ def __init__(
         self.vae_scale_factor = (
             2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
         )
-        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+        # Flux latents are turned into 2x2 patches and packed. This means the latent width and height have to be divisible
+        # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
+        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
         self.tokenizer_max_length = (
             self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
         )
@@ -410,9 +412,9 @@ def check_inputs(
         callback_on_step_end_tensor_inputs=None,
         max_sequence_length=None,
     ):
-        if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0:
-            raise ValueError(
-                f"`height` and `width` have to be divisible by {self.vae_scale_factor} but are {height} and {width}."
+        if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
+            logger.warning(
+                f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
             )
 
         if callback_on_step_end_tensor_inputs is not None and not all(
@@ -500,8 +502,10 @@ def prepare_latents(
         generator,
         latents=None,
     ):
-        height = int(height) // self.vae_scale_factor
-        width = int(width) // self.vae_scale_factor
+        # VAE applies 8x compression on images but we must also account for packing which requires
+        # latent height and width to be divisible by 2.
+        height = int(height) // self.vae_scale_factor - ((int(height) // self.vae_scale_factor) % 2)
+        width = int(width) // self.vae_scale_factor - ((int(width) // self.vae_scale_factor) % 2)
 
         shape = (batch_size, num_channels_latents, height, width)
 
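In `prepare_latents`, the same constraint is enforced on the latent grid itself: the pixel dimension is divided by the VAE scale factor, then rounded down to an even number so the 2x2 packing divides it cleanly. A quick check of the arithmetic, assuming `vae_scale_factor = 8` (illustrative, not part of the commit):

    vae_scale_factor = 8
    for size in (1024, 72, 56):
        latent = size // vae_scale_factor - ((size // vae_scale_factor) % 2)
        print(size, "->", latent, "latent cells ->", latent * vae_scale_factor, "px")
    # 1024 -> 128 latent cells -> 1024 px
    # 72   -> 8 latent cells -> 64 px
    # 56   -> 6 latent cells -> 48 px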
16 changes: 10 additions & 6 deletions src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py
@@ -230,7 +230,9 @@ def __init__(
         self.vae_scale_factor = (
             2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
         )
-        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+        # Flux latents are turned into 2x2 patches and packed. This means the latent width and height have to be divisible
+        # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
+        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
         self.tokenizer_max_length = (
             self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
         )
@@ -453,9 +455,9 @@ def check_inputs(
         if strength < 0 or strength > 1:
             raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
 
-        if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0:
-            raise ValueError(
-                f"`height` and `width` have to be divisible by {self.vae_scale_factor} but are {height} and {width}."
+        if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
+            logger.warning(
+                f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
             )
 
         if callback_on_step_end_tensor_inputs is not None and not all(
@@ -551,8 +553,10 @@ def prepare_latents(
                 f" size of {batch_size}. Make sure the batch size matches the length of the generators."
             )
 
-        height = int(height) // self.vae_scale_factor
-        width = int(width) // self.vae_scale_factor
+        # VAE applies 8x compression on images but we must also account for packing which requires
+        # latent height and width to be divisible by 2.
+        height = int(height) // self.vae_scale_factor - ((int(height) // self.vae_scale_factor) % 2)
+        width = int(width) // self.vae_scale_factor - ((int(width) // self.vae_scale_factor) % 2)
 
         shape = (batch_size, num_channels_latents, height, width)
         latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
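Note the parentheses in the `check_inputs` hunk above. Since `%` and `*` have equal precedence and associate left to right in Python, `height % self.vae_scale_factor * 2` would parse as `(height % self.vae_scale_factor) * 2` and silently pass values divisible by 8 but not by 16. A small standalone demonstration, not part of the commit:

    vae_scale_factor = 8
    height = 56
    print(height % vae_scale_factor * 2)    # 0 -> (56 % 8) * 2: the check would never fire
    print(height % (vae_scale_factor * 2))  # 8 -> 56 % 16: correctly flags 56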
24 changes: 15 additions & 9 deletions src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py
@@ -233,9 +233,11 @@ def __init__(
         self.vae_scale_factor = (
             2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
         )
-        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+        # Flux latents are turned into 2x2 patches and packed. This means the latent width and height have to be divisible
+        # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
+        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
         self.mask_processor = VaeImageProcessor(
-            vae_scale_factor=self.vae_scale_factor,
+            vae_scale_factor=self.vae_scale_factor * 2,
             vae_latent_channels=self.vae.config.latent_channels,
             do_normalize=False,
             do_binarize=True,
@@ -467,9 +469,9 @@ def check_inputs(
         if strength < 0 or strength > 1:
             raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
 
-        if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0:
-            raise ValueError(
-                f"`height` and `width` have to be divisible by {self.vae_scale_factor} but are {height} and {width}."
+        if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
+            logger.warning(
+                f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
             )
 
         if callback_on_step_end_tensor_inputs is not None and not all(
@@ -578,8 +580,10 @@ def prepare_latents(
                 f" size of {batch_size}. Make sure the batch size matches the length of the generators."
             )
 
-        height = int(height) // self.vae_scale_factor
-        width = int(width) // self.vae_scale_factor
+        # VAE applies 8x compression on images but we must also account for packing which requires
+        # latent height and width to be divisible by 2.
+        height = int(height) // self.vae_scale_factor - ((int(height) // self.vae_scale_factor) % 2)
+        width = int(width) // self.vae_scale_factor - ((int(width) // self.vae_scale_factor) % 2)
 
         shape = (batch_size, num_channels_latents, height, width)
         latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
@@ -624,8 +628,10 @@ def prepare_mask_latents(
         device,
         generator,
     ):
-        height = int(height) // self.vae_scale_factor
-        width = int(width) // self.vae_scale_factor
+        # VAE applies 8x compression on images but we must also account for packing which requires
+        # latent height and width to be divisible by 2.
+        height = int(height) // self.vae_scale_factor - ((int(height) // self.vae_scale_factor) % 2)
+        width = int(width) // self.vae_scale_factor - ((int(width) // self.vae_scale_factor) % 2)
         # resize the mask to latents shape as we concatenate the mask to the latents
         # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
         # and half precision
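The even-dimension requirement exists because Flux packs latents before they reach the transformer: every 2x2 patch of the latent grid becomes a single token. A sketch of that packing, modeled on the `_pack_latents` helper used by these pipelines, with shapes chosen purely for illustration:

    import torch

    # Sketch of Flux-style 2x2 latent packing, modeled on the pipelines' _pack_latents helper.
    def pack_latents(latents: torch.Tensor) -> torch.Tensor:
        batch_size, channels, height, width = latents.shape  # height and width must be even
        latents = latents.view(batch_size, channels, height // 2, 2, width // 2, 2)
        latents = latents.permute(0, 2, 4, 1, 3, 5)
        return latents.reshape(batch_size, (height // 2) * (width // 2), channels * 4)

    packed = pack_latents(torch.randn(1, 16, 64, 64))
    print(packed.shape)  # torch.Size([1, 1024, 64])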
16 changes: 10 additions & 6 deletions src/diffusers/pipelines/flux/pipeline_flux_img2img.py
@@ -214,7 +214,9 @@ def __init__(
         self.vae_scale_factor = (
             2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
         )
-        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+        # Flux latents are turned into 2x2 patches and packed. This means the latent width and height have to be divisible
+        # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
+        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
         self.tokenizer_max_length = (
             self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
         )
@@ -436,9 +438,9 @@ def check_inputs(
         if strength < 0 or strength > 1:
             raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
 
-        if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0:
-            raise ValueError(
-                f"`height` and `width` have to be divisible by {self.vae_scale_factor} but are {height} and {width}."
+        if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
+            logger.warning(
+                f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
             )
 
         if callback_on_step_end_tensor_inputs is not None and not all(
@@ -533,8 +535,10 @@ def prepare_latents(
                 f" size of {batch_size}. Make sure the batch size matches the length of the generators."
             )
 
-        height = int(height) // self.vae_scale_factor
-        width = int(width) // self.vae_scale_factor
+        # VAE applies 8x compression on images but we must also account for packing which requires
+        # latent height and width to be divisible by 2.
+        height = int(height) // self.vae_scale_factor - ((int(height) // self.vae_scale_factor) % 2)
+        width = int(width) // self.vae_scale_factor - ((int(width) // self.vae_scale_factor) % 2)
 
         shape = (batch_size, num_channels_latents, height, width)
         latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
24 changes: 15 additions & 9 deletions src/diffusers/pipelines/flux/pipeline_flux_inpaint.py
@@ -211,9 +211,11 @@ def __init__(
         self.vae_scale_factor = (
             2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
         )
-        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+        # Flux latents are turned into 2x2 patches and packed. This means the latent width and height have to be divisible
+        # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
+        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
         self.mask_processor = VaeImageProcessor(
-            vae_scale_factor=self.vae_scale_factor,
+            vae_scale_factor=self.vae_scale_factor * 2,
             vae_latent_channels=self.vae.config.latent_channels,
             do_normalize=False,
             do_binarize=True,
@@ -445,9 +447,9 @@ def check_inputs(
         if strength < 0 or strength > 1:
             raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
 
-        if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0:
-            raise ValueError(
-                f"`height` and `width` have to be divisible by {self.vae_scale_factor} but are {height} and {width}."
+        if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
+            logger.warning(
+                f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
             )
 
         if callback_on_step_end_tensor_inputs is not None and not all(
@@ -555,8 +557,10 @@ def prepare_latents(
                 f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )
 
-        height = int(height) // self.vae_scale_factor
-        width = int(width) // self.vae_scale_factor
+        # VAE applies 8x compression on images but we must also account for packing which requires
+        # latent height and width to be divisible by 2.
+        height = int(height) // self.vae_scale_factor - ((int(height) // self.vae_scale_factor) % 2)
+        width = int(width) // self.vae_scale_factor - ((int(width) // self.vae_scale_factor) % 2)
 
         shape = (batch_size, num_channels_latents, height, width)
         latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
@@ -600,8 +604,10 @@ def prepare_mask_latents(
         device,
         generator,
     ):
-        height = int(height) // self.vae_scale_factor
-        width = int(width) // self.vae_scale_factor
+        # VAE applies 8x compression on images but we must also account for packing which requires
+        # latent height and width to be divisible by 2.
+        height = int(height) // self.vae_scale_factor - ((int(height) // self.vae_scale_factor) % 2)
+        width = int(width) // self.vae_scale_factor - ((int(width) // self.vae_scale_factor) % 2)
         # resize the mask to latents shape as we concatenate the mask to the latents
         # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
         # and half precision
14 changes: 14 additions & 0 deletions tests/pipelines/controlnet_flux/test_controlnet_flux.py
@@ -181,6 +181,20 @@ def test_controlnet_flux(self):
     def test_xformers_attention_forwardGenerator_pass(self):
         pass
 
+    def test_flux_image_output_shape(self):
+        pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
+        inputs = self.get_dummy_inputs(torch_device)
+
+        height_width_pairs = [(32, 32), (72, 56)]
+        for height, width in height_width_pairs:
+            expected_height = height - height % (pipe.vae_scale_factor * 2)
+            expected_width = width - width % (pipe.vae_scale_factor * 2)
+
+            inputs.update({"height": height, "width": width})
+            image = pipe(**inputs).images[0]
+            output_height, output_width, _ = image.shape
+            assert (output_height, output_width) == (expected_height, expected_width)
+
 
 @slow
 @require_big_gpu_with_torch_cuda
14 changes: 14 additions & 0 deletions tests/pipelines/flux/test_pipeline_flux.py
@@ -191,6 +191,20 @@ def test_fused_qkv_projections(self):
             original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
         ), "Original outputs should match when fused QKV projections are disabled."
 
+    def test_flux_image_output_shape(self):
+        pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
+        inputs = self.get_dummy_inputs(torch_device)
+
+        height_width_pairs = [(32, 32), (72, 56)]
+        for height, width in height_width_pairs:
+            expected_height = height - height % (pipe.vae_scale_factor * 2)
+            expected_width = width - width % (pipe.vae_scale_factor * 2)
+
+            inputs.update({"height": height, "width": width})
+            image = pipe(**inputs).images[0]
+            output_height, output_width, _ = image.shape
+            assert (output_height, output_width) == (expected_height, expected_width)
+
 
 @slow
 @require_big_gpu_with_torch_cuda
14 changes: 14 additions & 0 deletions tests/pipelines/flux/test_pipeline_flux_img2img.py
@@ -147,3 +147,17 @@ def test_flux_prompt_embeds(self):
 
         max_diff = np.abs(output_with_prompt - output_with_embeds).max()
         assert max_diff < 1e-4
+
+    def test_flux_image_output_shape(self):
+        pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
+        inputs = self.get_dummy_inputs(torch_device)
+
+        height_width_pairs = [(32, 32), (72, 56)]
+        for height, width in height_width_pairs:
+            expected_height = height - height % (pipe.vae_scale_factor * 2)
+            expected_width = width - width % (pipe.vae_scale_factor * 2)
+
+            inputs.update({"height": height, "width": width})
+            image = pipe(**inputs).images[0]
+            output_height, output_width, _ = image.shape
+            assert (output_height, output_width) == (expected_height, expected_width)
14 changes: 14 additions & 0 deletions tests/pipelines/flux/test_pipeline_flux_inpaint.py
@@ -149,3 +149,17 @@ def test_flux_inpaint_prompt_embeds(self):
 
         max_diff = np.abs(output_with_prompt - output_with_embeds).max()
         assert max_diff < 1e-4
+
+    def test_flux_image_output_shape(self):
+        pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
+        inputs = self.get_dummy_inputs(torch_device)
+
+        height_width_pairs = [(32, 32), (72, 56)]
+        for height, width in height_width_pairs:
+            expected_height = height - height % (pipe.vae_scale_factor * 2)
+            expected_width = width - width % (pipe.vae_scale_factor * 2)
+
+            inputs.update({"height": height, "width": width})
+            image = pipe(**inputs).images[0]
+            output_height, output_width, _ = image.shape
+            assert (output_height, output_width) == (expected_height, expected_width)
