Skip to content

Commit

Permalink
Fix using non-square images with UNet2DModel and DDIM/DDPM pipelines (o…
Browse files Browse the repository at this point in the history
…pen-mmlab#1289)

* fix non square images with UNet2DModel and DDIM/DDPM pipelines

* fix unet_2d `sample_size` docstring

* update pipeline tests for unet uncond

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
  • Loading branch information
jenkspt and patrickvonplaten authored Nov 23, 2022
1 parent 44e56de commit 8fd3a74
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 26 deletions.
6 changes: 3 additions & 3 deletions src/diffusers/models/unet_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ class UNet2DModel(ModelMixin, ConfigMixin):
implements for all the model (such as downloading or saving, etc.)
Parameters:
sample_size (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`, *optional*):
Input sample size.
sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
Height and width of input/output sample.
in_channels (`int`, *optional*, defaults to 3): Number of channels in the input image.
out_channels (`int`, *optional*, defaults to 3): Number of channels in the output.
center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
Expand All @@ -71,7 +71,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
@register_to_config
def __init__(
self,
sample_size: Optional[int] = None,
sample_size: Optional[Union[int, Tuple[int, int]]] = None,
in_channels: int = 3,
out_channels: int = 3,
center_input_sample: bool = False,
Expand Down
3 changes: 2 additions & 1 deletion src/diffusers/models/unet_2d_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
implements for all the models (such as downloading or saving, etc.)
Parameters:
sample_size (`int`, *optional*): The size of the input sample.
sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
Height and width of input/output sample.
in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
Expand Down
6 changes: 5 additions & 1 deletion src/diffusers/pipelines/ddim/pipeline_ddim.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,11 @@ def __call__(
generator = None

# Sample gaussian noise to begin loop
image_shape = (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size)
if isinstance(self.unet.sample_size, int):
image_shape = (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size)
else:
image_shape = (batch_size, self.unet.in_channels, *self.unet.sample_size)

if self.device.type == "mps":
# randn does not work reproducibly on mps
image = torch.randn(image_shape, generator=generator)
Expand Down
6 changes: 5 additions & 1 deletion src/diffusers/pipelines/ddpm/pipeline_ddpm.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,11 @@ def __call__(
generator = None

# Sample gaussian noise to begin loop
image_shape = (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size)
if isinstance(self.unet.sample_size, int):
image_shape = (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size)
else:
image_shape = (batch_size, self.unet.in_channels, *self.unet.sample_size)

if self.device.type == "mps":
# randn does not work reproducibly on mps
image = torch.randn(image_shape, generator=generator)
Expand Down
62 changes: 42 additions & 20 deletions tests/test_pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import random
import tempfile
import unittest
from functools import partial

import numpy as np
import torch
Expand Down Expand Up @@ -46,6 +47,7 @@
from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME, floats_tensor, slow, torch_device
from diffusers.utils.testing_utils import CaptureLogger, get_tests_dir, require_torch_gpu
from parameterized import parameterized
from PIL import Image
from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextConfig, CLIPTextModel, CLIPTokenizer

Expand Down Expand Up @@ -247,7 +249,6 @@ def test_load_pipeline_from_git(self):


class PipelineFastTests(unittest.TestCase):
@property
def dummy_image(self):
batch_size = 1
num_channels = 3
Expand All @@ -256,27 +257,25 @@ def dummy_image(self):
image = floats_tensor((batch_size, num_channels) + sizes, rng=random.Random(0)).to(torch_device)
return image

@property
def dummy_uncond_unet(self):
def dummy_uncond_unet(self, sample_size=32):
torch.manual_seed(0)
model = UNet2DModel(
block_out_channels=(32, 64),
layers_per_block=2,
sample_size=32,
sample_size=sample_size,
in_channels=3,
out_channels=3,
down_block_types=("DownBlock2D", "AttnDownBlock2D"),
up_block_types=("AttnUpBlock2D", "UpBlock2D"),
)
return model

@property
def dummy_cond_unet(self):
def dummy_cond_unet(self, sample_size=32):
torch.manual_seed(0)
model = UNet2DConditionModel(
block_out_channels=(32, 64),
layers_per_block=2,
sample_size=32,
sample_size=sample_size,
in_channels=4,
out_channels=4,
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
Expand All @@ -285,13 +284,12 @@ def dummy_cond_unet(self):
)
return model

@property
def dummy_cond_unet_inpaint(self):
def dummy_cond_unet_inpaint(self, sample_size=32):
torch.manual_seed(0)
model = UNet2DConditionModel(
block_out_channels=(32, 64),
layers_per_block=2,
sample_size=32,
sample_size=sample_size,
in_channels=9,
out_channels=4,
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
Expand All @@ -300,7 +298,6 @@ def dummy_cond_unet_inpaint(self):
)
return model

@property
def dummy_vq_model(self):
torch.manual_seed(0)
model = VQModel(
Expand All @@ -313,7 +310,6 @@ def dummy_vq_model(self):
)
return model

@property
def dummy_vae(self):
torch.manual_seed(0)
model = AutoencoderKL(
Expand All @@ -326,7 +322,6 @@ def dummy_vae(self):
)
return model

@property
def dummy_text_encoder(self):
torch.manual_seed(0)
config = CLIPTextConfig(
Expand All @@ -342,7 +337,6 @@ def dummy_text_encoder(self):
)
return CLIPTextModel(config)

@property
def dummy_extractor(self):
def extract(*args, **kwargs):
class Out:
Expand All @@ -357,15 +351,43 @@ def to(self, device):

return extract

def test_components(self):
@parameterized.expand(
[
[DDIMScheduler, DDIMPipeline, 32],
[partial(DDPMScheduler, predict_epsilon=True), DDPMPipeline, 32],
[DDIMScheduler, DDIMPipeline, (32, 64)],
[partial(DDPMScheduler, predict_epsilon=True), DDPMPipeline, (64, 32)],
]
)
def test_uncond_unet_components(self, scheduler_fn=DDPMScheduler, pipeline_fn=DDPMPipeline, sample_size=32):
unet = self.dummy_uncond_unet(sample_size)
# DDIM doesn't take `predict_epsilon`, and DDPM requires it -- so using partial in parameterized decorator
scheduler = scheduler_fn()
pipeline = pipeline_fn(unet, scheduler).to(torch_device)

# Device type MPS is not supported for torch.Generator() api.
if torch_device == "mps":
generator = torch.manual_seed(0)
else:
generator = torch.Generator(device=torch_device).manual_seed(0)

out_image = pipeline(
generator=generator,
num_inference_steps=2,
output_type="np",
).images
sample_size = (sample_size, sample_size) if isinstance(sample_size, int) else sample_size
assert out_image.shape == (1, *sample_size, 3)

def test_stable_diffusion_components(self):
"""Test that components property works correctly"""
unet = self.dummy_cond_unet
unet = self.dummy_cond_unet()
scheduler = PNDMScheduler(skip_prk_steps=True)
vae = self.dummy_vae
bert = self.dummy_text_encoder
vae = self.dummy_vae()
bert = self.dummy_text_encoder()
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0]
image = self.dummy_image().cpu().permute(0, 2, 3, 1)[0]
init_image = Image.fromarray(np.uint8(image)).convert("RGB")
mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((128, 128))

Expand All @@ -377,7 +399,7 @@ def test_components(self):
text_encoder=bert,
tokenizer=tokenizer,
safety_checker=None,
feature_extractor=self.dummy_extractor,
feature_extractor=self.dummy_extractor(),
).to(torch_device)
img2img = StableDiffusionImg2ImgPipeline(**inpaint.components).to(torch_device)
text2img = StableDiffusionPipeline(**inpaint.components).to(torch_device)
Expand Down

0 comments on commit 8fd3a74

Please sign in to comment.