Add VAEBackbone and refactor SD3 to utilize VAE.
james77777778 committed Oct 1, 2024
1 parent 1ffc0d1 commit 22db582
Showing 13 changed files with 1,185 additions and 516 deletions.
@@ -27,7 +27,7 @@ class FlowMatchEulerDiscreteScheduler(layers.Layer):
https://arxiv.org/abs/2403.03206).
"""

def __init__(self, num_train_timesteps=1000, shift=1.0, **kwargs):
def __init__(self, num_train_timesteps=1000, shift=3.0, **kwargs):
super().__init__(**kwargs)
self.num_train_timesteps = int(num_train_timesteps)
self.shift = float(shift)
@@ -65,6 +65,13 @@ def call(self, inputs, num_steps):
timestep = self._sigma_to_timestep(sigma)
return sigma, timestep

def scale_noise(self, inputs, noises, step, num_steps):
sigma, _ = self(step, num_steps)
return ops.add(
ops.multiply(sigma, noises),
ops.multiply(ops.subtract(1.0, sigma), inputs),
)

def get_config(self):
config = super().get_config()
config.update(
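The new `scale_noise` helper implements the flow-matching forward process, `sigma * noises + (1 - sigma) * inputs`, with `sigma` looked up from the step index. A minimal usage sketch; the module path and tensor shapes here are assumptions based on this repository's conventions, not shown in the diff:

from keras import ops

from keras_hub.src.models.stable_diffusion_3.flow_match_euler_discrete_scheduler import (
    FlowMatchEulerDiscreteScheduler,  # assumed import path
)

scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=3.0)
latents = ops.ones((1, 8, 8, 16))  # clean latents
noises = ops.zeros((1, 8, 8, 16))  # a random normal sample in practice
# At step 0 sigma is near 1, so the output is mostly noise; at the final
# step sigma approaches 0 and the output approaches the clean latents.
noisy_latents = scheduler.scale_noise(latents, noises, step=0, num_steps=28)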
123 changes: 40 additions & 83 deletions keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py
@@ -8,9 +8,6 @@
FlowMatchEulerDiscreteScheduler,
)
from keras_hub.src.models.stable_diffusion_3.mmdit import MMDiT
from keras_hub.src.models.stable_diffusion_3.vae_image_decoder import (
VAEImageDecoder,
)
from keras_hub.src.utils.keras_utils import standardize_data_format


@@ -159,48 +156,6 @@ def compute_output_shape(self, latents_shape):
return latents_shape


class LatentSpaceDecoder(layers.Layer):
"""Decoder to transform the latent space back to the original image space.
During decoding, the latents are transformed back to the original image
space using the equation: `latents / scale + shift`.
Args:
scale: float. The scaling factor.
shift: float. The shift factor.
**kwargs: other keyword arguments passed to `keras.layers.Layer`,
including `name`, `dtype` etc.
Call arguments:
latents: The latent tensor to be transformed.
Reference:
- [High-Resolution Image Synthesis with Latent Diffusion Models](
https://arxiv.org/abs/2112.10752).
"""

def __init__(self, scale, shift, **kwargs):
super().__init__(**kwargs)
self.scale = scale
self.shift = shift

def call(self, latents):
return ops.add(ops.divide(latents, self.scale), self.shift)

def get_config(self):
config = super().get_config()
config.update(
{
"scale": self.scale,
"shift": self.shift,
}
)
return config

def compute_output_shape(self, latents_shape):
return latents_shape


@keras_hub_export("keras_hub.models.StableDiffusion3Backbone")
class StableDiffusion3Backbone(Backbone):
"""Stable Diffusion 3 core network with hyperparameters.
@@ -222,24 +177,19 @@ class StableDiffusion3Backbone(Backbone):
transformer in MMDiT.
mmdit_position_size: int. The size of the height and width for the
position embedding in MMDiT.
vae_stackwise_num_filters: list of ints. The number of filters for each
stack in VAE.
vae_stackwise_num_blocks: list of ints. The number of blocks for each
stack in VAE.
clip_l: `keras_hub.models.CLIPTextEncoder`. The text encoder for
encoding the inputs.
clip_g: `keras_hub.models.CLIPTextEncoder`. The text encoder for
encoding the inputs.
t5: optional `keras_hub.models.T5Encoder`. The text encoder for
encoding the inputs.
vae: The VAE used for transformations between pixel
space and latent space.
clip_l: The CLIP text encoder for encoding the inputs.
clip_g: The CLIP text encoder for encoding the inputs.
t5: optional. The T5 text encoder for encoding the inputs.
latent_channels: int. The number of channels in the latent. Defaults to
`16`.
output_channels: int. The number of channels in the output. Defaults to
`3`.
num_train_timesteps: int. The number of diffusion steps to train the
model. Defaults to `1000`.
shift: float. The shift value for the timestep schedule. Defaults to
`1.0`.
`3.0`.
height: optional int. The output height of the image.
width: optional int. The output width of the image.
data_format: `None` or str. If specified, either `"channels_last"` or
@@ -287,15 +237,14 @@ def __init__(
mmdit_num_layers,
mmdit_num_heads,
mmdit_position_size,
vae_stackwise_num_filters,
vae_stackwise_num_blocks,
vae,
clip_l,
clip_g,
t5=None,
latent_channels=16,
output_channels=3,
num_train_timesteps=1000,
shift=1.0,
shift=3.0,
height=None,
width=None,
data_format=None,
@@ -312,6 +261,7 @@ def __init__(
data_format = standardize_data_format(data_format)
if data_format != "channels_last":
raise NotImplementedError
image_input = (height, width, vae.input_channels)
latent_shape = (height // 8, width // 8, latent_channels)
context_shape = (None, 4096 if t5 is None else t5.hidden_dim)
pooled_projection_shape = (clip_l.hidden_dim + clip_g.hidden_dim,)
@@ -341,15 +291,7 @@ def __init__(
dtype=dtype,
name="diffuser",
)
self.decoder = VAEImageDecoder(
vae_stackwise_num_filters,
vae_stackwise_num_blocks,
output_channels,
latent_shape=latent_shape,
data_format=data_format,
dtype=dtype,
name="decoder",
)
self.vae = vae
# Set `dtype="float32"` to ensure the high precision for the noise
# residual.
self.scheduler = FlowMatchEulerDiscreteScheduler(
@@ -365,14 +307,18 @@ def __init__(
dtype="float32", name="classifier_free_guidance"
)
self.euler_step = EulerStep(dtype="float32", name="euler_step")
self.latent_space_decoder = LatentSpaceDecoder(
scale=self.decoder.scaling_factor,
shift=self.decoder.shift_factor,
self.latent_rescaling = layers.Rescaling(
scale=1.0 / self.vae.scale,
offset=self.vae.shift,
dtype="float32",
name="latent_space_decoder",
name="latent_rescaling",
)

# === Functional Model ===
image_input = keras.Input(
shape=image_input,
name="images",
)
latent_input = keras.Input(
shape=latent_shape,
name="latents",
@@ -428,17 +374,19 @@ def __init__(
dtype="float32",
name="guidance_scale",
)
embeddings = self.encode_step(token_ids, negative_token_ids)
embeddings = self.encode_text_step(token_ids, negative_token_ids)
latents = self.encode_image_step(image_input)
# Use `step=0` to define the functional model.
latents = self.denoise_step(
denoised_latents = self.denoise_step(
latent_input,
embeddings,
0,
num_step_input[0],
guidance_scale_input[0],
)
outputs = self.decode_step(latents)
images = self.decode_step(denoised_latents)
inputs = {
"images": image_input,
"latents": latent_input,
"clip_l_token_ids": clip_l_token_id_input,
"clip_l_negative_token_ids": clip_l_negative_token_id_input,
Expand All @@ -447,6 +395,10 @@ def __init__(
"num_steps": num_step_input,
"guidance_scale": guidance_scale_input,
}
outputs = {
"latents": latents,
"images": images,
}
if self.t5 is not None:
inputs["t5_token_ids"] = t5_token_id_input
inputs["t5_negative_token_ids"] = t5_negative_token_id_input
@@ -463,8 +415,6 @@ def __init__(
self.mmdit_num_layers = mmdit_num_layers
self.mmdit_num_heads = mmdit_num_heads
self.mmdit_position_size = mmdit_position_size
self.vae_stackwise_num_filters = vae_stackwise_num_filters
self.vae_stackwise_num_blocks = vae_stackwise_num_blocks
self.latent_channels = latent_channels
self.output_channels = output_channels
self.num_train_timesteps = num_train_timesteps
@@ -484,7 +434,7 @@ def clip_hidden_dim(self):
def t5_hidden_dim(self):
return 4096 if self.t5 is None else self.t5.hidden_dim

def encode_step(self, token_ids, negative_token_ids):
def encode_text_step(self, token_ids, negative_token_ids):
clip_hidden_dim = self.clip_hidden_dim
t5_hidden_dim = self.t5_hidden_dim

@@ -537,6 +487,9 @@ def encode(token_ids):
negative_pooled_embeddings,
)

def encode_image_step(self, images):
return self.vae.encode(images)
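# Annotation (not part of the diff): with a constructed backbone, this maps
# pixel-space images to VAE latents, e.g. a (2, 64, 64, 3) batch becomes
# (2, 8, 8, 16) latents, given 8x downsampling and 16 latent channels:
#   latents = backbone.encode_image_step(ops.ones((2, 64, 64, 3)))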

def denoise_step(
self,
latents,
@@ -573,8 +526,8 @@ def denoise_step(
return self.euler_step(latents, predicted_noise, sigma, sigma_next)

def decode_step(self, latents):
latents = self.latent_space_decoder(latents)
return self.decoder(latents, training=False)
latents = self.latent_rescaling(latents)
return self.vae.decode(latents, training=False)
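# Annotation (not part of the diff): `latent_rescaling` computes
# `latents * (1 / vae.scale) + vae.shift`, i.e. `latents / scale + shift`,
# matching the removed `LatentSpaceDecoder` before the VAE decode.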

def get_config(self):
config = super().get_config()
@@ -585,8 +538,7 @@ def get_config(self):
"mmdit_num_layers": self.mmdit_num_layers,
"mmdit_num_heads": self.mmdit_num_heads,
"mmdit_position_size": self.mmdit_position_size,
"vae_stackwise_num_filters": self.vae_stackwise_num_filters,
"vae_stackwise_num_blocks": self.vae_stackwise_num_blocks,
"vae": layers.serialize(self.vae),
"clip_l": layers.serialize(self.clip_l),
"clip_g": layers.serialize(self.clip_g),
"t5": layers.serialize(self.t5),
@@ -607,6 +559,8 @@ def from_config(cls, config, custom_objects=None):
# Propagate `dtype` to text encoders if needed.
if "dtype" in config and config["dtype"] is not None:
dtype_config = config["dtype"]
if "dtype" not in config["vae"]["config"]:
config["vae"]["config"]["dtype"] = dtype_config
if "dtype" not in config["clip_l"]["config"]:
config["clip_l"]["config"]["dtype"] = dtype_config
if "dtype" not in config["clip_g"]["config"]:
@@ -617,7 +571,10 @@
):
config["t5"]["config"]["dtype"] = dtype_config

# We expect `clip_l`, `clip_g` and/or `t5` to be instantiated.
# We expect `vae`, `clip_l`, `clip_g` and/or `t5` to be instantiated.
config["vae"] = layers.deserialize(
config["vae"], custom_objects=custom_objects
)
config["clip_l"] = layers.deserialize(
config["clip_l"], custom_objects=custom_objects
)
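The refactor swaps the custom `LatentSpaceDecoder` layer for a stock `keras.layers.Rescaling`, which computes `x * scale + offset`. A small sketch checking the equivalence; the numeric VAE factors below are assumptions for illustration, not values taken from this diff:

import keras
from keras import ops

vae_scale, vae_shift = 1.5305, 0.0609  # assumed SD3 VAE factors
rescaling = keras.layers.Rescaling(scale=1.0 / vae_scale, offset=vae_shift)
latents = keras.random.normal((1, 8, 8, 16))
# `Rescaling` reproduces the old `latents / scale + shift` transform.
old_style = ops.add(ops.divide(latents, vae_scale), vae_shift)
assert ops.all(ops.isclose(rescaling(latents), old_style))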
@@ -5,11 +5,22 @@
from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_backbone import (
StableDiffusion3Backbone,
)
from keras_hub.src.models.vae.vae_backbone import VAEBackbone
from keras_hub.src.tests.test_case import TestCase


class StableDiffusion3BackboneTest(TestCase):
def setUp(self):
height, width = 64, 64
vae = VAEBackbone(
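# Annotation (not in the diff): the four positional lists are assumed to be
# encoder stackwise filters, encoder stackwise blocks, decoder stackwise
# filters, and decoder stackwise blocks, in that order.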
[32, 32, 32, 32],
[1, 1, 1, 1],
[32, 32, 32, 32],
[1, 1, 1, 1],
# Use `mode` to generate a deterministic output.
sampler_method="mode",
name="vae",
)
clip_l = CLIPTextEncoder(
20, 32, 32, 2, 2, 64, "quick_gelu", -2, name="clip_l"
)
@@ -22,15 +33,15 @@ def setUp(self):
"mmdit_num_layers": 2,
"mmdit_num_heads": 2,
"mmdit_position_size": 192,
"vae_stackwise_num_filters": [32, 32, 32, 32],
"vae_stackwise_num_blocks": [1, 1, 1, 1],
"vae": vae,
"clip_l": clip_l,
"clip_g": clip_g,
"height": 64,
"width": 64,
"height": height,
"width": width,
}
self.input_data = {
"latents": ops.ones((2, 8, 8, 16)),
"images": ops.ones((2, height, width, 3)),
"latents": ops.ones((2, height // 8, width // 8, 16)),
"clip_l_token_ids": ops.ones((2, 5), dtype="int32"),
"clip_l_negative_token_ids": ops.ones((2, 5), dtype="int32"),
"clip_g_token_ids": ops.ones((2, 5), dtype="int32"),
@@ -44,7 +55,10 @@ def test_backbone_basics(self):
cls=StableDiffusion3Backbone,
init_kwargs=self.init_kwargs,
input_data=self.input_data,
expected_output_shape=(2, 64, 64, 3),
expected_output_shape={
"images": (2, 64, 64, 3),
"latents": (2, 8, 8, 16),
},
# Since `clip_l` and `clip_g` were instantiated outside of
# `StableDiffusion3Backbone`, the mixed precision and
# quantization checks will fail.
@@ -5,14 +5,14 @@
"metadata": {
"description": (
"3 billion parameter, including CLIP L and CLIP G text "
"encoders, MMDiT generative model, and VAE decoder. "
"encoders, MMDiT generative model, and VAE autoencoder. "
"Developed by Stability AI."
),
"params": 2952806723,
"params": 2987080931,
"official_name": "StableDiffusion3",
"path": "stablediffusion3",
"model_card": "https://arxiv.org/abs/2110.00476",
},
"kaggle_handle": "kaggle://kerashub/stablediffusion3/keras/stable_diffusion_3_medium/1",
"kaggle_handle": "kaggle://kerashub/stablediffusion3/keras/stable_diffusion_3_medium/3",
}
}
@@ -104,7 +104,9 @@ def generate_step(
the expense of lower image quality.
"""
# Encode inputs.
embeddings = self.backbone.encode_step(token_ids, negative_token_ids)
embeddings = self.backbone.encode_text_step(
token_ids, negative_token_ids
)

# Denoise.
def body_fun(step, latents):
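The loop driving `body_fun` is truncated in this view. A hedged sketch of how such a denoising loop is typically wired up with `keras.ops.fori_loop`; the exact loop construct is not shown in this diff:

from keras import ops

# Run `num_steps` denoising iterations, threading `latents` through.
latents = ops.fori_loop(0, num_steps, body_fun, latents)
# Decode the final latents into images.
images = self.backbone.decode_step(latents)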