Skip to content

Commit

Permalink
Upscaling fixed (open-mmlab#1402)
Browse files Browse the repository at this point in the history
* Upscaling fixed

* up

* more fixes

* fix

* more fixes

* finish again

* up
  • Loading branch information
patrickvonplaten authored Nov 24, 2022
1 parent cbfed0c commit 05a36d5
Show file tree
Hide file tree
Showing 26 changed files with 226 additions and 120 deletions.
20 changes: 10 additions & 10 deletions src/diffusers/pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,7 +544,14 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
init_kwargs = {**init_kwargs, **passed_pipe_kwargs}

# remove `null` components
init_dict = {k: v for k, v in init_dict.items() if v[0] is not None}
def load_module(name, value):
if value[0] is None:
return False
if name in passed_class_obj and passed_class_obj[name] is None:
return False
return True

init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)}

if len(unused_kwargs) > 0:
logger.warning(f"Keyword arguments {unused_kwargs} not recognized.")
Expand All @@ -560,12 +567,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P

is_pipeline_module = hasattr(pipelines, library_name)
loaded_sub_model = None
sub_model_should_be_defined = True

# if the model is in a pipeline module, then we load it from the pipeline
if name in passed_class_obj:
# 1. check that passed_class_obj has correct parent class
if not is_pipeline_module and passed_class_obj[name] is not None:
if not is_pipeline_module:
library = importlib.import_module(library_name)
class_obj = getattr(library, class_name)
importable_classes = LOADABLE_CLASSES[library_name]
Expand All @@ -581,12 +587,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
f"{passed_class_obj[name]} is of type: {type(passed_class_obj[name])}, but should be"
f" {expected_class_obj}"
)
elif passed_class_obj[name] is None and name not in pipeline_class._optional_components:
logger.warning(
f"You have passed `None` for {name} to disable its functionality in {pipeline_class}. Note"
f" that this might lead to problems when using {pipeline_class} and is not recommended."
)
sub_model_should_be_defined = False
else:
logger.warning(
f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it"
Expand All @@ -608,7 +608,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
importable_classes = LOADABLE_CLASSES[library_name]
class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}

if loaded_sub_model is None and sub_model_should_be_defined:
if loaded_sub_model is None:
load_method_name = None
for class_name, class_candidate in class_candidates.items():
if class_candidate is not None and issubclass(class_obj, class_candidate):
Expand Down
11 changes: 6 additions & 5 deletions src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ def __init__(
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.register_to_config(requires_safety_checker=requires_safety_checker)

def enable_xformers_memory_efficient_attention(self):
Expand Down Expand Up @@ -379,7 +380,7 @@ def check_inputs(self, prompt, height, width, callback_steps):
)

