From 22db58239133c1101f22a75b8aae70d428ad71a8 Mon Sep 17 00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Sun, 29 Sep 2024 14:03:38 +0800 Subject: [PATCH] Add VAEBackbone and refactor SD3 to utilize VAE. --- .../flow_match_euler_discrete_scheduler.py | 9 +- .../stable_diffusion_3_backbone.py | 123 +-- .../stable_diffusion_3_backbone_test.py | 26 +- .../stable_diffusion_3_presets.py | 6 +- .../stable_diffusion_3_text_to_image.py | 4 +- .../stable_diffusion_3_text_to_image_test.py | 24 +- .../stable_diffusion_3/vae_image_decoder.py | 320 -------- keras_hub/src/models/task.py | 31 +- keras_hub/src/models/vae/__init__.py | 1 + keras_hub/src/models/vae/vae_backbone.py | 172 ++++ keras_hub/src/models/vae/vae_backbone_test.py | 35 + keras_hub/src/models/vae/vae_layers.py | 740 ++++++++++++++++++ .../convert_stable_diffusion_3_checkpoints.py | 210 +++-- 13 files changed, 1185 insertions(+), 516 deletions(-) delete mode 100644 keras_hub/src/models/stable_diffusion_3/vae_image_decoder.py create mode 100644 keras_hub/src/models/vae/__init__.py create mode 100644 keras_hub/src/models/vae/vae_backbone.py create mode 100644 keras_hub/src/models/vae/vae_backbone_test.py create mode 100644 keras_hub/src/models/vae/vae_layers.py diff --git a/keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py b/keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py index c80cf7d46..427d5d681 100644 --- a/keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py +++ b/keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py @@ -27,7 +27,7 @@ class FlowMatchEulerDiscreteScheduler(layers.Layer): https://arxiv.org/abs/2403.03206). """ - def __init__(self, num_train_timesteps=1000, shift=1.0, **kwargs): + def __init__(self, num_train_timesteps=1000, shift=3.0, **kwargs): super().__init__(**kwargs) self.num_train_timesteps = int(num_train_timesteps) self.shift = float(shift) @@ -65,6 +65,13 @@ def call(self, inputs, num_steps): timestep = self._sigma_to_timestep(sigma) return sigma, timestep + def scale_noise(self, inputs, noises, step, num_steps): + sigma, _ = self(step, num_steps) + return ops.add( + ops.multiply(sigma, noises), + ops.multiply(ops.subtract(1.0, sigma), inputs), + ) + def get_config(self): config = super().get_config() config.update( diff --git a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py index 62f5890e7..0efca389b 100644 --- a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +++ b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py @@ -8,9 +8,6 @@ FlowMatchEulerDiscreteScheduler, ) from keras_hub.src.models.stable_diffusion_3.mmdit import MMDiT -from keras_hub.src.models.stable_diffusion_3.vae_image_decoder import ( - VAEImageDecoder, -) from keras_hub.src.utils.keras_utils import standardize_data_format @@ -159,48 +156,6 @@ def compute_output_shape(self, latents_shape): return latents_shape -class LatentSpaceDecoder(layers.Layer): - """Decoder to transform the latent space back to the original image space. - - During decoding, the latents are transformed back to the original image - space using the equation: `latents / scale + shift`. - - Args: - scale: float. The scaling factor. - shift: float. The shift factor. - **kwargs: other keyword arguments passed to `keras.layers.Layer`, - including `name`, `dtype` etc. 
- - Call arguments: - latents: The latent tensor to be transformed. - - Reference: - - [High-Resolution Image Synthesis with Latent Diffusion Models]( - https://arxiv.org/abs/2112.10752). - """ - - def __init__(self, scale, shift, **kwargs): - super().__init__(**kwargs) - self.scale = scale - self.shift = shift - - def call(self, latents): - return ops.add(ops.divide(latents, self.scale), self.shift) - - def get_config(self): - config = super().get_config() - config.update( - { - "scale": self.scale, - "shift": self.shift, - } - ) - return config - - def compute_output_shape(self, latents_shape): - return latents_shape - - @keras_hub_export("keras_hub.models.StableDiffusion3Backbone") class StableDiffusion3Backbone(Backbone): """Stable Diffusion 3 core network with hyperparameters. @@ -222,16 +177,11 @@ class StableDiffusion3Backbone(Backbone): transformer in MMDiT. mmdit_position_size: int. The size of the height and width for the position embedding in MMDiT. - vae_stackwise_num_filters: list of ints. The number of filters for each - stack in VAE. - vae_stackwise_num_blocks: list of ints. The number of blocks for each - stack in VAE. - clip_l: `keras_hub.models.CLIPTextEncoder`. The text encoder for - encoding the inputs. - clip_g: `keras_hub.models.CLIPTextEncoder`. The text encoder for - encoding the inputs. - t5: optional `keras_hub.models.T5Encoder`. The text encoder for - encoding the inputs. + vae: The VAE used for transformations between pixel + space and latent space. + clip_l: The CLIP text encoder for encoding the inputs. + clip_g: The CLIP text encoder for encoding the inputs. + t5: optional The T5 text encoder for encoding the inputs. latent_channels: int. The number of channels in the latent. Defaults to `16`. output_channels: int. The number of channels in the output. Defaults to @@ -239,7 +189,7 @@ class StableDiffusion3Backbone(Backbone): num_train_timesteps: int. The number of diffusion steps to train the model. Defaults to `1000`. shift: float. The shift value for the timestep schedule. Defaults to - `1.0`. + `3.0`. height: optional int. The output height of the image. width: optional int. The output width of the image. data_format: `None` or str. If specified, either `"channels_last"` or @@ -287,15 +237,14 @@ def __init__( mmdit_num_layers, mmdit_num_heads, mmdit_position_size, - vae_stackwise_num_filters, - vae_stackwise_num_blocks, + vae, clip_l, clip_g, t5=None, latent_channels=16, output_channels=3, num_train_timesteps=1000, - shift=1.0, + shift=3.0, height=None, width=None, data_format=None, @@ -312,6 +261,7 @@ def __init__( data_format = standardize_data_format(data_format) if data_format != "channels_last": raise NotImplementedError + image_input = (height, width, vae.input_channels) latent_shape = (height // 8, width // 8, latent_channels) context_shape = (None, 4096 if t5 is None else t5.hidden_dim) pooled_projection_shape = (clip_l.hidden_dim + clip_g.hidden_dim,) @@ -341,15 +291,7 @@ def __init__( dtype=dtype, name="diffuser", ) - self.decoder = VAEImageDecoder( - vae_stackwise_num_filters, - vae_stackwise_num_blocks, - output_channels, - latent_shape=latent_shape, - data_format=data_format, - dtype=dtype, - name="decoder", - ) + self.vae = vae # Set `dtype="float32"` to ensure the high precision for the noise # residual. 
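        # The scheduler also exposes `scale_noise`, which mixes clean latents
        # with noise as `sigma * noises + (1.0 - sigma) * inputs`.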
self.scheduler = FlowMatchEulerDiscreteScheduler( @@ -365,14 +307,18 @@ def __init__( dtype="float32", name="classifier_free_guidance" ) self.euler_step = EulerStep(dtype="float32", name="euler_step") - self.latent_space_decoder = LatentSpaceDecoder( - scale=self.decoder.scaling_factor, - shift=self.decoder.shift_factor, + self.latent_rescaling = layers.Rescaling( + scale=1.0 / self.vae.scale, + offset=self.vae.shift, dtype="float32", - name="latent_space_decoder", + name="latent_rescaling", ) # === Functional Model === + image_input = keras.Input( + shape=image_input, + name="images", + ) latent_input = keras.Input( shape=latent_shape, name="latents", @@ -428,17 +374,19 @@ def __init__( dtype="float32", name="guidance_scale", ) - embeddings = self.encode_step(token_ids, negative_token_ids) + embeddings = self.encode_text_step(token_ids, negative_token_ids) + latents = self.encode_image_step(image_input) # Use `steps=0` to define the functional model. - latents = self.denoise_step( + denoised_latents = self.denoise_step( latent_input, embeddings, 0, num_step_input[0], guidance_scale_input[0], ) - outputs = self.decode_step(latents) + images = self.decode_step(denoised_latents) inputs = { + "images": image_input, "latents": latent_input, "clip_l_token_ids": clip_l_token_id_input, "clip_l_negative_token_ids": clip_l_negative_token_id_input, @@ -447,6 +395,10 @@ def __init__( "num_steps": num_step_input, "guidance_scale": guidance_scale_input, } + outputs = { + "latents": latents, + "images": images, + } if self.t5 is not None: inputs["t5_token_ids"] = t5_token_id_input inputs["t5_negative_token_ids"] = t5_negative_token_id_input @@ -463,8 +415,6 @@ def __init__( self.mmdit_num_layers = mmdit_num_layers self.mmdit_num_heads = mmdit_num_heads self.mmdit_position_size = mmdit_position_size - self.vae_stackwise_num_filters = vae_stackwise_num_filters - self.vae_stackwise_num_blocks = vae_stackwise_num_blocks self.latent_channels = latent_channels self.output_channels = output_channels self.num_train_timesteps = num_train_timesteps @@ -484,7 +434,7 @@ def clip_hidden_dim(self): def t5_hidden_dim(self): return 4096 if self.t5 is None else self.t5.hidden_dim - def encode_step(self, token_ids, negative_token_ids): + def encode_text_step(self, token_ids, negative_token_ids): clip_hidden_dim = self.clip_hidden_dim t5_hidden_dim = self.t5_hidden_dim @@ -537,6 +487,9 @@ def encode(token_ids): negative_pooled_embeddings, ) + def encode_image_step(self, images): + return self.vae.encode(images) + def denoise_step( self, latents, @@ -573,8 +526,8 @@ def denoise_step( return self.euler_step(latents, predicted_noise, sigma, sigma_next) def decode_step(self, latents): - latents = self.latent_space_decoder(latents) - return self.decoder(latents, training=False) + latents = self.latent_rescaling(latents) + return self.vae.decode(latents, training=False) def get_config(self): config = super().get_config() @@ -585,8 +538,7 @@ def get_config(self): "mmdit_num_layers": self.mmdit_num_layers, "mmdit_num_heads": self.mmdit_num_heads, "mmdit_position_size": self.mmdit_position_size, - "vae_stackwise_num_filters": self.vae_stackwise_num_filters, - "vae_stackwise_num_blocks": self.vae_stackwise_num_blocks, + "vae": layers.serialize(self.vae), "clip_l": layers.serialize(self.clip_l), "clip_g": layers.serialize(self.clip_g), "t5": layers.serialize(self.t5), @@ -607,6 +559,8 @@ def from_config(cls, config, custom_objects=None): # Propagate `dtype` to text encoders if needed. 
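        # The serialized `vae` sub-config receives the same fallback below.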
if "dtype" in config and config["dtype"] is not None: dtype_config = config["dtype"] + if "dtype" not in config["vae"]["config"]: + config["vae"]["config"]["dtype"] = dtype_config if "dtype" not in config["clip_l"]["config"]: config["clip_l"]["config"]["dtype"] = dtype_config if "dtype" not in config["clip_g"]["config"]: @@ -617,7 +571,10 @@ def from_config(cls, config, custom_objects=None): ): config["t5"]["config"]["dtype"] = dtype_config - # We expect `clip_l`, `clip_g` and/or `t5` to be instantiated. + # We expect `vae`, `clip_l`, `clip_g` and/or `t5` to be instantiated. + config["vae"] = layers.deserialize( + config["vae"], custom_objects=custom_objects + ) config["clip_l"] = layers.deserialize( config["clip_l"], custom_objects=custom_objects ) diff --git a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone_test.py b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone_test.py index d0a5a82a9..b942ef061 100644 --- a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone_test.py +++ b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone_test.py @@ -5,11 +5,22 @@ from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_backbone import ( StableDiffusion3Backbone, ) +from keras_hub.src.models.vae.vae_backbone import VAEBackbone from keras_hub.src.tests.test_case import TestCase class StableDiffusion3BackboneTest(TestCase): def setUp(self): + height, width = 64, 64 + vae = VAEBackbone( + [32, 32, 32, 32], + [1, 1, 1, 1], + [32, 32, 32, 32], + [1, 1, 1, 1], + # Use `mode` generate a deterministic output. + sampler_method="mode", + name="vae", + ) clip_l = CLIPTextEncoder( 20, 32, 32, 2, 2, 64, "quick_gelu", -2, name="clip_l" ) @@ -22,15 +33,15 @@ def setUp(self): "mmdit_num_layers": 2, "mmdit_num_heads": 2, "mmdit_position_size": 192, - "vae_stackwise_num_filters": [32, 32, 32, 32], - "vae_stackwise_num_blocks": [1, 1, 1, 1], + "vae": vae, "clip_l": clip_l, "clip_g": clip_g, - "height": 64, - "width": 64, + "height": height, + "width": width, } self.input_data = { - "latents": ops.ones((2, 8, 8, 16)), + "images": ops.ones((2, height, width, 3)), + "latents": ops.ones((2, height // 8, width // 8, 16)), "clip_l_token_ids": ops.ones((2, 5), dtype="int32"), "clip_l_negative_token_ids": ops.ones((2, 5), dtype="int32"), "clip_g_token_ids": ops.ones((2, 5), dtype="int32"), @@ -44,7 +55,10 @@ def test_backbone_basics(self): cls=StableDiffusion3Backbone, init_kwargs=self.init_kwargs, input_data=self.input_data, - expected_output_shape=(2, 64, 64, 3), + expected_output_shape={ + "images": (2, 64, 64, 3), + "latents": (2, 8, 8, 16), + }, # Since `clip_l` and `clip_g` were instantiated outside of # `StableDiffusion3Backbone`, the mixed precision and # quantization checks will fail. diff --git a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py index a05534611..2067fdb8d 100644 --- a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +++ b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py @@ -5,14 +5,14 @@ "metadata": { "description": ( "3 billion parameter, including CLIP L and CLIP G text " - "encoders, MMDiT generative model, and VAE decoder. " + "encoders, MMDiT generative model, and VAE autoencoder. " "Developed by Stability AI." 
), - "params": 2952806723, + "params": 2987080931, "official_name": "StableDiffusion3", "path": "stablediffusion3", "model_card": "https://arxiv.org/abs/2110.00476", }, - "kaggle_handle": "kaggle://kerashub/stablediffusion3/keras/stable_diffusion_3_medium/1", + "kaggle_handle": "kaggle://kerashub/stablediffusion3/keras/stable_diffusion_3_medium/3", } } diff --git a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py index 514ee5d10..9a32d8a68 100644 --- a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +++ b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py @@ -104,7 +104,9 @@ def generate_step( the expense of lower image quality. """ # Encode inputs. - embeddings = self.backbone.encode_step(token_ids, negative_token_ids) + embeddings = self.backbone.encode_text_step( + token_ids, negative_token_ids + ) # Denoise. def body_fun(step, latents): diff --git a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_test.py b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_test.py index b3cbf9ec0..837c95fa3 100644 --- a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_test.py +++ b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_test.py @@ -14,6 +14,7 @@ from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_text_to_image_preprocessor import ( StableDiffusion3TextToImagePreprocessor, ) +from keras_hub.src.models.vae.vae_backbone import VAEBackbone from keras_hub.src.tests.test_case import TestCase @@ -39,23 +40,31 @@ def setUp(self): mmdit_num_layers=2, mmdit_num_heads=2, mmdit_position_size=192, - vae_stackwise_num_filters=[32, 32, 32, 32], - vae_stackwise_num_blocks=[1, 1, 1, 1], + vae=VAEBackbone( + [32, 32, 32, 32], + [1, 1, 1, 1], + [32, 32, 32, 32], + [1, 1, 1, 1], + # Use `mode` generate a deterministic output. 
+ sampler_method="mode", + name="vae", + ), clip_l=CLIPTextEncoder( 20, 64, 64, 2, 2, 128, "quick_gelu", -2, name="clip_l" ), clip_g=CLIPTextEncoder( 20, 128, 128, 2, 2, 256, "gelu", -2, name="clip_g" ), - height=128, - width=128, + height=64, + width=64, ) self.init_kwargs = { "preprocessor": self.preprocessor, "backbone": self.backbone, } self.input_data = { - "latents": ops.ones((2, 16, 16, 16)), + "images": ops.ones((2, 64, 64, 3)), + "latents": ops.ones((2, 8, 8, 16)), "clip_l_token_ids": ops.ones((2, 5), dtype="int32"), "clip_l_negative_token_ids": ops.ones((2, 5), dtype="int32"), "clip_g_token_ids": ops.ones((2, 5), dtype="int32"), @@ -72,7 +81,10 @@ def test_text_to_image_basics(self): cls=StableDiffusion3TextToImage, init_kwargs=self.init_kwargs, train_data=None, - expected_output_shape=(2, 128, 128, 3), + expected_output_shape={ + "images": (2, 64, 64, 3), + "latents": (2, 8, 8, 16), + }, ) def test_generate(self): diff --git a/keras_hub/src/models/stable_diffusion_3/vae_image_decoder.py b/keras_hub/src/models/stable_diffusion_3/vae_image_decoder.py deleted file mode 100644 index 5df9790bb..000000000 --- a/keras_hub/src/models/stable_diffusion_3/vae_image_decoder.py +++ /dev/null @@ -1,320 +0,0 @@ -import math - -from keras import layers -from keras import ops - -from keras_hub.src.models.backbone import Backbone -from keras_hub.src.utils.keras_utils import standardize_data_format - - -class VAEAttention(layers.Layer): - def __init__(self, filters, groups=32, data_format=None, **kwargs): - super().__init__(**kwargs) - self.filters = filters - self.data_format = standardize_data_format(data_format) - gn_axis = -1 if self.data_format == "channels_last" else 1 - - self.group_norm = layers.GroupNormalization( - groups=groups, - axis=gn_axis, - epsilon=1e-6, - dtype="float32", - name="group_norm", - ) - self.query_conv2d = layers.Conv2D( - filters, - 1, - 1, - data_format=self.data_format, - dtype=self.dtype_policy, - name="query_conv2d", - ) - self.key_conv2d = layers.Conv2D( - filters, - 1, - 1, - data_format=self.data_format, - dtype=self.dtype_policy, - name="key_conv2d", - ) - self.value_conv2d = layers.Conv2D( - filters, - 1, - 1, - data_format=self.data_format, - dtype=self.dtype_policy, - name="value_conv2d", - ) - self.softmax = layers.Softmax(dtype="float32") - self.output_conv2d = layers.Conv2D( - filters, - 1, - 1, - data_format=self.data_format, - dtype=self.dtype_policy, - name="output_conv2d", - ) - - self.groups = groups - self._inverse_sqrt_filters = 1.0 / math.sqrt(float(filters)) - - def build(self, input_shape): - self.group_norm.build(input_shape) - self.query_conv2d.build(input_shape) - self.key_conv2d.build(input_shape) - self.value_conv2d.build(input_shape) - self.output_conv2d.build(input_shape) - - def call(self, inputs, training=None): - x = self.group_norm(inputs) - query = self.query_conv2d(x) - key = self.key_conv2d(x) - value = self.value_conv2d(x) - - if self.data_format == "channels_first": - query = ops.transpose(query, (0, 2, 3, 1)) - key = ops.transpose(key, (0, 2, 3, 1)) - value = ops.transpose(value, (0, 2, 3, 1)) - shape = ops.shape(inputs) - b = shape[0] - query = ops.reshape(query, (b, -1, self.filters)) - key = ops.reshape(key, (b, -1, self.filters)) - value = ops.reshape(value, (b, -1, self.filters)) - - # Compute attention. 
- query = ops.multiply( - query, ops.cast(self._inverse_sqrt_filters, query.dtype) - ) - # [B, H0 * W0, C], [B, H1 * W1, C] -> [B, H0 * W0, H1 * W1] - attention_scores = ops.einsum("abc,adc->abd", query, key) - attention_scores = ops.cast( - self.softmax(attention_scores), self.compute_dtype - ) - # [B, H2 * W2, C], [B, H0 * W0, H1 * W1] -> [B, H1 * W1 ,C] - attention_output = ops.einsum("abc,adb->adc", value, attention_scores) - x = ops.reshape(attention_output, shape) - - x = self.output_conv2d(x) - if self.data_format == "channels_first": - x = ops.transpose(x, (0, 3, 1, 2)) - x = ops.add(x, inputs) - return x - - def get_config(self): - config = super().get_config() - config.update( - { - "filters": self.filters, - "groups": self.groups, - } - ) - return config - - def compute_output_shape(self, input_shape): - return input_shape - - -def apply_resnet_block(x, filters, data_format=None, dtype=None, name=None): - data_format = standardize_data_format(data_format) - gn_axis = -1 if data_format == "channels_last" else 1 - input_filters = x.shape[gn_axis] - - residual = x - x = layers.GroupNormalization( - groups=32, - axis=gn_axis, - epsilon=1e-6, - dtype="float32", - name=f"{name}_norm1", - )(x) - x = layers.Activation("swish", dtype=dtype)(x) - x = layers.Conv2D( - filters, - 3, - 1, - padding="same", - data_format=data_format, - dtype=dtype, - name=f"{name}_conv1", - )(x) - x = layers.GroupNormalization( - groups=32, - axis=gn_axis, - epsilon=1e-6, - dtype="float32", - name=f"{name}_norm2", - )(x) - x = layers.Activation("swish", dtype=dtype)(x) - x = layers.Conv2D( - filters, - 3, - 1, - padding="same", - data_format=data_format, - dtype=dtype, - name=f"{name}_conv2", - )(x) - if input_filters != filters: - residual = layers.Conv2D( - filters, - 1, - 1, - data_format=data_format, - dtype=dtype, - name=f"{name}_residual_projection", - )(residual) - x = layers.Add(dtype=dtype)([residual, x]) - return x - - -class VAEImageDecoder(Backbone): - """Decoder for the VAE model used in Stable Diffusion 3. - - Args: - stackwise_num_filters: list of ints. The number of filters for each - stack. - stackwise_num_blocks: list of ints. The number of blocks for each stack. - output_channels: int. The number of channels in the output. - latent_shape: tuple. The shape of the latent image. - data_format: `None` or str. If specified, either `"channels_last"` or - `"channels_first"`. The ordering of the dimensions in the - inputs. `"channels_last"` corresponds to inputs with shape - `(batch_size, height, width, channels)` - while `"channels_first"` corresponds to inputs with shape - `(batch_size, channels, height, width)`. It defaults to the - `image_data_format` value found in your Keras config file at - `~/.keras/keras.json`. If you never set it, then it will be - `"channels_last"`. - dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype - to use for the model's computations and weights. 
- """ - - def __init__( - self, - stackwise_num_filters, - stackwise_num_blocks, - output_channels=3, - latent_shape=(None, None, 16), - data_format=None, - dtype=None, - **kwargs, - ): - data_format = standardize_data_format(data_format) - gn_axis = -1 if data_format == "channels_last" else 1 - - # === Functional Model === - latent_inputs = layers.Input(shape=latent_shape) - - x = layers.Conv2D( - stackwise_num_filters[0], - 3, - 1, - padding="same", - data_format=data_format, - dtype=dtype, - name="input_projection", - )(latent_inputs) - x = apply_resnet_block( - x, - stackwise_num_filters[0], - data_format=data_format, - dtype=dtype, - name="input_block0", - ) - x = VAEAttention( - stackwise_num_filters[0], - data_format=data_format, - dtype=dtype, - name="input_attention", - )(x) - x = apply_resnet_block( - x, - stackwise_num_filters[0], - data_format=data_format, - dtype=dtype, - name="input_block1", - ) - - # Stacks. - for i, filters in enumerate(stackwise_num_filters): - for j in range(stackwise_num_blocks[i]): - x = apply_resnet_block( - x, - filters, - data_format=data_format, - dtype=dtype, - name=f"block{i}_{j}", - ) - if i != len(stackwise_num_filters) - 1: - # No upsamling in the last blcok. - x = layers.UpSampling2D( - 2, - data_format=data_format, - dtype=dtype, - name=f"upsample_{i}", - )(x) - x = layers.Conv2D( - filters, - 3, - 1, - padding="same", - data_format=data_format, - dtype=dtype, - name=f"upsample_{i}_conv", - )(x) - - # Ouput block. - x = layers.GroupNormalization( - groups=32, - axis=gn_axis, - epsilon=1e-6, - dtype="float32", - name="output_norm", - )(x) - x = layers.Activation("swish", dtype=dtype, name="output_activation")(x) - image_outputs = layers.Conv2D( - output_channels, - 3, - 1, - padding="same", - data_format=data_format, - dtype=dtype, - name="output_projection", - )(x) - super().__init__(inputs=latent_inputs, outputs=image_outputs, **kwargs) - - # === Config === - self.stackwise_num_filters = stackwise_num_filters - self.stackwise_num_blocks = stackwise_num_blocks - self.output_channels = output_channels - self.latent_shape = latent_shape - - @property - def scaling_factor(self): - """The scaling factor for the latent space. - - This is used to scale the latent space to have unit variance when - training the diffusion model. - """ - return 1.5305 - - @property - def shift_factor(self): - """The shift factor for the latent space. - - This is used to shift the latent space to have zero mean when - training the diffusion model. 
- """ - return 0.0609 - - def get_config(self): - config = super().get_config() - config.update( - { - "stackwise_num_filters": self.stackwise_num_filters, - "stackwise_num_blocks": self.stackwise_num_blocks, - "output_channels": self.output_channels, - "image_shape": self.latent_shape, - } - ) - return config diff --git a/keras_hub/src/models/task.py b/keras_hub/src/models/task.py index 080c67c22..73c551ab4 100644 --- a/keras_hub/src/models/task.py +++ b/keras_hub/src/models/task.py @@ -6,6 +6,9 @@ from rich import table as rich_table from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.layers.preprocessing.audio_converter import AudioConverter +from keras_hub.src.layers.preprocessing.image_converter import ImageConverter +from keras_hub.src.tokenizers.tokenizer import Tokenizer from keras_hub.src.utils.keras_utils import print_msg from keras_hub.src.utils.pipeline_model import PipelineModel from keras_hub.src.utils.preset_utils import TASK_CONFIG_FILE @@ -327,21 +330,19 @@ def add_layer(layer, info): info, ) - tokenizer = self.preprocessor.tokenizer - if tokenizer: - info = "Vocab size: " - info += highlight_number(tokenizer.vocabulary_size()) - add_layer(tokenizer, info) - image_converter = self.preprocessor.image_converter - if image_converter: - info = "Image size: " - info += highlight_shape(image_converter.image_size()) - add_layer(image_converter, info) - audio_converter = self.preprocessor.audio_converter - if audio_converter: - info = "Audio shape: " - info += highlight_shape(audio_converter.audio_shape()) - add_layer(audio_converter, info) + for layer in self.preprocessor._flatten_layers(include_self=False): + if isinstance(layer, Tokenizer): + info = "Vocab size: " + info += highlight_number(layer.vocabulary_size()) + add_layer(layer, info) + elif isinstance(layer, ImageConverter): + info = "Image size: " + info += highlight_shape(layer.image_size()) + add_layer(layer, info) + elif isinstance(layer, AudioConverter): + info = "Audio shape: " + info += highlight_shape(layer.audio_shape()) + add_layer(layer, info) # Print the to the console. preprocessor_name = markup.escape(self.preprocessor.name) diff --git a/keras_hub/src/models/vae/__init__.py b/keras_hub/src/models/vae/__init__.py new file mode 100644 index 000000000..9f6cf4a62 --- /dev/null +++ b/keras_hub/src/models/vae/__init__.py @@ -0,0 +1 @@ +from keras_hub.src.models.vae.vae_backbone import VAEBackbone diff --git a/keras_hub/src/models/vae/vae_backbone.py b/keras_hub/src/models/vae/vae_backbone.py new file mode 100644 index 000000000..c84986314 --- /dev/null +++ b/keras_hub/src/models/vae/vae_backbone.py @@ -0,0 +1,172 @@ +import keras + +from keras_hub.src.models.backbone import Backbone +from keras_hub.src.models.vae.vae_layers import ( + DiagonalGaussianDistributionSampler, +) +from keras_hub.src.models.vae.vae_layers import VAEDecoder +from keras_hub.src.models.vae.vae_layers import VAEEncoder +from keras_hub.src.utils.keras_utils import standardize_data_format + + +class VAEBackbone(Backbone): + """VAE backbone used in latent diffusion models. + + When encoding, this model generates mean and log variance of the input + images. When decoding, it reconstructs images from the latent space. + + Args: + encoder_num_filters: list of ints. The number of filters for each + block in encoder. + encoder_num_blocks: list of ints. The number of blocks for each block in + encoder. + decoder_num_filters: list of ints. The number of filters for each + block in decoder. + decoder_num_blocks: list of ints. 
The number of blocks for each block in + decoder. + sampler_method: str. The method of the sampler for the intermediate + output. Available methods are `"sample"` and `"mode"`. `"sample"` + draws from the distribution using both the mean and log variance. + `"mode"` draws from the distribution using the mean only. Defaults + to `sample`. + input_channels: int. The number of channels in the input. + sample_channels: int. The number of channels in the sample. Typically, + this indicates the intermediate output of VAE, which is mean and + log variance. + output_channels: int. The number of channels in the output. + scale: float. The scaling factor applied to the latent space to ensure + it has unit variance during training of the diffusion model. + Defaults to `1.5305`, which is the value used in Stable Diffusion 3. + shift: float. The shift factor applied to the latent space to ensure it + has zero mean during training of the diffusion model. Defaults to + `0.0609`, which is the value used in Stable Diffusion 3. + data_format: `None` or str. If specified, either `"channels_last"` or + `"channels_first"`. The ordering of the dimensions in the + inputs. `"channels_last"` corresponds to inputs with shape + `(batch_size, height, width, channels)` + while `"channels_first"` corresponds to inputs with shape + `(batch_size, channels, height, width)`. It defaults to the + `image_data_format` value found in your Keras config file at + `~/.keras/keras.json`. If you never set it, then it will be + `"channels_last"`. + dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype + to use for the model's computations and weights. + """ + + def __init__( + self, + encoder_num_filters, + encoder_num_blocks, + decoder_num_filters, + decoder_num_blocks, + sampler_method="sample", + input_channels=3, + sample_channels=32, + output_channels=3, + scale=1.5305, + shift=0.0609, + data_format=None, + dtype=None, + **kwargs, + ): + data_format = standardize_data_format(data_format) + if data_format == "channels_last": + image_shape = (None, None, input_channels) + channel_axis = -1 + else: + image_shape = (input_channels, None, None) + channel_axis = 1 + + # === Layers === + self.encoder = VAEEncoder( + encoder_num_filters, + encoder_num_blocks, + output_channels=sample_channels, + data_format=data_format, + dtype=dtype, + name="encoder", + ) + # Use `sample()` to define the functional model. + self.distribution_sampler = DiagonalGaussianDistributionSampler( + method=sampler_method, + axis=channel_axis, + dtype=dtype, + name="distribution_sampler", + ) + self.decoder = VAEDecoder( + decoder_num_filters, + decoder_num_blocks, + output_channels=output_channels, + data_format=data_format, + dtype=dtype, + name="decoder", + ) + + # === Functional Model === + image_input = keras.Input(shape=image_shape) + sample = self.encoder(image_input) + latent = self.distribution_sampler(sample) + image_output = self.decoder(latent) + super().__init__( + inputs=image_input, + outputs=image_output, + dtype=dtype, + **kwargs, + ) + + # === Config === + self.encoder_num_filters = encoder_num_filters + self.encoder_num_blocks = encoder_num_blocks + self.decoder_num_filters = decoder_num_filters + self.decoder_num_blocks = decoder_num_blocks + self.sampler_method = sampler_method + self.input_channels = input_channels + self.sample_channels = sample_channels + self.output_channels = output_channels + self._scale = scale + self._shift = shift + + @property + def scale(self): + """The scaling factor for the latent space. 
+ + This is used to scale the latent space to have unit variance when + training the diffusion model. + """ + return self._scale + + @property + def shift(self): + """The shift factor for the latent space. + + This is used to shift the latent space to have zero mean when + training the diffusion model. + """ + return self._shift + + def encode(self, inputs, **kwargs): + """Encode the input images into latent space.""" + sample = self.encoder(inputs, **kwargs) + return self.distribution_sampler(sample) + + def decode(self, inputs, **kwargs): + """Decode the input latent space into images.""" + return self.decoder(inputs, **kwargs) + + def get_config(self): + config = super().get_config() + config.update( + { + "encoder_num_filters": self.encoder_num_filters, + "encoder_num_blocks": self.encoder_num_blocks, + "decoder_num_filters": self.decoder_num_filters, + "decoder_num_blocks": self.decoder_num_blocks, + "sampler_method": self.sampler_method, + "input_channels": self.input_channels, + "sample_channels": self.sample_channels, + "output_channels": self.output_channels, + "scale": self.scale, + "shift": self.shift, + } + ) + return config diff --git a/keras_hub/src/models/vae/vae_backbone_test.py b/keras_hub/src/models/vae/vae_backbone_test.py new file mode 100644 index 000000000..f5bd6f27a --- /dev/null +++ b/keras_hub/src/models/vae/vae_backbone_test.py @@ -0,0 +1,35 @@ +import pytest +from keras import ops + +from keras_hub.src.models.vae.vae_backbone import VAEBackbone +from keras_hub.src.tests.test_case import TestCase + + +class VAEBackboneTest(TestCase): + def setUp(self): + self.height, self.width = 64, 64 + self.init_kwargs = { + "encoder_num_filters": [32, 32, 32, 32], + "encoder_num_blocks": [1, 1, 1, 1], + "decoder_num_filters": [32, 32, 32, 32], + "decoder_num_blocks": [1, 1, 1, 1], + # Use `mode` generate a deterministic output. + "sampler_method": "mode", + } + self.input_data = ops.ones((2, self.height, self.width, 3)) + + def test_backbone_basics(self): + self.run_backbone_test( + cls=VAEBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output_shape=(2, self.height, self.width, 3), + ) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=VAEBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) diff --git a/keras_hub/src/models/vae/vae_layers.py b/keras_hub/src/models/vae/vae_layers.py new file mode 100644 index 000000000..0f5ad82c4 --- /dev/null +++ b/keras_hub/src/models/vae/vae_layers.py @@ -0,0 +1,740 @@ +import math + +import keras +from keras import ops + +from keras_hub.src.utils.keras_utils import standardize_data_format + + +class Conv2DMultiHeadAttention(keras.layers.Layer): + """A MultiHeadAttention layer utilizing `Conv2D` and `GroupNormalization`. + + Args: + filters: int. The number of the filters for the convolutional layers. + groups: int. The number of the groups for the group normalization + layers. Defaults to `32`. + data_format: `None` or str. If specified, either `"channels_last"` or + `"channels_first"`. The ordering of the dimensions in the + inputs. `"channels_last"` corresponds to inputs with shape + `(batch_size, height, width, channels)` + while `"channels_first"` corresponds to inputs with shape + `(batch_size, channels, height, width)`. It defaults to the + `image_data_format` value found in your Keras config file at + `~/.keras/keras.json`. If you never set it, then it will be + `"channels_last"`. 
+ **kwargs: other keyword arguments passed to `keras.layers.Layer`, + including `name`, `dtype` etc. + """ + + def __init__(self, filters, groups=32, data_format=None, **kwargs): + super().__init__(**kwargs) + data_format = standardize_data_format(data_format) + channel_axis = -1 if data_format == "channels_last" else 1 + self.filters = int(filters) + self.groups = int(groups) + self._inverse_sqrt_filters = 1.0 / math.sqrt(float(filters)) + self.data_format = data_format + + self.group_norm = keras.layers.GroupNormalization( + groups=groups, + axis=channel_axis, + epsilon=1e-6, + dtype=self.dtype_policy, + name="group_norm", + ) + self.query_conv2d = keras.layers.Conv2D( + filters, + 1, + 1, + data_format=data_format, + dtype=self.dtype_policy, + name="query_conv2d", + ) + self.key_conv2d = keras.layers.Conv2D( + filters, + 1, + 1, + data_format=data_format, + dtype=self.dtype_policy, + name="key_conv2d", + ) + self.value_conv2d = keras.layers.Conv2D( + filters, + 1, + 1, + data_format=data_format, + dtype=self.dtype_policy, + name="value_conv2d", + ) + self.softmax = keras.layers.Softmax(dtype="float32") + self.output_conv2d = keras.layers.Conv2D( + filters, + 1, + 1, + data_format=data_format, + dtype=self.dtype_policy, + name="output_conv2d", + ) + + def build(self, input_shape): + self.group_norm.build(input_shape) + self.query_conv2d.build(input_shape) + self.key_conv2d.build(input_shape) + self.value_conv2d.build(input_shape) + self.output_conv2d.build(input_shape) + + def call(self, inputs, training=None): + x = self.group_norm(inputs, training=training) + query = self.query_conv2d(x, training=training) + key = self.key_conv2d(x, training=training) + value = self.value_conv2d(x, training=training) + + if self.data_format == "channels_first": + query = ops.transpose(query, (0, 2, 3, 1)) + key = ops.transpose(key, (0, 2, 3, 1)) + value = ops.transpose(value, (0, 2, 3, 1)) + shape = ops.shape(inputs) + b = shape[0] + query = ops.reshape(query, (b, -1, self.filters)) + key = ops.reshape(key, (b, -1, self.filters)) + value = ops.reshape(value, (b, -1, self.filters)) + + # Compute attention. + query = ops.multiply( + query, ops.cast(self._inverse_sqrt_filters, query.dtype) + ) + # [B, H0 * W0, C], [B, H1 * W1, C] -> [B, H0 * W0, H1 * W1] + attention_scores = ops.einsum("abc,adc->abd", query, key) + attention_scores = ops.cast( + self.softmax(attention_scores), self.compute_dtype + ) + # [B, H2 * W2, C], [B, H0 * W0, H1 * W1] -> [B, H1 * W1 ,C] + attention_output = ops.einsum("abc,adb->adc", value, attention_scores) + x = ops.reshape(attention_output, shape) + + x = self.output_conv2d(x, training=training) + if self.data_format == "channels_first": + x = ops.transpose(x, (0, 3, 1, 2)) + x = ops.add(x, inputs) + return x + + def get_config(self): + config = super().get_config() + config.update( + { + "filters": self.filters, + "groups": self.groups, + } + ) + return config + + def compute_output_shape(self, input_shape): + return input_shape + + +class ResNetBlock(keras.layers.Layer): + """A ResNet block utilizing `GroupNormalization` and SiLU activation. + + Args: + filters: The number of filters in the block. + has_residual_projection: Whether to add a projection layer for the + residual connection. Defaults to `False`. + data_format: `None` or str. If specified, either `"channels_last"` or + `"channels_first"`. The ordering of the dimensions in the + inputs. 
`"channels_last"` corresponds to inputs with shape + `(batch_size, height, width, channels)` + while `"channels_first"` corresponds to inputs with shape + `(batch_size, channels, height, width)`. It defaults to the + `image_data_format` value found in your Keras config file at + `~/.keras/keras.json`. If you never set it, then it will be + `"channels_last"`. + **kwargs: other keyword arguments passed to `keras.layers.Layer`, + including `name`, `dtype` etc. + """ + + def __init__( + self, + filters, + has_residual_projection=False, + data_format=None, + **kwargs, + ): + + super().__init__(**kwargs) + data_format = standardize_data_format(data_format) + channel_axis = -1 if data_format == "channels_last" else 1 + self.filters = int(filters) + self.has_residual_projection = bool(has_residual_projection) + + # === Layers === + self.norm1 = keras.layers.GroupNormalization( + groups=32, + axis=channel_axis, + epsilon=1e-6, + dtype=self.dtype_policy, + name="norm1", + ) + self.act1 = keras.layers.Activation("silu", dtype=self.dtype_policy) + self.conv1 = keras.layers.Conv2D( + filters, + 3, + 1, + padding="same", + data_format=data_format, + dtype=self.dtype_policy, + name="conv1", + ) + self.norm2 = keras.layers.GroupNormalization( + groups=32, + axis=channel_axis, + epsilon=1e-6, + dtype=self.dtype_policy, + name="norm2", + ) + self.act2 = keras.layers.Activation("silu", dtype=self.dtype_policy) + self.conv2 = keras.layers.Conv2D( + filters, + 3, + 1, + padding="same", + data_format=data_format, + dtype=self.dtype_policy, + name="conv2", + ) + if self.has_residual_projection: + self.residual_projection = keras.layers.Conv2D( + filters, + 1, + 1, + data_format=data_format, + dtype=self.dtype_policy, + name="residual_projection", + ) + self.add = keras.layers.Add(dtype=self.dtype_policy) + + def build(self, input_shape): + residual_shape = list(input_shape) + self.norm1.build(input_shape) + self.act1.build(input_shape) + self.conv1.build(input_shape) + input_shape = self.conv1.compute_output_shape(input_shape) + self.norm2.build(input_shape) + self.act2.build(input_shape) + self.conv2.build(input_shape) + input_shape = self.conv2.compute_output_shape(input_shape) + if self.has_residual_projection: + self.residual_projection.build(residual_shape) + self.add.build([input_shape, input_shape]) + + def call(self, inputs, training=None): + x = inputs + residual = x + x = self.norm1(x, training=training) + x = self.act1(x, training=training) + x = self.conv1(x, training=training) + x = self.norm2(x, training=training) + x = self.act2(x, training=training) + x = self.conv2(x, training=training) + if self.has_residual_projection: + residual = self.residual_projection(residual, training=training) + x = self.add([residual, x]) + return x + + def get_config(self): + config = super().get_config() + config.update( + { + "filters": self.filters, + "has_residual_projection": self.has_residual_projection, + } + ) + return config + + def compute_output_shape(self, input_shape): + outputs_shape = list(input_shape) + if self.has_residual_projection: + outputs_shape = self.residual_projection.compute_output_shape( + outputs_shape + ) + return outputs_shape + + +class VAEEncoder(keras.layers.Layer): + """The encoder part of VAE. + + Args: + stackwise_num_filters: list of ints. The number of filters for each + stack. + stackwise_num_blocks: list of ints. The number of blocks for each stack. + output_channels: int. The number of channels in the output. Defaults to + `32`. + data_format: `None` or str. 
If specified, either `"channels_last"` or + `"channels_first"`. The ordering of the dimensions in the + inputs. `"channels_last"` corresponds to inputs with shape + `(batch_size, height, width, channels)` + while `"channels_first"` corresponds to inputs with shape + `(batch_size, channels, height, width)`. It defaults to the + `image_data_format` value found in your Keras config file at + `~/.keras/keras.json`. If you never set it, then it will be + `"channels_last"`. + **kwargs: other keyword arguments passed to `keras.layers.Layer`, + including `name`, `dtype` etc. + """ + + def __init__( + self, + stackwise_num_filters, + stackwise_num_blocks, + output_channels=32, + data_format=None, + **kwargs, + ): + super().__init__(**kwargs) + data_format = standardize_data_format(data_format) + channel_axis = -1 if data_format == "channels_last" else 1 + self.stackwise_num_filters = stackwise_num_filters + self.stackwise_num_blocks = stackwise_num_blocks + self.output_channels = int(output_channels) + self.data_format = data_format + + # === Layers === + self.input_projection = keras.layers.Conv2D( + stackwise_num_filters[0], + 3, + 1, + padding="same", + data_format=data_format, + dtype=self.dtype_policy, + name="input_projection", + ) + + # Blocks. + input_filters = stackwise_num_filters[0] + self.blocks = [] + self.downsamples = [] + for i, filters in enumerate(stackwise_num_filters): + for j in range(stackwise_num_blocks[i]): + self.blocks.append( + ResNetBlock( + filters, + has_residual_projection=input_filters != filters, + data_format=data_format, + dtype=self.dtype_policy, + name=f"block_{i}_{j}", + ) + ) + input_filters = filters + # No downsample in the last block. + if i != len(stackwise_num_filters) - 1: + self.downsamples.append( + keras.layers.ZeroPadding2D( + padding=((0, 1), (0, 1)), + data_format=data_format, + dtype=self.dtype_policy, + name=f"downsample_{i}_pad", + ) + ) + self.downsamples.append( + keras.layers.Conv2D( + filters, + 3, + 2, + data_format=data_format, + dtype=self.dtype_policy, + name=f"downsample_{i}_conv", + ) + ) + + # Mid block. + self.mid_block_0 = ResNetBlock( + stackwise_num_filters[-1], + has_residual_projection=False, + data_format=data_format, + dtype=self.dtype_policy, + name="mid_block_0", + ) + self.mid_attention = Conv2DMultiHeadAttention( + stackwise_num_filters[-1], + data_format=data_format, + dtype=self.dtype_policy, + name="mid_attention", + ) + self.mid_block_1 = ResNetBlock( + stackwise_num_filters[-1], + has_residual_projection=False, + data_format=data_format, + dtype=self.dtype_policy, + name="mid_block_1", + ) + + # Output layers. 
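        # GroupNormalization -> swish -> 3x3 Conv2D projecting to
        # `output_channels` (the stacked mean and log-variance consumed by
        # the sampler).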
+ self.output_norm = keras.layers.GroupNormalization( + groups=32, + axis=channel_axis, + epsilon=1e-6, + dtype=self.dtype_policy, + name="output_norm", + ) + self.output_act = keras.layers.Activation( + "swish", dtype=self.dtype_policy + ) + self.output_projection = keras.layers.Conv2D( + output_channels, + 3, + 1, + padding="same", + data_format=data_format, + dtype=self.dtype_policy, + name="output_projection", + ) + + def build(self, input_shape): + self.input_projection.build(input_shape) + input_shape = self.input_projection.compute_output_shape(input_shape) + blocks_idx = 0 + downsamples_idx = 0 + for i, _ in enumerate(self.stackwise_num_filters): + for _ in range(self.stackwise_num_blocks[i]): + self.blocks[blocks_idx].build(input_shape) + input_shape = self.blocks[blocks_idx].compute_output_shape( + input_shape + ) + blocks_idx += 1 + if i != len(self.stackwise_num_filters) - 1: + self.downsamples[downsamples_idx].build(input_shape) + input_shape = self.downsamples[ + downsamples_idx + ].compute_output_shape(input_shape) + downsamples_idx += 1 + self.downsamples[downsamples_idx].build(input_shape) + input_shape = self.downsamples[ + downsamples_idx + ].compute_output_shape(input_shape) + downsamples_idx += 1 + self.mid_block_0.build(input_shape) + input_shape = self.mid_block_0.compute_output_shape(input_shape) + self.mid_attention.build(input_shape) + input_shape = self.mid_attention.compute_output_shape(input_shape) + self.mid_block_1.build(input_shape) + input_shape = self.mid_block_1.compute_output_shape(input_shape) + self.output_norm.build(input_shape) + self.output_act.build(input_shape) + self.output_projection.build(input_shape) + + def call(self, inputs, training=None): + x = inputs + x = self.input_projection(x, training=training) + blocks_idx = 0 + upsamples_idx = 0 + for i, _ in enumerate(self.stackwise_num_filters): + for _ in range(self.stackwise_num_blocks[i]): + x = self.blocks[blocks_idx](x, training=training) + blocks_idx += 1 + if i != len(self.stackwise_num_filters) - 1: + x = self.downsamples[upsamples_idx](x, training=training) + x = self.downsamples[upsamples_idx + 1](x, training=training) + upsamples_idx += 2 + x = self.mid_block_0(x, training=training) + x = self.mid_attention(x, training=training) + x = self.mid_block_1(x, training=training) + x = self.output_norm(x, training=training) + x = self.output_act(x, training=training) + x = self.output_projection(x, training=training) + return x + + def get_config(self): + config = super().get_config() + config.update( + { + "stackwise_num_filters": self.stackwise_num_filters, + "stackwise_num_blocks": self.stackwise_num_blocks, + "output_channels": self.output_channels, + } + ) + return config + + def compute_output_shape(self, input_shape): + if self.data_format == "channels_last": + h_axis, w_axis, c_axis = 1, 2, 3 + else: + c_axis, h_axis, w_axis = 1, 2, 3 + scale_factor = 2 ** (len(self.stackwise_num_filters) - 1) + outputs_shape = list(input_shape) + if ( + outputs_shape[h_axis] is not None + and outputs_shape[w_axis] is not None + ): + outputs_shape[h_axis] = outputs_shape[h_axis] // scale_factor + outputs_shape[w_axis] = outputs_shape[w_axis] // scale_factor + outputs_shape[c_axis] = self.output_channels + return outputs_shape + + +class VAEDecoder(keras.layers.Layer): + """The decoder part of VAE. + + Args: + stackwise_num_filters: list of ints. The number of filters for each + stack. + stackwise_num_blocks: list of ints. The number of blocks for each stack. + output_channels: int. 
The number of channels in the output. Defaults to + `3`. + data_format: `None` or str. If specified, either `"channels_last"` or + `"channels_first"`. The ordering of the dimensions in the + inputs. `"channels_last"` corresponds to inputs with shape + `(batch_size, height, width, channels)` + while `"channels_first"` corresponds to inputs with shape + `(batch_size, channels, height, width)`. It defaults to the + `image_data_format` value found in your Keras config file at + `~/.keras/keras.json`. If you never set it, then it will be + `"channels_last"`. + **kwargs: other keyword arguments passed to `keras.layers.Layer`, + including `name`, `dtype` etc. + """ + + def __init__( + self, + stackwise_num_filters, + stackwise_num_blocks, + output_channels=3, + data_format=None, + **kwargs, + ): + super().__init__(**kwargs) + data_format = standardize_data_format(data_format) + channel_axis = -1 if data_format == "channels_last" else 1 + self.stackwise_num_filters = stackwise_num_filters + self.stackwise_num_blocks = stackwise_num_blocks + self.output_channels = int(output_channels) + self.data_format = data_format + + # === Layers === + self.input_projection = keras.layers.Conv2D( + stackwise_num_filters[0], + 3, + 1, + padding="same", + data_format=data_format, + dtype=self.dtype_policy, + name="input_projection", + ) + + # Mid block. + self.mid_block_0 = ResNetBlock( + stackwise_num_filters[0], + data_format=data_format, + dtype=self.dtype_policy, + name="mid_block_0", + ) + self.mid_attention = Conv2DMultiHeadAttention( + stackwise_num_filters[0], + data_format=data_format, + dtype=self.dtype_policy, + name="mid_attention", + ) + self.mid_block_1 = ResNetBlock( + stackwise_num_filters[0], + data_format=data_format, + dtype=self.dtype_policy, + name="mid_block_1", + ) + + # Blocks. + input_filters = stackwise_num_filters[0] + self.blocks = [] + self.upsamples = [] + for i, filters in enumerate(stackwise_num_filters): + for j in range(stackwise_num_blocks[i]): + self.blocks.append( + ResNetBlock( + filters, + has_residual_projection=input_filters != filters, + data_format=data_format, + dtype=self.dtype_policy, + name=f"block_{i}_{j}", + ) + ) + input_filters = filters + # No upsample in the last block. + if i != len(stackwise_num_filters) - 1: + self.upsamples.append( + keras.layers.UpSampling2D( + 2, + data_format=data_format, + dtype=self.dtype_policy, + name=f"upsample_{i}", + ) + ) + self.upsamples.append( + keras.layers.Conv2D( + filters, + 3, + 1, + padding="same", + data_format=data_format, + dtype=self.dtype_policy, + name=f"upsample_{i}_conv", + ) + ) + + # Output layers. 
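        # GroupNormalization -> swish -> 3x3 Conv2D projecting to
        # `output_channels` (RGB pixels by default).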
+ self.output_norm = keras.layers.GroupNormalization( + groups=32, + axis=channel_axis, + epsilon=1e-6, + dtype=self.dtype_policy, + name="output_norm", + ) + self.output_act = keras.layers.Activation( + "swish", dtype=self.dtype_policy + ) + self.output_projection = keras.layers.Conv2D( + output_channels, + 3, + 1, + padding="same", + data_format=data_format, + dtype=self.dtype_policy, + name="output_projection", + ) + + def build(self, input_shape): + self.input_projection.build(input_shape) + input_shape = self.input_projection.compute_output_shape(input_shape) + self.mid_block_0.build(input_shape) + input_shape = self.mid_block_0.compute_output_shape(input_shape) + self.mid_attention.build(input_shape) + input_shape = self.mid_attention.compute_output_shape(input_shape) + self.mid_block_1.build(input_shape) + input_shape = self.mid_block_1.compute_output_shape(input_shape) + blocks_idx = 0 + upsamples_idx = 0 + for i, _ in enumerate(self.stackwise_num_filters): + for _ in range(self.stackwise_num_blocks[i]): + self.blocks[blocks_idx].build(input_shape) + input_shape = self.blocks[blocks_idx].compute_output_shape( + input_shape + ) + blocks_idx += 1 + if i != len(self.stackwise_num_filters) - 1: + self.upsamples[upsamples_idx].build(input_shape) + input_shape = self.upsamples[ + upsamples_idx + ].compute_output_shape(input_shape) + self.upsamples[upsamples_idx + 1].build(input_shape) + input_shape = self.upsamples[ + upsamples_idx + 1 + ].compute_output_shape(input_shape) + upsamples_idx += 2 + self.output_norm.build(input_shape) + self.output_act.build(input_shape) + self.output_projection.build(input_shape) + + def call(self, inputs, training=None): + x = inputs + x = self.input_projection(x, training=training) + x = self.mid_block_0(x, training=training) + x = self.mid_attention(x, training=training) + x = self.mid_block_1(x, training=training) + blocks_idx = 0 + upsamples_idx = 0 + for i, _ in enumerate(self.stackwise_num_filters): + for _ in range(self.stackwise_num_blocks[i]): + x = self.blocks[blocks_idx](x, training=training) + blocks_idx += 1 + if i != len(self.stackwise_num_filters) - 1: + x = self.upsamples[upsamples_idx](x, training=training) + x = self.upsamples[upsamples_idx + 1](x, training=training) + upsamples_idx += 2 + x = self.output_norm(x, training=training) + x = self.output_act(x, training=training) + x = self.output_projection(x, training=training) + return x + + def get_config(self): + config = super().get_config() + config.update( + { + "stackwise_num_filters": self.stackwise_num_filters, + "stackwise_num_blocks": self.stackwise_num_blocks, + "output_channels": self.output_channels, + } + ) + return config + + def compute_output_shape(self, input_shape): + if self.data_format == "channels_last": + h_axis, w_axis, c_axis = 1, 2, 3 + else: + c_axis, h_axis, w_axis = 1, 2, 3 + scale_factor = 2 ** (len(self.stackwise_num_filters) - 1) + outputs_shape = list(input_shape) + if ( + outputs_shape[h_axis] is not None + and outputs_shape[w_axis] is not None + ): + outputs_shape[h_axis] = outputs_shape[h_axis] * scale_factor + outputs_shape[w_axis] = outputs_shape[w_axis] * scale_factor + outputs_shape[c_axis] = self.output_channels + return outputs_shape + + +class DiagonalGaussianDistributionSampler(keras.layers.Layer): + """A sampler for a diagonal Gaussian distribution. + + This layer samples latent variables from a diagonal Gaussian distribution. + + Args: + method: str. The method used to sample from the distribution. 
Available + methods are `"sample"` and `"mode"`. `"sample"` draws from the + distribution using both the mean and log variance. `"mode"` draws + from the distribution using the mean only. + axis: int. The axis along which to split the mean and log variance. + Defaults to `-1`. + seed: optional int. Used as a random seed. + **kwargs: other keyword arguments passed to `keras.layers.Layer`, + including `name`, `dtype` etc. + """ + + def __init__(self, method, axis=-1, seed=None, **kwargs): + super().__init__(**kwargs) + # TODO: Support `kl` and `nll` modes. + valid_methods = ("sample", "mode") + if method not in valid_methods: + raise ValueError( + f"Invalid method {method}. Valid methods are " + f"{list(valid_methods)}." + ) + self.method = method + self.axis = axis + self.seed = seed + self.seed_generator = keras.random.SeedGenerator(seed) + + def call(self, inputs): + x = inputs + if self.method == "sample": + x_mean, x_logvar = ops.split(x, 2, axis=self.axis) + x_logvar = ops.clip(x_logvar, -30.0, 20.0) + x_std = ops.exp(ops.multiply(0.5, x_logvar)) + sample = keras.random.normal( + ops.shape(x_mean), dtype=x_mean.dtype, seed=self.seed_generator + ) + x = ops.add(x_mean, ops.multiply(x_std, sample)) + else: + x, _ = ops.split(x, 2, axis=self.axis) + return x + + def get_config(self): + config = super().get_config() + config.update( + { + "axis": self.axis, + "seed": self.seed, + } + ) + return config + + def compute_output_shape(self, input_shape): + output_shape = list(input_shape) + output_shape[self.axis] = output_shape[self.axis] // 2 + return output_shape diff --git a/tools/checkpoint_conversion/convert_stable_diffusion_3_checkpoints.py b/tools/checkpoint_conversion/convert_stable_diffusion_3_checkpoints.py index 481117eb8..15b969153 100644 --- a/tools/checkpoint_conversion/convert_stable_diffusion_3_checkpoints.py +++ b/tools/checkpoint_conversion/convert_stable_diffusion_3_checkpoints.py @@ -29,6 +29,7 @@ from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_text_to_image_preprocessor import ( StableDiffusion3TextToImagePreprocessor, ) +from keras_hub.src.models.vae.vae_backbone import VAEBackbone from keras_hub.src.utils.preset_utils import load_json from keras_hub.src.utils.transformers.safetensor_utils import SafetensorLoader @@ -42,7 +43,7 @@ "clip_l": "text_encoders/clip_l.safetensors", "clip_g": "text_encoders/clip_g.safetensors", "diffuser": "sd3_medium.safetensors", - "decoder": "sd3_medium.safetensors", + "vae": "sd3_medium.safetensors", # Tokenizer "clip_tokenizer": "hf://openai/clip-vit-large-patch14", } @@ -69,12 +70,35 @@ def convert_model(preset, height, width): - # The text encoders are all the same. + # The vae and text encoders are common in all presets. + vae = VAEBackbone( + [128, 256, 512, 512], + [2, 2, 2, 2], + [512, 512, 256, 128], + [3, 3, 3, 3], + name="vae", + ) clip_l = CLIPTextEncoder( - 49408, 768, 768, 12, 12, 3072, "quick_gelu", -2, name="clip_l" + 49408, + 768, + 768, + 12, + 12, + 3072, + "quick_gelu", + -2, + name="clip_l", ) clip_g = CLIPTextEncoder( - 49408, 1280, 1280, 32, 20, 5120, "gelu", -2, name="clip_g" + 49408, + 1280, + 1280, + 32, + 20, + 5120, + "gelu", + -2, + name="clip_g", ) # TODO: Add T5. 
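A minimal usage sketch of the new `VAEBackbone` (illustrative only, not part of the patch; the small filter configuration mirrors the unit tests rather than the full-size SD3 VAE built above):

import numpy as np

from keras_hub.src.models.vae.vae_backbone import VAEBackbone

# Build a small VAE and round-trip an image batch through it.
vae = VAEBackbone(
    encoder_num_filters=[32, 32, 32, 32],
    encoder_num_blocks=[1, 1, 1, 1],
    decoder_num_filters=[32, 32, 32, 32],
    decoder_num_blocks=[1, 1, 1, 1],
    sampler_method="mode",  # Deterministic: use the mean of the distribution.
)
images = np.ones((1, 64, 64, 3), dtype="float32")
latents = vae.encode(images)  # Shape (1, 8, 8, 16) after 8x downsampling.
reconstructed = vae.decode(latents)  # Shape (1, 64, 64, 3).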
@@ -86,12 +110,12 @@ def convert_model(preset, height, width):
         24,
         24,
         192,
-        [512, 512, 256, 128],
-        [3, 3, 3, 3],
+        vae,
         clip_l,
         clip_g,
         height=height,
         width=width,
+        name="stable_diffusion_3_backbone",
     )
     return backbone
 
@@ -107,18 +131,28 @@ def convert_preprocessor():
         merges,
         pad_with_end_token=True,
         config_name="clip_l_tokenizer.json",
+        name="clip_l_tokenizer",
     )
     clip_g_tokenizer = CLIPTokenizer(
-        vocabulary, merges, config_name="clip_g_tokenizer.json"
+        vocabulary,
+        merges,
+        config_name="clip_g_tokenizer.json",
+        name="clip_g_tokenizer",
     )
     clip_l_preprocessor = CLIPPreprocessor(
-        clip_l_tokenizer, config_name="clip_l_preprocessor.json"
+        clip_l_tokenizer,
+        config_name="clip_l_preprocessor.json",
+        name="clip_l_preprocessor",
    )
     clip_g_preprocessor = CLIPPreprocessor(
-        clip_g_tokenizer, config_name="clip_g_preprocessor.json"
+        clip_g_tokenizer,
+        config_name="clip_g_preprocessor.json",
+        name="clip_g_preprocessor",
     )
     preprocessor = StableDiffusion3TextToImagePreprocessor(
-        clip_l_preprocessor, clip_g_preprocessor
+        clip_l_preprocessor,
+        clip_g_preprocessor,
+        name="stable_diffusion_3_text_to_image_preprocessor",
     )
     return preprocessor
 
@@ -331,104 +365,117 @@ def port_diffuser(preset, filename, model):
         )
         return model
 
-    def port_decoder(preset, filename, model):
+    def port_vae(preset, filename, model):
         hf_prefix = "first_stage_model."
 
-        def port_resnet_block(
-            keras_variable_name, hf_weight_key, has_residual=False
-        ):
+        def port_resnet_block(loader, keras_variable, hf_weight_key):
             port_ln_or_gn(
-                loader,
-                model.get_layer(f"{keras_variable_name}_norm1"),
-                f"{hf_weight_key}.norm1",
-            )
-            port_conv2d(
-                loader,
-                model.get_layer(f"{keras_variable_name}_conv1"),
-                f"{hf_weight_key}.conv1",
+                loader, keras_variable.norm1, f"{hf_weight_key}.norm1"
             )
+            port_conv2d(loader, keras_variable.conv1, f"{hf_weight_key}.conv1")
             port_ln_or_gn(
-                loader,
-                model.get_layer(f"{keras_variable_name}_norm2"),
-                f"{hf_weight_key}.norm2",
-            )
-            port_conv2d(
-                loader,
-                model.get_layer(f"{keras_variable_name}_conv2"),
-                f"{hf_weight_key}.conv2",
+                loader, keras_variable.norm2, f"{hf_weight_key}.norm2"
             )
-            if has_residual:
+            port_conv2d(loader, keras_variable.conv2, f"{hf_weight_key}.conv2")
+            if hasattr(keras_variable, "residual_projection"):
                 port_conv2d(
                     loader,
-                    model.get_layer(
-                        f"{keras_variable_name}_residual_projection"
-                    ),
+                    keras_variable.residual_projection,
                     f"{hf_weight_key}.nin_shortcut",
                 )
 
-        def port_attention(keras_variable_name, hf_weight_key):
+        def port_attention(loader, keras_variable, hf_weight_key):
             port_ln_or_gn(
-                loader,
-                model.get_layer(keras_variable_name).group_norm,
-                f"{hf_weight_key}.norm",
-            )
-            port_conv2d(
-                loader,
-                model.get_layer(keras_variable_name).query_conv2d,
-                f"{hf_weight_key}.q",
+                loader, keras_variable.group_norm, f"{hf_weight_key}.norm"
             )
             port_conv2d(
-                loader,
-                model.get_layer(keras_variable_name).key_conv2d,
-                f"{hf_weight_key}.k",
+                loader, keras_variable.query_conv2d, f"{hf_weight_key}.q"
             )
+            port_conv2d(loader, keras_variable.key_conv2d, f"{hf_weight_key}.k")
             port_conv2d(
-                loader,
-                model.get_layer(keras_variable_name).value_conv2d,
-                f"{hf_weight_key}.v",
+                loader, keras_variable.value_conv2d, f"{hf_weight_key}.v"
             )
             port_conv2d(
                 loader,
-                model.get_layer(keras_variable_name).output_conv2d,
+                keras_variable.output_conv2d,
                 f"{hf_weight_key}.proj_out",
             )
 
+        # Port encoder.
with SafetensorLoader( preset, prefix=hf_prefix, fname=filename ) as loader: - # Stem - port_conv2d( - loader, model.get_layer("input_projection"), "decoder.conv_in" - ) - port_resnet_block("input_block0", "decoder.mid.block_1") - port_attention("input_attention", "decoder.mid.attn_1") - port_resnet_block("input_block1", "decoder.mid.block_2") - - # Stacks - input_filters = model.stackwise_num_filters[0] - for i, filters in enumerate(model.stackwise_num_filters): - for j in range(model.stackwise_num_blocks[i]): - n = model.stackwise_num_blocks[i] - prefix = f"decoder.up.{n-i}.block.{j}" + encoder = keras_model.vae.encoder + # Stem. + port_conv2d(loader, encoder.input_projection, "encoder.conv_in") + + # Blocks. + blocks_idx = 0 + downsamples_idx = 0 + for i, _ in enumerate(encoder.stackwise_num_filters): + for j in range(encoder.stackwise_num_blocks[i]): + prefix = f"encoder.down.{i}.block.{j}" port_resnet_block( - f"block{i}_{j}", - prefix, - has_residual=filters != input_filters, + loader, encoder.blocks[blocks_idx], prefix ) - input_filters = filters - if i != len(model.stackwise_num_filters) - 1: - port_conv2d( - loader, - model.get_layer(f"upsample_{i}_conv"), - f"decoder.up.{n-i}.upsample.conv", - ) + blocks_idx += 1 + if i != len(encoder.stackwise_num_filters) - 1: + port_conv2d( + loader, + encoder.downsamples[downsamples_idx + 1], + f"encoder.down.{i}.downsample.conv", + ) + downsamples_idx += 2 # Skip `ZeroPadding2D`. + # Output layers - port_ln_or_gn( - loader, model.get_layer("output_norm"), "decoder.norm_out" + port_resnet_block( + loader, encoder.mid_block_0, "encoder.mid.block_1" ) - port_conv2d( - loader, model.get_layer("output_projection"), "decoder.conv_out" + port_attention(loader, encoder.mid_attention, "encoder.mid.attn_1") + port_resnet_block( + loader, encoder.mid_block_1, "encoder.mid.block_2" + ) + port_ln_or_gn(loader, encoder.output_norm, "encoder.norm_out") + port_conv2d(loader, encoder.output_projection, "encoder.conv_out") + + # Port decoder. + with SafetensorLoader( + preset, prefix=hf_prefix, fname=filename + ) as loader: + decoder = keras_model.vae.decoder + # Stem. + port_conv2d(loader, decoder.input_projection, "decoder.conv_in") + port_resnet_block( + loader, decoder.mid_block_0, "decoder.mid.block_1" ) + port_attention(loader, decoder.mid_attention, "decoder.mid.attn_1") + port_resnet_block( + loader, decoder.mid_block_1, "decoder.mid.block_2" + ) + + # Blocks. + blocks_idx = 0 + upsamples_idx = 0 + for i, _ in enumerate(decoder.stackwise_num_filters): + for j in range(decoder.stackwise_num_blocks[i]): + n = len(decoder.stackwise_num_blocks) - 1 + prefix = f"decoder.up.{n-i}.block.{j}" + port_resnet_block( + loader, decoder.blocks[blocks_idx], prefix + ) + blocks_idx += 1 + if i != len(decoder.stackwise_num_filters) - 1: + port_conv2d( + loader, + decoder.upsamples[upsamples_idx + 1], + f"decoder.up.{n-i}.upsample.conv", + ) + upsamples_idx += 2 # Skip `UpSampling2D`. + + # Output layers + port_ln_or_gn(loader, decoder.output_norm, "decoder.norm_out") + port_conv2d(loader, decoder.output_projection, "decoder.conv_out") + return model # Start conversion. 
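# --- Editorial sketch, not part of the patch: why `port_vae` above uses the
# reversed index `n - i` when building decoder weight keys. The safetensors
# checkpoint numbers its decoder up-stages from the image end (`decoder.up.0`
# is the last, highest-resolution stage), while the Keras decoder walks its
# stacks from the latent end, so the stack index has to be flipped.
stackwise_num_filters = [512, 512, 256, 128]  # decoder config from convert_model
n = len(stackwise_num_filters) - 1
for i, filters in enumerate(stackwise_num_filters):
    print(f"keras decoder stack {i} ({filters} filters) -> decoder.up.{n - i}")
# keras decoder stack 0 (512 filters) -> decoder.up.3
# keras decoder stack 3 (128 filters) -> decoder.up.0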
@@ -446,13 +493,14 @@ def port_attention(keras_variable_name, hf_weight_key): keras_model.clip_g_projection, ) port_diffuser(config["root"], config["diffuser"], keras_model.diffuser) - port_decoder(config["root"], config["decoder"], keras_model.decoder) + port_vae(config["root"], config["vae"], keras_model.vae) def validate_output(keras_model, keras_preprocessor, output_dir): # TODO: Verify the numerics. + prompt = "A cat holding a sign that says hello world" text_to_image = StableDiffusion3TextToImage(keras_model, keras_preprocessor) - image = text_to_image.generate("cute wallpaper art of a cat", seed=42) + image = text_to_image.generate(prompt, seed=42) image = Image.fromarray(image) image.save(os.path.join(output_dir, "test.png")) @@ -472,7 +520,7 @@ def main(_): # Currently SD3 weights are float16 (and have much faster download # times for it). We follow suit with Keras weights. keras.config.set_dtype_policy("float16") - height, width = 512, 512 # Use a smaller image size to speed up generation. + height, width = 800, 800 # Use a smaller image size to speed up generation. keras_preprocessor = convert_preprocessor() keras_model = convert_model(preset, height, width)
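# --- Editorial sketch, not part of the patch: a shape-level sanity check for
# the converted VAE. It relies only on pieces shown above (the `VAEBackbone`
# constructor arguments from `convert_model` and the `encoder`/`decoder`
# sublayers used by `port_vae`); `check_vae_shapes` itself is a hypothetical
# helper, and no numerics are verified here.
import numpy as np

from keras_hub.src.models.vae.vae_backbone import VAEBackbone


def check_vae_shapes(height=64, width=64):
    vae = VAEBackbone(
        [128, 256, 512, 512],
        [2, 2, 2, 2],
        [512, 512, 256, 128],
        [3, 3, 3, 3],
        name="vae",
    )
    images = np.zeros((1, height, width, 3), dtype="float32")
    # Expected: stacked mean/log-variance at 1/8 the spatial resolution.
    moments = vae.encoder(images)
    latents = np.zeros((1, height // 8, width // 8, 16), dtype="float32")
    # Expected: pixels at 8x the latent resolution with 3 output channels.
    reconstruction = vae.decoder(latents)
    print(moments.shape, reconstruction.shape)


check_vae_shapes()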