diff --git a/src/diffusers/models/unet_2d.py b/src/diffusers/models/unet_2d.py
index 0432405760..3279830d7d 100644
--- a/src/diffusers/models/unet_2d.py
+++ b/src/diffusers/models/unet_2d.py
@@ -43,8 +43,8 @@ class UNet2DModel(ModelMixin, ConfigMixin):
     implements for all the model (such as downloading or saving, etc.)
 
     Parameters:
-        sample_size (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`, *optional*):
-            Input sample size.
+        sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
+            Height and width of input/output sample.
         in_channels (`int`, *optional*, defaults to 3): Number of channels in the input image.
         out_channels (`int`, *optional*, defaults to 3): Number of channels in the output.
         center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
@@ -71,7 +71,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
     @register_to_config
     def __init__(
         self,
-        sample_size: Optional[int] = None,
+        sample_size: Optional[Union[int, Tuple[int, int]]] = None,
         in_channels: int = 3,
         out_channels: int = 3,
         center_input_sample: bool = False,
diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py
index c3f2fb87b6..b09044b57b 100644
--- a/src/diffusers/models/unet_2d_condition.py
+++ b/src/diffusers/models/unet_2d_condition.py
@@ -56,7 +56,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
     implements for all the models (such as downloading or saving, etc.)
 
     Parameters:
-        sample_size (`int`, *optional*): The size of the input sample.
+        sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
+            Height and width of input/output sample.
         in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
         out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
         center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
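The two model changes above only widen the `sample_size` annotation and its docstring. To make the intent concrete, here is a small usage sketch (not part of the diff) of a `UNet2DModel` built with a rectangular `(height, width)` tuple; the block and channel settings mirror the tiny dummy model used in the tests below and are purely illustrative.

```python
import torch
from diffusers import UNet2DModel

# Illustrative only: sample_size given as a (height, width) tuple rather than a single int.
model = UNet2DModel(
    block_out_channels=(32, 64),
    layers_per_block=2,
    sample_size=(32, 64),  # non-square resolution
    in_channels=3,
    out_channels=3,
    down_block_types=("DownBlock2D", "AttnDownBlock2D"),
    up_block_types=("AttnUpBlock2D", "UpBlock2D"),
)

# sample_size is stored on the config as metadata; the forward pass accepts any compatible
# spatial shape, so feed noise matching the configured (height, width).
noise = torch.randn(1, 3, *model.config.sample_size)
out = model(noise, timestep=10).sample
print(out.shape)  # torch.Size([1, 3, 32, 64])
```

The pipelines are what actually consume `sample_size`, which is why the DDIM/DDPM changes below branch on `int` versus tuple when shaping the initial noise.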
diff --git a/src/diffusers/pipelines/ddim/pipeline_ddim.py b/src/diffusers/pipelines/ddim/pipeline_ddim.py
index 6db6298329..b9e590dea6 100644
--- a/src/diffusers/pipelines/ddim/pipeline_ddim.py
+++ b/src/diffusers/pipelines/ddim/pipeline_ddim.py
@@ -89,7 +89,11 @@ def __call__(
             generator = None
 
         # Sample gaussian noise to begin loop
-        image_shape = (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size)
+        if isinstance(self.unet.sample_size, int):
+            image_shape = (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size)
+        else:
+            image_shape = (batch_size, self.unet.in_channels, *self.unet.sample_size)
+
         if self.device.type == "mps":
             # randn does not work reproducibly on mps
             image = torch.randn(image_shape, generator=generator)
diff --git a/src/diffusers/pipelines/ddpm/pipeline_ddpm.py b/src/diffusers/pipelines/ddpm/pipeline_ddpm.py
index c937a23003..634e1c0f99 100644
--- a/src/diffusers/pipelines/ddpm/pipeline_ddpm.py
+++ b/src/diffusers/pipelines/ddpm/pipeline_ddpm.py
@@ -94,7 +94,11 @@ def __call__(
             generator = None
 
         # Sample gaussian noise to begin loop
-        image_shape = (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size)
+        if isinstance(self.unet.sample_size, int):
+            image_shape = (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size)
+        else:
+            image_shape = (batch_size, self.unet.in_channels, *self.unet.sample_size)
+
         if self.device.type == "mps":
             # randn does not work reproducibly on mps
             image = torch.randn(image_shape, generator=generator)
diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py
index b593791c94..19493e3231 100644
--- a/tests/test_pipelines.py
+++ b/tests/test_pipelines.py
@@ -18,6 +18,7 @@
 import random
 import tempfile
 import unittest
+from functools import partial
 
 import numpy as np
 import torch
@@ -46,6 +47,7 @@
 from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
 from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME, floats_tensor, slow, torch_device
 from diffusers.utils.testing_utils import CaptureLogger, get_tests_dir, require_torch_gpu
+from parameterized import parameterized
 from PIL import Image
 from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextConfig, CLIPTextModel, CLIPTokenizer
 
@@ -247,7 +249,6 @@ def test_load_pipeline_from_git(self):
 
 
 class PipelineFastTests(unittest.TestCase):
-    @property
     def dummy_image(self):
         batch_size = 1
         num_channels = 3
@@ -256,13 +257,12 @@ def dummy_image(self):
         image = floats_tensor((batch_size, num_channels) + sizes, rng=random.Random(0)).to(torch_device)
         return image
 
-    @property
-    def dummy_uncond_unet(self):
+    def dummy_uncond_unet(self, sample_size=32):
         torch.manual_seed(0)
         model = UNet2DModel(
             block_out_channels=(32, 64),
             layers_per_block=2,
-            sample_size=32,
+            sample_size=sample_size,
             in_channels=3,
             out_channels=3,
             down_block_types=("DownBlock2D", "AttnDownBlock2D"),
@@ -270,13 +270,12 @@ def dummy_uncond_unet(self):
         )
         return model
 
-    @property
-    def dummy_cond_unet(self):
+    def dummy_cond_unet(self, sample_size=32):
         torch.manual_seed(0)
         model = UNet2DConditionModel(
             block_out_channels=(32, 64),
             layers_per_block=2,
-            sample_size=32,
+            sample_size=sample_size,
             in_channels=4,
             out_channels=4,
             down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
@@ -285,13 +284,12 @@ def dummy_cond_unet(self):
         )
         return model
 
-    @property
-    def dummy_cond_unet_inpaint(self):
+    def dummy_cond_unet_inpaint(self, sample_size=32):
         torch.manual_seed(0)
         model = UNet2DConditionModel(
             block_out_channels=(32, 64),
             layers_per_block=2,
-            sample_size=32,
+            sample_size=sample_size,
             in_channels=9,
             out_channels=4,
             down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
@@ -300,7 +298,6 @@ def dummy_cond_unet_inpaint(self):
         )
         return model
 
-    @property
     def dummy_vq_model(self):
         torch.manual_seed(0)
         model = VQModel(
@@ -313,7 +310,6 @@ def dummy_vq_model(self):
         )
         return model
 
-    @property
     def dummy_vae(self):
         torch.manual_seed(0)
         model = AutoencoderKL(
@@ -326,7 +322,6 @@ def dummy_vae(self):
         )
         return model
 
-    @property
    def dummy_text_encoder(self):
         torch.manual_seed(0)
         config = CLIPTextConfig(
@@ -342,7 +337,6 @@ def dummy_text_encoder(self):
         )
         return CLIPTextModel(config)
 
-    @property
     def dummy_extractor(self):
         def extract(*args, **kwargs):
             class Out:
@@ -357,15 +351,43 @@ def to(self, device):
 
         return extract
 
-    def test_components(self):
+    @parameterized.expand(
+        [
+            [DDIMScheduler, DDIMPipeline, 32],
+            [partial(DDPMScheduler, predict_epsilon=True), DDPMPipeline, 32],
+            [DDIMScheduler, DDIMPipeline, (32, 64)],
+            [partial(DDPMScheduler, predict_epsilon=True), DDPMPipeline, (64, 32)],
+        ]
+    )
+    def test_uncond_unet_components(self, scheduler_fn=DDPMScheduler, pipeline_fn=DDPMPipeline, sample_size=32):
+        unet = self.dummy_uncond_unet(sample_size)
+        # DDIM doesn't take `predict_epsilon`, and DDPM requires it -- so using partial in parameterized decorator
+        scheduler = scheduler_fn()
+        pipeline = pipeline_fn(unet, scheduler).to(torch_device)
+
+        # Device type MPS is not supported for torch.Generator() api.
+        if torch_device == "mps":
+            generator = torch.manual_seed(0)
+        else:
+            generator = torch.Generator(device=torch_device).manual_seed(0)
+
+        out_image = pipeline(
+            generator=generator,
+            num_inference_steps=2,
+            output_type="np",
+        ).images
+        sample_size = (sample_size, sample_size) if isinstance(sample_size, int) else sample_size
+        assert out_image.shape == (1, *sample_size, 3)
+
+    def test_stable_diffusion_components(self):
         """Test that components property works correctly"""
-        unet = self.dummy_cond_unet
+        unet = self.dummy_cond_unet()
         scheduler = PNDMScheduler(skip_prk_steps=True)
-        vae = self.dummy_vae
-        bert = self.dummy_text_encoder
+        vae = self.dummy_vae()
+        bert = self.dummy_text_encoder()
         tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
 
-        image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0]
+        image = self.dummy_image().cpu().permute(0, 2, 3, 1)[0]
         init_image = Image.fromarray(np.uint8(image)).convert("RGB")
         mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((128, 128))
 
@@ -377,7 +399,7 @@ def test_components(self):
             text_encoder=bert,
             tokenizer=tokenizer,
             safety_checker=None,
-            feature_extractor=self.dummy_extractor,
+            feature_extractor=self.dummy_extractor(),
         ).to(torch_device)
         img2img = StableDiffusionImg2ImgPipeline(**inpaint.components).to(torch_device)
         text2img = StableDiffusionPipeline(**inpaint.components).to(torch_device)
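For completeness, a hedged sketch (not part of the diff) of the end-to-end behaviour the new `test_uncond_unet_components` parameterization exercises: a tiny `UNet2DModel` built with a rectangular `sample_size`, run through `DDPMPipeline`, so the generated image follows the `(height, width)` tuple. The hyperparameters copy the dummy test model; `predict_epsilon=True` matches the scheduler construction in the test and may not exist under that name in other diffusers versions.

```python
import torch
from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel

torch.manual_seed(0)
unet = UNet2DModel(
    block_out_channels=(32, 64),
    layers_per_block=2,
    sample_size=(64, 32),  # rectangular (height, width)
    in_channels=3,
    out_channels=3,
    down_block_types=("DownBlock2D", "AttnDownBlock2D"),
    up_block_types=("AttnUpBlock2D", "UpBlock2D"),
)
scheduler = DDPMScheduler(predict_epsilon=True)  # as constructed in the parameterized test
pipeline = DDPMPipeline(unet, scheduler)

# With the pipeline patch above, the initial noise is sampled with shape
# (batch, channels, *sample_size) instead of assuming a square image.
images = pipeline(generator=torch.manual_seed(0), num_inference_steps=2, output_type="np").images
print(images.shape)  # (1, 64, 32, 3)
```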