Add VAEBackbone and use it for SD3 (#1892)
* Add VAEBackbone and refactor SD3 to utilize VAE.

* Minor updates.

* Update `Task`

* Fix CI.

* Fix nit and add comments
james77777778 authored Oct 3, 2024
1 parent d13bfa1 commit 11227f3
Showing 13 changed files with 1,221 additions and 528 deletions.
@@ -27,7 +27,7 @@ class FlowMatchEulerDiscreteScheduler(layers.Layer):
https://arxiv.org/abs/2403.03206).
"""

def __init__(self, num_train_timesteps=1000, shift=1.0, **kwargs):
def __init__(self, num_train_timesteps=1000, shift=3.0, **kwargs):
super().__init__(**kwargs)
self.num_train_timesteps = int(num_train_timesteps)
self.shift = float(shift)
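The default `shift` moves from `1.0` to `3.0`, the value SD3 uses for its shifted timestep schedule. As a rough sketch of the effect (assuming the shifted-sigma form used by SD3-style flow-matching schedulers; the exact formula lives outside this hunk), a larger shift biases sampling toward higher noise levels:

import numpy as np

def shifted_sigma(sigma, shift):
    # Shifted flow-matching schedule: shift > 1 pushes sigmas upward,
    # so more of the sampling trajectory is spent at high noise levels.
    return shift * sigma / (1.0 + (shift - 1.0) * sigma)

sigmas = np.linspace(0.0, 1.0, 5)
print(shifted_sigma(sigmas, 1.0))  # [0.   0.25 0.5  0.75 1.  ] (identity)
print(shifted_sigma(sigmas, 3.0))  # [0.   0.5  0.75 0.9  1.  ]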
@@ -65,6 +65,13 @@ def call(self, inputs, num_steps):
timestep = self._sigma_to_timestep(sigma)
return sigma, timestep

def add_noise(self, inputs, noises, step, num_steps):
sigma, _ = self(step, num_steps)
return ops.add(
ops.multiply(sigma, noises),
ops.multiply(ops.subtract(1.0, sigma), inputs),
)

def get_config(self):
config = super().get_config()
config.update(
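The new `add_noise` method is the flow-matching forward process: a straight-line interpolation between the clean inputs and noise, weighted by the sigma for the given step. A minimal standalone sketch (plain NumPy, with a hand-picked sigma; the method above derives sigma from `step` and `num_steps`):

import numpy as np

def add_noise(latents, noises, sigma):
    # sigma=0 returns the clean latents; sigma=1 returns pure noise.
    return sigma * noises + (1.0 - sigma) * latents

latents = np.ones((1, 8, 8, 16), dtype="float32")
noises = np.random.normal(size=latents.shape).astype("float32")
noisy = add_noise(latents, noises, sigma=0.5)  # halfway along the path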
150 changes: 57 additions & 93 deletions keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py
@@ -8,9 +8,6 @@
FlowMatchEulerDiscreteScheduler,
)
from keras_hub.src.models.stable_diffusion_3.mmdit import MMDiT
from keras_hub.src.models.stable_diffusion_3.vae_image_decoder import (
VAEImageDecoder,
)
from keras_hub.src.utils.keras_utils import standardize_data_format


@@ -159,48 +156,6 @@ def compute_output_shape(self, latents_shape):
return latents_shape


class LatentSpaceDecoder(layers.Layer):
"""Decoder to transform the latent space back to the original image space.
During decoding, the latents are transformed back to the original image
space using the equation: `latents / scale + shift`.
Args:
scale: float. The scaling factor.
shift: float. The shift factor.
**kwargs: other keyword arguments passed to `keras.layers.Layer`,
including `name`, `dtype` etc.
Call arguments:
latents: The latent tensor to be transformed.
Reference:
- [High-Resolution Image Synthesis with Latent Diffusion Models](
https://arxiv.org/abs/2112.10752).
"""

def __init__(self, scale, shift, **kwargs):
super().__init__(**kwargs)
self.scale = scale
self.shift = shift

def call(self, latents):
return ops.add(ops.divide(latents, self.scale), self.shift)

def get_config(self):
config = super().get_config()
config.update(
{
"scale": self.scale,
"shift": self.shift,
}
)
return config

def compute_output_shape(self, latents_shape):
return latents_shape
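This removed `LatentSpaceDecoder` is replaced below by a stock `keras.layers.Rescaling`, which computes `inputs * scale + offset`; passing `scale=1.0 / vae.scale` and `offset=vae.shift` reproduces the old `latents / scale + shift`. A quick check of the equivalence (the numeric factors are illustrative, not read from a real VAE config):

import keras
from keras import layers, ops

scale, shift = 1.5305, 0.0609  # illustrative SD3-style VAE factors

rescale = layers.Rescaling(scale=1.0 / scale, offset=shift)
latents = keras.random.normal((1, 8, 8, 16))

# `Rescaling` computes `inputs * scale + offset`, matching the removed
# layer's `latents / scale + shift`.
assert ops.all(ops.isclose(rescale(latents), latents / scale + shift))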


@keras_hub_export("keras_hub.models.StableDiffusion3Backbone")
class StableDiffusion3Backbone(Backbone):
"""Stable Diffusion 3 core network with hyperparameters.
@@ -222,24 +177,19 @@ class StableDiffusion3Backbone(Backbone):
transformer in MMDiT.
mmdit_position_size: int. The size of the height and width for the
position embedding in MMDiT.
vae_stackwise_num_filters: list of ints. The number of filters for each
stack in VAE.
vae_stackwise_num_blocks: list of ints. The number of blocks for each
stack in VAE.
clip_l: `keras_hub.models.CLIPTextEncoder`. The text encoder for
encoding the inputs.
clip_g: `keras_hub.models.CLIPTextEncoder`. The text encoder for
encoding the inputs.
t5: optional `keras_hub.models.T5Encoder`. The text encoder for
encoding the inputs.
vae: The VAE used for transformations between pixel space and latent
space.
clip_l: The CLIP text encoder for encoding the inputs.
clip_g: The CLIP text encoder for encoding the inputs.
t5: optional. The T5 text encoder for encoding the inputs.
latent_channels: int. The number of channels in the latent. Defaults to
`16`.
output_channels: int. The number of channels in the output. Defaults to
`3`.
num_train_timesteps: int. The number of diffusion steps to train the
model. Defaults to `1000`.
shift: float. The shift value for the timestep schedule. Defaults to
`1.0`.
`3.0`.
height: optional int. The output height of the image.
width: optional int. The output width of the image.
data_format: `None` or str. If specified, either `"channels_last"` or
@@ -264,6 +214,7 @@ class StableDiffusion3Backbone(Backbone):
)
# Randomly initialized Stable Diffusion 3 model with custom config.
vae = keras_hub.models.VAEBackbone(...)
clip_l = keras_hub.models.CLIPTextEncoder(...)
clip_g = keras_hub.models.CLIPTextEncoder(...)
model = keras_hub.models.StableDiffusion3Backbone(
@@ -272,8 +223,7 @@ class StableDiffusion3Backbone(Backbone):
mmdit_hidden_dim=256,
mmdit_depth=4,
mmdit_position_size=192,
vae_stackwise_num_filters=[128, 128, 64, 32],
vae_stackwise_num_blocks=[1, 1, 1, 1],
vae=vae,
clip_l=clip_l,
clip_g=clip_g,
)
@@ -287,15 +237,14 @@ def __init__(
mmdit_num_layers,
mmdit_num_heads,
mmdit_position_size,
vae_stackwise_num_filters,
vae_stackwise_num_blocks,
vae,
clip_l,
clip_g,
t5=None,
latent_channels=16,
output_channels=3,
num_train_timesteps=1000,
shift=1.0,
shift=3.0,
height=None,
width=None,
data_format=None,
@@ -312,9 +261,11 @@ def __init__(
data_format = standardize_data_format(data_format)
if data_format != "channels_last":
raise NotImplementedError
latent_shape = (height // 8, width // 8, latent_channels)
image_shape = (height, width, int(vae.input_channels))
latent_shape = (height // 8, width // 8, int(latent_channels))
context_shape = (None, 4096 if t5 is None else t5.hidden_dim)
pooled_projection_shape = (clip_l.hidden_dim + clip_g.hidden_dim,)
self._latent_shape = latent_shape

# === Layers ===
self.clip_l = clip_l
@@ -341,15 +292,7 @@ def __init__(
dtype=dtype,
name="diffuser",
)
self.decoder = VAEImageDecoder(
vae_stackwise_num_filters,
vae_stackwise_num_blocks,
output_channels,
latent_shape=latent_shape,
data_format=data_format,
dtype=dtype,
name="decoder",
)
self.vae = vae
# Set `dtype="float32"` to ensure the high precision for the noise
# residual.
self.scheduler = FlowMatchEulerDiscreteScheduler(
Expand All @@ -365,14 +308,18 @@ def __init__(
dtype="float32", name="classifier_free_guidance"
)
self.euler_step = EulerStep(dtype="float32", name="euler_step")
self.latent_space_decoder = LatentSpaceDecoder(
scale=self.decoder.scaling_factor,
shift=self.decoder.shift_factor,
self.latent_rescaling = layers.Rescaling(
scale=1.0 / self.vae.scale,
offset=self.vae.shift,
dtype="float32",
name="latent_space_decoder",
name="latent_rescaling",
)

# === Functional Model ===
image_input = keras.Input(
shape=image_shape,
name="images",
)
latent_input = keras.Input(
shape=latent_shape,
name="latents",
@@ -428,17 +375,19 @@ def __init__(
dtype="float32",
name="guidance_scale",
)
embeddings = self.encode_step(token_ids, negative_token_ids)
embeddings = self.encode_text_step(token_ids, negative_token_ids)
latents = self.encode_image_step(image_input)
# Use `step=0` to define the functional model.
latents = self.denoise_step(
denoised_latents = self.denoise_step(
latent_input,
embeddings,
0,
num_step_input[0],
guidance_scale_input[0],
)
outputs = self.decode_step(latents)
images = self.decode_step(denoised_latents)
inputs = {
"images": image_input,
"latents": latent_input,
"clip_l_token_ids": clip_l_token_id_input,
"clip_l_negative_token_ids": clip_l_negative_token_id_input,
Expand All @@ -447,6 +396,10 @@ def __init__(
"num_steps": num_step_input,
"guidance_scale": guidance_scale_input,
}
outputs = {
"latents": latents,
"images": images,
}
if self.t5 is not None:
inputs["t5_token_ids"] = t5_token_id_input
inputs["t5_negative_token_ids"] = t5_negative_token_id_input
@@ -463,8 +416,6 @@ def __init__(
self.mmdit_num_layers = mmdit_num_layers
self.mmdit_num_heads = mmdit_num_heads
self.mmdit_position_size = mmdit_position_size
self.vae_stackwise_num_filters = vae_stackwise_num_filters
self.vae_stackwise_num_blocks = vae_stackwise_num_blocks
self.latent_channels = latent_channels
self.output_channels = output_channels
self.num_train_timesteps = num_train_timesteps
@@ -474,7 +425,7 @@

@property
def latent_shape(self):
return (None,) + tuple(self.diffuser.latent_shape)
return (None,) + self._latent_shape

@property
def clip_hidden_dim(self):
@@ -484,7 +435,7 @@ def clip_hidden_dim(self):
def t5_hidden_dim(self):
return 4096 if self.t5 is None else self.t5.hidden_dim

def encode_step(self, token_ids, negative_token_ids):
def encode_text_step(self, token_ids, negative_token_ids):
clip_hidden_dim = self.clip_hidden_dim
t5_hidden_dim = self.t5_hidden_dim

@@ -537,18 +488,27 @@ def encode(token_ids):
negative_pooled_embeddings,
)

def encode_image_step(self, images):
latents = self.vae.encode(images)
return ops.multiply(
ops.subtract(latents, self.vae.shift), self.vae.scale
)

def add_noise_step(self, latents, noises, step, num_steps):
return self.scheduler.add_noise(latents, noises, step, num_steps)

def denoise_step(
self,
latents,
embeddings,
steps,
step,
num_steps,
guidance_scale,
):
steps = ops.convert_to_tensor(steps)
steps_next = ops.add(steps, 1)
sigma, timestep = self.scheduler(steps, num_steps)
sigma_next, _ = self.scheduler(steps_next, num_steps)
step = ops.convert_to_tensor(step)
next_step = ops.add(step, 1)
sigma, timestep = self.scheduler(step, num_steps)
next_sigma, _ = self.scheduler(next_step, num_steps)

# Concatenation for classifier-free guidance.
concated_latents, contexts, pooled_projs, timesteps = self.cfg_concat(
Expand All @@ -570,11 +530,11 @@ def denoise_step(
predicted_noise = self.cfg(predicted_noise, guidance_scale)

# Euler step.
return self.euler_step(latents, predicted_noise, sigma, sigma_next)
return self.euler_step(latents, predicted_noise, sigma, next_sigma)

def decode_step(self, latents):
latents = self.latent_space_decoder(latents)
return self.decoder(latents, training=False)
latents = self.latent_rescaling(latents)
return self.vae.decode(latents, training=False)

def get_config(self):
config = super().get_config()
@@ -585,8 +545,7 @@ def get_config(self):
"mmdit_num_layers": self.mmdit_num_layers,
"mmdit_num_heads": self.mmdit_num_heads,
"mmdit_position_size": self.mmdit_position_size,
"vae_stackwise_num_filters": self.vae_stackwise_num_filters,
"vae_stackwise_num_blocks": self.vae_stackwise_num_blocks,
"vae": layers.serialize(self.vae),
"clip_l": layers.serialize(self.clip_l),
"clip_g": layers.serialize(self.clip_g),
"t5": layers.serialize(self.t5),
@@ -607,6 +566,8 @@ def from_config(cls, config, custom_objects=None):
# Propagate `dtype` to text encoders if needed.
if "dtype" in config and config["dtype"] is not None:
dtype_config = config["dtype"]
if "dtype" not in config["vae"]["config"]:
config["vae"]["config"]["dtype"] = dtype_config
if "dtype" not in config["clip_l"]["config"]:
config["clip_l"]["config"]["dtype"] = dtype_config
if "dtype" not in config["clip_g"]["config"]:
@@ -617,7 +578,10 @@
):
config["t5"]["config"]["dtype"] = dtype_config

# We expect `clip_l`, `clip_g` and/or `t5` to be instantiated.
# We expect `vae`, `clip_l`, `clip_g` and/or `t5` to be instantiated.
config["vae"] = layers.deserialize(
config["vae"], custom_objects=custom_objects
)
config["clip_l"] = layers.deserialize(
config["clip_l"], custom_objects=custom_objects
)
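Worth noting: `encode_image_step` and `decode_step` are mirror images around the VAE. Encoding normalizes raw VAE latents with `(latents - shift) * scale`; decoding undoes it via the `Rescaling` layer before `vae.decode`. A sketch of the round trip in isolation (scalar factors assumed, values illustrative):

import numpy as np

scale, shift = 1.5305, 0.0609  # illustrative SD3-style VAE factors

def to_latent_space(raw_latents):
    # Mirrors `encode_image_step` after `vae.encode`.
    return (raw_latents - shift) * scale

def from_latent_space(latents):
    # Mirrors the rescaling in `decode_step` before `vae.decode`.
    return latents / scale + shift

raw = np.random.normal(size=(1, 8, 8, 16)).astype("float32")
assert np.allclose(from_latent_space(to_latent_space(raw)), raw, atol=1e-5)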