def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels_latents, height // 8, width // 8)
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
if latents is None:
if device.type == "mps":
# randn does not work reproducibly on mps
Expand Down Expand Up @@ -420,9 +421,9 @@ def __call__(
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
height (`int`, *optional*, defaults to self.unet.config.sample_size * 8):
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to self.unet.config.sample_size * 8):
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
Expand Down Expand Up @@ -469,8 +470,8 @@ def __call__(
(nsfw) content, according to the `safety_checker`.
"""
# 0. Default height and width to unet
height = height or self.unet.config.sample_size * 8
width = width or self.unet.config.sample_size * 8
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor

# 1. Check inputs. Raise error if not correct
self.check_inputs(prompt, height, width, callback_steps)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ def __init__(
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.register_to_config(requires_safety_checker=requires_safety_checker)

def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def __init__(
):
super().__init__()
self.register_modules(vqvae=vqvae, bert=bert, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
self.vae_scale_factor = 2 ** (len(self.vqvae.config.block_out_channels) - 1)

@torch.no_grad()
def __call__(
Expand All @@ -79,9 +80,9 @@ def __call__(
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
height (`int`, *optional*, defaults to self.unet.config.sample_size * 8):
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to self.unet.config.sample_size * 8):
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
Expand All @@ -107,8 +108,8 @@ def __call__(
generated images.
"""
# 0. Default height and width to unet
height = height or self.unet.config.sample_size * 8
width = width or self.unet.config.sample_size * 8
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor

if isinstance(prompt, str):
batch_size = 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ def __init__(
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)

def prepare_inputs(self, prompt: Union[str, List[str]]):
if not isinstance(prompt, (str, list)):
Expand Down Expand Up @@ -168,8 +169,8 @@ def _generate(
neg_prompt_ids: jnp.array = None,
):
# 0. Default height and width to unet
height = height or self.unet.config.sample_size * 8
width = width or self.unet.config.sample_size * 8
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor

if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
Expand All @@ -192,7 +193,12 @@ def _generate(
uncond_embeddings = self.text_encoder(uncond_input, params=params["text_encoder"])[0]
context = jnp.concatenate([uncond_embeddings, text_embeddings])

latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
latents_shape = (
batch_size,
self.unet.in_channels,
height // self.vae_scale_factor,
width // self.vae_scale_factor,
)
if latents is None:
latents = jax.random.normal(prng_seed, shape=latents_shape, dtype=jnp.float32)
else:
Expand Down Expand Up @@ -269,9 +275,9 @@ def __call__(
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
height (`int`, *optional*, defaults to self.unet.config.sample_size * 8):
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to self.unet.config.sample_size * 8):
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
Expand Down Expand Up @@ -307,8 +313,8 @@ def __call__(
"not-safe-for-work" (nsfw) content, according to the `safety_checker`.
"""
# 0. Default height and width to unet
height = height or self.unet.config.sample_size * 8
width = width or self.unet.config.sample_size * 8
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor

if jit:
images = _p_generate(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def __init__(
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.register_to_config(requires_safety_checker=requires_safety_checker)

def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
Expand Down Expand Up @@ -206,8 +207,8 @@ def __call__(
**kwargs,
):
# 0. Default height and width to unet
height = height or self.unet.config.sample_size * 8
width = width or self.unet.config.sample_size * 8
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor

if isinstance(prompt, str):
batch_size = 1
Expand Down Expand Up @@ -241,7 +242,12 @@ def __call__(

# get the initial random noise unless the user supplied it
latents_dtype = text_embeddings.dtype
latents_shape = (batch_size * num_images_per_prompt, 4, height // 8, width // 8)
latents_shape = (
batch_size * num_images_per_prompt,
4,
height // self.vae_scale_factor,
width // self.vae_scale_factor,
)
if latents is None:
latents = generator.randn(*latents_shape).astype(latents_dtype)
elif latents.shape != latents_shape:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ def __init__(
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.register_to_config(requires_safety_checker=requires_safety_checker)

# Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline._encode_prompt
Expand Down Expand Up @@ -273,9 +274,9 @@ def __call__(
repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
instead of 3, so the expected shape would be `(B, H, W, 1)`.
height (`int`, *optional*, defaults to self.unet.config.sample_size * 8):
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to self.unet.config.sample_size * 8):
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
Expand Down Expand Up @@ -321,8 +322,8 @@ def __call__(
(nsfw) content, according to the `safety_checker`.
"""
# 0. Default height and width to unet
height = height or self.unet.config.sample_size * 8
width = width or self.unet.config.sample_size * 8
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor

if isinstance(prompt, str):
batch_size = 1
Expand Down Expand Up @@ -358,7 +359,12 @@ def __call__(
)

num_channels_latents = NUM_LATENT_CHANNELS
latents_shape = (batch_size * num_images_per_prompt, num_channels_latents, height // 8, width // 8)
latents_shape = (
batch_size * num_images_per_prompt,
num_channels_latents,
height // self.vae_scale_factor,
width // self.vae_scale_factor,
)
latents_dtype = text_embeddings.dtype
if latents is None:
latents = generator.randn(*latents_shape).astype(latents_dtype)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,11 @@ def preprocess(image):
return 2.0 * image - 1.0


def preprocess_mask(mask):
def preprocess_mask(mask, scale_factor=8):
mask = mask.convert("L")
w, h = mask.size
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
mask = mask.resize((w // 8, h // 8), resample=PIL.Image.NEAREST)
mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL.Image.NEAREST)
mask = np.array(mask).astype(np.float32) / 255.0
mask = np.tile(mask, (4, 1, 1))
mask = mask[None].transpose(0, 1, 2, 3) # what does this step do?
Expand Down Expand Up @@ -143,6 +143,7 @@ def __init__(
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.register_to_config(requires_safety_checker=requires_safety_checker)

# Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline._encode_prompt
Expand Down Expand Up @@ -349,7 +350,7 @@ def __call__(

# preprocess mask
if not isinstance(mask_image, np.ndarray):
mask_image = preprocess_mask(mask_image)
mask_image = preprocess_mask(mask_image, self.vae_scale_factor)
mask_image = mask_image.astype(latents_dtype)
mask = np.concatenate([mask_image] * num_images_per_prompt, axis=0)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ def __init__(
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.register_to_config(requires_safety_checker=requires_safety_checker)

def enable_xformers_memory_efficient_attention(self):
Expand Down Expand Up @@ -378,7 +379,7 @@ def check_inputs(self, prompt, height, width, callback_steps):
)

def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels_latents, height // 8, width // 8)
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
if latents is None:
if device.type == "mps":
# randn does not work reproducibly on mps
Expand Down Expand Up @@ -419,9 +420,9 @@ def __call__(
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
height (`int`, *optional*, defaults to self.unet.config.sample_size * 8):
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to self.unet.config.sample_size * 8):
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
Expand Down Expand Up @@ -468,8 +469,8 @@ def __call__(
(nsfw) content, according to the `safety_checker`.
"""
# 0. Default height and width to unet
height = height or self.unet.config.sample_size * 8
width = width or self.unet.config.sample_size * 8
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor

# 1. Check inputs. Raise error if not correct
self.check_inputs(prompt, height, width, callback_steps)
Expand Down
Loading

0 comments on commit 05a36d5

Please sign in to comment.