diff --git a/README.md b/README.md index cdbce8b532..5c9157e43c 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ [![contributions welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg?style=flat)](https://github.com/keras-team/keras-hub/issues) > [!IMPORTANT] -> šŸ“¢ KerasNLP is becoming KerasHub! šŸ“¢ Read +> šŸ“¢ KerasNLP is now KerasHub! šŸ“¢ Read > [the announcement](https://github.com/keras-team/keras-hub/issues/1831). > > We have renamed the repo to KerasHub in preparation for the release, but have not yet @@ -26,7 +26,7 @@ All models support JAX, TensorFlow, and PyTorch from a single model definition and can be fine-tuned on GPUs and TPUs out of the box. Models can be trained on individual accelerators with built-in PEFT techniques, or fine-tuned at scale with model and data parallel training. See our -[Getting Started guide](https://keras.io/guides/keras_nlp/getting_started) +[Getting Started guide](https://keras.io/guides/keras_hub/getting_started) to start learning our API. Browse our models on [Kaggle](https://www.kaggle.com/organizations/keras/models). We welcome contributions. @@ -35,9 +35,9 @@ We welcome contributions. ### For everyone -- [Home Page](https://keras.io/keras_nlp) -- [Developer Guides](https://keras.io/guides/keras_nlp) -- [API Reference](https://keras.io/api/keras_nlp) +- [Home Page](https://keras.io/keras_hub) +- [Developer Guides](https://keras.io/guides/keras_hub) +- [API Reference](https://keras.io/api/keras_hub) - [Pre-trained Models](https://www.kaggle.com/organizations/keras/models) ### For contributors @@ -56,7 +56,7 @@ Fine-tune a BERT classifier on IMDb movie reviews: import os os.environ["KERAS_BACKEND"] = "jax" # Or "tensorflow" or "torch"! -import keras_nlp +import keras_hub import tensorflow_datasets as tfds imdb_train, imdb_test = tfds.load( @@ -67,7 +67,7 @@ imdb_train, imdb_test = tfds.load( ) # Load a BERT model. -classifier = keras_nlp.models.Classifier.from_preset( +classifier = keras_hub.models.Classifier.from_preset( "bert_base_en", num_classes=2, activation="softmax", @@ -79,25 +79,17 @@ classifier.fit(imdb_train, validation_data=imdb_test) classifier.predict(["What an amazing movie!", "A total waste of my time."]) ``` -Try it out [in a colab](https://colab.research.google.com/gist/mattdangerw/e457e42d5ea827110c8d5cb4eb9d9a07/kerasnlp-quickstart.ipynb). +Try it out [in a colab](https://colab.research.google.com/drive/1gSWkh3yOLwmKAaNh2dQQ6kQIlnGte7P2?usp=sharing). For more in depth guides and examples, visit -[keras.io/keras_nlp](https://keras.io/keras_nlp/). +[keras.io/keras_hub](https://keras.io/keras_hub/). ## Installation -KerasHub is currently in pre-release. Note that pre-release versions may -introduce breaking changes to the API in future versions. 
For a stable and -supported experience, we recommend installing `keras-nlp` version 0.15.1: - -```bash -pip install keras-nlp==0.15.1 -``` - -To try out the latest pre-release version of KerasHub, you can use +To try out the latest version of KerasHub, you can use our nightly package: ```bash -pip install keras-hub-nightly +pip install keras-hub ``` KerasHub currently requires TensorFlow to be installed for use of the diff --git a/keras_hub/api/layers/__init__.py b/keras_hub/api/layers/__init__.py index 95d0d40919..53e0074414 100644 --- a/keras_hub/api/layers/__init__.py +++ b/keras_hub/api/layers/__init__.py @@ -40,6 +40,9 @@ from keras_hub.src.models.densenet.densenet_image_converter import ( DenseNetImageConverter, ) +from keras_hub.src.models.mix_transformer.mix_transformer_image_converter import ( + MiTImageConverter, +) from keras_hub.src.models.mobilenet.mobilenet_image_converter import ( MobileNetImageConverter, ) @@ -52,6 +55,10 @@ from keras_hub.src.models.sam.sam_image_converter import SAMImageConverter from keras_hub.src.models.sam.sam_mask_decoder import SAMMaskDecoder from keras_hub.src.models.sam.sam_prompt_encoder import SAMPromptEncoder +from keras_hub.src.models.segformer.segformer_image_converter import ( + SegFormerImageConverter, +) +from keras_hub.src.models.vgg.vgg_image_converter import VGGImageConverter from keras_hub.src.models.whisper.whisper_audio_converter import ( WhisperAudioConverter, ) diff --git a/keras_hub/api/models/__init__.py b/keras_hub/api/models/__init__.py index 7f4f87d9cc..88aa733c78 100644 --- a/keras_hub/api/models/__init__.py +++ b/keras_hub/api/models/__init__.py @@ -180,6 +180,8 @@ from keras_hub.src.models.image_segmenter_preprocessor import ( ImageSegmenterPreprocessor, ) +from keras_hub.src.models.image_to_image import ImageToImage +from keras_hub.src.models.inpaint import Inpaint from keras_hub.src.models.llama3.llama3_backbone import Llama3Backbone from keras_hub.src.models.llama3.llama3_causal_lm import Llama3CausalLM from keras_hub.src.models.llama3.llama3_causal_lm_preprocessor import ( @@ -200,11 +202,10 @@ MistralCausalLMPreprocessor, ) from keras_hub.src.models.mistral.mistral_tokenizer import MistralTokenizer -from keras_hub.src.models.mix_transformer.mix_transformer_backbone import ( - MiTBackbone, -) -from keras_hub.src.models.mix_transformer.mix_transformer_classifier import ( - MiTImageClassifier, +from keras_hub.src.models.mit.mit_backbone import MiTBackbone +from keras_hub.src.models.mit.mit_image_classifier import MiTImageClassifier +from keras_hub.src.models.mit.mit_image_classifier_preprocessor import ( + MiTImageClassifierPreprocessor, ) from keras_hub.src.models.mobilenet.mobilenet_backbone import MobileNetBackbone from keras_hub.src.models.mobilenet.mobilenet_image_classifier import ( @@ -268,11 +269,24 @@ from keras_hub.src.models.sam.sam_image_segmenter_preprocessor import ( SAMImageSegmenterPreprocessor, ) +from keras_hub.src.models.segformer.segformer_backbone import SegFormerBackbone +from keras_hub.src.models.segformer.segformer_image_segmenter import ( + SegFormerImageSegmenter, +) +from keras_hub.src.models.segformer.segformer_image_segmenter_preprocessor import ( + SegFormerImageSegmenterPreprocessor, +) from keras_hub.src.models.seq_2_seq_lm import Seq2SeqLM from keras_hub.src.models.seq_2_seq_lm_preprocessor import Seq2SeqLMPreprocessor from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_backbone import ( StableDiffusion3Backbone, ) +from 
keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_image_to_image import ( + StableDiffusion3ImageToImage, +) +from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_inpaint import ( + StableDiffusion3Inpaint, +) from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_text_to_image import ( StableDiffusion3TextToImage, ) @@ -291,6 +305,9 @@ from keras_hub.src.models.text_to_image import TextToImage from keras_hub.src.models.vgg.vgg_backbone import VGGBackbone from keras_hub.src.models.vgg.vgg_image_classifier import VGGImageClassifier +from keras_hub.src.models.vgg.vgg_image_classifier_preprocessor import ( + VGGImageClassifierPreprocessor, +) from keras_hub.src.models.vit_det.vit_det_backbone import ViTDetBackbone from keras_hub.src.models.whisper.whisper_backbone import WhisperBackbone from keras_hub.src.models.whisper.whisper_tokenizer import WhisperTokenizer diff --git a/keras_hub/src/layers/modeling/transformer_encoder.py b/keras_hub/src/layers/modeling/transformer_encoder.py index 8d3fb0f950..5ed121e457 100644 --- a/keras_hub/src/layers/modeling/transformer_encoder.py +++ b/keras_hub/src/layers/modeling/transformer_encoder.py @@ -170,7 +170,12 @@ def build(self, inputs_shape): self.built = True def call( - self, inputs, padding_mask=None, attention_mask=None, training=None + self, + inputs, + padding_mask=None, + attention_mask=None, + training=None, + return_attention_scores=False, ): """Forward pass of the TransformerEncoder. @@ -185,6 +190,7 @@ def call( [batch_size, sequence_length, sequence_length]. training: a boolean indicating whether the layer should behave in training mode or in inference mode. + return_attention_scores: a boolean indicating whether the output should be `(attention_output, attention_scores)` if `True` or `attention_output` if `False`. Defaults to `False`. Returns: A Tensor of the same shape as the `inputs`. 
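The docstring addition above describes the new `return_attention_scores` argument on `TransformerEncoder.call`; the hunks that follow wire it through to the underlying attention layer. A minimal usage sketch, with arbitrary layer sizes and input shape chosen only for illustration:

```python
import keras_hub
from keras import random

# Hypothetical sizes, for illustration only.
encoder = keras_hub.layers.TransformerEncoder(intermediate_dim=64, num_heads=4)
inputs = random.uniform(shape=(2, 10, 16))

# With the flag set, the call returns an (outputs, attention_scores) tuple.
outputs, scores = encoder(inputs, return_attention_scores=True)
print(outputs.shape)  # (2, 10, 16) -- same shape as the inputs
print(scores.shape)   # (2, 4, 10, 10) -- (batch, heads, query_len, key_len)
```

The score shape matches the assertion in the new test further down: `(batch_size, num_heads, sequence_length, sequence_length)`.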
@@ -200,12 +206,24 @@ def call( residual = x if self.normalize_first: x = self._self_attention_layer_norm(x) - x = self._self_attention_layer( - query=x, - value=x, - attention_mask=self_attention_mask, - training=training, - ) + + if return_attention_scores: + x, attention_scores = self._self_attention_layer( + query=x, + value=x, + attention_mask=self_attention_mask, + return_attention_scores=return_attention_scores, + training=training, + ) + return x, attention_scores + else: + x = self._self_attention_layer( + query=x, + value=x, + attention_mask=self_attention_mask, + training=training, + ) + x = self._self_attention_dropout(x, training=training) x = x + residual if not self.normalize_first: @@ -222,6 +240,9 @@ def call( if not self.normalize_first: x = self._feedforward_layer_norm(x) + if return_attention_scores: + return x, attention_scores + return x def get_config(self): diff --git a/keras_hub/src/layers/modeling/transformer_encoder_test.py b/keras_hub/src/layers/modeling/transformer_encoder_test.py index c4763d3763..0f12a0920b 100644 --- a/keras_hub/src/layers/modeling/transformer_encoder_test.py +++ b/keras_hub/src/layers/modeling/transformer_encoder_test.py @@ -95,3 +95,14 @@ def test_mask_propagation(self): inputs._keras_mask = mask outputs = encoder(inputs) self.assertAllEqual(outputs._keras_mask, mask) + + def test_attention_scores(self): + encoder = TransformerEncoder(intermediate_dim=4, num_heads=2) + inputs = random.uniform(shape=[1, 4, 6]) + outputs, attention_scores = encoder( + inputs, return_attention_scores=True + ) + self.assertAllEqual(outputs.shape, inputs.shape) + + # attention scores shape (batch_size, num_of_attn_heads, seq_length, seq_length) + self.assertAllEqual(attention_scores.shape, [1, 2, 4, 4]) diff --git a/keras_hub/src/layers/preprocessing/image_converter.py b/keras_hub/src/layers/preprocessing/image_converter.py index e3b55bbde0..89142c469b 100644 --- a/keras_hub/src/layers/preprocessing/image_converter.py +++ b/keras_hub/src/layers/preprocessing/image_converter.py @@ -145,8 +145,9 @@ def image_size(self, value): @preprocessing_function def call(self, inputs): + x = inputs if self.image_size is not None: - x = self.resizing(inputs) + x = self.resizing(x) if self.scale is not None: x = x * self._expand_non_channel_dims(self.scale, x) if self.offset is not None: diff --git a/keras_hub/src/layers/preprocessing/image_converter_test.py b/keras_hub/src/layers/preprocessing/image_converter_test.py index 5e0fd940c2..d638ccf9ab 100644 --- a/keras_hub/src/layers/preprocessing/image_converter_test.py +++ b/keras_hub/src/layers/preprocessing/image_converter_test.py @@ -6,12 +6,10 @@ from keras import ops from keras_hub.src.layers.preprocessing.image_converter import ImageConverter -from keras_hub.src.models.pali_gemma.pali_gemma_backbone import ( - PaliGemmaBackbone, -) from keras_hub.src.models.pali_gemma.pali_gemma_image_converter import ( PaliGemmaImageConverter, ) +from keras_hub.src.models.resnet.resnet_backbone import ResNetBackbone from keras_hub.src.tests.test_case import TestCase @@ -86,24 +84,19 @@ def test_from_preset_errors(self): def test_save_to_preset(self): save_dir = self.get_temp_dir() converter = ImageConverter.from_preset( - "pali_gemma_3b_mix_224", + "resnet_50_imagenet", interpolation="nearest", ) converter.save_to_preset(save_dir) # Save a tiny backbone so the preset is valid. 
- backbone = PaliGemmaBackbone( - vocabulary_size=100, - image_size=224, - num_layers=1, - num_query_heads=1, - num_key_value_heads=1, - hidden_dim=8, - intermediate_dim=16, - head_dim=8, - vit_patch_size=14, - vit_num_heads=1, - vit_hidden_dim=8, - vit_num_layers=1, + backbone = ResNetBackbone( + input_conv_filters=[64], + input_conv_kernel_sizes=[7], + stackwise_num_filters=[64, 64, 64], + stackwise_num_blocks=[2, 2, 2], + stackwise_num_strides=[1, 2, 2], + block_type="basic_block", + use_pre_activation=True, ) backbone.save_to_preset(save_dir) diff --git a/keras_hub/src/models/causal_lm.py b/keras_hub/src/models/causal_lm.py index c86bd7be9f..2514022c4d 100644 --- a/keras_hub/src/models/causal_lm.py +++ b/keras_hub/src/models/causal_lm.py @@ -274,6 +274,7 @@ def generate( inputs, max_length=None, stop_token_ids="auto", + strip_prompt=False, ): """Generate text given prompt `inputs`. @@ -309,6 +310,9 @@ def generate( specify a list of token id's the model should stop on. Note that sequences of tokens will each be interpreted as a stop token, multi-token stop sequences are not supported. + strip_prompt: Optional. By default, generate() returns the full prompt + followed by its completion generated by the model. If this option + is set to True, only the newly generated text is returned. """ # Setup our three main passes. # 1. Optionally preprocessing strings to dense integer tensors. @@ -326,6 +330,10 @@ def generate( ) elif stop_token_ids == "auto": stop_token_ids = [self.preprocessor.tokenizer.end_token_id] + # Some models like Llama3 use two end tokens: <|eot_id|> in + # "instruct" versions and <|end_of_text|> in others. + if hasattr(self.preprocessor.tokenizer, "end_token2_id"): + stop_token_ids.append(self.preprocessor.tokenizer.end_token2_id) def preprocess(x): return self.preprocessor.generate_preprocess( @@ -335,6 +343,33 @@ def preprocess(x): def generate(x): return generate_function(x, stop_token_ids=stop_token_ids) + def strip_prompt_function(x, prompt): + # This function removes the prompt from the generated + # response, in a batch-friendly fashion. + y = {} + prompt_mask = prompt["padding_mask"] + seq_len = prompt_mask.shape[1] + + # We need to shift every output sequence by the size of the prompt. + shifts = -ops.sum(ops.cast(prompt_mask, "int"), axis=1) % seq_len + ix = ops.arange(seq_len, dtype="int") + ix = ops.expand_dims(ix, axis=0) - ops.expand_dims(shifts, axis=1) + + # This produces the desired shift (in fact a rollover). + def roll_sequence(seq): + return ops.take_along_axis(seq, ix, axis=1) + + # The shifting rolls the content over so the prompt is at the end of + # the sequence and the generated text is at the beginning. We mask + # it to retain the generated text only. 
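+            # Worked example (hypothetical 4-token sequences): with a prompt
+            # mask of [1, 1, 0, 0] and a generated mask of [1, 1, 1, 0],
+            # shifting by 2 rolls them to [0, 0, 1, 1] and [1, 0, 1, 1]; the
+            # XOR below then keeps only the newly generated position(s).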
+ y["padding_mask"] = ops.logical_xor( + roll_sequence(prompt_mask), roll_sequence(x["padding_mask"]) + ) + # we assume the mask is enough and there is no need to zero-out the values + y["token_ids"] = roll_sequence(x["token_ids"]) + + return y + def postprocess(x): return self.preprocessor.generate_postprocess(x) @@ -343,7 +378,12 @@ def postprocess(x): if self.preprocessor is not None: inputs = [preprocess(x) for x in inputs] - outputs = [generate(x) for x in inputs] + + if strip_prompt: + outputs = [strip_prompt_function(generate(x), x) for x in inputs] + else: + outputs = [generate(x) for x in inputs] + if self.preprocessor is not None: outputs = [postprocess(x) for x in outputs] diff --git a/keras_hub/src/models/deeplab_v3/deeplab_v3_backbone_test.py b/keras_hub/src/models/deeplab_v3/deeplab_v3_backbone_test.py index a7b1809085..a0a3a8d1df 100644 --- a/keras_hub/src/models/deeplab_v3/deeplab_v3_backbone_test.py +++ b/keras_hub/src/models/deeplab_v3/deeplab_v3_backbone_test.py @@ -51,6 +51,7 @@ def test_saved_model(self): cls=DeepLabV3Backbone, init_kwargs=self.init_kwargs, input_data=self.input_data, + atol=0.00001, ) diff --git a/keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py b/keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py index 1b1dde181d..85cd186830 100644 --- a/keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py +++ b/keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py @@ -1,4 +1,18 @@ """DeepLabV3 preset configurations.""" -# TODO https://github.com/keras-team/keras-hub/issues/1896, -backbone_presets = {} +backbone_presets = { + "deeplab_v3_plus_resnet50_pascalvoc": { + "metadata": { + "description": ( + "DeepLabV3+ model with ResNet50 as image encoder and trained on " + "augmented Pascal VOC dataset by Semantic Boundaries Dataset(SBD)" + "which is having categorical accuracy of 90.01 and 0.63 Mean IoU." 
+ ), + "params": 39190656, + "official_name": "DeepLabV3", + "path": "deeplab_v3", + "model_card": "https://arxiv.org/abs/1802.02611", + }, + "kaggle_handle": "kaggle://keras/deeplabv3plus/keras/deeplab_v3_plus_resnet50_pascalvoc/3", + }, +} diff --git a/keras_hub/src/models/densenet/densenet_presets.py b/keras_hub/src/models/densenet/densenet_presets.py index 2c3ef77842..99702bf86f 100644 --- a/keras_hub/src/models/densenet/densenet_presets.py +++ b/keras_hub/src/models/densenet/densenet_presets.py @@ -12,7 +12,7 @@ "path": "densenet", "model_card": "https://arxiv.org/abs/1608.06993", }, - "kaggle_handle": "kaggle://kerashub/densenet/keras/densenet_121_imagenet", + "kaggle_handle": "kaggle://keras/densenet/keras/densenet_121_imagenet/2", }, "densenet_169_imagenet": { "metadata": { @@ -25,7 +25,7 @@ "path": "densenet", "model_card": "https://arxiv.org/abs/1608.06993", }, - "kaggle_handle": "kaggle://kerashub/densenet/keras/densenet_169_imagenet", + "kaggle_handle": "kaggle://keras/densenet/keras/densenet_169_imagenet/2", }, "densenet_201_imagenet": { "metadata": { @@ -38,6 +38,6 @@ "path": "densenet", "model_card": "https://arxiv.org/abs/1608.06993", }, - "kaggle_handle": "kaggle://kerashub/densenet/keras/densenet_201_imagenet", + "kaggle_handle": "kaggle://keras/densenet/keras/densenet_201_imagenet/2", }, } diff --git a/keras_hub/src/models/gemma/gemma_backbone.py b/keras_hub/src/models/gemma/gemma_backbone.py index c34547b83e..1d6482b96b 100644 --- a/keras_hub/src/models/gemma/gemma_backbone.py +++ b/keras_hub/src/models/gemma/gemma_backbone.py @@ -224,7 +224,7 @@ def get_layout_map( Example: ``` - # Feel free to change the mesh shape to balance data and model parallel + # Feel free to change the mesh shape to balance data and model parallelism mesh = keras.distribution.DeviceMesh( shape=(1, 8), axis_names=('batch', 'model'), devices=keras.distribution.list_devices()) @@ -232,11 +232,19 @@ def get_layout_map( mesh, model_parallel_dim_name="model") distribution = keras.distribution.ModelParallel( - mesh, layout_map, batch_dim_name='batch') + layout_map=layout_map, batch_dim_name='batch') with distribution.scope(): gemma_model = keras_hub.models.GemmaCausalLM.from_preset() ``` + To see how the layout map was applied, load the model then run (for one decoder block): + ``` + embedding_layer = gemma_model.backbone.get_layer("token_embedding") + decoder_block_1 = gemma_model.backbone.get_layer('decoder_block_1') + for variable in embedding_layer.weights + decoder_block_1.weights: + print(f'{variable.path:<58} {str(variable.shape):<16} {str(variable.value.sharding.spec)}') + ``` + Args: device_mesh: The `keras.distribution.DeviceMesh` instance for distribution. @@ -246,7 +254,7 @@ def get_layout_map( the data should be partition on. Return: `keras.distribution.LayoutMap` that contains the sharding spec - of all the model weights. + for all the model weights. 
""" # The weight path and shape of the Gemma backbone is like below (for 2G) # token_embedding/embeddings, (256128, 2048), 524550144 diff --git a/keras_hub/src/models/gemma/gemma_backbone_test.py b/keras_hub/src/models/gemma/gemma_backbone_test.py index bbd383e687..b5f8575332 100644 --- a/keras_hub/src/models/gemma/gemma_backbone_test.py +++ b/keras_hub/src/models/gemma/gemma_backbone_test.py @@ -74,11 +74,10 @@ def test_architecture_characteristics(self): def test_distribution(self): if keras.backend.backend() != "jax": - return + self.skipTest("`ModelParallel` testing requires the Jax backend.") devices = keras.distribution.list_devices("CPU") if len(devices) == 1: - # Need more than 1 device for distribution testing. - return + self.skipTest("`ModelParallel` testing requires multiple devices.") device_mesh = keras.distribution.DeviceMesh( shape=(1, len(devices)), axis_names=("batch", "model"), @@ -86,7 +85,7 @@ def test_distribution(self): ) layout_map = GemmaBackbone.get_layout_map(device_mesh) - distribution = keras.distribution.ModelParallel(device_mesh, layout_map) + distribution = keras.distribution.ModelParallel(layout_map=layout_map) with distribution.scope(): model = GemmaBackbone(**self.init_kwargs) @@ -129,7 +128,6 @@ def test_distribution_with_lora(self): self.skipTest("`ModelParallel` testing requires the Jax backend.") devices = keras.distribution.list_devices("CPU") if len(devices) == 1: - # Need more than 1 device for distribution testing. self.skipTest("`ModelParallel` testing requires multiple devices.") device_mesh = keras.distribution.DeviceMesh( shape=(1, len(devices)), diff --git a/keras_hub/src/models/image_to_image.py b/keras_hub/src/models/image_to_image.py new file mode 100644 index 0000000000..d3194a5815 --- /dev/null +++ b/keras_hub/src/models/image_to_image.py @@ -0,0 +1,411 @@ +import itertools +from functools import partial + +import keras +from keras import ops +from keras import random + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.task import Task +from keras_hub.src.utils.keras_utils import standardize_data_format + +try: + import tensorflow as tf +except ImportError: + tf = None + + +@keras_hub_export("keras_hub.models.ImageToImage") +class ImageToImage(Task): + """Base class for image-to-image tasks. + + `ImageToImage` tasks wrap a `keras_hub.models.Backbone` and + a `keras_hub.models.Preprocessor` to create a model that can be used for + generation and generative fine-tuning. + + `ImageToImage` tasks provide an additional, high-level `generate()` function + which can be used to generate image by token with a (image, string) in, + image out signature. + + All `ImageToImage` tasks include a `from_preset()` constructor which can be + used to load a pre-trained config and weights. + + Example: + + ```python + # Load a Stable Diffusion 3 backbone with pre-trained weights. + reference_image = np.ones((1024, 1024, 3), dtype="float32") + image_to_image = keras_hub.models.ImageToImage.from_preset( + "stable_diffusion_3_medium", + ) + image_to_image.generate( + reference_image, + "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", + ) + + # Load a Stable Diffusion 3 backbone at bfloat16 precision. 
+ image_to_image = keras_hub.models.ImageToImage.from_preset( + "stable_diffusion_3_medium", + dtype="bfloat16", + ) + image_to_image.generate( + reference_image, + "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", + ) + ``` + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # Default compilation. + self.compile() + + @property + def support_negative_prompts(self): + """Whether the model supports `negative_prompts` key in `generate()`.""" + return bool(True) + + @property + def image_shape(self): + return tuple(self.backbone.image_shape) + + @property + def latent_shape(self): + return tuple(self.backbone.latent_shape) + + def compile( + self, + optimizer="auto", + loss="auto", + *, + metrics="auto", + **kwargs, + ): + """Configures the `ImageToImage` task for training. + + The `ImageToImage` task extends the default compilation signature of + `keras.Model.compile` with defaults for `optimizer`, `loss`, and + `metrics`. To override these defaults, pass any value + to these arguments during compilation. + + Args: + optimizer: `"auto"`, an optimizer name, or a `keras.Optimizer` + instance. Defaults to `"auto"`, which uses the default optimizer + for the given model and task. See `keras.Model.compile` and + `keras.optimizers` for more info on possible `optimizer` values. + loss: `"auto"`, a loss name, or a `keras.losses.Loss` instance. + Defaults to `"auto"`, where a + `keras.losses.MeanSquaredError` loss will be applied. See + `keras.Model.compile` and `keras.losses` for more info on + possible `loss` values. + metrics: `"auto"`, or a list of metrics to be evaluated by + the model during training and testing. Defaults to `"auto"`, + where a `keras.metrics.MeanSquaredError` will be applied to + track the loss of the model during training. See + `keras.Model.compile` and `keras.metrics` for more info on + possible `metrics` values. + **kwargs: See `keras.Model.compile` for a full list of arguments + supported by the compile method. 
+ """ + # Ref: https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py#L410-L414 + if optimizer == "auto": + optimizer = keras.optimizers.AdamW( + 1e-4, weight_decay=1e-2, epsilon=1e-8, clipnorm=1.0 + ) + if loss == "auto": + loss = keras.losses.MeanSquaredError() + if metrics == "auto": + metrics = [keras.metrics.MeanSquaredError()] + super().compile( + optimizer=optimizer, + loss=loss, + metrics=metrics, + **kwargs, + ) + self.generate_function = None + + def generate_step(self, *args, **kwargs): + """Run generation on batches of input.""" + raise NotImplementedError + + def make_generate_function(self): + """Create or return the compiled generation function.""" + if self.generate_function is not None: + return self.generate_function + + self.generate_function = self.generate_step + if keras.config.backend() == "torch": + import torch + + def wrapped_function(*args, **kwargs): + with torch.no_grad(): + return self.generate_step(*args, **kwargs) + + self.generate_function = wrapped_function + elif keras.config.backend() == "tensorflow" and not self.run_eagerly: + self.generate_function = tf.function( + self.generate_step, jit_compile=self.jit_compile + ) + elif keras.config.backend() == "jax" and not self.run_eagerly: + import jax + + @partial(jax.jit) + def compiled_function(state, *args, **kwargs): + ( + trainable_variables, + non_trainable_variables, + ) = state + mapping = itertools.chain( + zip(self.trainable_variables, trainable_variables), + zip(self.non_trainable_variables, non_trainable_variables), + ) + + with keras.StatelessScope(state_mapping=mapping): + outputs = self.generate_step(*args, **kwargs) + return outputs + + def wrapped_function(*args, **kwargs): + # Create an explicit tuple of all variable state. + state = ( + # Use the explicit variable.value to preserve the + # sharding spec of distribution. + [v.value for v in self.trainable_variables], + [v.value for v in self.non_trainable_variables], + ) + outputs = compiled_function(state, *args, **kwargs) + return outputs + + self.generate_function = wrapped_function + return self.generate_function + + def _normalize_generate_inputs(self, inputs): + """Normalize user input to the generate function. + + This function converts all inputs to tensors, adds a batch dimension if + necessary, and returns a iterable "dataset like" object (either an + actual `tf.data.Dataset` or a list with a single batch element). + + The input format must be one of the following: + - A dict with "images", "prompts" and/or "negative_prompts" keys + - A tf.data.Dataset with "images", "prompts" and/or "negative_prompts" + keys + + The output will be a dict with "images", "prompts" and/or + "negative_prompts" keys. + """ + if tf and isinstance(inputs, tf.data.Dataset): + _inputs = { + "images": inputs.map(lambda x: x["images"]).as_numpy_iterator(), + "prompts": inputs.map( + lambda x: x["prompts"] + ).as_numpy_iterator(), + } + if self.support_negative_prompts: + _inputs["negative_prompts"] = inputs.map( + lambda x: x["negative_prompts"] + ).as_numpy_iterator() + return _inputs, False + + if ( + not isinstance(inputs, dict) + or "images" not in inputs + or "prompts" not in inputs + ): + raise ValueError( + '`inputs` must be a dict with "images" and "prompts" keys or a' + f"tf.data.Dataset. 
Received: inputs={inputs}" + ) + + def normalize(x): + if isinstance(x, str): + return [x], True + if tf and isinstance(x, tf.Tensor) and x.shape.rank == 0: + return x[tf.newaxis], True + return x, False + + def normalize_images(x): + data_format = getattr( + self.backbone, "data_format", standardize_data_format(None) + ) + input_is_scalar = False + x = ops.convert_to_tensor(x) + if len(ops.shape(x)) < 4: + x = ops.expand_dims(x, axis=0) + input_is_scalar = True + x = ops.image.resize( + x, + (self.backbone.image_shape[0], self.backbone.image_shape[1]), + interpolation="nearest", + data_format=data_format, + ) + return x, input_is_scalar + + def get_dummy_prompts(x): + dummy_prompts = [""] * len(x) + if tf and isinstance(x, tf.Tensor): + return tf.convert_to_tensor(dummy_prompts) + else: + return dummy_prompts + + for key in inputs: + if key == "images": + inputs[key], input_is_scalar = normalize_images(inputs[key]) + else: + inputs[key], input_is_scalar = normalize(inputs[key]) + + if self.support_negative_prompts and "negative_prompts" not in inputs: + inputs["negative_prompts"] = get_dummy_prompts(inputs["prompts"]) + + return [inputs], input_is_scalar + + def _normalize_generate_outputs(self, outputs, input_is_scalar): + """Normalize user output from the generate function. + + This function converts all output to numpy with a value range of + `[0, 255]`. If a batch dimension was added to the input, it is removed + from the output. + """ + + def normalize(x): + outputs = ops.concatenate(x, axis=0) + outputs = ops.clip(ops.divide(ops.add(outputs, 1.0), 2.0), 0.0, 1.0) + outputs = ops.cast(ops.round(ops.multiply(outputs, 255.0)), "uint8") + outputs = ops.squeeze(outputs, 0) if input_is_scalar else outputs + return ops.convert_to_numpy(outputs) + + if isinstance(outputs[0], dict): + normalized = {} + for key in outputs[0]: + normalized[key] = normalize([x[key] for x in outputs]) + return normalized + return normalize([x for x in outputs]) + + def generate( + self, + inputs, + num_steps, + guidance_scale, + strength, + seed=None, + ): + """Generate image based on the provided `inputs`. + + Typically, `inputs` is a dict with `"images"` and `"prompts"` keys. + `"images"` are reference images within a value range of + `[-1.0, 1.0]`, which will be resized to `self.backbone.height` and + `self.backbone.width`, then encoded into latent space by the VAE + encoder. `"prompts"` are strings that will be tokenized and encoded by + the text encoder. + + Some models support a `"negative_prompts"` key, which helps steer the + model away from generating certain styles and elements. To enable this, + add `"negative_prompts"` to the input dict. + + If `inputs` are a `tf.data.Dataset`, outputs will be generated + "batch-by-batch" and concatenated. Otherwise, all inputs will be + processed as batches. + + Args: + inputs: python data, tensor data, or a `tf.data.Dataset`. The format + must be one of the following: + - A dict with `"images"`, `"prompts"` and/or + `"negative_prompts"` keys. + - A `tf.data.Dataset` with `"images"`, `"prompts"` and/or + `"negative_prompts"` keys. + num_steps: int. The number of diffusion steps to take. + guidance_scale: float. The classifier free guidance scale defined in + [Classifier-Free Diffusion Guidance]( + https://arxiv.org/abs/2207.12598). A higher scale encourages + generating images more closely related to the prompts, typically + at the cost of lower image quality. + strength: float. Indicates the extent to which the reference + `images` are transformed. 
Must be between `0.0` and `1.0`. When + `strength=1.0`, `images` is essentially ignore and added noise + is maximum and the denoising process runs for the full number of + iterations specified in `num_steps`. + seed: optional int. Used as a random seed. + """ + num_steps = int(num_steps) + guidance_scale = float(guidance_scale) + strength = float(strength) + if strength < 0.0 or strength > 1.0: + raise ValueError( + "`strength` must be between `0.0` and `1.0`. " + f"Received strength={strength}." + ) + starting_step = int(num_steps * (1.0 - strength)) + starting_step = ops.convert_to_tensor(starting_step, "int32") + num_steps = ops.convert_to_tensor(num_steps, "int32") + guidance_scale = ops.convert_to_tensor(guidance_scale) + + # Check `inputs` format. + required_keys = ["images", "prompts"] + if tf and isinstance(inputs, tf.data.Dataset): + spec = inputs.element_spec + if not all(key in spec for key in required_keys): + raise ValueError( + "Expected a `tf.data.Dataset` with the following keys:" + f"{required_keys}. Received: inputs.element_spec={spec}" + ) + else: + if not isinstance(inputs, dict): + raise ValueError( + "Expected a `dict` or `tf.data.Dataset`. " + f"Received: inputs={inputs} of type {type(inputs)}." + ) + if not all(key in inputs for key in required_keys): + raise ValueError( + "Expected a `dict` with the following keys:" + f"{required_keys}. " + f"Received: inputs.keys={list(inputs.keys())}" + ) + + # Setup our three main passes. + # 1. Preprocessing strings to dense integer tensors. + # 2. Generate outputs via a compiled function on dense tensors. + # 3. Postprocess dense tensors to a value range of `[0, 255]`. + generate_function = self.make_generate_function() + + def preprocess(x): + if self.preprocessor is not None: + return self.preprocessor.generate_preprocess(x) + else: + return x + + def generate(images, x): + token_ids = x[0] if self.support_negative_prompts else x + + # Initialize noises. + if isinstance(token_ids, dict): + arbitrary_key = list(token_ids.keys())[0] + batch_size = ops.shape(token_ids[arbitrary_key])[0] + else: + batch_size = ops.shape(token_ids)[0] + noise_shape = (batch_size,) + self.latent_shape[1:] + noises = random.normal(noise_shape, dtype="float32", seed=seed) + + return generate_function( + images, noises, x, starting_step, num_steps, guidance_scale + ) + + # Normalize and preprocess inputs. + inputs, input_is_scalar = self._normalize_generate_inputs(inputs) + if self.support_negative_prompts: + images = [x["images"] for x in inputs] + token_ids = [preprocess(x["prompts"]) for x in inputs] + negative_token_ids = [ + preprocess(x["negative_prompts"]) for x in inputs + ] + # Tuple format: (images, (token_ids, negative_token_ids)). + inputs = [ + x for x in zip(images, zip(token_ids, negative_token_ids)) + ] + else: + images = [x["images"] for x in inputs] + token_ids = [preprocess(x["prompts"]) for x in inputs] + # Tuple format: (images, token_ids). + inputs = [x for x in zip(images, token_ids)] + + # Image-to-image. 
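+        # `generate(*x)` unpacks the per-batch tuples built above; the decoded
+        # images come back in roughly `[-1, 1]` and are mapped to uint8
+        # `[0, 255]` by `_normalize_generate_outputs`.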
+ outputs = [generate(*x) for x in inputs] + return self._normalize_generate_outputs(outputs, input_is_scalar) diff --git a/keras_hub/src/models/inpaint.py b/keras_hub/src/models/inpaint.py new file mode 100644 index 0000000000..40bcc7ad15 --- /dev/null +++ b/keras_hub/src/models/inpaint.py @@ -0,0 +1,513 @@ +import itertools +from functools import partial + +import keras +from keras import ops +from keras import random + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.task import Task +from keras_hub.src.utils.keras_utils import standardize_data_format + +try: + import tensorflow as tf +except ImportError: + tf = None + + +@keras_hub_export("keras_hub.models.Inpaint") +class Inpaint(Task): + """Base class for image-to-image tasks. + + `Inpaint` tasks wrap a `keras_hub.models.Backbone` and + a `keras_hub.models.Preprocessor` to create a model that can be used for + generation and generative fine-tuning. + + `Inpaint` tasks provide an additional, high-level `generate()` function + which can be used to generate image by token with a (image, mask, string) + in, image out signature. + + All `Inpaint` tasks include a `from_preset()` constructor which can be + used to load a pre-trained config and weights. + + Example: + + ```python + # Load a Stable Diffusion 3 backbone with pre-trained weights. + reference_image = np.ones((1024, 1024, 3), dtype="float32") + reference_mask = np.ones((1024, 1024), dtype="float32") + inpaint = keras_hub.models.Inpaint.from_preset( + "stable_diffusion_3_medium", + ) + inpaint.generate( + reference_image, + reference_mask, + "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", + ) + + # Load a Stable Diffusion 3 backbone at bfloat16 precision. + inpaint = keras_hub.models.Inpaint.from_preset( + "stable_diffusion_3_medium", + dtype="bfloat16", + ) + inpaint.generate( + reference_image, + reference_mask, + "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", + ) + ``` + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # Default compilation. + self.compile() + + @property + def support_negative_prompts(self): + """Whether the model supports `negative_prompts` key in `generate()`.""" + return bool(True) + + @property + def image_shape(self): + return tuple(self.backbone.image_shape) + + @property + def latent_shape(self): + return tuple(self.backbone.latent_shape) + + def compile( + self, + optimizer="auto", + loss="auto", + *, + metrics="auto", + **kwargs, + ): + """Configures the `Inpaint` task for training. + + The `Inpaint` task extends the default compilation signature of + `keras.Model.compile` with defaults for `optimizer`, `loss`, and + `metrics`. To override these defaults, pass any value + to these arguments during compilation. + + Args: + optimizer: `"auto"`, an optimizer name, or a `keras.Optimizer` + instance. Defaults to `"auto"`, which uses the default optimizer + for the given model and task. See `keras.Model.compile` and + `keras.optimizers` for more info on possible `optimizer` values. + loss: `"auto"`, a loss name, or a `keras.losses.Loss` instance. + Defaults to `"auto"`, where a + `keras.losses.MeanSquaredError` loss will be applied. See + `keras.Model.compile` and `keras.losses` for more info on + possible `loss` values. + metrics: `"auto"`, or a list of metrics to be evaluated by + the model during training and testing. 
Defaults to `"auto"`, + where a `keras.metrics.MeanSquaredError` will be applied to + track the loss of the model during training. See + `keras.Model.compile` and `keras.metrics` for more info on + possible `metrics` values. + **kwargs: See `keras.Model.compile` for a full list of arguments + supported by the compile method. + """ + # Ref: https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py#L410-L414 + if optimizer == "auto": + optimizer = keras.optimizers.AdamW( + 1e-4, weight_decay=1e-2, epsilon=1e-8, clipnorm=1.0 + ) + if loss == "auto": + loss = keras.losses.MeanSquaredError() + if metrics == "auto": + metrics = [keras.metrics.MeanSquaredError()] + super().compile( + optimizer=optimizer, + loss=loss, + metrics=metrics, + **kwargs, + ) + self.generate_function = None + + def generate_step(self, *args, **kwargs): + """Run generation on batches of input.""" + raise NotImplementedError + + def make_generate_function(self): + """Create or return the compiled generation function.""" + if self.generate_function is not None: + return self.generate_function + + self.generate_function = self.generate_step + if keras.config.backend() == "torch": + import torch + + def wrapped_function(*args, **kwargs): + with torch.no_grad(): + return self.generate_step(*args, **kwargs) + + self.generate_function = wrapped_function + elif keras.config.backend() == "tensorflow" and not self.run_eagerly: + self.generate_function = tf.function( + self.generate_step, jit_compile=self.jit_compile + ) + elif keras.config.backend() == "jax" and not self.run_eagerly: + import jax + + @partial(jax.jit) + def compiled_function(state, *args, **kwargs): + ( + trainable_variables, + non_trainable_variables, + ) = state + mapping = itertools.chain( + zip(self.trainable_variables, trainable_variables), + zip(self.non_trainable_variables, non_trainable_variables), + ) + + with keras.StatelessScope(state_mapping=mapping): + outputs = self.generate_step(*args, **kwargs) + return outputs + + def wrapped_function(*args, **kwargs): + # Create an explicit tuple of all variable state. + state = ( + # Use the explicit variable.value to preserve the + # sharding spec of distribution. + [v.value for v in self.trainable_variables], + [v.value for v in self.non_trainable_variables], + ) + outputs = compiled_function(state, *args, **kwargs) + return outputs + + self.generate_function = wrapped_function + return self.generate_function + + def _normalize_generate_images(self, inputs): + """Normalize user image to the generate function. + + This function converts all inputs to tensors, adds a batch dimension if + necessary, and returns a iterable "dataset like" object (either an + actual `tf.data.Dataset` or a list with a single batch element). 
+ """ + if tf and isinstance(inputs, tf.data.Dataset): + return inputs.as_numpy_iterator(), False + + def normalize(x): + data_format = getattr( + self.backbone, "data_format", standardize_data_format(None) + ) + input_is_scalar = False + x = ops.convert_to_tensor(x) + if len(ops.shape(x)) < 4: + x = ops.expand_dims(x, axis=0) + input_is_scalar = True + x = ops.image.resize( + x, + (self.backbone.image_shape[0], self.backbone.image_shape[1]), + interpolation="nearest", + data_format=data_format, + ) + return x, input_is_scalar + + if isinstance(inputs, dict): + for key in inputs: + inputs[key], input_is_scalar = normalize(inputs[key]) + else: + inputs, input_is_scalar = normalize(inputs) + + return inputs, input_is_scalar + + def _normalize_generate_masks(self, inputs): + """Normalize user masks to the generate function. + + This function converts all inputs to tensors, adds a batch dimension if + necessary, and returns a iterable "dataset like" object (either an + actual `tf.data.Dataset` or a list with a single batch element). + """ + if tf and isinstance(inputs, tf.data.Dataset): + return inputs.as_numpy_iterator(), False + + def normalize(x): + data_format = getattr( + self.backbone, "data_format", standardize_data_format(None) + ) + input_is_scalar = False + x = ops.convert_to_tensor(x) + if len(ops.shape(x)) < 3: + x = ops.expand_dims(x, axis=0) + input_is_scalar = True + x = ops.expand_dims(x, axis=-1) + if keras.backend.standardize_dtype(x.dtype) == "bool": + x = ops.cast(x, "float32") + x = ops.image.resize( + x, + (self.backbone.image_shape[0], self.backbone.image_shape[1]), + interpolation="nearest", + data_format=data_format, + ) + x = ops.squeeze(x, axis=-1) + return x, input_is_scalar + + if isinstance(inputs, dict): + for key in inputs: + inputs[key], input_is_scalar = normalize(inputs[key]) + else: + inputs, input_is_scalar = normalize(inputs) + + return inputs, input_is_scalar + + def _normalize_generate_inputs(self, inputs): + """Normalize user input to the generate function. + + This function converts all inputs to tensors, adds a batch dimension if + necessary, and returns a iterable "dataset like" object (either an + actual `tf.data.Dataset` or a list with a single batch element). + + The input format must be one of the following: + - A dict with "images", "masks", "prompts" and/or "negative_prompts" + keys + - A tf.data.Dataset with "images", "masks", "prompts" and/or + "negative_prompts" keys + + The output will be a dict with "images", "masks", "prompts" and/or + "negative_prompts" keys. 
+ """ + if tf and isinstance(inputs, tf.data.Dataset): + _inputs = { + "images": inputs.map(lambda x: x["images"]).as_numpy_iterator(), + "masks": inputs.map(lambda x: x["masks"]).as_numpy_iterator(), + "prompts": inputs.map( + lambda x: x["prompts"] + ).as_numpy_iterator(), + } + if self.support_negative_prompts: + _inputs["negative_prompts"] = inputs.map( + lambda x: x["negative_prompts"] + ).as_numpy_iterator() + return _inputs, False + + def normalize(x): + if isinstance(x, str): + return [x], True + if tf and isinstance(x, tf.Tensor) and x.shape.rank == 0: + return x[tf.newaxis], True + return x, False + + def normalize_images(x): + data_format = getattr( + self.backbone, "data_format", standardize_data_format(None) + ) + input_is_scalar = False + x = ops.convert_to_tensor(x) + if len(ops.shape(x)) < 4: + x = ops.expand_dims(x, axis=0) + input_is_scalar = True + x = ops.image.resize( + x, + (self.backbone.image_shape[0], self.backbone.image_shape[1]), + interpolation="nearest", + data_format=data_format, + ) + return x, input_is_scalar + + def normalize_masks(x): + data_format = getattr( + self.backbone, "data_format", standardize_data_format(None) + ) + input_is_scalar = False + x = ops.convert_to_tensor(x) + if len(ops.shape(x)) < 3: + x = ops.expand_dims(x, axis=0) + input_is_scalar = True + x = ops.expand_dims(x, axis=-1) + if keras.backend.standardize_dtype(x.dtype) == "bool": + x = ops.cast(x, "float32") + x = ops.image.resize( + x, + (self.backbone.image_shape[0], self.backbone.image_shape[1]), + interpolation="nearest", + data_format=data_format, + ) + x = ops.squeeze(x, axis=-1) + return x, input_is_scalar + + def get_dummy_prompts(x): + dummy_prompts = [""] * len(x) + if tf and isinstance(x, tf.Tensor): + return tf.convert_to_tensor(dummy_prompts) + else: + return dummy_prompts + + for key in inputs: + if key == "images": + inputs[key], input_is_scalar = normalize_images(inputs[key]) + elif key == "masks": + inputs[key], input_is_scalar = normalize_masks(inputs[key]) + else: + inputs[key], input_is_scalar = normalize(inputs[key]) + + if self.support_negative_prompts and "negative_prompts" not in inputs: + inputs["negative_prompts"] = get_dummy_prompts(inputs["prompts"]) + + return [inputs], input_is_scalar + + def _normalize_generate_outputs(self, outputs, input_is_scalar): + """Normalize user output from the generate function. + + This function converts all output to numpy with a value range of + `[0, 255]`. If a batch dimension was added to the input, it is removed + from the output. + """ + + def normalize(x): + outputs = ops.concatenate(x, axis=0) + outputs = ops.clip(ops.divide(ops.add(outputs, 1.0), 2.0), 0.0, 1.0) + outputs = ops.cast(ops.round(ops.multiply(outputs, 255.0)), "uint8") + outputs = ops.squeeze(outputs, 0) if input_is_scalar else outputs + return ops.convert_to_numpy(outputs) + + if isinstance(outputs[0], dict): + normalized = {} + for key in outputs[0]: + normalized[key] = normalize([x[key] for x in outputs]) + return normalized + return normalize([x for x in outputs]) + + def generate( + self, + inputs, + num_steps, + guidance_scale, + strength, + seed=None, + ): + """Generate image based on the provided `inputs`. + + Typically, `inputs` is a dict with `"images"` `"masks"` and `"prompts"` + keys. `"images"` are reference images within a value range of + `[-1.0, 1.0]`, which will be resized to height and width from + `self.backbone.image_shape`, then encoded into latent space by the VAE + encoder. 
`"masks"` are mask images with a boolean dtype, where white + pixels are repainted while black pixels are preserved. `"prompts"` are + strings that will be tokenized and encoded by the text encoder. + + Some models support a `"negative_prompts"` key, which helps steer the + model away from generating certain styles and elements. To enable this, + add `"negative_prompts"` to the input dict. + + If `inputs` are a `tf.data.Dataset`, outputs will be generated + "batch-by-batch" and concatenated. Otherwise, all inputs will be + processed as batches. + + Args: + inputs: python data, tensor data, or a `tf.data.Dataset`. The format + must be one of the following: + - A dict with `"images"`, `"masks"`, `"prompts"` and/or + `"negative_prompts"` keys. + - A `tf.data.Dataset` with `"images"`, `"masks"`, `"prompts"` + and/or `"negative_prompts"` keys. + num_steps: int. The number of diffusion steps to take. + guidance_scale: float. The classifier free guidance scale defined in + [Classifier-Free Diffusion Guidance]( + https://arxiv.org/abs/2207.12598). A higher scale encourages + generating images more closely related to the prompts, typically + at the cost of lower image quality. + strength: float. Indicates the extent to which the reference + `images` are transformed. Must be between `0.0` and `1.0`. When + `strength=1.0`, `images` is essentially ignore and added noise + is maximum and the denoising process runs for the full number of + iterations specified in `num_steps`. + seed: optional int. Used as a random seed. + """ + num_steps = int(num_steps) + guidance_scale = float(guidance_scale) + strength = float(strength) + if strength < 0.0 or strength > 1.0: + raise ValueError( + "`strength` must be between `0.0` and `1.0`. " + f"Received strength={strength}." + ) + starting_step = int(num_steps * (1.0 - strength)) + starting_step = ops.convert_to_tensor(starting_step, "int32") + num_steps = ops.convert_to_tensor(num_steps, "int32") + guidance_scale = ops.convert_to_tensor(guidance_scale) + + # Check `inputs` format. + required_keys = ["images", "masks", "prompts"] + if tf and isinstance(inputs, tf.data.Dataset): + spec = inputs.element_spec + if not all(key in spec for key in required_keys): + raise ValueError( + "Expected a `tf.data.Dataset` with the following keys:" + f"{required_keys}. Received: inputs.element_spec={spec}" + ) + else: + if not isinstance(inputs, dict): + raise ValueError( + "Expected a `dict` or `tf.data.Dataset`. " + f"Received: inputs={inputs} of type {type(inputs)}." + ) + if not all(key in inputs for key in required_keys): + raise ValueError( + "Expected a `dict` with the following keys:" + f"{required_keys}. " + f"Received: inputs.keys={list(inputs.keys())}" + ) + + # Setup our three main passes. + # 1. Preprocessing strings to dense integer tensors. + # 2. Generate outputs via a compiled function on dense tensors. + # 3. Postprocess dense tensors to a value range of `[0, 255]`. + generate_function = self.make_generate_function() + + def preprocess(x): + if self.preprocessor is not None: + return self.preprocessor.generate_preprocess(x) + else: + return x + + def generate(images, masks, x): + token_ids = x[0] if self.support_negative_prompts else x + + # Initialize noises. 
+ if isinstance(token_ids, dict): + arbitrary_key = list(token_ids.keys())[0] + batch_size = ops.shape(token_ids[arbitrary_key])[0] + else: + batch_size = ops.shape(token_ids)[0] + noise_shape = (batch_size,) + self.latent_shape[1:] + noises = random.normal(noise_shape, dtype="float32", seed=seed) + + return generate_function( + images, + masks, + noises, + x, + starting_step, + num_steps, + guidance_scale, + ) + + # Normalize and preprocess inputs. + inputs, input_is_scalar = self._normalize_generate_inputs(inputs) + if self.support_negative_prompts: + images = [x["images"] for x in inputs] + masks = [x["masks"] for x in inputs] + token_ids = [preprocess(x["prompts"]) for x in inputs] + negative_token_ids = [ + preprocess(x["negative_prompts"]) for x in inputs + ] + # Tuple format: (images, masks, (token_ids, negative_token_ids)). + inputs = [ + x + for x in zip(images, masks, zip(token_ids, negative_token_ids)) + ] + else: + images = [x["images"] for x in inputs] + masks = [x["masks"] for x in inputs] + token_ids = [preprocess(x["prompts"]) for x in inputs] + # Tuple format: (images, masks, token_ids). + inputs = [x for x in zip(images, masks, token_ids)] + + # Inpaint. + outputs = [generate(*x) for x in inputs] + return self._normalize_generate_outputs(outputs, input_is_scalar) diff --git a/keras_hub/src/models/llama/llama_backbone.py b/keras_hub/src/models/llama/llama_backbone.py index a654bdf267..0e923c29cd 100644 --- a/keras_hub/src/models/llama/llama_backbone.py +++ b/keras_hub/src/models/llama/llama_backbone.py @@ -59,7 +59,7 @@ class LlamaBackbone(Backbone): } # Pretrained Llama decoder. - model = keras_hub.models.LlamaBackbone.from_preset("llama7b_base_en") + model = keras_hub.models.LlamaBackbone.from_preset("llama2_7b_en") model(input_data) # Randomly initialized Llama decoder with custom config. @@ -175,3 +175,121 @@ def get_config(self): } ) return config + + @staticmethod + def get_layout_map( + device_mesh, + model_parallel_dim_name="model", + data_parallel_dim_name="batch", + ): + """Get a `keras.distribution.LayoutMap` for model parallel distribution. + + The returned `LayoutMap` contains the sharding spec for the Llama + backbone weights, so that you can use it to distribute weights across + the accelerators. + + Example: + ``` + # Feel free to change the mesh shape to balance data and model parallelism + mesh = keras.distribution.DeviceMesh( + shape=(1, 8), + axis_names=('batch', 'model'), + devices=keras.distribution.list_devices(), + ) + layout_map = LlamaBackbone.get_layout_map( + mesh, + model_parallel_dim_name="model", + ) + + distribution = keras.distribution.ModelParallel( + layout_map=layout_map, + batch_dim_name='batch', + ) + + with distribution.scope(): + llama_model = keras_hub.models.LlamaCausalLM.from_preset() + ``` + + To see how the layout map was applied, load the model then run (for one decoder block): + ``` + embedding_layer = llama_model.backbone.get_layer("token_embedding") + decoder_block_1 = llama_model.backbone.get_layer('transformer_layer_0') + for variable in embedding_layer.weights + decoder_block_1.weights: + print(f'{variable.path:<58} {str(variable.shape):<16} {str(variable.value.sharding.spec)}') + ``` + + Args: + device_mesh: The `keras.distribution.DeviceMesh` instance for + distribution. + model_parallel_dim_name: The axis name of the device mesh, where + the weights should be partition on. + data_parallel_dim_name: The axis name of the device mesh, where + the data should be partition on. 
+ Return: + `keras.distribution.LayoutMap` that contains the sharding spec + for all the model weights. + """ + # The weight path and shape of the Llama backbone is like below + # token_embedding/embeddings (128256, 2048) + # repeat block for decoder + # transformer_layer_0/self_attention/query/kernel (2048, 32, 64) + # transformer_layer_0/self_attention/key/kernel (2048, 8, 64) + # transformer_layer_0/self_attention/value/kernel (2048, 8, 64) + # transformer_layer_0/self_attention/attention_output/kernel (32, 64, 2048) + # transformer_layer_0/self_attention_layernorm/scale (2048,) + # transformer_layer_0/feedforward_intermediate_dense/kernel (2048, 8192) + # transformer_layer_0/feedforward_gate_dense/kernel (2048, 8192) + # transformer_layer_0/feedforward_output_dense/kernel (8192, 2048) + # transformer_layer_0/feedforward_layernorm/scale (2048,) + + if not isinstance(device_mesh, keras.distribution.DeviceMesh): + raise ValueError( + "Invalid device_mesh type. Expected `keras.distribution.Device`," + f" got {type(device_mesh)}" + ) + if model_parallel_dim_name not in device_mesh.axis_names: + raise ValueError( + f"{model_parallel_dim_name} is not found in the " + f"device_mesh.axis_names. {device_mesh.axis_name=}" + ) + if data_parallel_dim_name not in device_mesh.axis_names: + raise ValueError( + f"{data_parallel_dim_name} is not found in the " + f"device_mesh.axis_names. {device_mesh.axis_name=}" + ) + # Note that it is possible to further config the mesh to be 3D, eg + # (data, seq, model). We leave it as 2D for now for simplicity. + data_dim = data_parallel_dim_name + model_dim = model_parallel_dim_name + # The sharding config is based on the Gemma team training config. + # See https://arxiv.org/abs/2403.08295 + layout_map = keras.distribution.LayoutMap(device_mesh) + layout_map["token_embedding/embeddings"] = (model_dim, data_dim) + layout_map[ + "transformer_layer.*self_attention.*(query|key|value).kernel" + ] = ( + model_dim, + data_dim, + None, + ) + layout_map["transformer_layer.*attention_output.kernel"] = ( + model_dim, + None, + data_dim, + ) + layout_map[ + "transformer_layer.*feedforward_intermediate_dense.kernel" + ] = ( + data_dim, + model_dim, + ) + layout_map["transformer_layer.*feedforward_gate_dense.kernel"] = ( + data_dim, + model_dim, + ) + layout_map["transformer_layer.*feedforward_output_dense.kernel"] = ( + model_dim, + data_dim, + ) + + return layout_map diff --git a/keras_hub/src/models/llama/llama_backbone_test.py b/keras_hub/src/models/llama/llama_backbone_test.py index 3b8eca49fe..0007dd7a96 100644 --- a/keras_hub/src/models/llama/llama_backbone_test.py +++ b/keras_hub/src/models/llama/llama_backbone_test.py @@ -1,3 +1,4 @@ +import keras import pytest from keras import ops @@ -66,3 +67,87 @@ def test_all_presets(self): preset=preset, input_data=self.input_data, ) + + def test_distribution(self): + if keras.backend.backend() != "jax": + self.skipTest("`ModelParallel` testing requires the Jax backend.") + devices = keras.distribution.list_devices("CPU") + if len(devices) == 1: + self.skipTest("`ModelParallel` testing requires multiple devices.") + device_mesh = keras.distribution.DeviceMesh( + shape=(1, len(devices)), + axis_names=("batch", "model"), + devices=devices, + ) + + layout_map = LlamaBackbone.get_layout_map(device_mesh) + distribution = keras.distribution.ModelParallel(layout_map=layout_map) + with distribution.scope(): + model = LlamaBackbone(**self.init_kwargs) + + for w in model.weights: + if "token_embedding/embeddings" in w.path: + 
self.assertEqual( + tuple(w.value.sharding.spec), ("model", "batch") + ) + if "self_attention/query/kernel" in w.path: + self.assertEqual( + tuple(w.value.sharding.spec), ("model", "batch", None) + ) + if "self_attention/key/kernel" in w.path: + self.assertEqual( + tuple(w.value.sharding.spec), ("model", "batch", None) + ) + if "self_attention/value/kernel" in w.path: + self.assertEqual( + tuple(w.value.sharding.spec), ("model", "batch", None) + ) + if "self_attention/attention_output/kernel" in w.path: + self.assertEqual( + tuple(w.value.sharding.spec), ("model", None, "batch") + ) + if "feedforward_intermediate_dense/kernel" in w.path: + self.assertEqual( + tuple(w.value.sharding.spec), ("batch", "model") + ) + if "feedforward_gate_dense/kernel" in w.path: + self.assertEqual( + tuple(w.value.sharding.spec), ("batch", "model") + ) + if "feedforward_output_dense" in w.path: + self.assertEqual( + tuple(w.value.sharding.spec), ("model", "batch") + ) + + def test_distribution_with_lora(self): + if keras.backend.backend() != "jax": + self.skipTest("`ModelParallel` testing requires the Jax backend.") + devices = keras.distribution.list_devices("CPU") + if len(devices) == 1: + # Need more than 1 device for distribution testing. + self.skipTest("`ModelParallel` testing requires multiple devices.") + device_mesh = keras.distribution.DeviceMesh( + shape=(1, len(devices)), + axis_names=("batch", "model"), + devices=devices, + ) + + layout_map = LlamaBackbone.get_layout_map(device_mesh) + distribution = keras.distribution.ModelParallel(layout_map=layout_map) + with distribution.scope(): + model = LlamaBackbone(**self.init_kwargs) + model.enable_lora(rank=4) + + for w in model.weights: + if "self_attention/query/lora_kernel_a" in w.path: + self.assertEqual( + tuple(w.value.sharding.spec), (None, None, None) + ) + if "self_attention/query/lora_kernel_b" in w.path: + self.assertEqual(tuple(w.value.sharding.spec), (None, None)) + if "self_attention/value/lora_kernel_a" in w.path: + self.assertEqual( + tuple(w.value.sharding.spec), (None, None, None) + ) + if "self_attention/value/lora_kernel_b" in w.path: + self.assertEqual(tuple(w.value.sharding.spec), (None, None)) diff --git a/keras_hub/src/models/llama/llama_causal_lm.py b/keras_hub/src/models/llama/llama_causal_lm.py index 7e1e319f1d..7f0f901d52 100644 --- a/keras_hub/src/models/llama/llama_causal_lm.py +++ b/keras_hub/src/models/llama/llama_causal_lm.py @@ -42,7 +42,9 @@ def __init__(self, backbone, preprocessor=None, **kwargs): self.preprocessor = preprocessor # === Functional Model === - inputs = backbone.inputs + # This must be "backbone.input" i.e. the full input structure, + # rather than "backbone.inputs" which is the flattened list of inputs. 
+ inputs = backbone.input hidden_states = backbone(inputs) outputs = backbone.token_embedding(hidden_states, reverse=True) super().__init__( diff --git a/keras_hub/src/models/llama/llama_presets.py b/keras_hub/src/models/llama/llama_presets.py index 6197cfe07f..f72a0ec95f 100644 --- a/keras_hub/src/models/llama/llama_presets.py +++ b/keras_hub/src/models/llama/llama_presets.py @@ -7,7 +7,7 @@ "description": "7 billion parameter, 32-layer, base LLaMA 2 model.", "params": 6738415616, "official_name": "LLaMA 2", - "path": "llama2", + "path": "llama", "model_card": "https://github.com/meta-llama/llama", }, "kaggle_handle": "kaggle://keras/llama2/keras/llama2_7b_en/1", @@ -20,7 +20,7 @@ ), "params": 6739839488, "official_name": "LLaMA 2", - "path": "llama2", + "path": "llama", "model_card": "https://github.com/meta-llama/llama", }, "kaggle_handle": "kaggle://keras/llama2/keras/llama2_7b_en_int8/1", @@ -33,7 +33,7 @@ ), "params": 6738415616, "official_name": "LLaMA 2", - "path": "llama2", + "path": "llama", "model_card": "https://github.com/meta-llama/llama", }, "kaggle_handle": "kaggle://keras/llama2/keras/llama2_instruct_7b_en/1", @@ -46,7 +46,7 @@ ), "params": 6739839488, "official_name": "LLaMA 2", - "path": "llama2", + "path": "llama", "model_card": "https://github.com/meta-llama/llama", }, "kaggle_handle": "kaggle://keras/llama2/keras/llama2_instruct_7b_en_int8/1", @@ -59,7 +59,7 @@ ), "params": 6738415616, "official_name": "Vicuna", - "path": "vicuna", + "path": "llama", "model_card": "https://github.com/lm-sys/FastChat", }, "kaggle_handle": "kaggle://keras/vicuna/keras/vicuna_1.5_7b_en/1", diff --git a/keras_hub/src/models/llama3/llama3_causal_lm_preprocessor_test.py b/keras_hub/src/models/llama3/llama3_causal_lm_preprocessor_test.py index b8b45d8fd6..f79be674fb 100644 --- a/keras_hub/src/models/llama3/llama3_causal_lm_preprocessor_test.py +++ b/keras_hub/src/models/llama3/llama3_causal_lm_preprocessor_test.py @@ -11,6 +11,8 @@ class Llama3CausalLMPreprocessorTest(TestCase): def setUp(self): self.vocab = ["!", "air", "Ä air", "plane", "Ä at", "port"] self.vocab += ["<|begin_of_text|>", "<|end_of_text|>"] + self.vocab += ["<|start_header_id|>", "<|end_header_id|>"] + self.vocab += ["<|eot_id|>"] self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)]) self.merges = ["Ä  a", "Ä  t", "Ä  i", "Ä  b", "a i", "p l", "n e"] self.merges += ["Ä a t", "p o", "r t", "Ä t h", "ai r", "pl a", "po rt"] diff --git a/keras_hub/src/models/llama3/llama3_causal_lm_test.py b/keras_hub/src/models/llama3/llama3_causal_lm_test.py index 995c1a00e1..a054b8ae14 100644 --- a/keras_hub/src/models/llama3/llama3_causal_lm_test.py +++ b/keras_hub/src/models/llama3/llama3_causal_lm_test.py @@ -16,6 +16,8 @@ class Llama3CausalLMTest(TestCase): def setUp(self): self.vocab = ["!", "air", "Ä air", "plane", "Ä at", "port"] self.vocab += ["<|begin_of_text|>", "<|end_of_text|>"] + self.vocab += ["<|start_header_id|>", "<|end_header_id|>"] + self.vocab += ["<|eot_id|>"] self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)]) self.merges = ["Ä  a", "Ä  t", "Ä  i", "Ä  b", "a i", "p l", "n e"] self.merges += ["Ä a t", "p o", "r t", "Ä t h", "ai r", "pl a", "po rt"] @@ -44,7 +46,7 @@ def test_causal_lm_basics(self): cls=Llama3CausalLM, init_kwargs=self.init_kwargs, train_data=self.train_data, - expected_output_shape=(2, 7, 8), + expected_output_shape=(2, 7, 11), ) def test_generate(self): @@ -67,6 +69,12 @@ def test_generate(self): prompt_ids["padding_mask"][:, :5], ) + def 
test_generate_strip_prompt(self): + causal_lm = Llama3CausalLM(**self.init_kwargs) + prompt = " airplane at airport" + output = causal_lm.generate(prompt, strip_prompt=True) + self.assertFalse(output.startswith(prompt)) + def test_early_stopping(self): causal_lm = Llama3CausalLM(**self.init_kwargs) call_with_cache = causal_lm.call_with_cache diff --git a/keras_hub/src/models/llama3/llama3_tokenizer.py b/keras_hub/src/models/llama3/llama3_tokenizer.py index 397b5e1923..ee3037e854 100644 --- a/keras_hub/src/models/llama3/llama3_tokenizer.py +++ b/keras_hub/src/models/llama3/llama3_tokenizer.py @@ -16,10 +16,33 @@ def __init__( self, vocabulary=None, merges=None, + bos_token="<|begin_of_text|>", + eos_token="<|end_of_text|>", + misc_special_tokens={"<|start_header_id|>", "<|end_header_id|>"}, **kwargs, ): - self._add_special_token("<|begin_of_text|>", "start_token") - self._add_special_token("<|end_of_text|>", "end_token") + # Note: all special tokens must also appear in "vocabulary" + + self._add_special_token(bos_token, "start_token") + misc_special_tokens -= {bos_token} + self._add_special_token(eos_token, "end_token") + misc_special_tokens -= {eos_token} + for i, token in enumerate(misc_special_tokens): + self._add_special_token(token, f"special_token_{i:03d}") + + # Hack: + # Llama models use the <|end_of_text|> or the <|eot_id|> as the stop + # token. This info can be read from config when loading a Hugging Face + # checkpoint but no such config exists for Keras checkpoints. + # Setting both probable end tokens when no config is availble will + # make text generation work in all cases as it will stop + # on both end tokens. However, the packer will always use + # "<|end_of_text|>" , which will be the wrong eos_token for "instruct" + # variants of Llama3. + # TODO: load this correctly from a Keras tokenizer config. + if eos_token == "<|end_of_text|>": + self._add_special_token("<|eot_id|>", "end_token2") + self.pad_token_id = 0 super().__init__( vocabulary=vocabulary, diff --git a/keras_hub/src/models/llama3/llama3_tokenizer_test.py b/keras_hub/src/models/llama3/llama3_tokenizer_test.py index 8440d8ebb2..aff591de04 100644 --- a/keras_hub/src/models/llama3/llama3_tokenizer_test.py +++ b/keras_hub/src/models/llama3/llama3_tokenizer_test.py @@ -8,6 +8,8 @@ class Llama3TokenizerTest(TestCase): def setUp(self): self.vocab = ["!", "air", "Ä air", "plane", "Ä at", "port"] self.vocab += ["<|end_of_text|>", "<|begin_of_text|>"] + self.vocab += ["<|start_header_id|>", "<|end_header_id|>"] + self.vocab += ["<|eot_id|>"] self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)]) self.merges = ["Ä  a", "Ä  t", "Ä  i", "Ä  b", "a i", "p l", "n e"] self.merges += ["Ä a t", "p o", "r t", "Ä t h", "ai r", "pl a", "po rt"] diff --git a/keras_hub/src/models/mistral/mistral_causal_lm.py b/keras_hub/src/models/mistral/mistral_causal_lm.py index 7f7ff03d14..06170aa089 100644 --- a/keras_hub/src/models/mistral/mistral_causal_lm.py +++ b/keras_hub/src/models/mistral/mistral_causal_lm.py @@ -42,7 +42,9 @@ def __init__(self, backbone, preprocessor=None, **kwargs): self.preprocessor = preprocessor # === Functional Model === - inputs = backbone.inputs + # This must be "backbone.input" i.e. the full input structure, + # rather than "backbone.inputs" which is the flattened list of inputs. 
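+        # The final hidden states are mapped back to vocabulary-sized logits
+        # below via the token embedding's reverse projection
+        # (`token_embedding(..., reverse=True)`).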
+ inputs = backbone.input hidden_states = backbone(inputs) outputs = backbone.token_embedding(hidden_states, reverse=True) super().__init__( diff --git a/keras_hub/src/models/mit/__init__.py b/keras_hub/src/models/mit/__init__.py new file mode 100644 index 0000000000..f581202b1c --- /dev/null +++ b/keras_hub/src/models/mit/__init__.py @@ -0,0 +1,6 @@ +from keras_hub.src.models.mit.mit_backbone import MiTBackbone +from keras_hub.src.models.mit.mit_image_classifier import MiTImageClassifier +from keras_hub.src.models.mit.mit_presets import backbone_presets +from keras_hub.src.utils.preset_utils import register_presets + +register_presets(backbone_presets, MiTBackbone) diff --git a/keras_hub/src/models/mix_transformer/mix_transformer_backbone.py b/keras_hub/src/models/mit/mit_backbone.py similarity index 87% rename from keras_hub/src/models/mix_transformer/mix_transformer_backbone.py rename to keras_hub/src/models/mit/mit_backbone.py index e8f881aee3..a6c57816c4 100644 --- a/keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +++ b/keras_hub/src/models/mit/mit_backbone.py @@ -1,15 +1,22 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import keras import numpy as np from keras import ops from keras_hub.src.api_export import keras_hub_export from keras_hub.src.models.feature_pyramid_backbone import FeaturePyramidBackbone -from keras_hub.src.models.mix_transformer.mix_transformer_layers import ( - HierarchicalTransformerEncoder, -) -from keras_hub.src.models.mix_transformer.mix_transformer_layers import ( - OverlappingPatchingAndEmbedding, -) +from keras_hub.src.models.mit.mit_layers import HierarchicalTransformerEncoder +from keras_hub.src.models.mit.mit_layers import OverlappingPatchingAndEmbedding @keras_hub_export("keras_hub.models.MiTBackbone") @@ -61,7 +68,7 @@ def __init__( ```python images = np.ones(shape=(1, 96, 96, 3)) labels = np.zeros(shape=(1, 96, 96, 1)) - backbone = keras_hub.models.MiTBackbone.from_preset("mit_b0_imagenet") + backbone = keras_hub.models.MiTBackbone.from_preset("mit_b0_ade20k_512") # Evaluate model model(images) @@ -104,7 +111,7 @@ def __init__( ] transformer_blocks.append(transformer_block) cur += depths[i] - layer_norms.append(keras.layers.LayerNormalization()) + layer_norms.append(keras.layers.LayerNormalization(epsilon=1e-5)) # === Functional Model === image_input = keras.layers.Input(shape=image_shape) diff --git a/keras_hub/src/models/mit/mit_backbone_test.py b/keras_hub/src/models/mit/mit_backbone_test.py new file mode 100644 index 0000000000..88c58e96a2 --- /dev/null +++ b/keras_hub/src/models/mit/mit_backbone_test.py @@ -0,0 +1,45 @@ +import numpy as np +import pytest + +from keras_hub.src.models.mit.mit_backbone import MiTBackbone +from keras_hub.src.tests.test_case import TestCase + + +class MiTBackboneTest(TestCase): + def setUp(self): + self.init_kwargs = { + "depths": [2, 2], + "image_shape": (32, 32, 3), + "hidden_dims": [4, 8], + "num_layers": 2, + "blockwise_num_heads": [1, 2], + "blockwise_sr_ratios": [8, 4], + "max_drop_path_rate": 
0.1, + "patch_sizes": [7, 3], + "strides": [4, 2], + } + self.input_size = 32 + self.input_data = np.ones( + (2, self.input_size, self.input_size, 3), dtype="float32" + ) + + def test_backbone_basics(self): + self.run_vision_backbone_test( + cls=MiTBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output_shape=(2, 4, 4, 8), + expected_pyramid_output_keys=["P1", "P2"], + expected_pyramid_image_sizes=[(8, 8), (4, 4)], + run_quantization_check=False, + run_mixed_precision_check=False, + run_data_format_check=False, + ) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=MiTBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) diff --git a/keras_hub/src/models/mix_transformer/mix_transformer_classifier.py b/keras_hub/src/models/mit/mit_image_classifier.py similarity index 53% rename from keras_hub/src/models/mix_transformer/mix_transformer_classifier.py rename to keras_hub/src/models/mit/mit_image_classifier.py index 0daac9327f..370920ddf9 100644 --- a/keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +++ b/keras_hub/src/models/mit/mit_image_classifier.py @@ -1,10 +1,12 @@ from keras_hub.src.api_export import keras_hub_export from keras_hub.src.models.image_classifier import ImageClassifier -from keras_hub.src.models.mix_transformer.mix_transformer_backbone import ( - MiTBackbone, +from keras_hub.src.models.mit.mit_backbone import MiTBackbone +from keras_hub.src.models.mit.mit_image_classifier_preprocessor import ( + MiTImageClassifierPreprocessor, ) @keras_hub_export("keras_hub.models.MiTImageClassifier") class MiTImageClassifier(ImageClassifier): backbone_cls = MiTBackbone + preprocessor_cls = MiTImageClassifierPreprocessor diff --git a/keras_hub/src/models/mit/mit_image_classifier_preprocessor.py b/keras_hub/src/models/mit/mit_image_classifier_preprocessor.py new file mode 100644 index 0000000000..d3859c30d6 --- /dev/null +++ b/keras_hub/src/models/mit/mit_image_classifier_preprocessor.py @@ -0,0 +1,12 @@ +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.image_classifier_preprocessor import ( + ImageClassifierPreprocessor, +) +from keras_hub.src.models.mit.mit_backbone import MiTBackbone +from keras_hub.src.models.mit.mit_image_converter import MiTImageConverter + + +@keras_hub_export("keras_hub.models.MiTImageClassifierPreprocessor") +class MiTImageClassifierPreprocessor(ImageClassifierPreprocessor): + backbone_cls = MiTBackbone + image_converter_cls = MiTImageConverter diff --git a/keras_hub/src/models/mix_transformer/mix_transformer_classifier_test.py b/keras_hub/src/models/mit/mit_image_classifier_test.py similarity index 78% rename from keras_hub/src/models/mix_transformer/mix_transformer_classifier_test.py rename to keras_hub/src/models/mit/mit_image_classifier_test.py index fb7ff5ce2b..32055c47ed 100644 --- a/keras_hub/src/models/mix_transformer/mix_transformer_classifier_test.py +++ b/keras_hub/src/models/mit/mit_image_classifier_test.py @@ -1,23 +1,19 @@ import numpy as np import pytest -from keras_hub.src.models.mix_transformer.mix_transformer_backbone import ( - MiTBackbone, -) -from keras_hub.src.models.mix_transformer.mix_transformer_classifier import ( - MiTImageClassifier, -) +from keras_hub.src.models.mit.mit_backbone import MiTBackbone +from keras_hub.src.models.mit.mit_image_classifier import MiTImageClassifier from keras_hub.src.tests.test_case import TestCase class MiTImageClassifierTest(TestCase): def setUp(self): # Setup model. 
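+        # 32x32 inputs here, consistent with the updated MiTBackbone tests.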
- self.images = np.ones((2, 16, 16, 3), dtype="float32") + self.images = np.ones((2, 32, 32, 3), dtype="float32") self.labels = [0, 3] self.backbone = MiTBackbone( depths=[2, 2, 2, 2], - image_shape=(16, 16, 3), + image_shape=(32, 32, 3), hidden_dims=[4, 8], num_layers=2, blockwise_num_heads=[1, 2], @@ -44,7 +40,7 @@ def test_classifier_basics(self): cls=MiTImageClassifier, init_kwargs=self.init_kwargs, train_data=self.train_data, - expected_output_shape=(2, 2), + expected_output_shape=(4, 4), ) @pytest.mark.large diff --git a/keras_hub/src/models/mit/mit_image_converter.py b/keras_hub/src/models/mit/mit_image_converter.py new file mode 100644 index 0000000000..269fcb88fd --- /dev/null +++ b/keras_hub/src/models/mit/mit_image_converter.py @@ -0,0 +1,8 @@ +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.layers.preprocessing.image_converter import ImageConverter +from keras_hub.src.models.mit import MiTBackbone + + +@keras_hub_export("keras_hub.layers.MiTImageConverter") +class MiTImageConverter(ImageConverter): + backbone_cls = MiTBackbone diff --git a/keras_hub/src/models/mix_transformer/mix_transformer_layers.py b/keras_hub/src/models/mit/mit_layers.py similarity index 92% rename from keras_hub/src/models/mix_transformer/mix_transformer_layers.py rename to keras_hub/src/models/mit/mit_layers.py index 42402da7ea..b949fcb6e2 100644 --- a/keras_hub/src/models/mix_transformer/mix_transformer_layers.py +++ b/keras_hub/src/models/mit/mit_layers.py @@ -28,19 +28,23 @@ def __init__(self, project_dim=32, patch_size=7, stride=4, **kwargs): self.patch_size = patch_size self.stride = stride + padding_size = self.patch_size // 2 + + self.padding = keras.layers.ZeroPadding2D( + padding=(padding_size, padding_size) + ) self.proj = keras.layers.Conv2D( filters=project_dim, kernel_size=patch_size, strides=stride, - padding="same", + padding="valid", ) - self.norm = keras.layers.LayerNormalization() + self.norm = keras.layers.LayerNormalization(epsilon=1e-5) def call(self, x): + x = self.padding(x) x = self.proj(x) - # B, H, W, C - shape = x.shape - x = ops.reshape(x, (-1, shape[1] * shape[2], shape[3])) + x = ops.reshape(x, (-1, x.shape[1] * x.shape[2], x.shape[3])) x = self.norm(x) return x @@ -179,20 +183,21 @@ def __init__(self, project_dim, num_heads, sr_ratio): self.k = keras.layers.Dense(project_dim) self.v = keras.layers.Dense(project_dim) self.proj = keras.layers.Dense(project_dim) + self.dropout = keras.layers.Dropout(0.1) + self.proj_drop = keras.layers.Dropout(0.1) if sr_ratio > 1: self.sr = keras.layers.Conv2D( filters=project_dim, kernel_size=sr_ratio, strides=sr_ratio, - padding="same", ) - self.norm = keras.layers.LayerNormalization() + self.norm = keras.layers.LayerNormalization(epsilon=1e-5) def call(self, x): input_shape = ops.shape(x) H, W = int(math.sqrt(input_shape[1])), int(math.sqrt(input_shape[1])) - B, C = input_shape[0], input_shape[2] + B, N, C = input_shape[0], input_shape[1], input_shape[2] q = self.q(x) q = ops.reshape( @@ -208,12 +213,11 @@ def call(self, x): if self.sr_ratio > 1: x = ops.reshape( - ops.transpose(x, [0, 2, 1]), + x, (B, H, W, C), ) x = self.sr(x) - x = ops.reshape(x, [input_shape[0], input_shape[2], -1]) - x = ops.transpose(x, [0, 2, 1]) + x = ops.reshape(x, [B, -1, C]) x = self.norm(x) k = self.k(x) @@ -237,14 +241,16 @@ def call(self, x): attn = (q @ ops.transpose(k, [0, 1, 3, 2])) * self.scale attn = ops.nn.softmax(attn, axis=-1) + attn = self.dropout(attn) attn = attn @ v attn = ops.reshape( ops.transpose(attn, [0, 2, 1, 3]), - 
[input_shape[0], input_shape[1], input_shape[2]], + [B, N, C], ) x = self.proj(attn) + x = self.proj_drop(x) return x diff --git a/keras_hub/src/models/mit/mit_presets.py b/keras_hub/src/models/mit/mit_presets.py new file mode 100644 index 0000000000..9c2a5fe362 --- /dev/null +++ b/keras_hub/src/models/mit/mit_presets.py @@ -0,0 +1,151 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""MiT model preset configurations.""" + +backbone_presets_with_weights = { + "mit_b0_ade20k_512": { + "metadata": { + "description": ( + "MiT (MixTransformer) model with 8 transformer blocks." + ), + "params": 3321962, + "official_name": "MiT", + "path": "mit", + }, + "kaggle_handle": "kaggle://keras/mit/keras/mit_b0_ade20k_512/1", + }, + "mit_b1_ade20k_512": { + "metadata": { + "description": ( + "MiT (MixTransformer) model with 8 transformer blocks." + ), + "params": 13156554, + "official_name": "MiT", + "path": "mit", + }, + "kaggle_handle": "kaggle://keras/mit/keras/mit_b1_ade20k_512/1", + }, + "mit_b2_ade20k_512": { + "metadata": { + "description": ( + "MiT (MixTransformer) model with 16 transformer blocks." + ), + "params": 24201418, + "official_name": "MiT", + "path": "mit", + }, + "kaggle_handle": "kaggle://keras/mit/keras/mit_b2_ade20k_512/1", + }, + "mit_b3_ade20k_512": { + "metadata": { + "description": ( + "MiT (MixTransformer) model with 28 transformer blocks." + ), + "params": 44077258, + "official_name": "MiT", + "path": "mit", + }, + "kaggle_handle": "kaggle://keras/mit/keras/mit_b3_ade20k_512/1", + }, + "mit_b4_ade20k_512": { + "metadata": { + "description": ( + "MiT (MixTransformer) model with 41 transformer blocks." + ), + "params": 60847818, + "official_name": "MiT", + "path": "mit", + }, + "kaggle_handle": "kaggle://keras/mit/keras/mit_b4_ade20k_512/1", + }, + "mit_b5_ade20k_640": { + "metadata": { + "description": ( + "MiT (MixTransformer) model with 52 transformer blocks." + ), + "params": 81448138, + "official_name": "MiT", + "path": "mit", + }, + "kaggle_handle": "kaggle://keras/mit/keras/mit_b5_ade20k_640/1", + }, + "mit_b0_cityscapes_1024": { + "metadata": { + "description": ( + "MiT (MixTransformer) model with 8 transformer blocks." + ), + "params": 3321962, + "official_name": "MiT", + "path": "mit", + }, + "kaggle_handle": "kaggle://keras/mit/keras/mit_b0_cityscapes_1024/1", + }, + "mit_b1_cityscapes_1024": { + "metadata": { + "description": ( + "MiT (MixTransformer) model with 8 transformer blocks." + ), + "params": 13156554, + "official_name": "MiT", + "path": "mit", + }, + "kaggle_handle": "kaggle://keras/mit/keras/mit_b1_cityscapes_1024/1", + }, + "mit_b2_cityscapes_1024": { + "metadata": { + "description": ( + "MiT (MixTransformer) model with 16 transformer blocks." + ), + "params": 24201418, + "official_name": "MiT", + "path": "mit", + }, + "kaggle_handle": "kaggle://keras/mit/keras/mit_b2_cityscapes_1024/1", + }, + "mit_b3_cityscapes_1024": { + "metadata": { + "description": ( + "MiT (MixTransformer) model with 28 transformer blocks." 
+ ), + "params": 44077258, + "official_name": "MiT", + "path": "mit", + }, + "kaggle_handle": "kaggle://keras/mit/keras/mit_b3_cityscapes_1024/1", + }, + "mit_b4_cityscapes_1024": { + "metadata": { + "description": ( + "MiT (MixTransformer) model with 41 transformer blocks." + ), + "params": 60847818, + "official_name": "MiT", + "path": "mit", + }, + "kaggle_handle": "kaggle://keras/mit/keras/mit_b4_cityscapes_1024/1", + }, + "mit_b5_cityscapes_1024": { + "metadata": { + "description": ( + "MiT (MixTransformer) model with 52 transformer blocks." + ), + "params": 81448138, + "official_name": "MiT", + "path": "mit", + }, + "kaggle_handle": "kaggle://keras/mit/keras/mit_b5_cityscapes_1024/1", + }, +} + +backbone_presets = { + **backbone_presets_with_weights, +} diff --git a/keras_hub/src/models/mix_transformer/mix_transformer_backbone_test.py b/keras_hub/src/models/mit/mix_transformer_backbone_test.py similarity index 81% rename from keras_hub/src/models/mix_transformer/mix_transformer_backbone_test.py rename to keras_hub/src/models/mit/mix_transformer_backbone_test.py index b3840f5c07..88c58e96a2 100644 --- a/keras_hub/src/models/mix_transformer/mix_transformer_backbone_test.py +++ b/keras_hub/src/models/mit/mix_transformer_backbone_test.py @@ -1,9 +1,7 @@ import numpy as np import pytest -from keras_hub.src.models.mix_transformer.mix_transformer_backbone import ( - MiTBackbone, -) +from keras_hub.src.models.mit.mit_backbone import MiTBackbone from keras_hub.src.tests.test_case import TestCase @@ -11,7 +9,7 @@ class MiTBackboneTest(TestCase): def setUp(self): self.init_kwargs = { "depths": [2, 2], - "image_shape": (16, 16, 3), + "image_shape": (32, 32, 3), "hidden_dims": [4, 8], "num_layers": 2, "blockwise_num_heads": [1, 2], @@ -20,7 +18,7 @@ def setUp(self): "patch_sizes": [7, 3], "strides": [4, 2], } - self.input_size = 16 + self.input_size = 32 self.input_data = np.ones( (2, self.input_size, self.input_size, 3), dtype="float32" ) @@ -30,9 +28,9 @@ def test_backbone_basics(self): cls=MiTBackbone, init_kwargs=self.init_kwargs, input_data=self.input_data, - expected_output_shape=(2, 2, 2, 8), + expected_output_shape=(2, 4, 4, 8), expected_pyramid_output_keys=["P1", "P2"], - expected_pyramid_image_sizes=[(4, 4), (2, 2)], + expected_pyramid_image_sizes=[(8, 8), (4, 4)], run_quantization_check=False, run_mixed_precision_check=False, run_data_format_check=False, diff --git a/keras_hub/src/models/mix_transformer/__init__.py b/keras_hub/src/models/mix_transformer/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/keras_hub/src/models/mobilenet/mobilenet_backbone.py b/keras_hub/src/models/mobilenet/mobilenet_backbone.py index 34ddbda6d0..e40eac32b1 100644 --- a/keras_hub/src/models/mobilenet/mobilenet_backbone.py +++ b/keras_hub/src/models/mobilenet/mobilenet_backbone.py @@ -1,560 +1,583 @@ -import keras -from keras import ops - -from keras_hub.src.api_export import keras_hub_export -from keras_hub.src.models.backbone import Backbone - -BN_EPSILON = 1e-5 -BN_MOMENTUM = 0.9 - - -@keras_hub_export("keras_hub.models.MobileNetBackbone") -class MobileNetBackbone(Backbone): - """Instantiates the MobileNet architecture. - - MobileNet is a lightweight convolutional neural network (CNN) - optimized for mobile and edge devices, striking a balance between - accuracy and efficiency. 
By employing depthwise separable convolutions - and techniques like Squeeze-and-Excitation (SE) blocks, - MobileNet models are highly suitable for real-time applications on - resource-constrained devices. - - References: - - [MobileNets: Efficient Convolutional Neural Networks - for Mobile Vision Applications]( - https://arxiv.org/abs/1704.04861) - - [MobileNetV2: Inverted Residuals and Linear Bottlenecks]( - https://arxiv.org/abs/1801.04381) (CVPR 2018) - - [Searching for MobileNetV3](https://arxiv.org/pdf/1905.02244.pdf) - (ICCV 2019) - - Args: - stackwise_expansion: list of list of ints, the expanded filters for - each inverted residual block for each block in the model. - stackwise_num_blocks: list of ints, number of inversted residual blocks - per block - stackwise_num_filters: list of list of ints, number of filters for - each inverted residual block in the model. - stackwise_kernel_size: list of list of ints, kernel size for each - inverted residual block in the model. - stackwise_num_strides: list of list of ints, stride length for each - inverted residual block in the model. - stackwise_se_ratio: se ratio for each inverted residual block in the - model. 0 if dont want to add Squeeze and Excite layer. - stackwise_activation: list of list of activation functions, for each - inverted residual block in the model. - image_shape: optional shape tuple, defaults to (224, 224, 3). - input_num_filters: number of filters in first convolution layer - output_num_filters: specifies whether to add conv and batch_norm in the - end, if set to None, it will not add these layers in the end. - 'None' for MobileNetV1 - input_activation: activation function to be used in the input layer - 'hard_swish' for MobileNetV3, - 'relu6' for MobileNetV1 and MobileNetV2 - output_activation: activation function to be used in the output layer - 'hard_swish' for MobileNetV3, - 'relu6' for MobileNetV1 and MobileNetV2 - depthwise_filters: int, number of filters in depthwise separable - convolution layer - squeeze_and_excite: float, squeeze and excite ratio in the depthwise - layer, None, if dont want to do squeeze and excite - - - Example: - ```python - input_data = tf.ones(shape=(8, 224, 224, 3)) - - # Randomly initialized backbone with a custom config - model = MobileNetBackbone( - stackwise_expansion=[ - [40, 56], - [64, 144, 144], - [72, 72], - [144, 288, 288], - ], - stackwise_num_blocks=[2, 3, 2, 3], - stackwise_num_filters=[ - [16, 16], - [24, 24, 24], - [24, 24], - [48, 48, 48], - ], - stackwise_kernel_size=[[3, 3], [5, 5, 5], [5, 5], [5, 5, 5]], - stackwise_num_strides=[[2, 1], [2, 1, 1], [1, 1], [2, 1, 1]], - stackwise_se_ratio=[ - [None, None], - [0.25, 0.25, 0.25], - [0.3, 0.3], - [0.3, 0.25, 0.25], - ], - stackwise_activation=[ - ["relu", "relu"], - ["hard_swish", "hard_swish", "hard_swish"], - ["hard_swish", "hard_swish"], - ["hard_swish", "hard_swish", "hard_swish"], - ], - output_num_filters=288, - input_activation="hard_swish", - output_activation="hard_swish", - input_num_filters=16, - image_shape=(224, 224, 3), - depthwise_filters=8, - squeeze_and_excite=0.5, - - ) - output = model(input_data) - ``` - """ - - def __init__( - self, - stackwise_expansion, - stackwise_num_blocks, - stackwise_num_filters, - stackwise_kernel_size, - stackwise_num_strides, - stackwise_se_ratio, - stackwise_activation, - output_num_filters, - depthwise_filters, - last_layer_filter, - squeeze_and_excite=None, - image_shape=(224, 224, 3), - input_activation="hard_swish", - output_activation="hard_swish", - 
input_num_filters=16, - **kwargs, - ): - # === Functional Model === - channel_axis = ( - -1 if keras.config.image_data_format() == "channels_last" else 1 - ) - - image_input = keras.layers.Input(shape=image_shape) - x = image_input - input_num_filters = adjust_channels(input_num_filters) - x = keras.layers.Conv2D( - input_num_filters, - kernel_size=3, - strides=(2, 2), - padding="same", - data_format=keras.config.image_data_format(), - use_bias=False, - name="input_conv", - )(x) - x = keras.layers.BatchNormalization( - axis=channel_axis, - epsilon=BN_EPSILON, - momentum=BN_MOMENTUM, - name="input_batch_norm", - )(x) - x = keras.layers.Activation(input_activation)(x) - - x = apply_depthwise_conv_block( - x, depthwise_filters, se=squeeze_and_excite, name="block_0" - ) - - for block in range(len(stackwise_num_blocks)): - for inverted_block in range(stackwise_num_blocks[block]): - x = apply_inverted_res_block( - x, - expansion=stackwise_expansion[block][inverted_block], - filters=adjust_channels( - stackwise_num_filters[block][inverted_block] - ), - kernel_size=stackwise_kernel_size[block][inverted_block], - stride=stackwise_num_strides[block][inverted_block], - se_ratio=stackwise_se_ratio[block][inverted_block], - activation=stackwise_activation[block][inverted_block], - name=f"block_{block+1}_{inverted_block}", - ) - - x = ConvBnAct( - x, - filter=adjust_channels(last_layer_filter), - activation="hard_swish", - name=f"block_{len(stackwise_num_blocks)+1}_0", - ) - - last_conv_ch = adjust_channels(output_num_filters) - - x = keras.layers.Conv2D( - last_conv_ch, - kernel_size=1, - padding="same", - data_format=keras.config.image_data_format(), - use_bias=False, - name="output_conv", - )(x) - - # no output normalization in mobilenetv3 - if output_activation == "relu6": - x = keras.layers.BatchNormalization( - axis=channel_axis, - epsilon=BN_EPSILON, - momentum=BN_MOMENTUM, - name="output_batch_norm", - )(x) - - x = keras.layers.Activation(output_activation)(x) - - super().__init__(inputs=image_input, outputs=x, **kwargs) - - # === Config === - self.stackwise_expansion = stackwise_expansion - self.stackwise_num_blocks = stackwise_num_blocks - self.stackwise_num_filters = stackwise_num_filters - self.stackwise_kernel_size = stackwise_kernel_size - self.stackwise_num_strides = stackwise_num_strides - self.stackwise_se_ratio = stackwise_se_ratio - self.stackwise_activation = stackwise_activation - self.input_num_filters = input_num_filters - self.output_num_filters = output_num_filters - self.depthwise_filters = depthwise_filters - self.last_layer_filter = last_layer_filter - self.squeeze_and_excite = squeeze_and_excite - self.input_activation = keras.activations.get(input_activation) - self.output_activation = keras.activations.get(output_activation) - self.image_shape = image_shape - - def get_config(self): - config = super().get_config() - config.update( - { - "stackwise_expansion": self.stackwise_expansion, - "stackwise_num_blocks": self.stackwise_num_blocks, - "stackwise_num_filters": self.stackwise_num_filters, - "stackwise_kernel_size": self.stackwise_kernel_size, - "stackwise_num_strides": self.stackwise_num_strides, - "stackwise_se_ratio": self.stackwise_se_ratio, - "stackwise_activation": self.stackwise_activation, - "image_shape": self.image_shape, - "input_num_filters": self.input_num_filters, - "output_num_filters": self.output_num_filters, - "depthwise_filters": self.depthwise_filters, - "last_layer_filter": self.last_layer_filter, - "squeeze_and_excite": self.squeeze_and_excite, - 
"input_activation": keras.activations.serialize( - activation=self.input_activation - ), - "output_activation": keras.activations.serialize( - activation=self.output_activation - ), - } - ) - return config - - -def adjust_channels(x, divisor=8, min_value=None): - """Ensure that all layers have a channel number divisible by the `divisor`. - - Args: - x: integer, input value. - divisor: integer, the value by which a channel number should be - divisible, defaults to 8. - min_value: float, optional minimum value for the new tensor. If None, - defaults to value of divisor. - - Returns: - the updated input scalar. - """ - - if min_value is None: - min_value = divisor - - new_x = max(min_value, int(x + divisor / 2) // divisor * divisor) - - # make sure that round down does not go down by more than 10%. - if new_x < 0.9 * x: - new_x += divisor - return new_x - - -def apply_inverted_res_block( - x, - expansion, - filters, - kernel_size, - stride, - se_ratio, - activation, - name=None, -): - """An Inverted Residual Block. - - Args: - x: input tensor. - expansion: integer, the expansion ratio, multiplied with infilters to - get the minimum value passed to adjust_channels. - filters: integer, number of filters for convolution layer. - kernel_size: integer, the kernel size for DepthWise Convolutions. - stride: integer, the stride length for DepthWise Convolutions. - se_ratio: float, ratio for bottleneck filters. Number of bottleneck - filters = filters * se_ratio. - activation: the activation layer to use. - name: string, block label. - - Returns: - the updated input tensor. - """ - channel_axis = ( - -1 if keras.config.image_data_format() == "channels_last" else 1 - ) - activation = keras.activations.get(activation) - shortcut = x - infilters = x.shape[channel_axis] - expanded_channels = adjust_channels(expansion) - - x = keras.layers.Conv2D( - expanded_channels, - kernel_size=1, - padding="same", - data_format=keras.config.image_data_format(), - use_bias=False, - name=f"{name}_conv1", - )(x) - - x = keras.layers.BatchNormalization( - axis=channel_axis, - epsilon=BN_EPSILON, - momentum=BN_MOMENTUM, - name=f"{name}_bn1", - )(x) - - x = keras.layers.Activation(activation=activation)(x) - - if stride == 2: - x = keras.layers.ZeroPadding2D( - padding=correct_pad_downsample(x, kernel_size), - )(x) - - x = keras.layers.Conv2D( - expanded_channels, - kernel_size, - strides=stride, - padding="same" if stride == 1 else "valid", - groups=expanded_channels, - data_format=keras.config.image_data_format(), - use_bias=False, - name=f"{name}_conv2", - )(x) - x = keras.layers.BatchNormalization( - axis=channel_axis, - epsilon=BN_EPSILON, - momentum=BN_MOMENTUM, - name=f"{name}_bn2", - )(x) - - x = keras.layers.Activation(activation=activation)(x) - - if se_ratio: - se_filters = expanded_channels - x = SqueezeAndExcite2D( - input=x, - filters=se_filters, - bottleneck_filters=adjust_channels(se_filters * se_ratio), - squeeze_activation="relu", - excite_activation=keras.activations.hard_sigmoid, - name=f"{name}_se", - ) - - x = keras.layers.Conv2D( - filters, - kernel_size=1, - padding="same", - data_format=keras.config.image_data_format(), - use_bias=False, - name=f"{name}_conv3", - )(x) - x = keras.layers.BatchNormalization( - axis=channel_axis, - epsilon=BN_EPSILON, - momentum=BN_MOMENTUM, - name=f"{name}_bn3", - )(x) - - if stride == 1 and infilters == filters: - x = keras.layers.Add(name=f"{name}_add")([shortcut, x]) - return x - - -def apply_depthwise_conv_block( - x, filters, kernel_size=3, stride=1, se=None, 
name=None -): - """Adds a depthwise convolution block. - - A depthwise convolution block consists of a depthwise conv, - batch normalization, relu6, pointwise convolution, - batch normalization and relu6 activation. - - Args: - x: Input tensor of shape `(rows, cols, channels) - filters: Integer, the dimensionality of the output space - (i.e. the number of output filters in the pointwise convolution). - strides: An integer or tuple/list of 2 integers, specifying the strides - of the convolution along the width and height. - Can be a single integer to specify the same value for - all spatial dimensions. Specifying any stride value != 1 is - incompatible with specifying any `dilation_rate` value != 1. - block_id: Integer, a unique identification designating the block number. - - Input shape: - 4D tensor with shape: `(batch, rows, cols, channels)` in "channels_last" - 4D tensor with shape: `(batch, channels, rows, cols)` in "channels_first" - Returns: - Output tensor of block. - """ - channel_axis = ( - -1 if keras.config.image_data_format() == "channels_last" else 1 - ) - infilters = x.shape[channel_axis] - name = f"{name}_0" - - if stride == 2: - x = keras.layers.ZeroPadding2D( - padding=correct_pad_downsample(x, kernel_size), - )(x) - - x = keras.layers.Conv2D( - infilters, - kernel_size, - strides=stride, - padding="same" if stride == 1 else "valid", - data_format=keras.config.image_data_format(), - groups=infilters, - use_bias=False, - name=f"{name}_conv1", - )(x) - x = keras.layers.BatchNormalization( - axis=channel_axis, - epsilon=BN_EPSILON, - momentum=BN_MOMENTUM, - name=f"{name}_bn1", - )(x) - x = keras.layers.ReLU(6.0)(x) - - if se: - x = SqueezeAndExcite2D( - input=x, - filters=infilters, - bottleneck_filters=adjust_channels(infilters * se), - squeeze_activation="relu", - excite_activation=keras.activations.hard_sigmoid, - name=f"{name}_se", - ) - - x = keras.layers.Conv2D( - filters, - kernel_size=1, - padding="same", - data_format=keras.config.image_data_format(), - use_bias=False, - name=f"{name}_conv2", - )(x) - x = keras.layers.BatchNormalization( - axis=channel_axis, - epsilon=BN_EPSILON, - momentum=BN_MOMENTUM, - name=f"{name}_bn2", - )(x) - return x - - -def SqueezeAndExcite2D( - input, - filters, - bottleneck_filters=None, - squeeze_activation="relu", - excite_activation="sigmoid", - name=None, -): - """ - Description: - This layer applies a content-aware mechanism to adaptively assign - channel-wise weights. It uses global average pooling to compress - feature maps into single values, which are then processed by - two Conv1D layers: the first reduces the dimensionality, and - the second restores it. - Args: - filters: Number of input and output filters. The number of input and - output filters is same. - bottleneck_filters: (Optional) Number of bottleneck filters. Defaults - to `0.25 * filters` - squeeze_activation: (Optional) String, callable (or - keras.layers.Layer) or keras.activations.Activation instance - denoting activation to be applied after squeeze convolution. - Defaults to `relu`. - excite_activation: (Optional) String, callable (or - keras.layers.Layer) or keras.activations.Activation instance - denoting activation to be applied after excite convolution. - Defaults to `sigmoid`. 
- name: Name of the layer - """ - if not bottleneck_filters: - bottleneck_filters = filters // 4 - - x = input - x = keras.layers.Conv2D( - bottleneck_filters, - (1, 1), - data_format=keras.config.image_data_format(), - activation=squeeze_activation, - name=f"{name}_conv_reduce", - )(x) - x = keras.layers.Conv2D( - filters, - (1, 1), - data_format=keras.config.image_data_format(), - activation=excite_activation, - name=f"{name}_conv_expand", - )(x) - - x = ops.multiply(x, input) - return x - - -def ConvBnAct(x, filter, activation, name=None): - channel_axis = ( - -1 if keras.config.image_data_format() == "channels_last" else 1 - ) - x = keras.layers.Conv2D( - filter, - kernel_size=1, - padding="same", - data_format=keras.config.image_data_format(), - use_bias=False, - name=f"{name}_conv", - )(x) - x = keras.layers.BatchNormalization( - axis=channel_axis, - epsilon=BN_EPSILON, - momentum=BN_MOMENTUM, - name=f"{name}_bn", - )(x) - x = keras.layers.Activation(activation)(x) - return x - - -def correct_pad_downsample(inputs, kernel_size): - """Returns a tuple for zero-padding for 2D convolution with downsampling. - - Args: - inputs: Input tensor. - kernel_size: An integer or tuple/list of 2 integers. - - Returns: - A tuple. - """ - img_dim = 1 - input_size = inputs.shape[img_dim : (img_dim + 2)] - if isinstance(kernel_size, int): - kernel_size = (kernel_size, kernel_size) - if input_size[0] is None: - adjust = (1, 1) - else: - adjust = (1 - input_size[0] % 2, 1 - input_size[1] % 2) - correct = (kernel_size[0] // 2, kernel_size[1] // 2) - return ( - (correct[0] - adjust[0], correct[0]), - (correct[1] - adjust[1], correct[1]), - ) +import keras +from keras import ops + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.backbone import Backbone + +BN_EPSILON = 1e-5 +BN_MOMENTUM = 0.9 + + +@keras_hub_export("keras_hub.models.MobileNetBackbone") +class MobileNetBackbone(Backbone): + """Instantiates the MobileNet architecture. + + MobileNet is a lightweight convolutional neural network (CNN) + optimized for mobile and edge devices, striking a balance between + accuracy and efficiency. By employing depthwise separable convolutions + and techniques like Squeeze-and-Excitation (SE) blocks, + MobileNet models are highly suitable for real-time applications on + resource-constrained devices. + + References: + - [MobileNets: Efficient Convolutional Neural Networks + for Mobile Vision Applications]( + https://arxiv.org/abs/1704.04861) + - [MobileNetV2: Inverted Residuals and Linear Bottlenecks]( + https://arxiv.org/abs/1801.04381) (CVPR 2018) + - [Searching for MobileNetV3](https://arxiv.org/pdf/1905.02244.pdf) + (ICCV 2019) + + Args: + stackwise_expansion: list of list of ints, the expanded filters for + each inverted residual block for each block in the model. + stackwise_num_blocks: list of ints, number of inversted residual blocks + per block + stackwise_num_filters: list of list of ints, number of filters for + each inverted residual block in the model. + stackwise_kernel_size: list of list of ints, kernel size for each + inverted residual block in the model. + stackwise_num_strides: list of list of ints, stride length for each + inverted residual block in the model. + stackwise_se_ratio: se ratio for each inverted residual block in the + model. 0 if dont want to add Squeeze and Excite layer. + stackwise_activation: list of list of activation functions, for each + inverted residual block in the model. + image_shape: optional shape tuple, defaults to (224, 224, 3). 
+ input_num_filters: number of filters in first convolution layer + output_num_filters: specifies whether to add conv and batch_norm in the + end, if set to None, it will not add these layers in the end. + 'None' for MobileNetV1 + input_activation: activation function to be used in the input layer + 'hard_swish' for MobileNetV3, + 'relu6' for MobileNetV1 and MobileNetV2 + output_activation: activation function to be used in the output layer + 'hard_swish' for MobileNetV3, + 'relu6' for MobileNetV1 and MobileNetV2 + depthwise_filters: int, number of filters in depthwise separable + convolution layer + squeeze_and_excite: float, squeeze and excite ratio in the depthwise + layer, None, if dont want to do squeeze and excite + + + Example: + ```python + input_data = tf.ones(shape=(8, 224, 224, 3)) + + # Randomly initialized backbone with a custom config + model = MobileNetBackbone( + stackwise_expansion=[ + [40, 56], + [64, 144, 144], + [72, 72], + [144, 288, 288], + ], + stackwise_num_blocks=[2, 3, 2, 3], + stackwise_num_filters=[ + [16, 16], + [24, 24, 24], + [24, 24], + [48, 48, 48], + ], + stackwise_kernel_size=[[3, 3], [5, 5, 5], [5, 5], [5, 5, 5]], + stackwise_num_strides=[[2, 1], [2, 1, 1], [1, 1], [2, 1, 1]], + stackwise_se_ratio=[ + [None, None], + [0.25, 0.25, 0.25], + [0.3, 0.3], + [0.3, 0.25, 0.25], + ], + stackwise_activation=[ + ["relu", "relu"], + ["hard_swish", "hard_swish", "hard_swish"], + ["hard_swish", "hard_swish"], + ["hard_swish", "hard_swish", "hard_swish"], + ], + output_num_filters=288, + input_activation="hard_swish", + output_activation="hard_swish", + input_num_filters=16, + image_shape=(224, 224, 3), + depthwise_filters=8, + squeeze_and_excite=0.5, + + ) + output = model(input_data) + ``` + """ + + def __init__( + self, + stackwise_expansion, + stackwise_num_blocks, + stackwise_num_filters, + stackwise_kernel_size, + stackwise_num_strides, + stackwise_se_ratio, + stackwise_activation, + stackwise_padding, + output_num_filters, + depthwise_filters, + last_layer_filter, + squeeze_and_excite=None, + image_shape=(None, None, 3), + input_activation="hard_swish", + output_activation="hard_swish", + input_num_filters=16, + **kwargs, + ): + # === Functional Model === + channel_axis = ( + -1 if keras.config.image_data_format() == "channels_last" else 1 + ) + + image_input = keras.layers.Input(shape=image_shape) + x = image_input + input_num_filters = adjust_channels(input_num_filters) + + pad_width = ( + (0, 0), # No padding for batch + (1, 1), # 1 pixel padding for height + (1, 1), # 1 pixel padding for width + (0, 0), + ) # No padding for channels + x = ops.pad(x, pad_width=pad_width) + x = keras.layers.Conv2D( + input_num_filters, + kernel_size=3, + strides=(2, 2), + data_format=keras.config.image_data_format(), + use_bias=False, + name="input_conv", + )(x) + x = keras.layers.BatchNormalization( + axis=channel_axis, + epsilon=BN_EPSILON, + momentum=BN_MOMENTUM, + name="input_batch_norm", + )(x) + x = keras.layers.Activation(input_activation)(x) + + x = apply_depthwise_conv_block( + x, depthwise_filters, se=squeeze_and_excite, name="block_0" + ) + + for block in range(len(stackwise_num_blocks)): + for inverted_block in range(stackwise_num_blocks[block]): + x = apply_inverted_res_block( + x, + expansion=stackwise_expansion[block][inverted_block], + filters=adjust_channels( + stackwise_num_filters[block][inverted_block] + ), + kernel_size=stackwise_kernel_size[block][inverted_block], + stride=stackwise_num_strides[block][inverted_block], + 
se_ratio=stackwise_se_ratio[block][inverted_block], + activation=stackwise_activation[block][inverted_block], + padding=stackwise_padding[block][inverted_block], + name=f"block_{block+1}_{inverted_block}", + ) + + x = ConvBnAct( + x, + filter=adjust_channels(last_layer_filter), + activation="hard_swish", + name=f"block_{len(stackwise_num_blocks)+1}_0", + ) + + last_conv_ch = adjust_channels(output_num_filters) + + x = keras.layers.Conv2D( + last_conv_ch, + kernel_size=1, + data_format=keras.config.image_data_format(), + use_bias=False, + name="output_conv", + )(x) + + # no output normalization in mobilenetv3 + if output_activation == "relu6": + x = keras.layers.BatchNormalization( + axis=channel_axis, + epsilon=BN_EPSILON, + momentum=BN_MOMENTUM, + name="output_batch_norm", + )(x) + + x = keras.layers.Activation(output_activation)(x) + + super().__init__(inputs=image_input, outputs=x, **kwargs) + + # === Config === + self.stackwise_expansion = stackwise_expansion + self.stackwise_num_blocks = stackwise_num_blocks + self.stackwise_num_filters = stackwise_num_filters + self.stackwise_kernel_size = stackwise_kernel_size + self.stackwise_num_strides = stackwise_num_strides + self.stackwise_se_ratio = stackwise_se_ratio + self.stackwise_activation = stackwise_activation + self.stackwise_padding = stackwise_padding + self.input_num_filters = input_num_filters + self.output_num_filters = output_num_filters + self.depthwise_filters = depthwise_filters + self.last_layer_filter = last_layer_filter + self.squeeze_and_excite = squeeze_and_excite + self.input_activation = keras.activations.get(input_activation) + self.output_activation = keras.activations.get(output_activation) + self.image_shape = image_shape + + def get_config(self): + config = super().get_config() + config.update( + { + "stackwise_expansion": self.stackwise_expansion, + "stackwise_num_blocks": self.stackwise_num_blocks, + "stackwise_num_filters": self.stackwise_num_filters, + "stackwise_kernel_size": self.stackwise_kernel_size, + "stackwise_num_strides": self.stackwise_num_strides, + "stackwise_se_ratio": self.stackwise_se_ratio, + "stackwise_activation": self.stackwise_activation, + "stackwise_padding": self.stackwise_padding, + "image_shape": self.image_shape, + "input_num_filters": self.input_num_filters, + "output_num_filters": self.output_num_filters, + "depthwise_filters": self.depthwise_filters, + "last_layer_filter": self.last_layer_filter, + "squeeze_and_excite": self.squeeze_and_excite, + "input_activation": keras.activations.serialize( + activation=self.input_activation + ), + "output_activation": keras.activations.serialize( + activation=self.output_activation + ), + } + ) + return config + + +def adjust_channels(x, divisor=8, min_value=None): + """Ensure that all layers have a channel number divisible by the `divisor`. + + Args: + x: integer, input value. + divisor: integer, the value by which a channel number should be + divisible, defaults to 8. + min_value: float, optional minimum value for the new tensor. If None, + defaults to value of divisor. + + Returns: + the updated input scalar. + """ + + if min_value is None: + min_value = divisor + + new_x = max(min_value, int(x + divisor / 2) // divisor * divisor) + + # make sure that round down does not go down by more than 10%. + if new_x < 0.9 * x: + new_x += divisor + return new_x + + +def apply_inverted_res_block( + x, + expansion, + filters, + kernel_size, + stride, + se_ratio, + activation, + padding, + name=None, +): + """An Inverted Residual Block. 
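+
+    Expands channels with a 1x1 convolution, applies a zero-padded "valid"
+    depthwise convolution (optionally followed by squeeze-and-excite), then
+    projects back down with a 1x1 convolution. A residual connection is added
+    when `stride == 1` and the input and output filter counts match.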
+ + Args: + x: input tensor. + expansion: integer, the expansion ratio, multiplied with infilters to + get the minimum value passed to adjust_channels. + filters: integer, number of filters for convolution layer. + kernel_size: integer, the kernel size for DepthWise Convolutions. + stride: integer, the stride length for DepthWise Convolutions. + se_ratio: float, ratio for bottleneck filters. Number of bottleneck + filters = filters * se_ratio. + activation: the activation layer to use. + padding: padding in the conv2d layer + name: string, block label. + + Returns: + the updated input tensor. + """ + channel_axis = ( + -1 if keras.config.image_data_format() == "channels_last" else 1 + ) + activation = keras.activations.get(activation) + shortcut = x + infilters = x.shape[channel_axis] + expanded_channels = adjust_channels(expansion) + + x = keras.layers.Conv2D( + expanded_channels, + kernel_size=1, + data_format=keras.config.image_data_format(), + use_bias=False, + name=f"{name}_conv1", + )(x) + + x = keras.layers.BatchNormalization( + axis=channel_axis, + epsilon=BN_EPSILON, + momentum=BN_MOMENTUM, + name=f"{name}_bn1", + )(x) + + x = keras.layers.Activation(activation=activation)(x) + + # if stride == 2: + # x = keras.layers.ZeroPadding2D( + # padding=correct_pad_downsample(x, kernel_size), + # )(x) + + # pad_width=[[padding, padding], [padding, padding]] + pad_width = ( + (0, 0), # No padding for batch + (padding, padding), # 1 pixel padding for height + (padding, padding), # 1 pixel padding for width + (0, 0), + ) # No padding for channels + x = ops.pad(x, pad_width=pad_width) + + x = keras.layers.Conv2D( + expanded_channels, + kernel_size, + strides=stride, + padding="valid", + groups=expanded_channels, + data_format=keras.config.image_data_format(), + use_bias=False, + name=f"{name}_conv2", + )(x) + x = keras.layers.BatchNormalization( + axis=channel_axis, + epsilon=BN_EPSILON, + momentum=BN_MOMENTUM, + name=f"{name}_bn2", + )(x) + + x = keras.layers.Activation(activation=activation)(x) + + if se_ratio: + se_filters = expanded_channels + x = SqueezeAndExcite2D( + input=x, + filters=se_filters, + bottleneck_filters=adjust_channels(se_filters * se_ratio), + squeeze_activation="relu", + excite_activation=keras.activations.hard_sigmoid, + name=f"{name}_se", + ) + + x = keras.layers.Conv2D( + filters, + kernel_size=1, + data_format=keras.config.image_data_format(), + use_bias=False, + name=f"{name}_conv3", + )(x) + x = keras.layers.BatchNormalization( + axis=channel_axis, + epsilon=BN_EPSILON, + momentum=BN_MOMENTUM, + name=f"{name}_bn3", + )(x) + + if stride == 1 and infilters == filters: + x = keras.layers.Add(name=f"{name}_add")([shortcut, x]) + return x + + +def apply_depthwise_conv_block( + x, filters, kernel_size=3, stride=2, se=None, name=None +): + """Adds a depthwise convolution block. + + A depthwise convolution block consists of a depthwise conv, + batch normalization, relu6, pointwise convolution, + batch normalization and relu6 activation. + + Args: + x: Input tensor of shape `(rows, cols, channels) + filters: Integer, the dimensionality of the output space + (i.e. the number of output filters in the pointwise convolution). + strides: An integer or tuple/list of 2 integers, specifying the strides + of the convolution along the width and height. + Can be a single integer to specify the same value for + all spatial dimensions. Specifying any stride value != 1 is + incompatible with specifying any `dilation_rate` value != 1. 
+ block_id: Integer, a unique identification designating the block number. + + Input shape: + 4D tensor with shape: `(batch, rows, cols, channels)` in "channels_last" + 4D tensor with shape: `(batch, channels, rows, cols)` in "channels_first" + Returns: + Output tensor of block. + """ + channel_axis = ( + -1 if keras.config.image_data_format() == "channels_last" else 1 + ) + infilters = x.shape[channel_axis] + name = f"{name}_0" + + # if stride == 2: + # x = keras.layers.ZeroPadding2D( + # padding=correct_pad_downsample(x, kernel_size), + # )(x) + pad_width = ( + (0, 0), # No padding for batch + (1, 1), # 1 pixel padding for height + (1, 1), # 1 pixel padding for width + (0, 0), + ) # No padding for channels + x = ops.pad(x, pad_width=pad_width) + x = keras.layers.Conv2D( + infilters, + kernel_size, + strides=stride, + padding="valid", + data_format=keras.config.image_data_format(), + groups=infilters, + use_bias=False, + name=f"{name}_conv1", + )(x) + x = keras.layers.BatchNormalization( + axis=channel_axis, + epsilon=BN_EPSILON, + momentum=BN_MOMENTUM, + name=f"{name}_bn1", + )(x) + x = keras.layers.ReLU(6.0)(x) + + if se: + x = SqueezeAndExcite2D( + input=x, + filters=infilters, + bottleneck_filters=adjust_channels(infilters * se), + squeeze_activation="relu", + excite_activation=keras.activations.hard_sigmoid, + name=f"{name}_se", + ) + + x = keras.layers.Conv2D( + filters, + kernel_size=1, + data_format=keras.config.image_data_format(), + use_bias=False, + name=f"{name}_conv2", + )(x) + x = keras.layers.BatchNormalization( + axis=channel_axis, + epsilon=BN_EPSILON, + momentum=BN_MOMENTUM, + name=f"{name}_bn2", + )(x) + return x + + +def SqueezeAndExcite2D( + input, + filters, + bottleneck_filters=None, + squeeze_activation="relu", + excite_activation="sigmoid", + name=None, +): + """ + Description: + This layer applies a content-aware mechanism to adaptively assign + channel-wise weights. It uses global average pooling to compress + feature maps into single values, which are then processed by + two Conv1D layers: the first reduces the dimensionality, and + the second restores it. + Args: + filters: Number of input and output filters. The number of input and + output filters is same. + bottleneck_filters: (Optional) Number of bottleneck filters. Defaults + to `0.25 * filters` + squeeze_activation: (Optional) String, callable (or + keras.layers.Layer) or keras.activations.Activation instance + denoting activation to be applied after squeeze convolution. + Defaults to `relu`. + excite_activation: (Optional) String, callable (or + keras.layers.Layer) or keras.activations.Activation instance + denoting activation to be applied after excite convolution. + Defaults to `sigmoid`. 
+ name: Name of the layer + """ + if not bottleneck_filters: + bottleneck_filters = filters // 4 + + x = input + x = keras.layers.Conv2D( + bottleneck_filters, + (1, 1), + data_format=keras.config.image_data_format(), + activation=squeeze_activation, + name=f"{name}_conv_reduce", + )(x) + x = keras.layers.Conv2D( + filters, + (1, 1), + data_format=keras.config.image_data_format(), + activation=excite_activation, + name=f"{name}_conv_expand", + )(x) + + x = ops.multiply(x, input) + return x + + +def ConvBnAct(x, filter, activation, name=None): + channel_axis = ( + -1 if keras.config.image_data_format() == "channels_last" else 1 + ) + x = keras.layers.Conv2D( + filter, + kernel_size=1, + data_format=keras.config.image_data_format(), + use_bias=False, + name=f"{name}_conv", + )(x) + x = keras.layers.BatchNormalization( + axis=channel_axis, + epsilon=BN_EPSILON, + momentum=BN_MOMENTUM, + name=f"{name}_bn", + )(x) + x = keras.layers.Activation(activation)(x) + return x + + +def correct_pad_downsample(inputs, kernel_size): + """Returns a tuple for zero-padding for 2D convolution with downsampling. + + Args: + inputs: Input tensor. + kernel_size: An integer or tuple/list of 2 integers. + + Returns: + A tuple. + """ + img_dim = 1 + input_size = inputs.shape[img_dim : (img_dim + 2)] + if isinstance(kernel_size, int): + kernel_size = (kernel_size, kernel_size) + if input_size[0] is None: + adjust = (1, 1) + else: + adjust = (1 - input_size[0] % 2, 1 - input_size[1] % 2) + correct = (kernel_size[0] // 2, kernel_size[1] // 2) + return ( + (correct[0] - adjust[0], correct[0]), + (correct[1] - adjust[1], correct[1]), + ) diff --git a/keras_hub/src/models/mobilenet/mobilenet_backbone_test.py b/keras_hub/src/models/mobilenet/mobilenet_backbone_test.py index 8119d0aa1b..3d909c9221 100644 --- a/keras_hub/src/models/mobilenet/mobilenet_backbone_test.py +++ b/keras_hub/src/models/mobilenet/mobilenet_backbone_test.py @@ -37,8 +37,8 @@ def setUp(self): "stackwise_se_ratio": [ [None, None], [0.25, 0.25, 0.25], - [0.3, 0.3], - [0.3, 0.25, 0.25], + [0.25, 0.25], + [0.25, 0.25, 0.25], ], "stackwise_activation": [ ["relu", "relu"], @@ -47,6 +47,7 @@ def setUp(self): ["hard_swish", "hard_swish", "hard_swish"], ["hard_swish"], ], + "stackwise_padding": [[1, 1], [2, 2, 2], [2, 2], [2, 2, 2]], "output_num_filters": 1024, "input_activation": "hard_swish", "output_activation": "hard_swish", @@ -63,7 +64,7 @@ def test_backbone_basics(self): cls=MobileNetBackbone, init_kwargs=self.init_kwargs, input_data=self.input_data, - expected_output_shape=(2, 14, 14, 1024), + expected_output_shape=(2, 7, 7, 1024), run_mixed_precision_check=False, run_data_format_check=False, ) diff --git a/keras_hub/src/models/mobilenet/mobilenet_image_classifier.py b/keras_hub/src/models/mobilenet/mobilenet_image_classifier.py index bf07914781..e9cc0fc153 100644 --- a/keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +++ b/keras_hub/src/models/mobilenet/mobilenet_image_classifier.py @@ -5,6 +5,7 @@ MobileNetImageClassifierPreprocessor, ) + @keras_hub_export("keras_hub.models.MobileNetImageClassifier") class MobileNetImageClassifier(ImageClassifier): backbone_cls = MobileNetBackbone diff --git a/keras_hub/src/models/mobilenet/mobilenet_image_classifier_test.py b/keras_hub/src/models/mobilenet/mobilenet_image_classifier_test.py index 36adb46613..7997b444fd 100644 --- a/keras_hub/src/models/mobilenet/mobilenet_image_classifier_test.py +++ b/keras_hub/src/models/mobilenet/mobilenet_image_classifier_test.py @@ -32,8 +32,8 @@ def setUp(self): 
stackwise_se_ratio=[ [None, None], [0.25, 0.25, 0.25], - [0.3, 0.3], - [0.3, 0.25, 0.25], + [0.25, 0.25], + [0.25, 0.25, 0.25], ], stackwise_activation=[ ["relu", "relu"], @@ -41,6 +41,7 @@ def setUp(self): ["hard_swish", "hard_swish"], ["hard_swish", "hard_swish", "hard_swish"], ], + stackwise_padding=[[1, 1], [2, 2, 2], [2, 2], [2, 2, 2], [1]], output_num_filters=1024, input_activation="hard_swish", output_activation="hard_swish", @@ -71,6 +72,18 @@ def test_classifier_basics(self): expected_output_shape=(2, 2), ) + @pytest.mark.large + def test_smallest_preset(self): + # Test that our forward pass is stable! + image_batch = self.load_test_image()[None, ...] / 255.0 + self.run_preset_test( + cls=MobileNetImageClassifier, + preset="mobilenetv3_small_050", + input_data=image_batch, + expected_output_shape=(1, 1000), + expected_labels=[85], + ) + @pytest.mark.large def test_saved_model(self): self.run_model_saving_test( diff --git a/keras_hub/src/models/pali_gemma/pali_gemma_backbone.py b/keras_hub/src/models/pali_gemma/pali_gemma_backbone.py index e8eb1dd232..75c1cb8ad0 100644 --- a/keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +++ b/keras_hub/src/models/pali_gemma/pali_gemma_backbone.py @@ -61,8 +61,6 @@ class PaliGemmaBackbone(Backbone): vit_classifier_activation: activation function. The activation that is used for final output classification in the vision transformer. vit_name: string. The name used for vision transformer layers. - include_rescaling: bool. If true, the image input will be rescaled from - the range `[0, 255]`, to the range `[0, 1]`. layer_norm_epsilon: float. The epsilon value user for every layer norm in all transformer blocks. dropout: float. Dropout probability for the Transformer decoder blocks. @@ -121,7 +119,6 @@ def __init__( vit_pooling=None, vit_classifier_activation=None, vit_name=None, - include_rescaling=True, layer_norm_epsilon=1e-6, dropout=0, dtype=None, @@ -145,7 +142,6 @@ def __init__( vit_intermediate_dim = vit_intermediate_dim or 4304 self.vit_encoder = PaliGemmaVit( image_size=image_size, - include_rescaling=include_rescaling, patch_size=vit_patch_size, num_heads=vit_num_heads, hidden_dim=vit_hidden_dim, @@ -215,7 +211,6 @@ def __init__( # === Config === self.vocabulary_size = vocabulary_size self.image_size = image_size - self.include_rescaling = include_rescaling self.num_layers = num_layers self.num_query_heads = num_query_heads self.num_key_value_heads = num_key_value_heads @@ -242,7 +237,6 @@ def get_config(self): { "vocabulary_size": self.vocabulary_size, "image_size": self.image_size, - "include_rescaling": self.include_rescaling, "num_layers": self.num_layers, "num_query_heads": self.num_query_heads, "num_key_value_heads": self.num_key_value_heads, diff --git a/keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py b/keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py index 5419daee5b..a0f912add1 100644 --- a/keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +++ b/keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py @@ -110,7 +110,9 @@ def __init__( self.backbone = backbone # === Functional Model === - inputs = backbone.inputs + # This must be "backbone.input" i.e. the full input structure, + # rather than "backbone.inputs" which is the flattened list of inputs. 
+ inputs = backbone.input hidden_state = backbone(inputs=inputs) outputs = backbone.token_embedding(hidden_state, reverse=True) outputs = outputs[:, backbone.image_sequence_length :, :] diff --git a/keras_hub/src/models/pali_gemma/pali_gemma_presets.py b/keras_hub/src/models/pali_gemma/pali_gemma_presets.py index 3f642833a4..af5443fd1a 100644 --- a/keras_hub/src/models/pali_gemma/pali_gemma_presets.py +++ b/keras_hub/src/models/pali_gemma/pali_gemma_presets.py @@ -12,7 +12,7 @@ "path": "pali_gemma", "model_card": "https://www.kaggle.com/models/google/paligemma", }, - "kaggle_handle": "kaggle://keras/paligemma/keras/pali_gemma_3b_mix_224/2", + "kaggle_handle": "kaggle://keras/paligemma/keras/pali_gemma_3b_mix_224/3", }, "pali_gemma_3b_mix_448": { "metadata": { @@ -24,7 +24,7 @@ "path": "pali_gemma", "model_card": "https://www.kaggle.com/models/google/paligemma", }, - "kaggle_handle": "kaggle://keras/paligemma/keras/pali_gemma_3b_mix_448/2", + "kaggle_handle": "kaggle://keras/paligemma/keras/pali_gemma_3b_mix_448/3", }, "pali_gemma_3b_224": { "metadata": { @@ -36,7 +36,7 @@ "path": "pali_gemma", "model_card": "https://www.kaggle.com/models/google/paligemma", }, - "kaggle_handle": "kaggle://keras/paligemma/keras/pali_gemma_3b_224/2", + "kaggle_handle": "kaggle://keras/paligemma/keras/pali_gemma_3b_224/3", }, "pali_gemma_3b_448": { "metadata": { @@ -48,7 +48,7 @@ "path": "pali_gemma", "model_card": "https://www.kaggle.com/models/google/paligemma", }, - "kaggle_handle": "kaggle://keras/paligemma/keras/pali_gemma_3b_448/2", + "kaggle_handle": "kaggle://keras/paligemma/keras/pali_gemma_3b_448/3", }, "pali_gemma_3b_896": { "metadata": { @@ -60,6 +60,6 @@ "path": "pali_gemma", "model_card": "https://www.kaggle.com/models/google/paligemma", }, - "kaggle_handle": "kaggle://keras/paligemma/keras/pali_gemma_3b_896/2", + "kaggle_handle": "kaggle://keras/paligemma/keras/pali_gemma_3b_896/3", }, } diff --git a/keras_hub/src/models/pali_gemma/pali_gemma_vit.py b/keras_hub/src/models/pali_gemma/pali_gemma_vit.py index 20194a6039..190a5e8e13 100644 --- a/keras_hub/src/models/pali_gemma/pali_gemma_vit.py +++ b/keras_hub/src/models/pali_gemma/pali_gemma_vit.py @@ -410,8 +410,6 @@ class PaliGemmaVit(keras.Model): Args: image_size: int. The height/width of the image. Both height and width is expected to be the same. - include_rescaling: bool. If true, the image input will be rescaled from - the range `[0, 255]`, to the range `[0, 1]`. patch_size: int. The size of each square patch in the input image. num_heads: int. The number of attention heads for the vision(image) transformer encoder. @@ -452,7 +450,6 @@ def __init__( num_layers, intermediate_dim, num_classes, - include_rescaling=True, pooling=None, classifier_activation=None, dtype=None, @@ -463,14 +460,6 @@ def __init__( shape=(image_size, image_size, 3), name="images" ) x = image_input # Intermediate result. - # TODO we have moved this rescaling to preprocessing layers for most - # models. We should consider removing it here, though it would break - # compatibility. 
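The `backbone.input` comment above (the same change appears in `phi3_causal_lm.py` below) relies on the distinction between `Model.input` and `Model.inputs`. A tiny functional-model sketch of that distinction, with made-up input names:

```python
# Sketch only: per the comment above, `inputs` is the flattened list of input
# tensors, while `input` preserves the structure the model was built with.
import keras

token_ids = keras.Input(shape=(8,), name="token_ids")
padding_mask = keras.Input(shape=(8,), name="padding_mask")
outputs = keras.layers.Add()([token_ids, padding_mask])
model = keras.Model(
    inputs={"token_ids": token_ids, "padding_mask": padding_mask},
    outputs=outputs,
)

print(model.inputs)  # flat list of two KerasTensors
print(model.input)   # the original dict structure of inputs
```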
- if include_rescaling: - rescaling = keras.layers.Rescaling( - scale=1.0 / 127.5, offset=-1.0, name="rescaling" - ) - x = rescaling(image_input) x = PaliGemmaVitEncoder( hidden_dim=hidden_dim, num_layers=num_layers, @@ -520,7 +509,6 @@ def __init__( self.pooling = pooling self.num_classes = num_classes self.image_size = image_size - self.include_rescaling = include_rescaling self.patch_size = patch_size self.classifier_activation = keras.activations.get( classifier_activation @@ -549,7 +537,6 @@ def get_config(self): self.classifier_activation ), "image_size": self.image_size, - "include_rescaling": self.include_rescaling, "patch_size": self.patch_size, } ) diff --git a/keras_hub/src/models/pali_gemma/pali_gemma_vit_test.py b/keras_hub/src/models/pali_gemma/pali_gemma_vit_test.py index 9611590da0..76d11e356b 100644 --- a/keras_hub/src/models/pali_gemma/pali_gemma_vit_test.py +++ b/keras_hub/src/models/pali_gemma/pali_gemma_vit_test.py @@ -30,23 +30,6 @@ def test_vit_encoder(self): output.shape, (batch_size, intermediate_dim, hidden_dim) ) - def test_vit_rescaling(self): - vit_encoder = PaliGemmaVit( - image_size=16, - patch_size=4, - hidden_dim=8, - num_layers=2, - num_heads=2, - intermediate_dim=16, - num_classes=32, - ) - self.assertIsNotNone(vit_encoder.get_layer("rescaling")) - with self.assertRaises(ValueError): - config = vit_encoder.get_config() - config["include_rescaling"] = False - vit_encoder = PaliGemmaVit.from_config(config) - vit_encoder.get_layer("rescaling") - def test_vision_embeddings(self): embeddings_layer = PaliGemmaVitEmbeddings( image_size=16, diff --git a/keras_hub/src/models/phi3/phi3_causal_lm.py b/keras_hub/src/models/phi3/phi3_causal_lm.py index fed4c2ea27..a60c336afb 100644 --- a/keras_hub/src/models/phi3/phi3_causal_lm.py +++ b/keras_hub/src/models/phi3/phi3_causal_lm.py @@ -41,7 +41,9 @@ def __init__(self, backbone, preprocessor=None, **kwargs): self.preprocessor = preprocessor # === Functional Model === - inputs = backbone.inputs + # This must be "backbone.input" i.e. the full input structure, + # rather than "backbone.inputs" which is the flattened list of inputs. + inputs = backbone.input hidden_states = backbone(inputs) outputs = backbone.token_embedding(hidden_states, reverse=True) super().__init__( diff --git a/keras_hub/src/models/preprocessor.py b/keras_hub/src/models/preprocessor.py index f0569a36f8..f338b45339 100644 --- a/keras_hub/src/models/preprocessor.py +++ b/keras_hub/src/models/preprocessor.py @@ -32,7 +32,7 @@ class Preprocessor(PreprocessingLayer): image_converter_cls = None def __init__(self, *args, **kwargs): - self.config_name = kwargs.pop("config_name", PREPROCESSOR_CONFIG_FILE) + self.config_file = kwargs.pop("config_file", PREPROCESSOR_CONFIG_FILE) super().__init__(*args, **kwargs) self._tokenizer = None self._image_converter = None @@ -71,6 +71,22 @@ def image_converter(self): def image_converter(self, value): self._image_converter = value + @property + def image_size(self): + """Shortcut to get/set the image size of the image converter.""" + if self.image_converter is None: + return None + return self.image_converter.image_size + + @image_size.setter + def image_size(self, value): + if self.image_converter is None: + raise ValueError( + "Cannot set `image_size` on preprocessor if `image_converter` " + " is `None`." 
+ ) + self.image_converter.image_size = value + def get_config(self): config = super().get_config() if self.tokenizer: @@ -85,7 +101,7 @@ def get_config(self): ) config.update( { - "config_name": self.config_name, + "config_file": self.config_file, } ) return config @@ -117,7 +133,7 @@ def presets(cls): def from_preset( cls, preset, - config_name=PREPROCESSOR_CONFIG_FILE, + config_file=PREPROCESSOR_CONFIG_FILE, **kwargs, ): """Instantiate a `keras_hub.models.Preprocessor` from a model preset. @@ -167,7 +183,7 @@ def from_preset( # Detect the correct subclass if we need to. if cls.backbone_cls != backbone_cls: cls = find_subclass(preset, cls, backbone_cls) - return loader.load_preprocessor(cls, config_name, **kwargs) + return loader.load_preprocessor(cls, config_file, **kwargs) @classmethod def _add_missing_kwargs(cls, loader, kwargs): diff --git a/keras_hub/src/models/resnet/resnet_backbone.py b/keras_hub/src/models/resnet/resnet_backbone.py index bc8def804a..407ce44f5b 100644 --- a/keras_hub/src/models/resnet/resnet_backbone.py +++ b/keras_hub/src/models/resnet/resnet_backbone.py @@ -68,7 +68,7 @@ class ResNetBackbone(FeaturePyramidBackbone): input_data = np.random.uniform(0, 1, size=(2, 224, 224, 3)) # Pretrained ResNet backbone. - model = keras_hub.models.ResNetBackbone.from_preset("resnet50") + model = keras_hub.models.ResNetBackbone.from_preset("resnet_50_imagenet") model(input_data) # Randomly initialized ResNetV2 backbone with a custom config. @@ -80,7 +80,6 @@ class ResNetBackbone(FeaturePyramidBackbone): stackwise_num_strides=[1, 2, 2], block_type="basic_block", use_pre_activation=True, - pooling="avg", ) model(input_data) ``` diff --git a/keras_hub/src/models/resnet/resnet_presets.py b/keras_hub/src/models/resnet/resnet_presets.py index 58bed3d90a..c3f7c17de6 100644 --- a/keras_hub/src/models/resnet/resnet_presets.py +++ b/keras_hub/src/models/resnet/resnet_presets.py @@ -12,7 +12,7 @@ "path": "resnet", "model_card": "https://arxiv.org/abs/2110.00476", }, - "kaggle_handle": "kaggle://kerashub/resnetv1/keras/resnet_18_imagenet/3", + "kaggle_handle": "kaggle://keras/resnetv1/keras/resnet_18_imagenet/2", }, "resnet_50_imagenet": { "metadata": { @@ -25,7 +25,7 @@ "path": "resnet", "model_card": "https://arxiv.org/abs/2110.00476", }, - "kaggle_handle": "kaggle://kerashub/resnetv1/keras/resnet_50_imagenet/3", + "kaggle_handle": "kaggle://keras/resnetv1/keras/resnet_50_imagenet/2", }, "resnet_101_imagenet": { "metadata": { @@ -38,7 +38,7 @@ "path": "resnet", "model_card": "https://arxiv.org/abs/2110.00476", }, - "kaggle_handle": "kaggle://kerashub/resnetv1/keras/resnet_101_imagenet/3", + "kaggle_handle": "kaggle://keras/resnetv1/keras/resnet_101_imagenet/2", }, "resnet_152_imagenet": { "metadata": { @@ -51,7 +51,7 @@ "path": "resnet", "model_card": "https://arxiv.org/abs/2110.00476", }, - "kaggle_handle": "kaggle://kerashub/resnetv1/keras/resnet_152_imagenet/3", + "kaggle_handle": "kaggle://keras/resnetv1/keras/resnet_152_imagenet/2", }, "resnet_v2_50_imagenet": { "metadata": { @@ -64,7 +64,7 @@ "path": "resnet", "model_card": "https://arxiv.org/abs/2110.00476", }, - "kaggle_handle": "kaggle://kerashub/resnetv2/keras/resnet_v2_50_imagenet/3", + "kaggle_handle": "kaggle://keras/resnetv2/keras/resnet_v2_50_imagenet/2", }, "resnet_v2_101_imagenet": { "metadata": { @@ -77,6 +77,147 @@ "path": "resnet", "model_card": "https://arxiv.org/abs/2110.00476", }, - "kaggle_handle": "kaggle://kerashub/resnetv2/keras/resnet_v2_101_imagenet/3", + "kaggle_handle": 
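A hypothetical usage sketch of the new `image_size` shortcut added to `Preprocessor` above. It simply proxies `preprocessor.image_converter.image_size` and raises if no image converter is attached; the class and preset name are taken from the SegFormer additions later in this diff, and the exact default size is assumed.

```python
# Hypothetical usage of the `Preprocessor.image_size` property (assumed preset
# configuration; the property just forwards to the image converter).
import keras_hub

preprocessor = keras_hub.models.SegFormerImageSegmenterPreprocessor.from_preset(
    "segformer_b0_ade20k_512"
)
print(preprocessor.image_size)        # e.g. (512, 512), read from the image converter
preprocessor.image_size = (1024, 1024)  # forwarded to the image converter
```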
"kaggle://keras/resnetv2/keras/resnet_v2_101_imagenet/2", + }, + "resnet_vd_18_imagenet": { + "metadata": { + "description": ( + "18-layer ResNetVD (ResNet with bag of tricks) model " + "pre-trained on the ImageNet 1k dataset at a 224x224 " + "resolution." + ), + "params": 11722824, + "official_name": "ResNet", + "path": "resnet", + "model_card": "https://arxiv.org/abs/1812.01187", + }, + "kaggle_handle": "kaggle://kerashub/resnetvd/keras/resnet_vd_18_imagenet", + }, + "resnet_vd_34_imagenet": { + "metadata": { + "description": ( + "34-layer ResNetVD (ResNet with bag of tricks) model " + "pre-trained on the ImageNet 1k dataset at a 224x224 " + "resolution." + ), + "params": 21838408, + "official_name": "ResNet", + "path": "resnet", + "model_card": "https://arxiv.org/abs/1812.01187", + }, + "kaggle_handle": "kaggle://kerashub/resnetvd/keras/resnet_vd_34_imagenet", + }, + "resnet_vd_50_imagenet": { + "metadata": { + "description": ( + "50-layer ResNetVD (ResNet with bag of tricks) model " + "pre-trained on the ImageNet 1k dataset at a 224x224 " + "resolution." + ), + "params": 25629512, + "official_name": "ResNet", + "path": "resnet", + "model_card": "https://arxiv.org/abs/1812.01187", + }, + "kaggle_handle": "kaggle://kerashub/resnetvd/keras/resnet_vd_50_imagenet", + }, + "resnet_vd_50_ssld_imagenet": { + "metadata": { + "description": ( + "50-layer ResNetVD (ResNet with bag of tricks) model " + "pre-trained on the ImageNet 1k dataset at a 224x224 " + "resolution with knowledge distillation." + ), + "params": 25629512, + "official_name": "ResNet", + "path": "resnet", + "model_card": "https://arxiv.org/abs/1812.01187", + }, + "kaggle_handle": "kaggle://kerashub/resnetvd/keras/resnet_vd_50_ssld_imagenet", + }, + "resnet_vd_50_ssld_v2_imagenet": { + "metadata": { + "description": ( + "50-layer ResNetVD (ResNet with bag of tricks) model " + "pre-trained on the ImageNet 1k dataset at a 224x224 " + "resolution with knowledge distillation and AutoAugment." + ), + "params": 25629512, + "official_name": "ResNet", + "path": "resnet", + "model_card": "https://arxiv.org/abs/1812.01187", + }, + "kaggle_handle": "kaggle://kerashub/resnetvd/keras/resnet_vd_50_ssld_v2_imagenet", + }, + "resnet_vd_50_ssld_v2_fix_imagenet": { + "metadata": { + "description": ( + "50-layer ResNetVD (ResNet with bag of tricks) model " + "pre-trained on the ImageNet 1k dataset at a 224x224 " + "resolution with knowledge distillation, AutoAugment and " + "additional fine-tuning of the classification head." + ), + "params": 25629512, + "official_name": "ResNet", + "path": "resnet", + "model_card": "https://arxiv.org/abs/1812.01187", + }, + "kaggle_handle": "kaggle://kerashub/resnetvd/keras/resnet_vd_50_ssld_v2_fix_imagenet", + }, + "resnet_vd_101_imagenet": { + "metadata": { + "description": ( + "101-layer ResNetVD (ResNet with bag of tricks) model " + "pre-trained on the ImageNet 1k dataset at a 224x224 " + "resolution." + ), + "params": 44673864, + "official_name": "ResNet", + "path": "resnet", + "model_card": "https://arxiv.org/abs/1812.01187", + }, + "kaggle_handle": "kaggle://kerashub/resnetvd/keras/resnet_vd_101_imagenet", + }, + "resnet_vd_101_ssld_imagenet": { + "metadata": { + "description": ( + "101-layer ResNetVD (ResNet with bag of tricks) model " + "pre-trained on the ImageNet 1k dataset at a 224x224 " + "resolution with knowledge distillation." 
+ ), + "params": 44673864, + "official_name": "ResNet", + "path": "resnet", + "model_card": "https://arxiv.org/abs/1812.01187", + }, + "kaggle_handle": "kaggle://kerashub/resnetvd/keras/resnet_vd_101_ssld_imagenet", + }, + "resnet_vd_152_imagenet": { + "metadata": { + "description": ( + "152-layer ResNetVD (ResNet with bag of tricks) model " + "pre-trained on the ImageNet 1k dataset at a 224x224 " + "resolution." + ), + "params": 60363592, + "official_name": "ResNet", + "path": "resnet", + "model_card": "https://arxiv.org/abs/1812.01187", + }, + "kaggle_handle": "kaggle://kerashub/resnetvd/keras/resnet_vd_152_imagenet", + }, + "resnet_vd_200_imagenet": { + "metadata": { + "description": ( + "200-layer ResNetVD (ResNet with bag of tricks) model " + "pre-trained on the ImageNet 1k dataset at a 224x224 " + "resolution." + ), + "params": 74933064, + "official_name": "ResNet", + "path": "resnet", + "model_card": "https://arxiv.org/abs/1812.01187", + }, + "kaggle_handle": "kaggle://kerashub/resnetvd/keras/resnet_vd_200_imagenet", }, } diff --git a/keras_hub/src/models/sam/sam_image_segmenter.py b/keras_hub/src/models/sam/sam_image_segmenter.py index ed4b63ecd0..19b0035cb7 100644 --- a/keras_hub/src/models/sam/sam_image_segmenter.py +++ b/keras_hub/src/models/sam/sam_image_segmenter.py @@ -31,7 +31,7 @@ class SAMImageSegmenter(ImageSegmenter): Args: - backbone: A `keras_hub.models.VGGBackbone` instance. + backbone: A `keras_hub.models.SAMBackbone` instance. Example: Load pretrained model using `from_preset`. diff --git a/keras_hub/src/models/sam/sam_presets.py b/keras_hub/src/models/sam/sam_presets.py index 7b7986662c..60e33616e7 100644 --- a/keras_hub/src/models/sam/sam_presets.py +++ b/keras_hub/src/models/sam/sam_presets.py @@ -9,7 +9,7 @@ "path": "sam", "model_card": "https://arxiv.org/abs/2304.02643", }, - "kaggle_handle": "kaggle://kerashub/sam/keras/sam_base_sa1b/2", + "kaggle_handle": "kaggle://keras/sam/keras/sam_base_sa1b/4", }, "sam_large_sa1b": { "metadata": { @@ -19,7 +19,7 @@ "path": "sam", "model_card": "https://arxiv.org/abs/2304.02643", }, - "kaggle_handle": "kaggle://kerashub/sam/keras/sam_large_sa1b/2", + "kaggle_handle": "kaggle://keras/sam/keras/sam_large_sa1b/4", }, "sam_huge_sa1b": { "metadata": { @@ -29,6 +29,6 @@ "path": "sam", "model_card": "https://arxiv.org/abs/2304.02643", }, - "kaggle_handle": "kaggle://kerashub/sam/keras/sam_huge_sa1b/2", + "kaggle_handle": "kaggle://keras/sam/keras/sam_huge_sa1b/4", }, } diff --git a/keras_hub/src/models/segformer/__init__.py b/keras_hub/src/models/segformer/__init__.py new file mode 100644 index 0000000000..3a95690dba --- /dev/null +++ b/keras_hub/src/models/segformer/__init__.py @@ -0,0 +1,8 @@ +from keras_hub.src.models.segformer.segformer_backbone import SegFormerBackbone +from keras_hub.src.models.segformer.segformer_image_segmenter import ( + SegFormerImageSegmenter, +) +from keras_hub.src.models.segformer.segformer_presets import presets +from keras_hub.src.utils.preset_utils import register_presets + +register_presets(presets, SegFormerImageSegmenter) diff --git a/keras_hub/src/models/segformer/segformer_backbone.py b/keras_hub/src/models/segformer/segformer_backbone.py new file mode 100644 index 0000000000..f5563b4c02 --- /dev/null +++ b/keras_hub/src/models/segformer/segformer_backbone.py @@ -0,0 +1,163 @@ +import keras + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.backbone import Backbone + + +@keras_hub_export("keras_hub.models.SegFormerBackbone") +class SegFormerBackbone(Backbone): 
+ """A Keras model implementing the SegFormer architecture for semantic segmentation. + + This class implements the majority of the SegFormer architecture described in + [SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers] + (https://arxiv.org/abs/2105.15203) and [based on the TensorFlow implementation from DeepVision] + (https://github.com/DavidLandup0/deepvision/tree/main/deepvision/models/segmentation/segformer). + + SegFormers are meant to be used with the MixTransformer (MiT) encoder family, and + and use a very lightweight all-MLP decoder head. + + The MiT encoder uses a hierarchical transformer which outputs features at multiple scales, + similar to that of the hierarchical outputs typically associated with CNNs. + + Args: + image_encoder: `keras.Model`. The backbone network for the model that is + used as a feature extractor for the SegFormer encoder. + Should be used with the MiT backbone model + (`keras_hub.models.MiTBackbone`) which was created + specifically for SegFormers. + num_classes: int, the number of classes for the detection model, + including the background class. + projection_filters: int, number of filters in the + convolution layer projecting the concatenated features into + a segmentation map. Defaults to 256`. + + Example: + + Using the class with a custom `backbone`: + + ```python + import keras_hub + + backbone = keras_hub.models.MiTBackbone( + depths=[2, 2, 2, 2], + image_shape=(224, 224, 3), + hidden_dims=[32, 64, 160, 256], + num_layers=4, + blockwise_num_heads=[1, 2, 5, 8], + blockwise_sr_ratios=[8, 4, 2, 1], + max_drop_path_rate=0.1, + patch_sizes=[7, 3, 3, 3], + strides=[4, 2, 2, 2], + ) + + segformer_backbone = keras_hub.models.SegFormerBackbone(image_encoder=backbone, projection_filters=256) + ``` + + Using the class with a preset `backbone`: + + ```python + import keras_hub + + backbone = keras_hub.models.MiTBackbone.from_preset("mit_b0_ade20k_512") + segformer_backbone = keras_hub.models.SegFormerBackbone(image_encoder=backbone, projection_filters=256) + ``` + + """ + + def __init__( + self, + image_encoder, + projection_filters, + **kwargs, + ): + if not isinstance(image_encoder, keras.layers.Layer) or not isinstance( + image_encoder, keras.Model + ): + raise ValueError( + "Argument `image_encoder` must be a `keras.layers.Layer` instance " + f" or `keras.Model`. Received instead " + f"image_encoder={image_encoder} (of type {type(image_encoder)})." 
+ ) + + # === Layers === + inputs = keras.layers.Input(shape=image_encoder.input.shape[1:]) + + self.feature_extractor = keras.Model( + image_encoder.inputs, image_encoder.pyramid_outputs + ) + + features = self.feature_extractor(inputs) + # Get height and width of level one output + _, height, width, _ = features["P1"].shape + + self.mlp_blocks = [] + + for feature_dim, feature in zip(image_encoder.hidden_dims, features): + self.mlp_blocks.append( + keras.layers.Dense( + projection_filters, name=f"linear_{feature_dim}" + ) + ) + + self.resizing = keras.layers.Resizing( + height, width, interpolation="bilinear" + ) + self.concat = keras.layers.Concatenate(axis=-1) + self.linear_fuse = keras.Sequential( + [ + keras.layers.Conv2D( + filters=projection_filters, kernel_size=1, use_bias=False + ), + keras.layers.BatchNormalization(epsilon=1e-5, momentum=0.9), + keras.layers.Activation("relu"), + ] + ) + + # === Functional Model === + # Project all multi-level outputs onto + # the same dimensionality and feature map shape + multi_layer_outs = [] + for index, (feature_dim, feature) in enumerate( + zip(image_encoder.hidden_dims, features) + ): + out = self.mlp_blocks[index](features[feature]) + out = self.resizing(out) + multi_layer_outs.append(out) + + # Concat now-equal feature maps + concatenated_outs = self.concat(multi_layer_outs[::-1]) + + # Fuse concatenated features into a segmentation map + seg = self.linear_fuse(concatenated_outs) + + super().__init__( + inputs=inputs, + outputs=seg, + **kwargs, + ) + + # === Config === + self.projection_filters = projection_filters + self.image_encoder = image_encoder + + def get_config(self): + config = super().get_config() + config.update( + { + "projection_filters": self.projection_filters, + "image_encoder": keras.saving.serialize_keras_object( + self.image_encoder + ), + } + ) + return config + + @classmethod + def from_config(cls, config): + if "image_encoder" in config and isinstance( + config["image_encoder"], dict + ): + config["image_encoder"] = keras.layers.deserialize( + config["image_encoder"] + ) + return super().from_config(config) diff --git a/keras_hub/src/models/segformer/segformer_backbone_tests.py b/keras_hub/src/models/segformer/segformer_backbone_tests.py new file mode 100644 index 0000000000..22133763e7 --- /dev/null +++ b/keras_hub/src/models/segformer/segformer_backbone_tests.py @@ -0,0 +1,76 @@ +import numpy as np +import pytest +from keras import ops + +from keras_hub.api.models import MiTBackbone +from keras_hub.api.models import SegFormerBackbone +from keras_hub.src.tests.test_case import TestCase + + +class SegFormerTest(TestCase): + def setUp(self): + image_encoder = MiTBackbone( + depths=[2, 2], + image_shape=(224, 224, 3), + hidden_dims=[32, 64], + num_layers=2, + blockwise_num_heads=[1, 2], + blockwise_sr_ratios=[8, 4], + max_drop_path_rate=0.1, + patch_sizes=[7, 3], + strides=[4, 2], + ) + projection_filters = 256 + self.input_size = 224 + self.input_data = ops.ones((2, self.input_size, self.input_size, 3)) + + self.init_kwargs = { + "projection_filters": projection_filters, + "image_encoder": image_encoder, + } + + def test_segformer_backbone_construction(self): + + SegFormerBackbone( + image_encoder=self.init_kwargs["image_encoder"], + projection_filters=self.init_kwargs["projection_filters"], + ) + + @pytest.mark.large + def test_segformer_call(self): + segformer_backbone = SegFormerBackbone( + image_encoder=self.init_kwargs["image_encoder"], + projection_filters=self.init_kwargs["projection_filters"], + ) + + 
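The decode head that `SegFormerBackbone` assembles above follows a simple recipe: project each pyramid level to a common width, resize everything to the P1 resolution, concatenate, and fuse with a 1x1 convolution. A standalone sketch with made-up pyramid shapes (not the exact MiT `pyramid_outputs` wiring):

```python
# Standalone sketch of the SegFormer-style all-MLP decode head (illustrative shapes).
import numpy as np
import keras

features = {
    "P1": np.random.rand(2, 56, 56, 32).astype("float32"),
    "P2": np.random.rand(2, 28, 28, 64).astype("float32"),
}
projection_filters = 256
_, height, width, _ = features["P1"].shape

outs = []
for level, feat in features.items():
    x = keras.layers.Dense(projection_filters)(feat)               # per-pixel linear projection
    x = keras.layers.Resizing(height, width, interpolation="bilinear")(x)
    outs.append(x)

x = keras.layers.Concatenate(axis=-1)(outs[::-1])                  # deepest level first
x = keras.layers.Conv2D(projection_filters, 1, use_bias=False)(x)  # fuse
x = keras.layers.BatchNormalization(epsilon=1e-5, momentum=0.9)(x)
seg_features = keras.layers.Activation("relu")(x)
print(seg_features.shape)  # (2, 56, 56, 256)
```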
images = np.random.uniform(size=(2, 224, 224, 3)) + segformer_output = segformer_backbone(images) + segformer_predict = segformer_backbone.predict(images) + + assert segformer_output.shape == (2, 56, 56, 256) + assert segformer_predict.shape == (2, 56, 56, 256) + + def test_backbone_basics(self): + + self.run_vision_backbone_test( + cls=SegFormerBackbone, + init_kwargs={**self.init_kwargs}, + input_data=self.input_data, + expected_output_shape=(2, 56, 56, 256), + ) + + def test_task(self): + self.run_task_test( + cls=SegFormerBackbone, + init_kwargs={**self.init_kwargs}, + train_data=self.input_data, + expected_output_shape=(2, 56, 56, 256), + ) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=SegFormerBackbone, + init_kwargs={**self.init_kwargs}, + input_data=self.input_data, + ) diff --git a/keras_hub/src/models/segformer/segformer_image_converter.py b/keras_hub/src/models/segformer/segformer_image_converter.py new file mode 100644 index 0000000000..44febd6833 --- /dev/null +++ b/keras_hub/src/models/segformer/segformer_image_converter.py @@ -0,0 +1,8 @@ +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.layers.preprocessing.image_converter import ImageConverter +from keras_hub.src.models.segformer.segformer_backbone import SegFormerBackbone + + +@keras_hub_export("keras_hub.layers.SegFormerImageConverter") +class SegFormerImageConverter(ImageConverter): + backbone_cls = SegFormerBackbone diff --git a/keras_hub/src/models/segformer/segformer_image_segmenter.py b/keras_hub/src/models/segformer/segformer_image_segmenter.py new file mode 100644 index 0000000000..1b00c7a754 --- /dev/null +++ b/keras_hub/src/models/segformer/segformer_image_segmenter.py @@ -0,0 +1,171 @@ +import keras + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.image_segmenter import ImageSegmenter +from keras_hub.src.models.segformer.segformer_backbone import SegFormerBackbone +from keras_hub.src.models.segformer.segformer_image_segmenter_preprocessor import ( + SegFormerImageSegmenterPreprocessor, +) + + +@keras_hub_export("keras_hub.models.SegFormerImageSegmenter") +class SegFormerImageSegmenter(ImageSegmenter): + """A Keras model implementing the SegFormer architecture for semantic segmentation. + + This class implements the segmentation head of the SegFormer architecture described in + [SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers] + (https://arxiv.org/abs/2105.15203) and [based on the TensorFlow implementation from DeepVision] + (https://github.com/DavidLandup0/deepvision/tree/main/deepvision/models/segmentation/segformer). + + SegFormers are meant to be used with the MixTransformer (MiT) encoder family, and + and use a very lightweight all-MLP decoder head. + + The MiT encoder uses a hierarchical transformer which outputs features at multiple scales, + similar to that of the hierarchical outputs typically associated with CNNs. + + Args: + image_encoder: `keras.Model`. The backbone network for the model that is + used as a feature extractor for the SegFormer encoder. + It is *intended* to be used only with the MiT backbone model + (`keras_hub.models.MiTBackbone`) which was created + specifically for SegFormers. + Alternatively, can be a `keras_hub.models.Backbone` a model subclassing + `keras_hub.models.FeaturePyramidBackbone`, or a `keras.Model` + that has a `pyramid_outputs` property which is + a dictionary with keys "P2", "P3", "P4", and "P5" and layer names as values. 
+ num_classes: int, the number of classes for the detection model, + including the background class. + projection_filters: int, number of filters in the + convolution layer projecting the concatenated features into + a segmentation map. Defaults to 256`. + + + Example: + + Using presets: + + ```python + import keras_hub + import numpy as np + + segmenter = keras_hub.models.SegFormerImageSegmenter.from_preset("segformer_b0_ade20k_512") + + images = np.random.rand(1, 512, 512, 3) + segformer(images) + ``` + + Using the SegFormer backbone: + + ```python + encoder = keras_hub.models.MiTBackbone.from_preset("mit_b0_ade20k_512") + backbone = keras_hub.models.SegFormerBackbone(image_encoder=encoder, projection_filters=256) + ``` + + Using the SegFormer backbone with a custom encoder: + + ```python + import keras + import keras_hub + import numpy as np + + images = np.ones(shape=(1, 96, 96, 3)) + labels = np.zeros(shape=(1, 96, 96, 1)) + + encoder = keras_hub.models.MiTBackbone( + depths=[2, 2, 2, 2], + image_shape=(96, 96, 3), + hidden_dims=[32, 64, 160, 256], + num_layers=4, + blockwise_num_heads=[1, 2, 5, 8], + blockwise_sr_ratios=[8, 4, 2, 1], + max_drop_path_rate=0.1, + patch_sizes=[7, 3, 3, 3], + strides=[4, 2, 2, 2], + ) + + backbone = keras_hub.models.SegFormerBackbone(image_encoder=encoder, projection_filters=256) + segformer = keras_hub.models.SegFormerImageSegmenter(backbone=backbone, num_classes=4) + + segformer(images) + ``` + + Using the segmentor class with a preset backbone: + + ```python + import keras_hub + + image_encoder = keras_hub.models.MiTBackbone.from_preset("mit_b0_ade20k_512") + backbone = keras_hub.models.SegFormerBackbone(image_encoder=encoder, projection_filters=256) + segformer = keras_hub.models.SegFormerImageSegmenter(backbone=backbone, num_classes=4) + ``` + """ + + backbone_cls = SegFormerBackbone + preprocessor_cls = SegFormerImageSegmenterPreprocessor + + def __init__( + self, + backbone, + num_classes, + preprocessor=None, + **kwargs, + ): + if not isinstance(backbone, keras.layers.Layer) or not isinstance( + backbone, keras.Model + ): + raise ValueError( + "Argument `backbone` must be a `keras.layers.Layer` instance " + f" or `keras.Model`. Received instead " + f"backbone={backbone} (of type {type(backbone)})." 
+ ) + + # === Layers === + inputs = backbone.input + + self.backbone = backbone + self.preprocessor = preprocessor + self.dropout = keras.layers.Dropout(0.1) + self.output_segmentation_head = keras.layers.Conv2D( + filters=num_classes, kernel_size=1, strides=1 + ) + self.resizing = keras.layers.Resizing( + height=inputs.shape[1], + width=inputs.shape[2], + interpolation="bilinear", + ) + + # === Functional Model === + x = self.backbone(inputs) + x = self.dropout(x) + x = self.output_segmentation_head(x) + output = self.resizing(x) + + super().__init__( + inputs=inputs, + outputs=output, + **kwargs, + ) + + # === Config === + self.num_classes = num_classes + self.backbone = backbone + + def get_config(self): + config = super().get_config() + config.update( + { + "num_classes": self.num_classes, + "backbone": keras.saving.serialize_keras_object(self.backbone), + } + ) + return config + + @classmethod + def from_config(cls, config): + if "image_encoder" in config and isinstance( + config["image_encoder"], dict + ): + config["image_encoder"] = keras.layers.deserialize( + config["image_encoder"] + ) + return super().from_config(config) diff --git a/keras_hub/src/models/segformer/segformer_image_segmenter_preprocessor.py b/keras_hub/src/models/segformer/segformer_image_segmenter_preprocessor.py new file mode 100644 index 0000000000..fd8c5fba35 --- /dev/null +++ b/keras_hub/src/models/segformer/segformer_image_segmenter_preprocessor.py @@ -0,0 +1,31 @@ +import keras + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.image_segmenter_preprocessor import ( + ImageSegmenterPreprocessor, +) +from keras_hub.src.models.segformer.segformer_backbone import SegFormerBackbone +from keras_hub.src.models.segformer.segformer_image_converter import ( + SegFormerImageConverter, +) +from keras_hub.src.utils.tensor_utils import preprocessing_function + +IMAGENET_DEFAULT_MEAN = [0.485, 0.456, 0.406] +IMAGENET_DEFAULT_STD = [0.229, 0.224, 0.225] + + +@keras_hub_export("keras_hub.models.SegFormerImageSegmenterPreprocessor") +class SegFormerImageSegmenterPreprocessor(ImageSegmenterPreprocessor): + backbone_cls = SegFormerBackbone + image_converter_cls = SegFormerImageConverter + + @preprocessing_function + def call(self, x, y=None, sample_weight=None): + if self.image_converter: + x = self.image_converter(x) + y = self.image_converter(y) + + x = x / 255 + x = (x - IMAGENET_DEFAULT_MEAN) / IMAGENET_DEFAULT_STD + + return keras.utils.pack_x_y_sample_weight(x, y, sample_weight) diff --git a/keras_hub/src/models/segformer/segformer_image_segmenter_tests.py b/keras_hub/src/models/segformer/segformer_image_segmenter_tests.py new file mode 100644 index 0000000000..4ad2e8bc6f --- /dev/null +++ b/keras_hub/src/models/segformer/segformer_image_segmenter_tests.py @@ -0,0 +1,65 @@ +import numpy as np +import pytest +from keras import ops + +from keras_hub.api.models import MiTBackbone +from keras_hub.api.models import SegFormerBackbone +from keras_hub.api.models import SegFormerImageSegmenter +from keras_hub.src.tests.test_case import TestCase + + +class SegFormerTest(TestCase): + def setUp(self): + image_encoder = MiTBackbone( + depths=[2, 2], + image_shape=(224, 224, 3), + hidden_dims=[32, 64], + num_layers=2, + blockwise_num_heads=[1, 2], + blockwise_sr_ratios=[8, 4], + max_drop_path_rate=0.1, + patch_sizes=[7, 3], + strides=[4, 2], + ) + projection_filters = 256 + self.backbone = SegFormerBackbone( + image_encoder=image_encoder, projection_filters=projection_filters + ) + + self.input_size = 
224 + self.input_data = ops.ones((2, self.input_size, self.input_size, 3)) + + self.init_kwargs = {"backbone": self.backbone, "num_classes": 4} + + def test_segformer_segmenter_construction(self): + SegFormerImageSegmenter(backbone=self.backbone, num_classes=4) + + @pytest.mark.large + def test_segformer_call(self): + + segformer = SegFormerImageSegmenter( + backbone=self.backbone, num_classes=4 + ) + + images = np.random.uniform(size=(2, 224, 224, 4)) + segformer_output = segformer(images) + segformer_predict = segformer.predict(images) + + assert segformer_output.shape == images.shape + assert segformer_predict.shape == images.shape + + def test_task(self): + self.run_task_test( + cls=SegFormerImageSegmenter, + init_kwargs={**self.init_kwargs}, + train_data=self.input_data, + expected_output_shape=(2, 224, 224), + ) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=SegFormerImageSegmenter, + init_kwargs={**self.init_kwargs}, + input_data=self.input_data, + ) diff --git a/keras_hub/src/models/segformer/segformer_presets.py b/keras_hub/src/models/segformer/segformer_presets.py new file mode 100644 index 0000000000..2c0fff0a50 --- /dev/null +++ b/keras_hub/src/models/segformer/segformer_presets.py @@ -0,0 +1,136 @@ +"""SegFormer model preset configurations.""" + +presets = { + "segformer_b0_ade20k_512": { + "metadata": { + "description": ( + "SegFormer model with MiTB0 backbone fine-tuned on ADE20k in 512x512 resolution." + ), + "params": 3719027, + "official_name": "SegFormerB0", + "path": "segformer_b0", + }, + "kaggle_handle": "kaggle://kerashub/segformer/keras/segformer_b0_ade20k_512", + }, + "segformer_b1_ade20k_512": { + "metadata": { + "description": ( + "SegFormer model with MiTB1 backbone fine-tuned on ADE20k in 512x512 resolution." + ), + "params": 13682643, + "official_name": "SegFormerB1", + "path": "segformer_b1", + }, + "kaggle_handle": "kaggle://kerashub/segformer/keras/segformer_b1_ade20k_512", + }, + "segformer_b2_ade20k_512": { + "metadata": { + "description": ( + "SegFormer model with MiTB2 backbone fine-tuned on ADE20k in 512x512 resolution." + ), + "params": 24727507, + "official_name": "SegFormerB2", + "path": "segformer_b2", + }, + "kaggle_handle": "kaggle://kerashub/segformer/keras/segformer_b2_ade20k_512", + }, + "segformer_b3_ade20k_512": { + "metadata": { + "description": ( + "SegFormer model with MiTB3 backbone fine-tuned on ADE20k in 512x512 resolution." + ), + "params": 44603347, + "official_name": "SegFormerB3", + "path": "segformer_b3", + }, + "kaggle_handle": "kaggle://kerashub/segformer/keras/segformer_b3_ade20k_512", + }, + "segformer_b4_ade20k_512": { + "metadata": { + "description": ( + "SegFormer model with MiTB4 backbone fine-tuned on ADE20k in 512x512 resolution." + ), + "params": 61373907, + "official_name": "SegFormerB4", + "path": "segformer_b4", + }, + "kaggle_handle": "kaggle://kerashub/segformer/keras/segformer_b4_ade20k_512", + }, + "segformer_b5_ade20k_640": { + "metadata": { + "description": ( + "SegFormer model with MiTB5 backbone fine-tuned on ADE20k in 640x640 resolution." + ), + "params": 81974227, + "official_name": "SegFormerB5", + "path": "segformer_b5", + }, + "kaggle_handle": "kaggle://kerashub/segformer/keras/segformer_b5_ade20k_640", + }, + "segformer_b0_cityscapes_1024": { + "metadata": { + "description": ( + "SegFormer model with MiTB0 backbone fine-tuned on Cityscapes in 1024x1024 resolution." 
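The `SegFormerImageSegmenterPreprocessor` added earlier in this diff rescales inputs to `[0, 1]` and then standardizes with ImageNet statistics; the arithmetic is equivalent to this short numpy sketch (input values are random stand-ins):

```python
# Equivalent arithmetic to the preprocessor's `x / 255` then (x - mean) / std.
import numpy as np

IMAGENET_DEFAULT_MEAN = [0.485, 0.456, 0.406]
IMAGENET_DEFAULT_STD = [0.229, 0.224, 0.225]

x = np.random.randint(0, 256, size=(1, 512, 512, 3)).astype("float32")  # raw RGB
x = x / 255.0                                            # to [0, 1]
x = (x - IMAGENET_DEFAULT_MEAN) / IMAGENET_DEFAULT_STD   # channel-wise standardization
print(x.mean(axis=(0, 1, 2)), x.std(axis=(0, 1, 2)))     # per-channel stats after standardization
```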
+ ), + "params": 3719027, + "official_name": "SegFormerB0", + "path": "segformer_b0", + }, + "kaggle_handle": "kaggle://kerashub/segformer/keras/segformer_b0_cityscapes_1024", + }, + "segformer_b1_cityscapes_1024": { + "metadata": { + "description": ( + "SegFormer model with MiTB1 backbone fine-tuned on Cityscapes in 1024x1024 resolution." + ), + "params": 13682643, + "official_name": "SegFormerB1", + "path": "segformer_b1", + }, + "kaggle_handle": "kaggle://kerashub/segformer/keras/segformer_b1_ade20k_512", + }, + "segformer_b2_cityscapes_1024": { + "metadata": { + "description": ( + "SegFormer model with MiTB2 backbone fine-tuned on Cityscapes in 1024x1024 resolution." + ), + "params": 24727507, + "official_name": "SegFormerB2", + "path": "segformer_b2", + }, + "kaggle_handle": "kaggle://kerashub/segformer/keras/segformer_b2_cityscapes_1024", + }, + "segformer_b3_cityscapes_1024": { + "metadata": { + "description": ( + "SegFormer model with MiTB3 backbone fine-tuned on Cityscapes in 1024x1024 resolution." + ), + "params": 44603347, + "official_name": "SegFormerB3", + "path": "segformer_b3", + }, + "kaggle_handle": "kaggle://kerashub/segformer/keras/segformer_b3_cityscapes_1024", + }, + "segformer_b4_cityscapes_1024": { + "metadata": { + "description": ( + "SegFormer model with MiTB4 backbone fine-tuned on Cityscapes in 1024x1024 resolution." + ), + "params": 61373907, + "official_name": "SegFormerB4", + "path": "segformer_b4", + }, + "kaggle_handle": "kaggle://kerashub/segformer/keras/segformer_b4_cityscapes_1024", + }, + "segformer_b5_cityscapes_1024": { + "metadata": { + "description": ( + "SegFormer model with MiTB5 backbone fine-tuned on Cityscapes in 1024x1024 resolution." + ), + "params": 81974227, + "official_name": "SegFormerB5", + "path": "segformer_b5", + }, + "kaggle_handle": "kaggle://kerashub/segformer/keras/segformer_b5_cityscapes_1024", + }, +} diff --git a/keras_hub/src/models/stable_diffusion_3/mmdit.py b/keras_hub/src/models/stable_diffusion_3/mmdit.py index 0fe78e571b..546d56f13a 100644 --- a/keras_hub/src/models/stable_diffusion_3/mmdit.py +++ b/keras_hub/src/models/stable_diffusion_3/mmdit.py @@ -2,7 +2,6 @@ import keras from keras import layers -from keras import models from keras import ops from keras_hub.src.layers.modeling.position_embedding import PositionEmbedding @@ -11,7 +10,167 @@ from keras_hub.src.utils.keras_utils import standardize_data_format +class AdaptiveLayerNormalization(layers.Layer): + """Adaptive layer normalization. + + Args: + embedding_dim: int. The size of each embedding vector. + residual_modulation: bool. Whether to output the modulation parameters + of the residual connection within the block of the diffusion + transformers. Defaults to `False`. + **kwargs: other keyword arguments passed to `keras.layers.Layer`, + including `name`, `dtype` etc. + + References: + - [FiLM: Visual Reasoning with a General Conditioning Layer]( + https://arxiv.org/abs/1709.07871). + - [Scalable Diffusion Models with Transformers]( + https://arxiv.org/abs/2212.09748). 
+ """ + + def __init__(self, hidden_dim, residual_modulation=False, **kwargs): + super().__init__(**kwargs) + self.hidden_dim = int(hidden_dim) + self.residual_modulation = bool(residual_modulation) + num_modulations = 6 if self.residual_modulation else 2 + + self.silu = layers.Activation("silu", dtype=self.dtype_policy) + self.dense = layers.Dense( + num_modulations * hidden_dim, dtype=self.dtype_policy, name="dense" + ) + self.norm = layers.LayerNormalization( + epsilon=1e-6, + center=False, + scale=False, + dtype="float32", + name="norm", + ) + + def build(self, inputs_shape, embeddings_shape): + self.silu.build(embeddings_shape) + self.dense.build(embeddings_shape) + self.norm.build(inputs_shape) + + def call(self, inputs, embeddings, training=None): + x = inputs + emb = self.dense(self.silu(embeddings), training=training) + if self.residual_modulation: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + ops.split(emb, 6, axis=1) + ) + else: + shift_msa, scale_msa = ops.split(emb, 2, axis=1) + scale_msa = ops.expand_dims(scale_msa, axis=1) + shift_msa = ops.expand_dims(shift_msa, axis=1) + x = ops.add( + ops.multiply( + self.norm(x, training=training), + ops.add(1.0, scale_msa), + ), + shift_msa, + ) + if self.residual_modulation: + return x, gate_msa, shift_mlp, scale_mlp, gate_mlp + else: + return x + + def get_config(self): + config = super().get_config() + config.update( + { + "hidden_dim": self.hidden_dim, + "residual_modulation": self.residual_modulation, + } + ) + return config + + def compute_output_shape(self, inputs_shape, embeddings_shape): + if self.residual_modulation: + return ( + inputs_shape, + embeddings_shape, + embeddings_shape, + embeddings_shape, + embeddings_shape, + ) + else: + return inputs_shape + + +class MLP(layers.Layer): + """A MLP block with architecture. + + Args: + hidden_dim: int. The number of units in the hidden layers. + output_dim: int. The number of units in the output layer. + activation: str of callable. Activation to use in the hidden layers. + Default to `None`. + """ + + def __init__(self, hidden_dim, output_dim, activation=None, **kwargs): + super().__init__(**kwargs) + self.hidden_dim = int(hidden_dim) + self.output_dim = int(output_dim) + self.activation = keras.activations.get(activation) + + self.dense1 = layers.Dense( + hidden_dim, + activation=self.activation, + dtype=self.dtype_policy, + name="dense1", + ) + self.dense2 = layers.Dense( + output_dim, + activation=None, + dtype=self.dtype_policy, + name="dense2", + ) + + def build(self, inputs_shape): + self.dense1.build(inputs_shape) + inputs_shape = self.dense1.compute_output_shape(inputs_shape) + self.dense2.build(inputs_shape) + + def call(self, inputs, training=None): + x = self.dense1(inputs, training=training) + return self.dense2(x, training=training) + + def get_config(self): + config = super().get_config() + config.update( + { + "hidden_dim": self.hidden_dim, + "output_dim": self.output_dim, + "activation": keras.activations.serialize(self.activation), + } + ) + return config + + def compute_output_shape(self, inputs_shape): + outputs_shape = list(inputs_shape) + outputs_shape[-1] = self.output_dim + return outputs_shape + + class PatchEmbedding(layers.Layer): + """A layer that converts images into patches. + + Args: + patch_size: int. The size of one side of each patch. + hidden_dim: int. The number of units in the hidden layers. + data_format: `None` or str. If specified, either `"channels_last"` or + `"channels_first"`. 
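For intuition, the modulation computed by the `AdaptiveLayerNormalization` layer above (in its non-residual, two-chunk form) reduces to the arithmetic below. Shapes are illustrative and the dense/silu projection of the conditioning embedding is replaced by a random stand-in.

```python
# FiLM-style modulation sketch: norm(x) * (1 + scale) + shift, with shift/scale
# split from a conditioning embedding (stand-in for dense(silu(embeddings))).
import numpy as np

batch, seq_len, hidden_dim = 2, 4, 8
x = np.random.rand(batch, seq_len, hidden_dim).astype("float32")
emb = np.random.rand(batch, 2 * hidden_dim).astype("float32")

shift, scale = np.split(emb, 2, axis=1)   # (batch, hidden_dim) each, shift first
shift = shift[:, None, :]                  # broadcast over the sequence axis
scale = scale[:, None, :]

# Layer norm without learned affine parameters (center=False, scale=False above).
mean = x.mean(axis=-1, keepdims=True)
var = x.var(axis=-1, keepdims=True)
x_norm = (x - mean) / np.sqrt(var + 1e-6)

out = x_norm * (1.0 + scale) + shift
print(out.shape)  # (2, 4, 8)
```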
The ordering of the dimensions in the + inputs. `"channels_last"` corresponds to inputs with shape + `(batch_size, height, width, channels)` + while `"channels_first"` corresponds to inputs with shape + `(batch_size, channels, height, width)`. It defaults to the + `image_data_format` value found in your Keras config file at + `~/.keras/keras.json`. If you never set it, then it will be + `"channels_last"`. + **kwargs: other keyword arguments passed to `keras.layers.Layer`, + including `name`, `dtype` etc. + """ + def __init__(self, patch_size, hidden_dim, data_format=None, **kwargs): super().__init__(**kwargs) self.patch_size = int(patch_size) @@ -48,6 +207,15 @@ def get_config(self): class AdjustablePositionEmbedding(PositionEmbedding): + """A position embedding layer with adjustable height and width. + + The embedding will be cropped to match the input dimensions. + + Args: + height: int. The maximum height of the embedding. + width: int. The maximum width of the embedding. + """ + def __init__( self, height, @@ -84,11 +252,36 @@ def call(self, inputs, height=None, width=None): position_embedding = ops.expand_dims(position_embedding, axis=0) return position_embedding + def get_config(self): + config = super().get_config() + del config["sequence_length"] + config.update( + { + "height": self.height, + "width": self.width, + } + ) + return config + def compute_output_shape(self, input_shape): return input_shape class TimestepEmbedding(layers.Layer): + """A layer which learns embedding for input timesteps. + + Args: + embedding_dim: int. The size of the embedding. + frequency_dim: int. The size of the frequency. + max_period: int. Controls the maximum frequency of the embeddings. + **kwargs: other keyword arguments passed to `keras.layers.Layer`, + including `name`, `dtype` etc. + + Reference: + - [Denoising Diffusion Probabilistic Models]( + https://arxiv.org/abs/2006.11239). + """ + def __init__( self, embedding_dim, frequency_dim=256, max_period=10000, **kwargs ): @@ -96,17 +289,23 @@ def __init__( self.embedding_dim = int(embedding_dim) self.frequency_dim = int(frequency_dim) self.max_period = float(max_period) - self.half_frequency_dim = self.frequency_dim // 2 - - self.mlp = models.Sequential( - [ - layers.Dense( - embedding_dim, activation="silu", dtype=self.dtype_policy - ), - layers.Dense( - embedding_dim, activation=None, dtype=self.dtype_policy + # Precomputed `freq`. 
+ half_frequency_dim = frequency_dim // 2 + self.freq = ops.exp( + ops.divide( + ops.multiply( + -math.log(max_period), + ops.arange(0, half_frequency_dim, dtype="float32"), ), - ], + half_frequency_dim, + ) + ) + + self.mlp = MLP( + embedding_dim, + embedding_dim, + "silu", + dtype=self.dtype_policy, name="mlp", ) @@ -118,16 +317,7 @@ def build(self, inputs_shape): def _create_timestep_embedding(self, inputs): compute_dtype = keras.backend.result_type(self.compute_dtype, "float32") x = ops.cast(inputs, compute_dtype) - freqs = ops.exp( - ops.divide( - ops.multiply( - -math.log(self.max_period), - ops.arange(0, self.half_frequency_dim, dtype="float32"), - ), - self.half_frequency_dim, - ) - ) - freqs = ops.cast(freqs, compute_dtype) + freqs = ops.cast(self.freq, compute_dtype) x = ops.multiply(x, ops.expand_dims(freqs, axis=0)) embedding = ops.concatenate([ops.cos(x), ops.sin(x)], axis=-1) if self.frequency_dim % 2 != 0: @@ -143,6 +333,7 @@ def get_config(self): config.update( { "embedding_dim": self.embedding_dim, + "frequency_dim": self.frequency_dim, "max_period": self.max_period, } ) @@ -155,6 +346,18 @@ def compute_output_shape(self, inputs_shape): class DismantledBlock(layers.Layer): + """A dismantled block used to compute pre- and post-attention. + + Args: + num_heads: int. Number of attention heads. + hidden_dim: int. The number of units in the hidden layers. + mlp_ratio: float. The expansion ratio of `MLP`. + use_projection: bool. Whether to use an attention projection layer at + the end of the block. + **kwargs: other keyword arguments passed to `keras.layers.Layer`, + including `name`, `dtype` etc. + """ + def __init__( self, num_heads, @@ -173,25 +376,18 @@ def __init__( self.head_dim = head_dim mlp_hidden_dim = int(hidden_dim * mlp_ratio) self.mlp_hidden_dim = mlp_hidden_dim - num_modulations = 6 if use_projection else 2 - self.num_modulations = num_modulations - - self.adaptive_norm_modulation = models.Sequential( - [ - layers.Activation("silu", dtype=self.dtype_policy), - layers.Dense( - num_modulations * hidden_dim, dtype=self.dtype_policy - ), - ], - name="adaptive_norm_modulation", - ) - self.norm1 = layers.LayerNormalization( - epsilon=1e-6, - center=False, - scale=False, - dtype="float32", - name="norm1", - ) + + if use_projection: + self.ada_layer_norm = AdaptiveLayerNormalization( + hidden_dim, + residual_modulation=True, + dtype=self.dtype_policy, + name="ada_layer_norm", + ) + else: + self.ada_layer_norm = AdaptiveLayerNormalization( + hidden_dim, dtype=self.dtype_policy, name="ada_layer_norm" + ) self.attention_qkv = layers.Dense( hidden_dim * 3, dtype=self.dtype_policy, name="attention_qkv" ) @@ -206,73 +402,45 @@ def __init__( dtype="float32", name="norm2", ) - self.mlp = models.Sequential( - [ - layers.Dense( - mlp_hidden_dim, - activation=gelu_approximate, - dtype=self.dtype_policy, - ), - layers.Dense( - hidden_dim, - dtype=self.dtype_policy, - ), - ], + self.mlp = MLP( + mlp_hidden_dim, + hidden_dim, + gelu_approximate, + dtype=self.dtype_policy, name="mlp", ) def build(self, inputs_shape, timestep_embedding): - self.adaptive_norm_modulation.build(timestep_embedding) + self.ada_layer_norm.build(inputs_shape, timestep_embedding) self.attention_qkv.build(inputs_shape) - self.norm1.build(inputs_shape) if self.use_projection: self.attention_proj.build(inputs_shape) self.norm2.build(inputs_shape) self.mlp.build(inputs_shape) def _modulate(self, inputs, shift, scale): - shift = ops.expand_dims(shift, axis=1) - scale = ops.expand_dims(scale, axis=1) + inputs = 
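`TimestepEmbedding` now precomputes the sinusoidal frequencies (`self.freq`) at construction time and applies them in `_create_timestep_embedding`. A standalone numpy sketch of the resulting embedding, with illustrative batch values:

```python
# Sinusoidal timestep embedding sketch: freq = exp(-log(max_period) * k / half),
# embedding = concat([cos(t * freq), sin(t * freq)]).
import math
import numpy as np

frequency_dim = 256
max_period = 10000.0
half = frequency_dim // 2

freqs = np.exp(-math.log(max_period) * np.arange(0, half, dtype="float32") / half)
timesteps = np.array([0.0, 10.0, 500.0], dtype="float32")        # (batch,)
x = timesteps[:, None] * freqs[None, :]                          # (batch, half)
embedding = np.concatenate([np.cos(x), np.sin(x)], axis=-1)      # (batch, frequency_dim)
print(embedding.shape)  # (3, 256)
```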
ops.cast(inputs, self.compute_dtype) + shift = ops.cast(shift, self.compute_dtype) + scale = ops.cast(scale, self.compute_dtype) return ops.add(ops.multiply(inputs, ops.add(scale, 1.0)), shift) def _compute_pre_attention(self, inputs, timestep_embedding, training=None): batch_size = ops.shape(inputs)[0] if self.use_projection: - modulation = self.adaptive_norm_modulation( - timestep_embedding, training=training - ) - modulation = ops.reshape( - modulation, (batch_size, 6, self.hidden_dim) - ) - ( - shift_msa, - scale_msa, - gate_msa, - shift_mlp, - scale_mlp, - gate_mlp, - ) = ops.unstack(modulation, 6, axis=1) - qkv = self.attention_qkv( - self._modulate(self.norm1(inputs), shift_msa, scale_msa), - training=training, + x, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.ada_layer_norm( + inputs, timestep_embedding, training=training ) + qkv = self.attention_qkv(x, training=training) qkv = ops.reshape( qkv, (batch_size, -1, 3, self.num_heads, self.head_dim) ) q, k, v = ops.unstack(qkv, 3, axis=2) return (q, k, v), (inputs, gate_msa, shift_mlp, scale_mlp, gate_mlp) else: - modulation = self.adaptive_norm_modulation( - timestep_embedding, training=training - ) - modulation = ops.reshape( - modulation, (batch_size, 2, self.hidden_dim) - ) - shift_msa, scale_msa = ops.unstack(modulation, 2, axis=1) - qkv = self.attention_qkv( - self._modulate(self.norm1(inputs), shift_msa, scale_msa), - training=training, + x = self.ada_layer_norm( + inputs, timestep_embedding, training=training ) + qkv = self.attention_qkv(x, training=training) qkv = ops.reshape( qkv, (batch_size, -1, 3, self.num_heads, self.head_dim) ) @@ -283,12 +451,16 @@ def _compute_post_attention( self, inputs, inputs_intermediates, training=None ): x, gate_msa, shift_mlp, scale_mlp, gate_mlp = inputs_intermediates + gate_msa = ops.expand_dims(gate_msa, axis=1) + shift_mlp = ops.expand_dims(shift_mlp, axis=1) + scale_mlp = ops.expand_dims(scale_mlp, axis=1) + gate_mlp = ops.expand_dims(gate_mlp, axis=1) attn = self.attention_proj(inputs, training=training) - x = ops.add(x, ops.multiply(ops.expand_dims(gate_msa, axis=1), attn)) + x = ops.add(x, ops.multiply(gate_msa, attn)) x = ops.add( x, ops.multiply( - ops.expand_dims(gate_mlp, axis=1), + gate_mlp, self.mlp( self._modulate(self.norm2(x), shift_mlp, scale_mlp), training=training, @@ -328,6 +500,27 @@ def get_config(self): class MMDiTBlock(layers.Layer): + """A MMDiT block consisting of two `DismantledBlock` layers. + + One `DismantledBlock` processes the input latents, and the other processes + the context embedding. This block integrates two modalities within the + attention operation, allowing each representation to operate in its own + space while considering the other. + + Args: + num_heads: int. Number of attention heads. + hidden_dim: int. The number of units in the hidden layers. + mlp_ratio: float. The expansion ratio of `MLP`. + use_context_projection: bool. Whether to use an attention projection + layer at the end of the context block. + **kwargs: other keyword arguments passed to `keras.layers.Layer`, + including `name`, `dtype` etc. 
+ + Reference: + - [Scaling Rectified Flow Transformers for High-Resolution Image Synthesis]( + https://arxiv.org/abs/2403.03206) + """ + def __init__( self, num_heads, @@ -345,8 +538,6 @@ def __init__( head_dim = hidden_dim // num_heads self.head_dim = head_dim self._inverse_sqrt_key_dim = 1.0 / math.sqrt(head_dim) - self._dot_product_equation = "aecd,abcd->acbe" - self._combine_equation = "acbe,aecd->abcd" self.x_block = DismantledBlock( num_heads=num_heads, @@ -371,20 +562,18 @@ def build(self, inputs_shape, context_shape, timestep_embedding_shape): self.context_block.build(context_shape, timestep_embedding_shape) def _compute_attention(self, query, key, value): - query = ops.multiply( - query, ops.cast(self._inverse_sqrt_key_dim, query.dtype) - ) - attention_scores = ops.einsum(self._dot_product_equation, key, query) - attention_scores = self.softmax(attention_scores) - attention_scores = ops.cast(attention_scores, self.compute_dtype) - attention_output = ops.einsum( - self._combine_equation, attention_scores, value - ) - batch_size = ops.shape(attention_output)[0] - attention_output = ops.reshape( - attention_output, (batch_size, -1, self.num_heads * self.head_dim) - ) - return attention_output + # Ref: jax.nn.dot_product_attention + # https://github.com/jax-ml/jax/blob/db89c245ac66911c98f265a05956fdfa4bc79d83/jax/_src/nn/functions.py#L846 + batch_size = ops.shape(query)[0] + logits = ops.einsum("BTNH,BSNH->BNTS", query, key) + logits = ops.multiply(logits, self._inverse_sqrt_key_dim) + probs = self.softmax(logits) + probs = ops.cast(probs, self.compute_dtype) + encoded = ops.einsum("BNTS,BSNH->BTNH", probs, value) + encoded = ops.reshape( + encoded, (batch_size, -1, self.num_heads * self.head_dim) + ) + return encoded def call(self, inputs, context, timestep_embedding, training=None): # Compute pre-attention. 
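The rewritten `_compute_attention` in `MMDiTBlock` uses plain einsum-based scaled dot-product attention (see the `jax.nn.dot_product_attention` reference in the code). A numpy sketch with illustrative shapes, using a plain softmax in place of the block's `Softmax` layer:

```python
# Einsum attention sketch (B=batch, T/S=query/key length, N=heads, H=head dim).
import numpy as np

B, T, S, N, H = 2, 4, 6, 3, 8
query = np.random.rand(B, T, N, H).astype("float32")
key = np.random.rand(B, S, N, H).astype("float32")
value = np.random.rand(B, S, N, H).astype("float32")

logits = np.einsum("BTNH,BSNH->BNTS", query, key) / np.sqrt(H)   # scaled scores
probs = np.exp(logits - logits.max(axis=-1, keepdims=True))
probs /= probs.sum(axis=-1, keepdims=True)                        # softmax over S
encoded = np.einsum("BNTS,BSNH->BTNH", probs, value)
encoded = encoded.reshape(B, -1, N * H)                           # merge heads
print(encoded.shape)  # (2, 4, 24)
```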
@@ -453,74 +642,16 @@ def compute_output_shape( return inputs_shape -class OutputLayer(layers.Layer): - def __init__(self, hidden_dim, output_dim, **kwargs): - super().__init__(**kwargs) - self.hidden_dim = hidden_dim - self.output_dim = output_dim - num_modulation = 2 - - self.adaptive_norm_modulation = models.Sequential( - [ - layers.Activation("silu", dtype=self.dtype_policy), - layers.Dense( - num_modulation * hidden_dim, dtype=self.dtype_policy - ), - ], - name="adaptive_norm_modulation", - ) - self.norm = layers.LayerNormalization( - epsilon=1e-6, - center=False, - scale=False, - dtype="float32", - name="norm", - ) - self.output_dense = layers.Dense( - output_dim, - use_bias=True, - dtype=self.dtype_policy, - name="output_dense", - ) - - def build(self, inputs_shape, timestep_embedding_shape): - self.adaptive_norm_modulation.build(timestep_embedding_shape) - self.norm.build(inputs_shape) - self.output_dense.build(inputs_shape) - - def _modulate(self, inputs, shift, scale): - shift = ops.expand_dims(shift, axis=1) - scale = ops.expand_dims(scale, axis=1) - return ops.add(ops.multiply(inputs, ops.add(scale, 1.0)), shift) - - def call(self, inputs, timestep_embedding, training=None): - x = inputs - modulation = self.adaptive_norm_modulation( - timestep_embedding, training=training - ) - modulation = ops.reshape(modulation, (-1, 2, self.hidden_dim)) - shift, scale = ops.unstack(modulation, 2, axis=1) - x = self._modulate(self.norm(x), shift, scale) - x = self.output_dense(x, training=training) - return x - - def get_config(self): - config = super().get_config() - config.update( - { - "hidden_dim": self.hidden_dim, - "output_dim": self.output_dim, - } - ) - return config - - def compute_output_shape(self, inputs_shape): - outputs_shape = list(inputs_shape) - outputs_shape[-1] = self.output_dim - return outputs_shape +class Unpatch(layers.Layer): + """A layer that reconstructs the image from hidden patches. + Args: + patch_size: int. The size of each square patch in the input image. + output_dim: int. The number of units in the output layer. + **kwargs: other keyword arguments passed to `keras.layers.Layer`, + including `name`, `dtype` etc. + """ -class Unpatch(layers.Layer): def __init__(self, patch_size, output_dim, **kwargs): super().__init__(**kwargs) self.patch_size = int(patch_size) @@ -556,7 +687,7 @@ def compute_output_shape(self, inputs_shape): class MMDiT(Backbone): - """Multimodal Diffusion Transformer (MMDiT) model for Stable Diffusion 3. + """A Multimodal Diffusion Transformer (MMDiT) model. 
MMDiT is introduced in [ Scaling Rectified Flow Transformers for High-Resolution Image Synthesis]( @@ -636,12 +767,8 @@ def __init__( dtype=dtype, name="context_embedding", ) - self.vector_embedding = models.Sequential( - [ - layers.Dense(hidden_dim, activation="silu", dtype=dtype), - layers.Dense(hidden_dim, activation=None, dtype=dtype), - ], - name="vector_embedding", + self.vector_embedding = MLP( + hidden_dim, hidden_dim, "silu", dtype=dtype, name="vector_embedding" ) self.vector_embedding_add = layers.Add( dtype=dtype, name="vector_embedding_add" @@ -660,8 +787,11 @@ def __init__( ) for i in range(num_layers) ] - self.output_layer = OutputLayer( - hidden_dim, output_dim_in_final, dtype=dtype, name="output_layer" + self.output_ada_layer_norm = AdaptiveLayerNormalization( + hidden_dim, dtype=dtype, name="output_ada_layer_norm" + ) + self.output_dense = layers.Dense( + output_dim_in_final, dtype=dtype, name="output_dense" ) self.unpatch = Unpatch( patch_size, output_dim, dtype=dtype, name="unpatch" @@ -696,7 +826,8 @@ def __init__( x = block(x, context, timestep_embedding) # Output layer. - x = self.output_layer(x, timestep_embedding) + x = self.output_ada_layer_norm(x, timestep_embedding) + x = self.output_dense(x) outputs = self.unpatch(x, height=image_height, width=image_width) super().__init__( diff --git a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py index c5930a3460..4dd3e4403d 100644 --- a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +++ b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py @@ -51,11 +51,52 @@ def compute_output_shape(self, inputs_shape): return (inputs_shape[0], self.hidden_dim) -class ClassifierFreeGuidanceConcatenate(layers.Layer): - def __init__(self, axis=0, **kwargs): - super().__init__(**kwargs) - self.axis = axis +class CLIPConcatenate(layers.Layer): + def call( + self, + clip_l_projection, + clip_g_projection, + clip_l_intermediate_output, + clip_g_intermediate_output, + padding, + ): + pooled_embeddings = ops.concatenate( + [clip_l_projection, clip_g_projection], axis=-1 + ) + embeddings = ops.concatenate( + [clip_l_intermediate_output, clip_g_intermediate_output], axis=-1 + ) + embeddings = ops.pad(embeddings, [[0, 0], [0, 0], [0, padding]]) + return pooled_embeddings, embeddings + + +class ImageRescaling(layers.Rescaling): + """Rescales inputs from image space to latent space. + + The rescaling is performed using the formula: `(inputs - offset) * scale`. + """ + + def call(self, inputs): + dtype = self.compute_dtype + scale = self.backend.cast(self.scale, dtype) + offset = self.backend.cast(self.offset, dtype) + return (self.backend.cast(inputs, dtype) - offset) * scale + + +class LatentRescaling(layers.Rescaling): + """Rescales inputs from latent space to image space. + The rescaling is performed using the formula: `inputs / scale + offset`. 
+ """ + + def call(self, inputs): + dtype = self.compute_dtype + scale = self.backend.cast(self.scale, dtype) + offset = self.backend.cast(self.offset, dtype) + return (self.backend.cast(inputs, dtype) / scale) + offset + + +class ClassifierFreeGuidanceConcatenate(layers.Layer): def call( self, latents, @@ -66,20 +107,16 @@ def call( timestep, ): timestep = ops.broadcast_to(timestep, ops.shape(latents)[:1]) - latents = ops.concatenate([latents, latents], axis=self.axis) + latents = ops.concatenate([latents, latents], axis=0) contexts = ops.concatenate( - [positive_contexts, negative_contexts], axis=self.axis + [positive_contexts, negative_contexts], axis=0 ) pooled_projections = ops.concatenate( - [positive_pooled_projections, negative_pooled_projections], - axis=self.axis, + [positive_pooled_projections, negative_pooled_projections], axis=0 ) - timesteps = ops.concatenate([timestep, timestep], axis=self.axis) + timesteps = ops.concatenate([timestep, timestep], axis=0) return latents, contexts, pooled_projections, timesteps - def get_config(self): - return super().get_config() - class ClassifierFreeGuidance(layers.Layer): """Perform classifier free guidance. @@ -100,9 +137,6 @@ class ClassifierFreeGuidance(layers.Layer): - [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). """ - def __init__(self, **kwargs): - super().__init__(**kwargs) - def call(self, inputs, guidance_scale): positive_noise, negative_noise = ops.split(inputs, 2, axis=0) return ops.add( @@ -112,9 +146,6 @@ def call(self, inputs, guidance_scale): ), ) - def get_config(self): - return super().get_config() - def compute_output_shape(self, inputs_shape): outputs_shape = list(inputs_shape) if outputs_shape[0] is not None: @@ -142,16 +173,10 @@ class EulerStep(layers.Layer): https://arxiv.org/abs/2206.00364). """ - def __init__(self, **kwargs): - super().__init__(**kwargs) - def call(self, latents, noise_residual, sigma, sigma_next): sigma_diff = ops.subtract(sigma_next, sigma) return ops.add(latents, ops.multiply(sigma_diff, noise_residual)) - def get_config(self): - return super().get_config() - def compute_output_shape(self, latents_shape): return latents_shape @@ -190,8 +215,8 @@ class StableDiffusion3Backbone(Backbone): model. Defaults to `1000`. shift: float. The shift value for the timestep schedule. Defaults to `3.0`. - height: optional int. The output height of the image. - width: optional int. The output width of the image. + image_shape: tuple. The input shape without the batch size. Defaults to + `(1024, 1024, 3)`. data_format: `None` or str. If specified, either `"channels_last"` or `"channels_first"`. The ordering of the dimensions in the inputs. `"channels_last"` corresponds to inputs with shape @@ -245,23 +270,21 @@ def __init__( output_channels=3, num_train_timesteps=1000, shift=3.0, - height=None, - width=None, + image_shape=(1024, 1024, 3), data_format=None, dtype=None, **kwargs, ): - height = int(height or 1024) - width = int(width or 1024) - if height % 8 != 0 or width % 8 != 0: - raise ValueError( - "`height` and `width` must be divisible by 8. " - f"Received: height={height}, width={width}" - ) data_format = standardize_data_format(data_format) if data_format != "channels_last": raise NotImplementedError - image_shape = (height, width, int(vae.input_channels)) + height = image_shape[0] + width = image_shape[1] + if height % 8 != 0 or width % 8 != 0: + raise ValueError( + "height and width in `image_shape` must be divisible by 8. 
" + f"Received: image_shape={image_shape}" + ) latent_shape = (height // 8, width // 8, int(latent_channels)) context_shape = (None, 4096 if t5 is None else t5.hidden_dim) pooled_projection_shape = (clip_l.hidden_dim + clip_g.hidden_dim,) @@ -272,12 +295,13 @@ def __init__( self.clip_l_projection = CLIPProjection( clip_l.hidden_dim, dtype=dtype, name="clip_l_projection" ) - self.clip_l_projection.build([None, clip_l.hidden_dim], None) self.clip_g = clip_g self.clip_g_projection = CLIPProjection( clip_g.hidden_dim, dtype=dtype, name="clip_g_projection" ) - self.clip_g_projection.build([None, clip_g.hidden_dim], None) + self.clip_concatenate = CLIPConcatenate( + dtype=dtype, name="clip_concatenate" + ) self.t5 = t5 self.diffuser = MMDiT( mmdit_patch_size, @@ -293,6 +317,12 @@ def __init__( name="diffuser", ) self.vae = vae + self.cfg_concat = ClassifierFreeGuidanceConcatenate( + dtype=dtype, name="classifier_free_guidance_concat" + ) + self.cfg = ClassifierFreeGuidance( + dtype=dtype, name="classifier_free_guidance" + ) # Set `dtype="float32"` to ensure the high precision for the noise # residual. self.scheduler = FlowMatchEulerDiscreteScheduler( @@ -301,17 +331,17 @@ def __init__( dtype="float32", name="scheduler", ) - self.cfg_concat = ClassifierFreeGuidanceConcatenate( - dtype="float32", name="classifier_free_guidance_concat" - ) - self.cfg = ClassifierFreeGuidance( - dtype="float32", name="classifier_free_guidance" - ) self.euler_step = EulerStep(dtype="float32", name="euler_step") - self.latent_rescaling = layers.Rescaling( - scale=1.0 / self.vae.scale, + self.image_rescaling = ImageRescaling( + scale=self.vae.scale, offset=self.vae.shift, - dtype="float32", + dtype=dtype, + name="image_rescaling", + ) + self.latent_rescaling = LatentRescaling( + scale=self.vae.scale, + offset=self.vae.shift, + dtype=dtype, name="latent_rescaling", ) @@ -420,8 +450,7 @@ def __init__( self.output_channels = output_channels self.num_train_timesteps = num_train_timesteps self.shift = shift - self.height = height - self.width = width + self.image_shape = image_shape @property def latent_shape(self): @@ -440,8 +469,12 @@ def encode_text_step(self, token_ids, negative_token_ids): t5_hidden_dim = self.t5_hidden_dim def encode(token_ids): - clip_l_outputs = self.clip_l(token_ids["clip_l"], training=False) - clip_g_outputs = self.clip_g(token_ids["clip_g"], training=False) + clip_l_outputs = self.clip_l( + {"token_ids": token_ids["clip_l"]}, training=False + ) + clip_g_outputs = self.clip_g( + {"token_ids": token_ids["clip_g"]}, training=False + ) clip_l_projection = self.clip_l_projection( clip_l_outputs["sequence_output"], token_ids["clip_l"], @@ -452,23 +485,21 @@ def encode(token_ids): token_ids["clip_g"], training=False, ) - pooled_embeddings = ops.concatenate( - [clip_l_projection, clip_g_projection], - axis=-1, - ) - embeddings = ops.concatenate( - [ - clip_l_outputs["intermediate_output"], - clip_g_outputs["intermediate_output"], - ], - axis=-1, - ) - embeddings = ops.pad( - embeddings, - [[0, 0], [0, 0], [0, t5_hidden_dim - clip_hidden_dim]], + pooled_embeddings, embeddings = self.clip_concatenate( + clip_l_projection, + clip_g_projection, + clip_l_outputs["intermediate_output"], + clip_g_outputs["intermediate_output"], + padding=t5_hidden_dim - clip_hidden_dim, ) if self.t5 is not None: - t5_outputs = self.t5(token_ids["t5"], training=False) + t5_outputs = self.t5( + { + "token_ids": token_ids["t5"], + "padding_mask": ops.ones_like(token_ids["t5"]), + }, + training=False, + ) embeddings = 
ops.concatenate([embeddings, t5_outputs], axis=-2) else: padded_size = self.clip_l.max_sequence_length @@ -490,9 +521,7 @@ def encode(token_ids): def encode_image_step(self, images): latents = self.vae.encode(images) - return ops.multiply( - ops.subtract(latents, self.vae.shift), self.vae.scale - ) + return self.image_rescaling(latents) def add_noise_step(self, latents, noises, step, num_steps): return self.scheduler.add_noise(latents, noises, step, num_steps) @@ -553,8 +582,7 @@ def get_config(self): "output_channels": self.output_channels, "num_train_timesteps": self.num_train_timesteps, "shift": self.shift, - "height": self.height, - "width": self.width, + "image_shape": self.image_shape, } ) return config diff --git a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone_test.py b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone_test.py index 37723b0b5a..77415a6eec 100644 --- a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone_test.py +++ b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone_test.py @@ -11,7 +11,8 @@ class StableDiffusion3BackboneTest(TestCase): def setUp(self): - height, width = 64, 64 + image_shape = (64, 64, 3) + height, width = image_shape[0], image_shape[1] vae = VAEBackbone( [32, 32, 32, 32], [1, 1, 1, 1], @@ -36,8 +37,7 @@ def setUp(self): "vae": vae, "clip_l": clip_l, "clip_g": clip_g, - "height": height, - "width": width, + "image_shape": image_shape, } self.input_data = { "images": ops.ones((2, height, width, 3)), @@ -82,7 +82,6 @@ def test_all_presets(self): preset=preset, input_data=self.input_data, init_kwargs={ - "height": self.init_kwargs["height"], - "width": self.init_kwargs["width"], + "image_shape": self.init_kwargs["image_shape"], }, ) diff --git a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py new file mode 100644 index 0000000000..285ba834b4 --- /dev/null +++ b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py @@ -0,0 +1,171 @@ +from keras import ops + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.image_to_image import ImageToImage +from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_backbone import ( + StableDiffusion3Backbone, +) +from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_text_to_image_preprocessor import ( + StableDiffusion3TextToImagePreprocessor, +) + + +@keras_hub_export("keras_hub.models.StableDiffusion3ImageToImage") +class StableDiffusion3ImageToImage(ImageToImage): + """An end-to-end Stable Diffusion 3 model for image-to-image generation. + + This model has a `generate()` method, which generates images based + on a combination of a reference image and a text prompt. + + Args: + backbone: A `keras_hub.models.StableDiffusion3Backbone` instance. + preprocessor: A + `keras_hub.models.StableDiffusion3TextToImagePreprocessor` instance. + + Examples: + + Use `generate()` to do image generation. + ```python + image_to_image = keras_hub.models.StableDiffusion3ImageToImage.from_preset( + "stable_diffusion_3_medium", image_shape=(512, 512, 3) + ) + image_to_image.generate( + { + "images": np.ones((512, 512, 3), dtype="float32"), + "prompts": "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", + } + ) + + # Generate with batched prompts. 
+ image_to_image.generate( + { + "images": np.ones((2, 512, 512, 3), dtype="float32"), + "prompts": ["cute wallpaper art of a cat", "cute wallpaper art of a dog"], + } + ) + + # Generate with different `num_steps`, `guidance_scale` and `strength`. + image_to_image.generate( + { + "images": np.ones((512, 512, 3), dtype="float32"), + "prompts": "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", + } + num_steps=50, + guidance_scale=5.0, + strength=0.6, + ) + + # Generate with `negative_prompts`. + text_to_image.generate( + { + "images": np.ones((512, 512, 3), dtype="float32"), + "prompts": "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", + "negative_prompts": "green color", + } + ) + ``` + """ + + backbone_cls = StableDiffusion3Backbone + preprocessor_cls = StableDiffusion3TextToImagePreprocessor + + def __init__( + self, + backbone, + preprocessor, + **kwargs, + ): + # === Layers === + self.backbone = backbone + self.preprocessor = preprocessor + + # === Functional Model === + inputs = backbone.input + outputs = backbone.output + super().__init__( + inputs=inputs, + outputs=outputs, + **kwargs, + ) + + def fit(self, *args, **kwargs): + raise NotImplementedError( + "Currently, `fit` is not supported for " + "`StableDiffusion3ImageToImage`." + ) + + def generate_step( + self, + images, + noises, + token_ids, + starting_step, + num_steps, + guidance_scale, + ): + """A compilable generation function for batched of inputs. + + This function represents the inner, XLA-compilable, generation function + for batched inputs. + + Args: + images: A (batch_size, image_height, image_width, 3) tensor + containing the reference images. + noises: A (batch_size, latent_height, latent_width, channels) tensor + containing the noises to be added to the latents. Typically, + this tensor is sampled from the Gaussian distribution. + token_ids: A pair of (batch_size, num_tokens) tensor containing the + tokens based on the input prompts and negative prompts. + starting_step: int. The number of the starting diffusion step. + num_steps: int. The number of diffusion steps to take. + guidance_scale: float. The classifier free guidance scale defined in + [Classifier-Free Diffusion Guidance]( + https://arxiv.org/abs/2207.12598). Higher scale encourages to + generate images that are closely linked to prompts, usually at + the expense of lower image quality. + """ + token_ids, negative_token_ids = token_ids + + # Encode images. + latents = self.backbone.encode_image_step(images) + + # Add noises to latents. + latents = self.backbone.add_noise_step( + latents, noises, starting_step, num_steps + ) + + # Encode inputs. + embeddings = self.backbone.encode_text_step( + token_ids, negative_token_ids + ) + + # Denoise. + def body_fun(step, latents): + return self.backbone.denoise_step( + latents, + embeddings, + step, + num_steps, + guidance_scale, + ) + + latents = ops.fori_loop(starting_step, num_steps, body_fun, latents) + + # Decode. 
+ return self.backbone.decode_step(latents) + + def generate( + self, + inputs, + num_steps=50, + guidance_scale=7.0, + strength=0.8, + seed=None, + ): + return super().generate( + inputs, + num_steps=num_steps, + guidance_scale=guidance_scale, + strength=strength, + seed=seed, + ) diff --git a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image_test.py b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image_test.py new file mode 100644 index 0000000000..8fa5b167ab --- /dev/null +++ b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image_test.py @@ -0,0 +1,180 @@ +import keras +import pytest +from keras import ops + +from keras_hub.src.models.clip.clip_preprocessor import CLIPPreprocessor +from keras_hub.src.models.clip.clip_text_encoder import CLIPTextEncoder +from keras_hub.src.models.clip.clip_tokenizer import CLIPTokenizer +from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_backbone import ( + StableDiffusion3Backbone, +) +from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_image_to_image import ( + StableDiffusion3ImageToImage, +) +from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_text_to_image_preprocessor import ( + StableDiffusion3TextToImagePreprocessor, +) +from keras_hub.src.models.vae.vae_backbone import VAEBackbone +from keras_hub.src.tests.test_case import TestCase + + +class StableDiffusion3ImageToImageTest(TestCase): + def setUp(self): + # Instantiate the preprocessor. + vocab = ["air", "plane", "port"] + vocab += ["<|endoftext|>", "<|startoftext|>"] + vocab = dict([(token, i) for i, token in enumerate(vocab)]) + merges = ["a i", "p l", "n e", "p o", "r t", "ai r", "pl a"] + merges += ["po rt", "pla ne"] + clip_l_tokenizer = CLIPTokenizer(vocab, merges, pad_with_end_token=True) + clip_g_tokenizer = CLIPTokenizer(vocab, merges) + clip_l_preprocessor = CLIPPreprocessor(clip_l_tokenizer) + clip_g_preprocessor = CLIPPreprocessor(clip_g_tokenizer) + self.preprocessor = StableDiffusion3TextToImagePreprocessor( + clip_l_preprocessor, clip_g_preprocessor + ) + + self.backbone = StableDiffusion3Backbone( + mmdit_patch_size=2, + mmdit_hidden_dim=16 * 2, + mmdit_num_layers=2, + mmdit_num_heads=2, + mmdit_position_size=192, + vae=VAEBackbone( + [32, 32, 32, 32], + [1, 1, 1, 1], + [32, 32, 32, 32], + [1, 1, 1, 1], + # Use `mode` generate a deterministic output. 
+ sampler_method="mode", + name="vae", + ), + clip_l=CLIPTextEncoder( + 20, 64, 64, 2, 2, 128, "quick_gelu", -2, name="clip_l" + ), + clip_g=CLIPTextEncoder( + 20, 128, 128, 2, 2, 256, "gelu", -2, name="clip_g" + ), + image_shape=(64, 64, 3), + ) + self.init_kwargs = { + "preprocessor": self.preprocessor, + "backbone": self.backbone, + } + self.input_data = { + "images": ops.ones((2, 64, 64, 3)), + "latents": ops.ones((2, 8, 8, 16)), + "clip_l_token_ids": ops.ones((2, 5), dtype="int32"), + "clip_l_negative_token_ids": ops.ones((2, 5), dtype="int32"), + "clip_g_token_ids": ops.ones((2, 5), dtype="int32"), + "clip_g_negative_token_ids": ops.ones((2, 5), dtype="int32"), + "num_steps": ops.ones((2,), dtype="int32"), + "guidance_scale": ops.ones((2,)), + } + + def test_image_to_image_basics(self): + pytest.skip( + reason="TODO: enable after preprocessor flow is figured out" + ) + self.run_task_test( + cls=StableDiffusion3ImageToImage, + init_kwargs=self.init_kwargs, + train_data=None, + expected_output_shape={ + "images": (2, 64, 64, 3), + "latents": (2, 8, 8, 16), + }, + ) + + def test_generate(self): + image_to_image = StableDiffusion3ImageToImage(**self.init_kwargs) + seed = 42 + image = self.input_data["images"][0] + # String input. + prompt = ["airplane"] + negative_prompt = [""] + output = image_to_image.generate( + { + "images": image, + "prompts": prompt, + "negative_prompts": negative_prompt, + }, + seed=seed, + ) + # Int tensor input. + prompt_ids = self.preprocessor.generate_preprocess(prompt) + negative_prompt_ids = self.preprocessor.generate_preprocess( + negative_prompt + ) + image_to_image.preprocessor = None + output2 = image_to_image.generate( + { + "images": image, + "prompts": prompt_ids, + "negative_prompts": negative_prompt_ids, + }, + seed=seed, + ) + self.assertAllClose(output, output2) + + def test_generate_with_lower_precision(self): + original_floatx = keras.config.floatx() + try: + for dtype in ["float16", "bfloat16"]: + keras.config.set_floatx(dtype) + image_to_image = StableDiffusion3ImageToImage( + **self.init_kwargs + ) + seed = 42 + image = self.input_data["images"][0] + # String input. + prompt = ["airplane"] + negative_prompt = [""] + output = image_to_image.generate( + { + "images": image, + "prompts": prompt, + "negative_prompts": negative_prompt, + }, + seed=seed, + ) + # Int tensor input. + prompt_ids = self.preprocessor.generate_preprocess(prompt) + negative_prompt_ids = self.preprocessor.generate_preprocess( + negative_prompt + ) + image_to_image.preprocessor = None + output2 = image_to_image.generate( + { + "images": image, + "prompts": prompt_ids, + "negative_prompts": negative_prompt_ids, + }, + seed=seed, + ) + self.assertAllClose(output, output2) + finally: + # Restore floatx to the original value to prevent impact on other + # tests even if there is an exception. + keras.config.set_floatx(original_floatx) + + def test_generate_compilation(self): + image_to_image = StableDiffusion3ImageToImage(**self.init_kwargs) + image = self.input_data["images"][0] + # Assert we do not recompile with successive calls. + image_to_image.generate({"images": image, "prompts": "airplane"}) + first_fn = image_to_image.generate_function + image_to_image.generate({"images": image, "prompts": "airplane"}) + second_fn = image_to_image.generate_function + self.assertEqual(first_fn, second_fn) + # Assert we do recompile after compile is called. 
+ image_to_image.compile() + self.assertIsNone(image_to_image.generate_function) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=StableDiffusion3ImageToImage, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) diff --git a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py new file mode 100644 index 0000000000..8d5ed7c6af --- /dev/null +++ b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py @@ -0,0 +1,194 @@ +from keras import ops + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.inpaint import Inpaint +from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_backbone import ( + StableDiffusion3Backbone, +) +from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_text_to_image_preprocessor import ( + StableDiffusion3TextToImagePreprocessor, +) + + +@keras_hub_export("keras_hub.models.StableDiffusion3Inpaint") +class StableDiffusion3Inpaint(Inpaint): + """An end-to-end Stable Diffusion 3 model for inpaint generation. + + This model has a `generate()` method, which generates images based + on a combination of a reference image, mask and a text prompt. + + Args: + backbone: A `keras_hub.models.StableDiffusion3Backbone` instance. + preprocessor: A + `keras_hub.models.StableDiffusion3TextToImagePreprocessor` instance. + + Examples: + + Use `generate()` to do image generation. + ```python + reference_image = np.ones((1024, 1024, 3), dtype="float32") + reference_mask = np.ones((1024, 1024), dtype="float32") + inpaint = keras_hub.models.StableDiffusion3Inpaint.from_preset( + "stable_diffusion_3_medium", image_shape=(512, 512, 3) + ) + inpaint.generate( + reference_image, + reference_mask, + "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", + ) + + # Generate with batched prompts. + reference_images = np.ones((2, 512, 512, 3), dtype="float32") + reference_mask = np.ones((2, 1024, 1024), dtype="float32") + inpaint.generate( + reference_images, + reference_mask, + ["cute wallpaper art of a cat", "cute wallpaper art of a dog"] + ) + + # Generate with different `num_steps`, `guidance_scale` and `strength`. + inpaint.generate( + reference_image, + reference_mask, + "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", + num_steps=50, + guidance_scale=5.0, + strength=0.6, + ) + ``` + """ + + backbone_cls = StableDiffusion3Backbone + preprocessor_cls = StableDiffusion3TextToImagePreprocessor + + def __init__( + self, + backbone, + preprocessor, + **kwargs, + ): + # === Layers === + self.backbone = backbone + self.preprocessor = preprocessor + + # === Functional Model === + inputs = backbone.input + outputs = backbone.output + super().__init__( + inputs=inputs, + outputs=outputs, + **kwargs, + ) + + def fit(self, *args, **kwargs): + raise NotImplementedError( + "Currently, `fit` is not supported for " + "`StableDiffusion3Inpaint`." + ) + + def generate_step( + self, + images, + masks, + noises, + token_ids, + starting_step, + num_steps, + guidance_scale, + ): + """A compilable generation function for batched of inputs. + + This function represents the inner, XLA-compilable, generation function + for batched inputs. + + Args: + images: A (batch_size, image_height, image_width, 3) tensor + containing the reference images. + masks: A (batch_size, image_height, image_width) tensor + containing the reference masks. 
+ noises: A (batch_size, latent_height, latent_width, channels) tensor + containing the noises to be added to the latents. Typically, + this tensor is sampled from the Gaussian distribution. + token_ids: A pair of (batch_size, num_tokens) tensor containing the + tokens based on the input prompts and negative prompts. + starting_step: int. The number of the starting diffusion step. + num_steps: int. The number of diffusion steps to take. + guidance_scale: float. The classifier free guidance scale defined in + [Classifier-Free Diffusion Guidance]( + https://arxiv.org/abs/2207.12598). Higher scale encourages to + generate images that are closely linked to prompts, usually at + the expense of lower image quality. + """ + token_ids, negative_token_ids = token_ids + + # Get masked images. + masks = ops.cast(ops.expand_dims(masks, axis=-1) > 0.5, images.dtype) + masks_latent_size = ops.image.resize( + masks, + (self.backbone.latent_shape[1], self.backbone.latent_shape[2]), + interpolation="nearest", + ) + + # Encode images. + image_latents = self.backbone.encode_image_step(images) + + # Add noises to latents. + latents = self.backbone.add_noise_step( + image_latents, noises, starting_step, num_steps + ) + + # Encode inputs. + embeddings = self.backbone.encode_text_step( + token_ids, negative_token_ids + ) + + # Denoise. + def body_fun(step, latents): + latents = self.backbone.denoise_step( + latents, + embeddings, + step, + num_steps, + guidance_scale, + ) + + # Compute the previous latents x_t -> x_t-1. + def true_fn(): + next_step = ops.add(step, 1) + return self.backbone.add_noise_step( + image_latents, noises, next_step, num_steps + ) + + init_latents = ops.cond( + step < ops.subtract(num_steps, 1), + true_fn, + lambda: ops.cast(image_latents, noises.dtype), + ) + latents = ops.add( + ops.multiply( + ops.subtract(1.0, masks_latent_size), init_latents + ), + ops.multiply(masks_latent_size, latents), + ) + return latents + + latents = ops.fori_loop(starting_step, num_steps, body_fun, latents) + + # Decode. + return self.backbone.decode_step(latents) + + def generate( + self, + inputs, + num_steps=50, + guidance_scale=7.0, + strength=0.6, + seed=None, + ): + return super().generate( + inputs, + num_steps=num_steps, + guidance_scale=guidance_scale, + strength=strength, + seed=seed, + ) diff --git a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint_test.py b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint_test.py new file mode 100644 index 0000000000..5e8ddd32c6 --- /dev/null +++ b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint_test.py @@ -0,0 +1,197 @@ +import keras +import pytest +from keras import ops + +from keras_hub.src.models.clip.clip_preprocessor import CLIPPreprocessor +from keras_hub.src.models.clip.clip_text_encoder import CLIPTextEncoder +from keras_hub.src.models.clip.clip_tokenizer import CLIPTokenizer +from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_backbone import ( + StableDiffusion3Backbone, +) +from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_inpaint import ( + StableDiffusion3Inpaint, +) +from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_text_to_image_preprocessor import ( + StableDiffusion3TextToImagePreprocessor, +) +from keras_hub.src.models.vae.vae_backbone import VAEBackbone +from keras_hub.src.tests.test_case import TestCase + + +class StableDiffusion3InpaintTest(TestCase): + def setUp(self): + # Instantiate the preprocessor. 
+ vocab = ["air", "plane", "port"] + vocab += ["<|endoftext|>", "<|startoftext|>"] + vocab = dict([(token, i) for i, token in enumerate(vocab)]) + merges = ["a i", "p l", "n e", "p o", "r t", "ai r", "pl a"] + merges += ["po rt", "pla ne"] + clip_l_tokenizer = CLIPTokenizer(vocab, merges, pad_with_end_token=True) + clip_g_tokenizer = CLIPTokenizer(vocab, merges) + clip_l_preprocessor = CLIPPreprocessor(clip_l_tokenizer) + clip_g_preprocessor = CLIPPreprocessor(clip_g_tokenizer) + self.preprocessor = StableDiffusion3TextToImagePreprocessor( + clip_l_preprocessor, clip_g_preprocessor + ) + + self.backbone = StableDiffusion3Backbone( + mmdit_patch_size=2, + mmdit_hidden_dim=16 * 2, + mmdit_num_layers=2, + mmdit_num_heads=2, + mmdit_position_size=192, + vae=VAEBackbone( + [32, 32, 32, 32], + [1, 1, 1, 1], + [32, 32, 32, 32], + [1, 1, 1, 1], + # Use `mode` generate a deterministic output. + sampler_method="mode", + name="vae", + ), + clip_l=CLIPTextEncoder( + 20, 64, 64, 2, 2, 128, "quick_gelu", -2, name="clip_l" + ), + clip_g=CLIPTextEncoder( + 20, 128, 128, 2, 2, 256, "gelu", -2, name="clip_g" + ), + image_shape=(64, 64, 3), + ) + self.init_kwargs = { + "preprocessor": self.preprocessor, + "backbone": self.backbone, + } + self.input_data = { + "images": ops.ones((2, 64, 64, 3)), + "latents": ops.ones((2, 8, 8, 16)), + "clip_l_token_ids": ops.ones((2, 5), dtype="int32"), + "clip_l_negative_token_ids": ops.ones((2, 5), dtype="int32"), + "clip_g_token_ids": ops.ones((2, 5), dtype="int32"), + "clip_g_negative_token_ids": ops.ones((2, 5), dtype="int32"), + "num_steps": ops.ones((2,), dtype="int32"), + "guidance_scale": ops.ones((2,)), + } + + def test_inpaint_basics(self): + pytest.skip( + reason="TODO: enable after preprocessor flow is figured out" + ) + self.run_task_test( + cls=StableDiffusion3Inpaint, + init_kwargs=self.init_kwargs, + train_data=None, + expected_output_shape={ + "images": (2, 64, 64, 3), + "latents": (2, 8, 8, 16), + }, + ) + + def test_generate(self): + inpaint = StableDiffusion3Inpaint(**self.init_kwargs) + seed = 42 + image = self.input_data["images"][0] + mask = self.input_data["images"][0][..., 0] # (B, H, W) + # String input. + prompt = ["airplane"] + negative_prompt = [""] + output = inpaint.generate( + { + "images": image, + "masks": mask, + "prompts": prompt, + "negative_prompts": negative_prompt, + }, + seed=seed, + ) + # Int tensor input. + prompt_ids = self.preprocessor.generate_preprocess(prompt) + negative_prompt_ids = self.preprocessor.generate_preprocess( + negative_prompt + ) + inpaint.preprocessor = None + output2 = inpaint.generate( + { + "images": image, + "masks": mask, + "prompts": prompt_ids, + "negative_prompts": negative_prompt_ids, + }, + seed=seed, + ) + self.assertAllClose(output, output2) + + def test_generate_with_lower_precision(self): + original_floatx = keras.config.floatx() + try: + for dtype in ["float16", "bfloat16"]: + keras.config.set_floatx(dtype) + inpaint = StableDiffusion3Inpaint(**self.init_kwargs) + seed = 42 + image = self.input_data["images"][0] + mask = self.input_data["images"][0][..., 0] # (B, H, W) + # String input. + prompt = ["airplane"] + negative_prompt = [""] + output = inpaint.generate( + { + "images": image, + "masks": mask, + "prompts": prompt, + "negative_prompts": negative_prompt, + }, + seed=seed, + ) + # Int tensor input. 
+ prompt_ids = self.preprocessor.generate_preprocess(prompt) + negative_prompt_ids = self.preprocessor.generate_preprocess( + negative_prompt + ) + inpaint.preprocessor = None + output2 = inpaint.generate( + { + "images": image, + "masks": mask, + "prompts": prompt_ids, + "negative_prompts": negative_prompt_ids, + }, + seed=seed, + ) + self.assertAllClose(output, output2) + finally: + # Restore floatx to the original value to prevent impact on other + # tests even if there is an exception. + keras.config.set_floatx(original_floatx) + + def test_generate_compilation(self): + inpaint = StableDiffusion3Inpaint(**self.init_kwargs) + image = self.input_data["images"][0] + mask = self.input_data["images"][0][..., 0] # (B, H, W) + # Assert we do not recompile with successive calls. + inpaint.generate( + { + "images": image, + "masks": mask, + "prompts": "airplane", + } + ) + first_fn = inpaint.generate_function + inpaint.generate( + { + "images": image, + "masks": mask, + "prompts": "airplane", + } + ) + second_fn = inpaint.generate_function + self.assertEqual(first_fn, second_fn) + # Assert we do recompile after compile is called. + inpaint.compile() + self.assertIsNone(inpaint.generate_function) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=StableDiffusion3Inpaint, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) diff --git a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py index 2067fdb8dc..a7756fc645 100644 --- a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +++ b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py @@ -10,9 +10,9 @@ ), "params": 2987080931, "official_name": "StableDiffusion3", - "path": "stablediffusion3", + "path": "stable_diffusion_3", "model_card": "https://arxiv.org/abs/2110.00476", }, - "kaggle_handle": "kaggle://kerashub/stablediffusion3/keras/stable_diffusion_3_medium/3", + "kaggle_handle": "kaggle://keras/stablediffusion3/keras/stable_diffusion_3_medium/3", } } diff --git a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py index 63f0ba6c28..739c6f4650 100644 --- a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +++ b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py @@ -27,7 +27,7 @@ class StableDiffusion3TextToImage(TextToImage): Use `generate()` to do image generation. ```python text_to_image = keras_hub.models.StableDiffusion3TextToImage.from_preset( - "stable_diffusion_3_medium", height=512, width=512 + "stable_diffusion_3_medium", image_shape=(512, 512, 3) ) text_to_image.generate( "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k" @@ -44,6 +44,14 @@ class StableDiffusion3TextToImage(TextToImage): num_steps=50, guidance_scale=5.0, ) + + # Generate with `negative_prompts`. + text_to_image.generate( + { + "prompts": "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", + "negative_prompts": "green color", + } + ) ``` """ @@ -79,7 +87,6 @@ def generate_step( self, latents, token_ids, - negative_token_ids, num_steps, guidance_scale, ): @@ -92,10 +99,8 @@ def generate_step( latents: A (batch_size, height, width, channels) tensor containing the latents to start generation from. Typically, this tensor is sampled from the Gaussian distribution. 
- token_ids: A (batch_size, num_tokens) tensor containing the - tokens based on the input prompts. - negative_token_ids: A (batch_size, num_tokens) tensor - containing the negative tokens based on the input prompts. + token_ids: A pair of (batch_size, num_tokens) tensor containing the + tokens based on the input prompts and negative prompts. num_steps: int. The number of diffusion steps to take. guidance_scale: float. The classifier free guidance scale defined in [Classifier-Free Diffusion Guidance]( @@ -103,7 +108,9 @@ def generate_step( generate images that are closely linked to prompts, usually at the expense of lower image quality. """ - # Encode inputs. + token_ids, negative_token_ids = token_ids + + # Encode prompts. embeddings = self.backbone.encode_text_step( token_ids, negative_token_ids ) @@ -126,14 +133,12 @@ def body_fun(step, latents): def generate( self, inputs, - negative_inputs=None, num_steps=28, guidance_scale=7.0, seed=None, ): return super().generate( inputs, - negative_inputs=negative_inputs, num_steps=num_steps, guidance_scale=guidance_scale, seed=seed, diff --git a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_test.py b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_test.py index 837c95fa37..69d30de834 100644 --- a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_test.py +++ b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_test.py @@ -55,8 +55,7 @@ def setUp(self): clip_g=CLIPTextEncoder( 20, 128, 128, 2, 2, 256, "gelu", -2, name="clip_g" ), - height=64, - width=64, + image_shape=(64, 64, 3), ) self.init_kwargs = { "preprocessor": self.preprocessor, @@ -93,7 +92,13 @@ def test_generate(self): # String input. prompt = ["airplane"] negative_prompt = [""] - output = text_to_image.generate(prompt, negative_prompt, seed=seed) + output = text_to_image.generate( + { + "prompts": prompt, + "negative_prompts": negative_prompt, + }, + seed=seed, + ) # Int tensor input. prompt_ids = self.preprocessor.generate_preprocess(prompt) negative_prompt_ids = self.preprocessor.generate_preprocess( @@ -101,7 +106,11 @@ def test_generate(self): ) text_to_image.preprocessor = None output2 = text_to_image.generate( - prompt_ids, negative_prompt_ids, seed=seed + { + "prompts": prompt_ids, + "negative_prompts": negative_prompt_ids, + }, + seed=seed, ) self.assertAllClose(output, output2) @@ -116,7 +125,11 @@ def test_generate_with_lower_precision(self): prompt = ["airplane"] negative_prompt = [""] output = text_to_image.generate( - prompt, negative_prompt, seed=seed + { + "prompts": prompt, + "negative_prompts": negative_prompt, + }, + seed=seed, ) # Int tensor input. prompt_ids = self.preprocessor.generate_preprocess(prompt) @@ -125,7 +138,11 @@ def test_generate_with_lower_precision(self): ) text_to_image.preprocessor = None output2 = text_to_image.generate( - prompt_ids, negative_prompt_ids, seed=seed + { + "prompts": prompt_ids, + "negative_prompts": negative_prompt_ids, + }, + seed=seed, ) self.assertAllClose(output, output2) finally: diff --git a/keras_hub/src/models/task.py b/keras_hub/src/models/task.py index b107284444..af12f1cb1c 100644 --- a/keras_hub/src/models/task.py +++ b/keras_hub/src/models/task.py @@ -280,7 +280,7 @@ def summary( def highlight_number(x): if x is None: - f"[color(45)]{x}[/]" + return f"[color(45)]{x}[/]" return f"[color(34)]{x:,}[/]" # Format number with commas. 
def highlight_symbol(x): @@ -339,7 +339,10 @@ def add_layer(layer, info): add_layer(layer, info) elif isinstance(layer, ImageConverter): info = "Image size: " - info += highlight_shape(layer.image_size()) + image_size = layer.image_size + if image_size is None: + image_size = (None, None) + info += highlight_shape(image_size) add_layer(layer, info) elif isinstance(layer, AudioConverter): info = "Audio shape: " diff --git a/keras_hub/src/models/text_to_image.py b/keras_hub/src/models/text_to_image.py index 291a4b023e..54b8dcdae2 100644 --- a/keras_hub/src/models/text_to_image.py +++ b/keras_hub/src/models/text_to_image.py @@ -56,6 +56,11 @@ def __init__(self, *args, **kwargs): # Default compilation. self.compile() + @property + def support_negative_prompts(self): + """Whether the model supports `negative_prompts` key in `generate()`.""" + return bool(True) + @property def latent_shape(self): return tuple(self.backbone.latent_shape) @@ -171,9 +176,26 @@ def _normalize_generate_inputs(self, inputs): This function converts all inputs to tensors, adds a batch dimension if necessary, and returns a iterable "dataset like" object (either an actual `tf.data.Dataset` or a list with a single batch element). + + The input format must be one of the following: + - A single string + - A list of strings + - A dict with "prompts" and/or "negative_prompts" keys + - A tf.data.Dataset with "prompts" and/or "negative_prompts" keys + + The output will be a dict with "prompts" and/or "negative_prompts" keys. """ if tf and isinstance(inputs, tf.data.Dataset): - return inputs.as_numpy_iterator(), False + _inputs = { + "prompts": inputs.map( + lambda x: x["prompts"] + ).as_numpy_iterator() + } + if self.support_negative_prompts: + _inputs["negative_prompts"] = inputs.map( + lambda x: x["negative_prompts"] + ).as_numpy_iterator() + return _inputs, False def normalize(x): if isinstance(x, str): @@ -182,13 +204,24 @@ def normalize(x): return x[tf.newaxis], True return x, False + def get_dummy_prompts(x): + dummy_prompts = [""] * len(x) + if tf and isinstance(x, tf.Tensor): + return tf.convert_to_tensor(dummy_prompts) + else: + return dummy_prompts + if isinstance(inputs, dict): for key in inputs: inputs[key], input_is_scalar = normalize(inputs[key]) else: inputs, input_is_scalar = normalize(inputs) + inputs = {"prompts": inputs} - return inputs, input_is_scalar + if self.support_negative_prompts and "negative_prompts" not in inputs: + inputs["negative_prompts"] = get_dummy_prompts(inputs["prompts"]) + + return [inputs], input_is_scalar def _normalize_generate_outputs(self, outputs, input_is_scalar): """Normalize user output from the generate function. @@ -199,12 +232,11 @@ def _normalize_generate_outputs(self, outputs, input_is_scalar): """ def normalize(x): - outputs = ops.clip(ops.divide(ops.add(x, 1.0), 2.0), 0.0, 1.0) + outputs = ops.concatenate(x, axis=0) + outputs = ops.clip(ops.divide(ops.add(outputs, 1.0), 2.0), 0.0, 1.0) outputs = ops.cast(ops.round(ops.multiply(outputs, 255.0)), "uint8") - outputs = ops.convert_to_numpy(outputs) - if input_is_scalar: - outputs = outputs[0] - return outputs + outputs = ops.squeeze(outputs, 0) if input_is_scalar else outputs + return ops.convert_to_numpy(outputs) if isinstance(outputs[0], dict): normalized = {} @@ -216,23 +248,40 @@ def normalize(x): def generate( self, inputs, - negative_inputs, num_steps, guidance_scale, seed=None, ): - """Generate image based on the provided `inputs` and `negative_inputs`. + """Generate image based on the provided `inputs`. 
+ + Typically, `inputs` contains a text description (known as a prompt) used + to guide the image generation. + + Some models support a `negative_prompts` key, which helps steer the + model away from generating certain styles and elements. To enable this, + pass `prompts` and `negative_prompts` as a dict: + + ```python + text_to_image.generate( + { + "prompts": "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", + "negative_prompts": "green color", + } + ) + ``` If `inputs` are a `tf.data.Dataset`, outputs will be generated "batch-by-batch" and concatenated. Otherwise, all inputs will be processed as batches. Args: - inputs: python data, tensor data, or a `tf.data.Dataset`. - negative_inputs: python data, tensor data, or a `tf.data.Dataset`. - Unlike `inputs`, these are used as negative inputs to guide the - generation. If not provided, it defaults to `""` for each input - in `inputs`. + inputs: python data, tensor data, or a `tf.data.Dataset`. The format + must be one of the following: + - A single string + - A list of strings + - A dict with "prompts" and/or "negative_prompts" keys + - A `tf.data.Dataset` with "prompts" and/or "negative_prompts" + keys num_steps: int. The number of diffusion steps to take. guidance_scale: float. The classifier free guidance scale defined in [Classifier-Free Diffusion Guidance]( @@ -251,32 +300,36 @@ def generate( generate_function = self.make_generate_function() def preprocess(x): - return self.preprocessor.generate_preprocess(x) + if self.preprocessor is not None: + return self.preprocessor.generate_preprocess(x) + else: + return x + + def generate(x): + token_ids = x[0] if self.support_negative_prompts else x + + # Initialize latents. + if isinstance(token_ids, dict): + arbitrary_key = list(token_ids.keys())[0] + batch_size = ops.shape(token_ids[arbitrary_key])[0] + else: + batch_size = ops.shape(token_ids)[0] + latent_shape = (batch_size,) + self.latent_shape[1:] + latents = random.normal(latent_shape, dtype="float32", seed=seed) + + return generate_function(latents, x, num_steps, guidance_scale) # Normalize and preprocess inputs. inputs, input_is_scalar = self._normalize_generate_inputs(inputs) - if negative_inputs is None: - negative_inputs = [""] * len(inputs) - negative_inputs, _ = self._normalize_generate_inputs(negative_inputs) - - if self.preprocessor is not None: - inputs = preprocess(inputs) - negative_inputs = preprocess(negative_inputs) - if isinstance(inputs, dict): - batch_size = len(inputs[list(inputs.keys())[0]]) + if self.support_negative_prompts: + token_ids = [preprocess(x["prompts"]) for x in inputs] + negative_token_ids = [ + preprocess(x["negative_prompts"]) for x in inputs + ] + inputs = [x for x in zip(token_ids, negative_token_ids)] else: - batch_size = len(inputs) - - # Initialize random latents. - latent_shape = (batch_size,) + self.latent_shape[1:] - latents = random.normal(latent_shape, dtype="float32", seed=seed) + inputs = [preprocess(x["prompts"]) for x in inputs] # Text-to-image. 
- outputs = generate_function( - latents, - inputs, - negative_inputs, - num_steps, - guidance_scale, - ) + outputs = [generate(x) for x in inputs] return self._normalize_generate_outputs(outputs, input_is_scalar) diff --git a/keras_hub/src/models/vae/vae_backbone.py b/keras_hub/src/models/vae/vae_backbone.py index c84986314d..606107d17f 100644 --- a/keras_hub/src/models/vae/vae_backbone.py +++ b/keras_hub/src/models/vae/vae_backbone.py @@ -10,7 +10,7 @@ class VAEBackbone(Backbone): - """VAE backbone used in latent diffusion models. + """Variational Autoencoder(VAE) backbone used in latent diffusion models. When encoding, this model generates mean and log variance of the input images. When decoding, it reconstructs images from the latent space. @@ -51,6 +51,18 @@ class VAEBackbone(Backbone): `"channels_last"`. dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype to use for the model's computations and weights. + + Example: + ```Python + backbone = VAEBackbone( + encoder_num_filters=[32, 32, 32, 32], + encoder_num_blocks=[1, 1, 1, 1], + decoder_num_filters=[32, 32, 32, 32], + decoder_num_blocks=[1, 1, 1, 1], + ) + input_data = ops.ones((2, self.height, self.width, 3)) + output = backbone(input_data) + ``` """ def __init__( diff --git a/keras_hub/src/models/vgg/__init__.py b/keras_hub/src/models/vgg/__init__.py index e69de29bb2..4850d0eab4 100644 --- a/keras_hub/src/models/vgg/__init__.py +++ b/keras_hub/src/models/vgg/__init__.py @@ -0,0 +1,5 @@ +from keras_hub.src.models.vgg.vgg_backbone import VGGBackbone +from keras_hub.src.models.vgg.vgg_presets import backbone_presets +from keras_hub.src.utils.preset_utils import register_presets + +register_presets(backbone_presets, VGGBackbone) diff --git a/keras_hub/src/models/vgg/vgg_backbone.py b/keras_hub/src/models/vgg/vgg_backbone.py index cf2638146e..ef91c8689d 100644 --- a/keras_hub/src/models/vgg/vgg_backbone.py +++ b/keras_hub/src/models/vgg/vgg_backbone.py @@ -20,7 +20,7 @@ class VGGBackbone(Backbone): stackwise_num_filters: list of ints, filter size for convolutional blocks per VGG block. For both VGG16 and VGG19 this is [ 64, 128, 256, 512, 512]. - image_shape: tuple, optional shape tuple, defaults to (224, 224, 3). + image_shape: tuple, optional shape tuple, defaults to (None, None, 3). 
Examples: ```python @@ -47,12 +47,11 @@ def __init__( image_shape=(None, None, 3), **kwargs, ): - # === Functional Model === img_input = keras.layers.Input(shape=image_shape) x = img_input - for stack_index in range(len(stackwise_num_repeats) - 1): + for stack_index in range(len(stackwise_num_repeats)): x = apply_vgg_block( x=x, num_layers=stackwise_num_repeats[stack_index], diff --git a/keras_hub/src/models/vgg/vgg_backbone_test.py b/keras_hub/src/models/vgg/vgg_backbone_test.py index 87e9ed6ef5..19dd7844da 100644 --- a/keras_hub/src/models/vgg/vgg_backbone_test.py +++ b/keras_hub/src/models/vgg/vgg_backbone_test.py @@ -19,7 +19,7 @@ def test_backbone_basics(self): cls=VGGBackbone, init_kwargs=self.init_kwargs, input_data=self.input_data, - expected_output_shape=(2, 4, 4, 64), + expected_output_shape=(2, 2, 2, 64), run_mixed_precision_check=False, ) diff --git a/keras_hub/src/models/vgg/vgg_image_classifier.py b/keras_hub/src/models/vgg/vgg_image_classifier.py index 4d02f1ca5f..a72b256288 100644 --- a/keras_hub/src/models/vgg/vgg_image_classifier.py +++ b/keras_hub/src/models/vgg/vgg_image_classifier.py @@ -4,6 +4,9 @@ from keras_hub.src.models.image_classifier import ImageClassifier from keras_hub.src.models.task import Task from keras_hub.src.models.vgg.vgg_backbone import VGGBackbone +from keras_hub.src.models.vgg.vgg_image_classifier_preprocessor import ( + VGGImageClassifierPreprocessor, +) @keras_hub_export("keras_hub.models.VGGImageClassifier") @@ -96,13 +99,14 @@ class VGGImageClassifier(ImageClassifier): """ backbone_cls = VGGBackbone + preprocessor_cls = VGGImageClassifierPreprocessor def __init__( self, backbone, num_classes, preprocessor=None, - pooling="flatten", + pooling="avg", pooling_hidden_dim=4096, activation=None, dropout=0.0, @@ -141,24 +145,46 @@ def __init__( "Unknown `pooling` type. Polling should be either `'avg'` or " f"`'max'`. Received: pooling={pooling}." ) - self.output_dropout = keras.layers.Dropout( - dropout, - dtype=head_dtype, - name="output_dropout", - ) - self.output_dense = keras.layers.Dense( - num_classes, - activation=activation, - dtype=head_dtype, - name="predictions", + + self.head = keras.Sequential( + [ + keras.layers.Conv2D( + filters=4096, + kernel_size=7, + name="fc1", + activation=activation, + use_bias=True, + padding="same", + ), + keras.layers.Dropout( + rate=dropout, + dtype=head_dtype, + name="output_dropout", + ), + keras.layers.Conv2D( + filters=4096, + kernel_size=1, + name="fc2", + activation=activation, + use_bias=True, + padding="same", + ), + self.pooler, + keras.layers.Dense( + num_classes, + activation=activation, + dtype=head_dtype, + name="predictions", + ), + ], + name="head", ) # === Functional Model === inputs = self.backbone.input x = self.backbone(inputs) - x = self.pooler(x) - x = self.output_dropout(x) - outputs = self.output_dense(x) + outputs = self.head(x) + # Skip the parent class functional model. 
Task.__init__( self, @@ -173,6 +199,7 @@ def __init__( self.pooling = pooling self.pooling_hidden_dim = pooling_hidden_dim self.dropout = dropout + self.preprocessor = preprocessor def get_config(self): # Backbone serialized in `super` diff --git a/keras_hub/src/models/vgg/vgg_image_classifier_preprocessor.py b/keras_hub/src/models/vgg/vgg_image_classifier_preprocessor.py new file mode 100644 index 0000000000..f32f965095 --- /dev/null +++ b/keras_hub/src/models/vgg/vgg_image_classifier_preprocessor.py @@ -0,0 +1,12 @@ +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.image_classifier_preprocessor import ( + ImageClassifierPreprocessor, +) +from keras_hub.src.models.vgg.vgg_backbone import VGGBackbone +from keras_hub.src.models.vgg.vgg_image_converter import VGGImageConverter + + +@keras_hub_export("keras_hub.models.VGGImageClassifierPreprocessor") +class VGGImageClassifierPreprocessor(ImageClassifierPreprocessor): + backbone_cls = VGGBackbone + image_converter_cls = VGGImageConverter diff --git a/keras_hub/src/models/vgg/vgg_image_classifier_test.py b/keras_hub/src/models/vgg/vgg_image_classifier_test.py index 6d95ddaac5..34bb7e3db8 100644 --- a/keras_hub/src/models/vgg/vgg_image_classifier_test.py +++ b/keras_hub/src/models/vgg/vgg_image_classifier_test.py @@ -3,24 +3,33 @@ from keras_hub.src.models.vgg.vgg_backbone import VGGBackbone from keras_hub.src.models.vgg.vgg_image_classifier import VGGImageClassifier +from keras_hub.src.models.vgg.vgg_image_classifier_preprocessor import ( + VGGImageClassifierPreprocessor, +) +from keras_hub.src.models.vgg.vgg_image_converter import VGGImageConverter from keras_hub.src.tests.test_case import TestCase class VGGImageClassifierTest(TestCase): def setUp(self): # Setup model. - self.images = np.ones((2, 4, 4, 3), dtype="float32") - self.labels = [0, 3] + self.images = np.ones((2, 8, 8, 3), dtype="float32") + self.labels = [0, 1] self.backbone = VGGBackbone( stackwise_num_repeats=[2, 4, 4], stackwise_num_filters=[2, 16, 16], - image_shape=(4, 4, 3), + image_shape=(8, 8, 3), + ) + image_converter = VGGImageConverter(image_size=(8, 8)) + self.preprocessor = VGGImageClassifierPreprocessor( + image_converter=image_converter, ) self.init_kwargs = { "backbone": self.backbone, "num_classes": 2, "activation": "softmax", "pooling": "flatten", + "preprocessor": self.preprocessor, } self.train_data = ( self.images, @@ -28,9 +37,6 @@ def setUp(self): ) def test_classifier_basics(self): - pytest.skip( - reason="TODO: enable after preprocessor flow is figured out" - ) self.run_task_test( cls=VGGImageClassifier, init_kwargs=self.init_kwargs, diff --git a/keras_hub/src/models/vgg/vgg_image_converter.py b/keras_hub/src/models/vgg/vgg_image_converter.py new file mode 100644 index 0000000000..69ccacbd1d --- /dev/null +++ b/keras_hub/src/models/vgg/vgg_image_converter.py @@ -0,0 +1,8 @@ +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.layers.preprocessing.image_converter import ImageConverter +from keras_hub.src.models.vgg.vgg_backbone import VGGBackbone + + +@keras_hub_export("keras_hub.layers.VGGImageConverter") +class VGGImageConverter(ImageConverter): + backbone_cls = VGGBackbone diff --git a/keras_hub/src/models/vgg/vgg_presets.py b/keras_hub/src/models/vgg/vgg_presets.py new file mode 100644 index 0000000000..e0379a8da0 --- /dev/null +++ b/keras_hub/src/models/vgg/vgg_presets.py @@ -0,0 +1,56 @@ +"""vgg preset configurations.""" + +backbone_presets = { + "vgg_11_imagenet": { + "metadata": { + "description": 
( + "11-layer vgg model pre-trained on the ImageNet 1k dataset " + "at a 224x224 resolution." + ), + "params": 9220480, + "official_name": "vgg", + "path": "vgg", + "model_card": "https://arxiv.org/abs/1409.1556", + }, + "kaggle_handle": "kaggle://keras/vgg/keras/vgg_11_imagenet/1", + }, + "vgg_13_imagenet": { + "metadata": { + "description": ( + "13-layer vgg model pre-trained on the ImageNet 1k dataset " + "at a 224x224 resolution." + ), + "params": 9404992, + "official_name": "vgg", + "path": "vgg", + "model_card": "https://arxiv.org/abs/1409.1556", + }, + "kaggle_handle": "kaggle://keras/vgg/keras/vgg_13_imagenet/1", + }, + "vgg_16_imagenet": { + "metadata": { + "description": ( + "16-layer vgg model pre-trained on the ImageNet 1k dataset " + "at a 224x224 resolution." + ), + "params": 14714688, + "official_name": "vgg", + "path": "vgg", + "model_card": "https://arxiv.org/abs/1409.1556", + }, + "kaggle_handle": "kaggle://keras/vgg/keras/vgg_16_imagenet/1", + }, + "vgg_19_imagenet": { + "metadata": { + "description": ( + "19-layer vgg model pre-trained on the ImageNet 1k dataset " + "at a 224x224 resolution." + ), + "params": 20024384, + "official_name": "vgg", + "path": "vgg", + "model_card": "https://arxiv.org/abs/1409.1556", + }, + "kaggle_handle": "kaggle://keras/vgg/keras/vgg_19_imagenet/1", + }, +} diff --git a/keras_hub/src/models/vit_det/vit_det_backbone.py b/keras_hub/src/models/vit_det/vit_det_backbone.py index 94f7887c44..7d5883409e 100644 --- a/keras_hub/src/models/vit_det/vit_det_backbone.py +++ b/keras_hub/src/models/vit_det/vit_det_backbone.py @@ -31,7 +31,7 @@ class ViTDetBackbone(Backbone): global_attention_layer_indices (list): Indexes for blocks using global attention. image_shape (tuple[int], optional): The size of the input image in - `(H, W, C)` format. Defaults to `(1024, 1024, 3)`. + `(H, W, C)` format. Defaults to `(None, None, 3)`. patch_size (int, optional): the patch size to be supplied to the Patching layer to turn input images into a flattened sequence of patches. Defaults to `16`. @@ -79,7 +79,7 @@ def __init__( intermediate_dim, num_heads, global_attention_layer_indices, - image_shape=(1024, 1024, 3), + image_shape=(None, None, 3), patch_size=16, num_output_channels=256, use_bias=True, diff --git a/keras_hub/src/models/whisper/whisper_audio_converter.py b/keras_hub/src/models/whisper/whisper_audio_converter.py index 633042f547..9890109bac 100644 --- a/keras_hub/src/models/whisper/whisper_audio_converter.py +++ b/keras_hub/src/models/whisper/whisper_audio_converter.py @@ -39,7 +39,7 @@ class WhisperAudioConverter(AudioConverter): audio_tensor = tf.ones((8000,), dtype="float32") # Compute the log-mel spectrogram. - audio_converter = keras_hub.models.WhisperAudioConverter.from_preset( + audio_converter = keras_hub.layers.WhisperAudioConverter.from_preset( "whisper_base_en", ) audio_converter(audio_tensor) diff --git a/keras_hub/src/tests/test_case.py b/keras_hub/src/tests/test_case.py index 6d06c7266c..03c01cb24b 100644 --- a/keras_hub/src/tests/test_case.py +++ b/keras_hub/src/tests/test_case.py @@ -388,6 +388,8 @@ def run_model_saving_test( cls, init_kwargs, input_data, + atol=0.000001, + rtol=0.000001, ): """Save and load a model from disk and assert output is unchanged.""" model = cls(**init_kwargs) @@ -401,7 +403,7 @@ def run_model_saving_test( # Check that output matches. 
restored_output = restored_model(input_data) - self.assertAllClose(model_output, restored_output) + self.assertAllClose(model_output, restored_output, atol=atol, rtol=rtol) def run_backbone_test( self, @@ -567,6 +569,15 @@ def run_task_test( ds = tf.data.Dataset.from_tensor_slices(train_data).batch(batch_size) x, y, sw = keras.utils.unpack_x_y_sample_weight(train_data) + # Test: the tree struct output by the + # preprocessor must match what model expects. + preprocessed_data = preprocessor(*train_data)[0] + tree.assert_same_structure( + preprocessed_data, + task._inputs_struct, + check_types=False, + ) + # Test predict. output = task.predict(x) if expected_output_shape is not None: diff --git a/keras_hub/src/tokenizers/byte_pair_tokenizer.py b/keras_hub/src/tokenizers/byte_pair_tokenizer.py index 41cef2b652..a7447c562e 100644 --- a/keras_hub/src/tokenizers/byte_pair_tokenizer.py +++ b/keras_hub/src/tokenizers/byte_pair_tokenizer.py @@ -43,7 +43,11 @@ SPLIT_PATTERN_1 = SPLIT_PATTERN_1.replace( "{special_spaces}", SPECIAL_WHITESPACES ) -SPLIT_PATTERN_2 = rf"""[\są„¬{SPECIAL_WHITESPACES}]$""" + +# The pattern " \t\r\f\v" is the same as \s "all spaces" but without the \n. +# Multiple \n\n\n in sequence must not be split for Llama3. +# SPLIT_PATTERN_2 = rf"""[\są„¬{SPECIAL_WHITESPACES}]$""" +SPLIT_PATTERN_2 = rf"""[ \t\r\f\vą„¬{SPECIAL_WHITESPACES}]$""" def create_alts_for_unsplittable_tokens(unsplittable_tokens): diff --git a/keras_hub/src/tokenizers/byte_pair_tokenizer_test.py b/keras_hub/src/tokenizers/byte_pair_tokenizer_test.py index 5995df2fed..1aef54e214 100644 --- a/keras_hub/src/tokenizers/byte_pair_tokenizer_test.py +++ b/keras_hub/src/tokenizers/byte_pair_tokenizer_test.py @@ -1,5 +1,4 @@ import keras -import pytest import tensorflow as tf from keras_hub.src.tests.test_case import TestCase @@ -15,7 +14,6 @@ ) -@pytest.mark.large class BytePairTokenizerTest(TestCase): def setUp(self): super().setUp() @@ -111,6 +109,24 @@ def test_whitespace_split(self): encoded = self.tokenizer(input_data) self.assertAllEqual(encoded, [1437, 1437, 50140, 50118, 29]) + # This is important for Llama3 which uses the \n\n sequence in chat + # templates: \n\n must be tokenized as a single token + input_data = "Hello\n\nHello" + encoded = self.tokenizer(input_data) + self.assertAllEqual(encoded, [31414, 50140, 31414]) + + input_data = "Hello\n\n\n\nHello" + encoded = self.tokenizer(input_data) + self.assertAllEqual(encoded, [31414, 50140, 50140, 31414]) + + input_data = "Hello\n\n" + encoded = self.tokenizer(input_data) + self.assertAllEqual(encoded, [31414, 50140]) + + input_data = "Hello\n\n\n\n" + encoded = self.tokenizer(input_data) + self.assertAllEqual(encoded, [31414, 50140, 50140]) + def test_special_whitespace(self): input_data = "\xa0 \xa0 \x3000 s" encoded = self.tokenizer(input_data) diff --git a/keras_hub/src/tokenizers/tokenizer.py b/keras_hub/src/tokenizers/tokenizer.py index b97efae444..5e8986a89e 100644 --- a/keras_hub/src/tokenizers/tokenizer.py +++ b/keras_hub/src/tokenizers/tokenizer.py @@ -66,7 +66,7 @@ def detokenize(self, inputs): backbone_cls = None def __init__(self, *args, **kwargs): - self.config_name = kwargs.pop("config_name", TOKENIZER_CONFIG_FILE) + self.config_file = kwargs.pop("config_file", TOKENIZER_CONFIG_FILE) super().__init__(*args, **kwargs) self.file_assets = None @@ -178,7 +178,7 @@ def get_config(self): config = super().get_config() config.update( { - "config_name": self.config_name, + "config_file": self.config_file, } ) return config @@ -199,11 +199,11 @@ def 
call(self, inputs, *args, training=None, **kwargs): def load_preset_assets(self, preset): asset_path = None for asset in self.file_assets: - subdir = self.config_name.split(".")[0] + subdir = self.config_file.split(".")[0] preset_path = os.path.join(ASSET_DIR, subdir, asset) asset_path = get_file(preset, preset_path) - tokenizer_config_name = os.path.dirname(asset_path) - self.load_assets(tokenizer_config_name) + tokenizer_config_file = os.path.dirname(asset_path) + self.load_assets(tokenizer_config_file) @classproperty def presets(cls): @@ -214,7 +214,7 @@ def presets(cls): def from_preset( cls, preset, - config_name=TOKENIZER_CONFIG_FILE, + config_file=TOKENIZER_CONFIG_FILE, **kwargs, ): """Instantiate a `keras_hub.models.Tokenizer` from a model preset. @@ -260,4 +260,4 @@ class like `keras_hub.models.Tokenizer.from_preset()`, or from backbone_cls = loader.check_backbone_class() if cls.backbone_cls != backbone_cls: cls = find_subclass(preset, cls, backbone_cls) - return loader.load_tokenizer(cls, config_name, **kwargs) + return loader.load_tokenizer(cls, config_file, **kwargs) diff --git a/keras_hub/src/utils/pipeline_model.py b/keras_hub/src/utils/pipeline_model.py index 68bc4d8877..f874b057fe 100644 --- a/keras_hub/src/utils/pipeline_model.py +++ b/keras_hub/src/utils/pipeline_model.py @@ -232,7 +232,7 @@ def train_on_batch( ): data = self.preprocess_samples(x, y, sample_weight) x, y, sample_weight = keras.utils.unpack_x_y_sample_weight(data) - x = ops.convert_to_tensor(x) + x = tree.map_structure(ops.convert_to_tensor, x) if y is not None: y = ops.convert_to_tensor(y) if sample_weight is not None: @@ -253,7 +253,7 @@ def test_on_batch( ): data = self.preprocess_samples(x, y, sample_weight) x, y, sample_weight = keras.utils.unpack_x_y_sample_weight(data) - x = ops.convert_to_tensor(x) + x = tree.map_structure(ops.convert_to_tensor, x) if y is not None: y = ops.convert_to_tensor(y) if sample_weight is not None: @@ -272,7 +272,7 @@ def predict_on_batch( ): data = self.preprocess_samples(x) x, _, _ = keras.utils.unpack_x_y_sample_weight(data) - x = ops.convert_to_tensor(x) + x = tree.map_structure(ops.convert_to_tensor, x) return super().predict_on_batch( x=x, **kwargs, diff --git a/keras_hub/src/utils/preset_utils.py b/keras_hub/src/utils/preset_utils.py index 65af19df7f..52aad373a0 100644 --- a/keras_hub/src/utils/preset_utils.py +++ b/keras_hub/src/utils/preset_utils.py @@ -563,10 +563,8 @@ def get_backbone_kwargs(self, **kwargs): backbone_kwargs["dtype"] = kwargs.pop("dtype", None) # Forward `height` and `width` to backbone when using `TextToImage`. 
- if "height" in kwargs: - backbone_kwargs["height"] = kwargs.pop("height", None) - if "width" in kwargs: - backbone_kwargs["width"] = kwargs.pop("width", None) + if "image_shape" in kwargs: + backbone_kwargs["image_shape"] = kwargs.pop("image_shape", None) return backbone_kwargs, kwargs @@ -578,7 +576,7 @@ def load_backbone(self, cls, load_weights, **kwargs): """Load the backbone model from the preset.""" raise NotImplementedError - def load_tokenizer(self, cls, config_name=TOKENIZER_CONFIG_FILE, **kwargs): + def load_tokenizer(self, cls, config_file=TOKENIZER_CONFIG_FILE, **kwargs): """Load a tokenizer layer from the preset.""" raise NotImplementedError @@ -609,7 +607,7 @@ def load_task(self, cls, load_weights, load_task_weights, **kwargs): return cls(**kwargs) def load_preprocessor( - self, cls, config_name=PREPROCESSOR_CONFIG_FILE, **kwargs + self, cls, config_file=PREPROCESSOR_CONFIG_FILE, **kwargs ): """Load a prepocessor layer from the preset. @@ -632,8 +630,8 @@ def load_backbone(self, cls, load_weights, **kwargs): backbone.load_weights(get_file(self.preset, MODEL_WEIGHTS_FILE)) return backbone - def load_tokenizer(self, cls, config_name=TOKENIZER_CONFIG_FILE, **kwargs): - tokenizer_config = load_json(self.preset, config_name) + def load_tokenizer(self, cls, config_file=TOKENIZER_CONFIG_FILE, **kwargs): + tokenizer_config = load_json(self.preset, config_file) tokenizer = load_serialized_object(tokenizer_config, **kwargs) if hasattr(tokenizer, "load_preset_assets"): tokenizer.load_preset_assets(self.preset) @@ -678,13 +676,13 @@ def load_task(self, cls, load_weights, load_task_weights, **kwargs): return task def load_preprocessor( - self, cls, config_name=PREPROCESSOR_CONFIG_FILE, **kwargs + self, cls, config_file=PREPROCESSOR_CONFIG_FILE, **kwargs ): # If there is no `preprocessing.json` or it's for the wrong class, # delegate to the super class loader. - if not check_file_exists(self.preset, config_name): + if not check_file_exists(self.preset, config_file): return super().load_preprocessor(cls, **kwargs) - preprocessor_json = load_json(self.preset, config_name) + preprocessor_json = load_json(self.preset, config_file) if not issubclass(check_config_class(preprocessor_json), cls): return super().load_preprocessor(cls, **kwargs) # We found a `preprocessing.json` with a complete config for our class. 
diff --git a/keras_hub/src/utils/preset_utils_test.py b/keras_hub/src/utils/preset_utils_test.py index 00baf28235..9d36428698 100644 --- a/keras_hub/src/utils/preset_utils_test.py +++ b/keras_hub/src/utils/preset_utils_test.py @@ -18,6 +18,7 @@ class PresetUtilsTest(TestCase): + @pytest.mark.large def test_preset_errors(self): with self.assertRaisesRegex(ValueError, "must be a string"): AlbertTextClassifier.from_preset(AlbertTextClassifier) @@ -34,6 +35,7 @@ def test_preset_errors(self): with self.assertRaisesRegex(ValueError, "class keras_hub>BortBackbone"): BertBackbone.from_preset(preset_dir) + @pytest.mark.large def test_upload_empty_preset(self): temp_dir = self.get_temp_dir() empty_preset = os.path.join(temp_dir, "empty") diff --git a/keras_hub/src/utils/timm/convert_mobilenet.py b/keras_hub/src/utils/timm/convert_mobilenet.py index 7b4cf4c8e1..e2de6d8e34 100644 --- a/keras_hub/src/utils/timm/convert_mobilenet.py +++ b/keras_hub/src/utils/timm/convert_mobilenet.py @@ -30,8 +30,8 @@ def convert_backbone_config(timm_config): stackwise_se_ratio = [ [None, None], [0.25, 0.25, 0.25], - [0.3, 0.3], - [0.3, 0.25, 0.25], + [0.25, 0.25], + [0.25, 0.25, 0.25], ] stackwise_activation = [ ["relu", "relu"], @@ -39,64 +39,12 @@ def convert_backbone_config(timm_config): ["hard_swish", "hard_swish"], ["hard_swish", "hard_swish", "hard_swish"], ] + stackwise_padding = [[1, 1], [2, 2, 2], [2, 2], [2, 2, 2]] output_num_filters = 1024 input_num_filters = 16 depthwise_filters = 8 squeeze_and_excite = 0.5 last_layer_filter = 288 - - # elif timm_architecture == "mobilenetv2_050": - # stackwise_num_blocks = ([2, 3, 4, 3, 3, 1],) - # stackwise_expansion = ( - # [ - # [48, 96], - # [96, 96, 96], - # [96, 192, 192, 192], - # [192, 288, 288], - # [288, 480, 480], - # [480], - # ], - # ) - # stackwise_num_filters = ( - # [ - # [16, 16], - # [16, 16, 16], - # [32, 32, 32, 32], - # [48, 48, 48], - # [80, 80, 80], - # [160], - # ], - # ) - # stackwise_kernel_size = ( - # [[3, 3], [3, 3, 3], [3, 3, 3, 3], [3, 3, 3], [3, 3, 3], [3]], - # ) - # stackwise_num_strides = ( - # [[2, 1], [2, 1, 1], [2, 1, 1, 1], [1, 1, 1], [2, 1, 1], [1]], - # ) - # stackwise_se_ratio = ( - # [ - # [None, None], - # [None, None, None], - # [None, None, None, None], - # [None, None, None], - # [None, None, None], - # [None], - # ], - # ) - # stackwise_activation = ( - # [ - # ["relu6", "relu6"], - # ["relu6", "relu6", "relu6"], - # ["relu6", "relu6", "relu6", "relu6"], - # ["relu6", "relu6", "relu6"], - # ["relu6", "relu6", "relu6"], - # ["relu6"], - # ], - # ) - # output_num_filters = 1280 - # input_num_filters = 16 - # depthwise_filters = 8 - # squeeze_and_excite = None else: raise ValueError( f"Currently, the architecture {timm_architecture} is not supported." 
@@ -114,6 +62,7 @@ def convert_backbone_config(timm_config): stackwise_num_strides=stackwise_num_strides, stackwise_se_ratio=stackwise_se_ratio, stackwise_activation=stackwise_activation, + stackwise_padding=stackwise_padding, output_num_filters=output_num_filters, output_activation=output_activation, last_layer_filter=last_layer_filter, @@ -122,6 +71,7 @@ def convert_weights(backbone, loader, timm_config): def port_conv2d(keras_layer_name, hf_weight_prefix): + print(f"porting weights {hf_weight_prefix} -> {keras_layer_name}") loader.port_weight( backbone.get_layer(keras_layer_name).kernel, hf_weight_key=f"{hf_weight_prefix}.weight", @@ -129,6 +79,7 @@ def port_conv2d(keras_layer_name, hf_weight_prefix): def port_batch_normalization(keras_layer_name, hf_weight_prefix): + print(f"porting weights {hf_weight_prefix} -> {keras_layer_name}") loader.port_weight( backbone.get_layer(keras_layer_name).gamma, hf_weight_key=f"{hf_weight_prefix}.weight", @@ -145,9 +96,7 @@ def port_batch_normalization(keras_layer_name, hf_weight_prefix): backbone.get_layer(keras_layer_name).moving_variance, hf_weight_key=f"{hf_weight_prefix}.running_var", ) - - version = "v3" if backbone.output_activation == "hard_swish" else "v2" - + # Stem port_conv2d("input_conv", "conv_stem") port_batch_normalization("input_batch_norm", "bn1") @@ -155,6 +108,7 @@ def port_batch_normalization(keras_layer_name, hf_weight_prefix): # DepthWise Block (block 0) hf_name = "blocks.0.0" keras_name = "block_0_0" + port_conv2d(f"{keras_name}_conv1", f"{hf_name}.conv_dw") port_batch_normalization(f"{keras_name}_bn1", f"{hf_name}.bn1") @@ -196,14 +150,10 @@ def port_batch_normalization(keras_layer_name, hf_weight_prefix): f"block_{num_stacks+1}_0_bn", f"blocks.{num_stacks+1}.0.bn1" ) - if version == "v3": - hf_name = f"blocks.{num_stacks+1}.0" - keras_name = "Dfs" port_conv2d("output_conv", "conv_head") # if version == "v2": # port_batch_normalization("output_batch_norm", "bn2") - def convert_head(task, loader, timm_config): prefix = "classifier."
loader.port_weight( diff --git a/keras_hub/src/utils/timm/convert_mobilenet_test.py b/keras_hub/src/utils/timm/convert_mobilenet_test.py index 4d036ae033..59c504b306 100644 --- a/keras_hub/src/utils/timm/convert_mobilenet_test.py +++ b/keras_hub/src/utils/timm/convert_mobilenet_test.py @@ -13,7 +13,7 @@ def test_convert_mobilenet_backbone(self): "hf://timm/mobilenetv3_small_050.lamb_in1k" ) outputs = model.predict(ops.ones((1, 224, 224, 3))) - self.assertEqual(outputs.shape, (1, 14, 14, 1024)) + self.assertEqual(outputs.shape, (1, 7, 7, 1024)) @pytest.mark.large def test_convert_mobilenet_classifier(self): diff --git a/keras_hub/src/utils/timm/convert_vgg.py b/keras_hub/src/utils/timm/convert_vgg.py new file mode 100644 index 0000000000..445d0ee436 --- /dev/null +++ b/keras_hub/src/utils/timm/convert_vgg.py @@ -0,0 +1,85 @@ +from typing import Any + +import numpy as np + +from keras_hub.src.models.vgg.vgg_backbone import VGGBackbone +from keras_hub.src.models.vgg.vgg_image_classifier import VGGImageClassifier + +backbone_cls = VGGBackbone + + +REPEATS_BY_SIZE = { + "vgg11": [1, 1, 2, 2, 2], + "vgg13": [2, 2, 2, 2, 2], + "vgg16": [2, 2, 3, 3, 3], + "vgg19": [2, 2, 4, 4, 4], +} + + +def convert_backbone_config(timm_config): + architecture = timm_config["architecture"] + stackwise_num_repeats = REPEATS_BY_SIZE[architecture] + return dict( + stackwise_num_repeats=stackwise_num_repeats, + stackwise_num_filters=[64, 128, 256, 512, 512], + ) + + +def convert_conv2d( + model, + loader, + keras_layer_name: str, + hf_layer_name: str, +): + loader.port_weight( + model.get_layer(keras_layer_name).kernel, + hf_weight_key=f"{hf_layer_name}.weight", + hook_fn=lambda x, _: np.transpose(x, (2, 3, 1, 0)), + ) + loader.port_weight( + model.get_layer(keras_layer_name).bias, + hf_weight_key=f"{hf_layer_name}.bias", + ) + + +def convert_weights( + backbone: VGGBackbone, + loader, + timm_config: dict[Any], +): + architecture = timm_config["architecture"] + stackwise_num_repeats = REPEATS_BY_SIZE[architecture] + + hf_index_to_keras_layer_name = {} + layer_index = 0 + for block_index, repeats_in_block in enumerate(stackwise_num_repeats): + for repeat_index in range(repeats_in_block): + hf_index = layer_index + layer_index += 2 # Conv + activation layers. + layer_name = f"block{block_index + 1}_conv{repeat_index + 1}" + hf_index_to_keras_layer_name[hf_index] = layer_name + layer_index += 1 # Pooling layer after blocks. 
+ + for hf_index, keras_layer_name in hf_index_to_keras_layer_name.items(): + convert_conv2d( + backbone, loader, keras_layer_name, f"features.{hf_index}" + ) + + +def convert_head( + task: VGGImageClassifier, + loader, + timm_config: dict[Any], +): + convert_conv2d(task.head, loader, "fc1", "pre_logits.fc1") + convert_conv2d(task.head, loader, "fc2", "pre_logits.fc2") + + loader.port_weight( + task.head.get_layer("predictions").kernel, + hf_weight_key="head.fc.weight", + hook_fn=lambda x, _: np.transpose(np.squeeze(x)), + ) + loader.port_weight( + task.head.get_layer("predictions").bias, + hf_weight_key="head.fc.bias", + ) diff --git a/keras_hub/src/utils/timm/preset_loader.py b/keras_hub/src/utils/timm/preset_loader.py index 65149b042f..392f432bb1 100644 --- a/keras_hub/src/utils/timm/preset_loader.py +++ b/keras_hub/src/utils/timm/preset_loader.py @@ -6,6 +6,7 @@ from keras_hub.src.utils.timm import convert_densenet from keras_hub.src.utils.timm import convert_mobilenet from keras_hub.src.utils.timm import convert_resnet +from keras_hub.src.utils.timm import convert_vgg from keras_hub.src.utils.transformers.safetensor_utils import SafetensorLoader @@ -19,6 +20,8 @@ def __init__(self, preset, config): self.converter = convert_densenet elif "mobilenet" in architecture: self.converter = convert_mobilenet + elif "vgg" in architecture: + self.converter = convert_vgg else: raise ValueError( "KerasHub has no converter for timm models " diff --git a/keras_hub/src/utils/transformers/convert_llama3.py b/keras_hub/src/utils/transformers/convert_llama3.py index 08e982e862..75c7eb801c 100644 --- a/keras_hub/src/utils/transformers/convert_llama3.py +++ b/keras_hub/src/utils/transformers/convert_llama3.py @@ -107,10 +107,26 @@ def convert_tokenizer(cls, preset, **kwargs): vocab = tokenizer_config["model"]["vocab"] merges = tokenizer_config["model"]["merges"] - bot = tokenizer_config["added_tokens"][0] # begin of text - eot = tokenizer_config["added_tokens"][1] # end of text - - vocab[bot["content"]] = bot["id"] - vocab[eot["content"]] = eot["id"] + # Load all special tokens with the exception of "reserved" ones. + special_tokens = set() + for token in tokenizer_config["added_tokens"]: + if not token["content"].startswith("<|reserved_special_token_"): + vocab[token["content"]] = token["id"] + special_tokens.add(token["content"]) + + # Load text start and stop tokens from the config. + # Llama3 uses the <|end_of_text|> end token for regular models + # but uses <|eot_id|> for instruction-tuned variants. + tokenizer_config2 = load_json(preset, "tokenizer_config.json") + bos_token = tokenizer_config2["bos_token"] + eos_token = tokenizer_config2["eos_token"] + + kwargs.update( + { + "bos_token": bos_token, + "eos_token": eos_token, + "misc_special_tokens": special_tokens, + } + ) return cls(vocabulary=vocab, merges=merges, **kwargs) diff --git a/keras_hub/src/version_utils.py b/keras_hub/src/version_utils.py index 1b36b8e41f..0a67b13192 100644 --- a/keras_hub/src/version_utils.py +++ b/keras_hub/src/version_utils.py @@ -1,7 +1,7 @@ from keras_hub.src.api_export import keras_hub_export # Unique source of truth for the version number. 
-__version__ = "0.16.1" +__version__ = "0.17.0.dev0" @keras_hub_export("keras_hub.version") diff --git a/tools/checkpoint_conversion/convert_mix_transformer.py b/tools/checkpoint_conversion/convert_mix_transformer.py new file mode 100644 index 0000000000..6419cc405e --- /dev/null +++ b/tools/checkpoint_conversion/convert_mix_transformer.py @@ -0,0 +1,196 @@ +# Usage example +# python tools/checkpoint_conversion/convert_mix_transformer.py --preset "B0_ade_512" + +from absl import app +from absl import flags +from transformers import SegformerForSemanticSegmentation + +import keras_hub + +FLAGS = flags.FLAGS + + +DOWNLOAD_URLS = { + "B0_ade_512": "nvidia/segformer-b0-finetuned-ade-512-512", + "B1_ade_512": "nvidia/segformer-b1-finetuned-ade-512-512", + "B2_ade_512": "nvidia/segformer-b2-finetuned-ade-512-512", + "B3_ade_512": "nvidia/segformer-b3-finetuned-ade-512-512", + "B4_ade_512": "nvidia/segformer-b4-finetuned-ade-512-512", + "B5_ade_640": "nvidia/segformer-b5-finetuned-ade-640-640", + "B0_cityscapes_1024": "nvidia/segformer-b0-finetuned-cityscapes-1024-1024", + "B1_cityscapes_1024": "nvidia/segformer-b1-finetuned-cityscapes-1024-1024", + "B2_cityscapes_1024": "nvidia/segformer-b2-finetuned-cityscapes-1024-1024", + "B3_cityscapes_1024": "nvidia/segformer-b3-finetuned-cityscapes-1024-1024", + "B4_cityscapes_1024": "nvidia/segformer-b4-finetuned-cityscapes-1024-1024", + "B5_cityscapes_1024": "nvidia/segformer-b5-finetuned-cityscapes-1024-1024", +} + + +MODEL_CONFIGS = { + "B0": {"hidden_dims": [32, 64, 160, 256], "depths": [2, 2, 2, 2]}, + "B1": {"hidden_dims": [64, 128, 320, 512], "depths": [2, 2, 2, 2]}, + "B2": {"hidden_dims": [64, 128, 320, 512], "depths": [3, 4, 6, 3]}, + "B3": {"hidden_dims": [64, 128, 320, 512], "depths": [3, 4, 18, 3]}, + "B4": {"hidden_dims": [64, 128, 320, 512], "depths": [3, 8, 27, 3]}, + "B5": {"hidden_dims": [64, 128, 320, 512], "depths": [3, 6, 40, 3]}, +} + +flags.DEFINE_string( + "preset", None, f'Must be one of {",".join(DOWNLOAD_URLS.keys())}' +) + + +def get_indices_from_depths(depths): + proj_indices = [] + norm_indices = [] + hierarchical_encoder_indices = [] + + current_layer_idx = 1 + + for layer_idx, depth in enumerate(depths): + # Add projection index (before the hierarchical encoders) + proj_indices.append(current_layer_idx) + + # Hierarchical encoder block indices + for block_idx in range(depth): + hierarchical_encoder_indices.append( + (current_layer_idx + 1, layer_idx, block_idx) + ) + current_layer_idx += 1 + + # Add normalization index (after the hierarchical encoders) + norm_indices.append(current_layer_idx + 1) + + # Skip to the next layer after output_level + current_layer_idx += 3 + + return proj_indices, norm_indices, hierarchical_encoder_indices + + +def set_conv_weights(conv_layer, state_dict): + conv_weights = state_dict["weight"].numpy().transpose(2, 3, 1, 0) + conv_bias = state_dict["bias"].numpy() + conv_layer.set_weights([conv_weights, conv_bias]) + + +def set_dwconv_weights(conv_layer, state_dict): + conv_weights = state_dict["dwconv.weight"].numpy().transpose(2, 3, 0, 1) + conv_bias = state_dict["dwconv.bias"].numpy() + conv_layer.set_weights([conv_weights, conv_bias]) + + +def set_layer_norm_weights(layer_norm, state_dict): + gamma = state_dict["weight"].numpy() + beta = state_dict["bias"].numpy() + layer_norm.set_weights([gamma, beta]) + + +def set_dense_weights(dense_layer, state_dict): + weight = state_dict["weight"].numpy().T + bias = state_dict["bias"].numpy() + dense_layer.set_weights([weight, bias]) + + +def 
set_hierarchical_encoder_weights(keras_layer, pytorch_layer, key): + + set_layer_norm_weights( + keras_layer.norm1, pytorch_layer.layer_norm_1.state_dict() + ) + + set_dense_weights( + keras_layer.attn.q, pytorch_layer.attention.self.query.state_dict() + ) + set_dense_weights( + keras_layer.attn.k, pytorch_layer.attention.self.key.state_dict() + ) + set_dense_weights( + keras_layer.attn.v, pytorch_layer.attention.self.value.state_dict() + ) + set_dense_weights( + keras_layer.attn.proj, pytorch_layer.attention.output.dense.state_dict() + ) + + if keras_layer.attn.sr_ratio > 1: + set_conv_weights( + keras_layer.attn.sr, pytorch_layer.attention.self.sr.state_dict() + ) + set_layer_norm_weights( + keras_layer.attn.norm, + pytorch_layer.attention.self.layer_norm.state_dict(), + ) + + set_layer_norm_weights( + keras_layer.norm2, pytorch_layer.layer_norm_2.state_dict() + ) + + set_dense_weights( + keras_layer.mlp.fc1, pytorch_layer.mlp.dense1.state_dict() + ) + set_dwconv_weights( + keras_layer.mlp.dwconv, pytorch_layer.mlp.dwconv.state_dict() + ) + set_dense_weights( + keras_layer.mlp.fc2, pytorch_layer.mlp.dense2.state_dict() + ) + + +def main(_): + print("\n-> Loading HuggingFace model") + model = SegformerForSemanticSegmentation.from_pretrained( + DOWNLOAD_URLS[FLAGS.preset] + ) + original_mit = model.segformer.encoder + + model_type = FLAGS.preset.split("_")[0] + print("\n-> Instantiating KerasHub Model") + keras_mit = keras_hub.models.MiTBackbone( + depths=MODEL_CONFIGS[model_type]["depths"], + image_shape=(224, 224, 3), + hidden_dims=MODEL_CONFIGS[model_type]["hidden_dims"], + num_layers=4, + blockwise_num_heads=[1, 2, 5, 8], + blockwise_sr_ratios=[8, 4, 2, 1], + max_drop_path_rate=0.1, + patch_sizes=[7, 3, 3, 3], + strides=[4, 2, 2, 2], + ) + + # Indices for the different patch embeddings and layer norms + proj_indices, layer_norm_indices, hierarchical_encoder_indices = ( + get_indices_from_depths(MODEL_CONFIGS[model_type]["depths"]) + ) + + print("\n-> Converting weights...") + # Loop through the indices to set convolutional and normalization weights + for i, idx in enumerate(proj_indices): + set_conv_weights( + keras_mit.layers[idx].proj, + original_mit.patch_embeddings[i].proj.state_dict(), + ) + set_layer_norm_weights( + keras_mit.layers[idx].norm, + original_mit.patch_embeddings[i].layer_norm.state_dict(), + ) + + # Set layer normalization weights + for i, idx in enumerate(layer_norm_indices): + set_layer_norm_weights( + keras_mit.layers[idx], original_mit.layer_norm[i].state_dict() + ) + + # Set hierarchical encoder weights + for layer_idx, block_idx, key in hierarchical_encoder_indices: + set_hierarchical_encoder_weights( + keras_mit.layers[layer_idx], + original_mit.block[block_idx][int(key)], + key=key, + ) + + directory = f"MiT_{FLAGS.preset}" + print(f"\n-> Saving converted KerasHub model in {directory}") + keras_mit.save_to_preset(directory) + + +if __name__ == "__main__": + flags.mark_flag_as_required("preset") + app.run(main) diff --git a/tools/checkpoint_conversion/convert_pali_gemma_checkpoints.py b/tools/checkpoint_conversion/convert_pali_gemma_checkpoints.py index d21bb8d82d..befb6093cf 100644 --- a/tools/checkpoint_conversion/convert_pali_gemma_checkpoints.py +++ b/tools/checkpoint_conversion/convert_pali_gemma_checkpoints.py @@ -1,17 +1,47 @@ +""" +python tools/checkpoint_conversion/convert_pali_gemma_checkpoints.py \ + --weights_path=paligemma-3b-mix-224.npz \ + --image_size=224 --checkpoint_name=pali_gemma_3b_mix_224 +python 
tools/checkpoint_conversion/convert_pali_gemma_checkpoints.py \ + --weights_path=paligemma-3b-mix-448.npz \ + --image_size=448 --checkpoint_name=pali_gemma_3b_mix_448 +python tools/checkpoint_conversion/convert_pali_gemma_checkpoints.py \ + --weights_path=paligemma-3b-pt-224.npz \ + --image_size=224 --checkpoint_name=pali_gemma_3b_224 +python tools/checkpoint_conversion/convert_pali_gemma_checkpoints.py \ + --weights_path=paligemma-3b-pt-448.npz \ + --image_size=448 --checkpoint_name=pali_gemma_3b_448 +python tools/checkpoint_conversion/convert_pali_gemma_checkpoints.py \ + --weights_path=paligemma-3b-pt-896.npz \ + --image_size=896 --checkpoint_name=pali_gemma_3b_896 +""" + import argparse import os import numpy as np +from keras_hub.src.models.pali_gemma.pali_gemma_backbone import ( + PaliGemmaBackbone, +) +from keras_hub.src.models.pali_gemma.pali_gemma_causal_lm import ( + PaliGemmaCausalLM, +) +from keras_hub.src.models.pali_gemma.pali_gemma_causal_lm_preprocessor import ( + PaliGemmaCausalLMPreprocessor, +) +from keras_hub.src.models.pali_gemma.pali_gemma_image_converter import ( + PaliGemmaImageConverter, +) +from keras_hub.src.models.pali_gemma.pali_gemma_tokenizer import ( + PaliGemmaTokenizer, +) + os.environ["KERAS_BACKEND"] = "jax" import keras # noqa: E402 from keras import ops # noqa: E402 -from keras_hub.src.models.pali_gemma.pali_gemma_backbone import ( # noqa: E402 - PaliGemmaBackbone, -) - # No GPU for conversion, makes memory management easier. os.environ["CUDA_VISIBLE_DEVICES"] = "-1" @@ -299,14 +329,39 @@ def main(args): pali_gemma_backbone_config = { "vit_num_layers": 27, "vit_hidden_dim": 1152, + "vocabulary_size": 257152, "image_size": args.image_size, + "num_layers": 18, + "num_query_heads": 8, + "num_key_value_heads": 1, + "hidden_dim": 2048, + "intermediate_dim": 32768, + "head_dim": 256, + "vit_patch_size": 14, + "vit_num_heads": 16, } - keras_model = PaliGemmaBackbone(**pali_gemma_backbone_config) + pg_image_converter = PaliGemmaImageConverter( + image_size=(args.image_size, args.image_size), + scale=1.0 / 127.5, + offset=-1, + ) + tokenizer = PaliGemmaTokenizer( + proto="vocabulary.spm", + ) + pg_presprocessor = PaliGemmaCausalLMPreprocessor( + tokenizer=tokenizer, image_converter=pg_image_converter + ) + pg_backbone = PaliGemmaBackbone(**pali_gemma_backbone_config) + keras_model = PaliGemmaCausalLM( + preprocessor=pg_presprocessor, backbone=pg_backbone + ) # This could be from kaggle or provide local dir path weights = np.load(args.weights_path) jax_weights = get_weights_as_numpy(weights, **pali_gemma_backbone_config) - keras_model = convert_pali_gemma_weights( - keras_model, jax_weights["params"], **pali_gemma_backbone_config + keras_model.backbone = convert_pali_gemma_weights( + keras_model.backbone, + jax_weights["params"], + **pali_gemma_backbone_config, ) # Specify preset name keras_model.save_to_preset(args.checkpoint_name) diff --git a/tools/checkpoint_conversion/convert_resnet_vd_checkpoints.py b/tools/checkpoint_conversion/convert_resnet_vd_checkpoints.py new file mode 100644 index 0000000000..a8a424f494 --- /dev/null +++ b/tools/checkpoint_conversion/convert_resnet_vd_checkpoints.py @@ -0,0 +1,317 @@ +#!/usr/bin/env python3 + +"""Converts ResNet_vd models from PaddleClas. + +Usage: python3 convert_resnet_vd_checkpoints.py + +ResNet_vd model weights from PaddleClas listed in `configurations` below will +be downloaded, saved as Keras model files and the resulting models will be +verified for numerical agreement with PaddleClas. 
+ +Requirements: +pip3 install -q git+https://github.com/keras-team/keras-hub.git +pip3 install -q paddleclas paddlepaddle +""" + +import os +import re +import tarfile +import urllib.request + +import keras +import numpy as np +import paddle +import paddleclas +from paddleclas.deploy.python import preprocess as pc_preproc +from PIL import Image + +import keras_hub + +"""Architecture Specifications""" + +configurations = { + "ResNet18_vd": { + "stackwise_num_blocks": [2, 2, 2, 2], + "block_type": "basic_block_vd", + }, + "ResNet34_vd": { + "stackwise_num_blocks": [3, 4, 6, 3], + "block_type": "basic_block_vd", + }, + "ResNet50_vd": { + "stackwise_num_blocks": [3, 4, 6, 3], + "block_type": "bottleneck_block_vd", + }, + "ResNet50_vd_ssld": { + "stackwise_num_blocks": [3, 4, 6, 3], + "block_type": "bottleneck_block_vd", + }, + "ResNet50_vd_ssld_v2": { + "stackwise_num_blocks": [3, 4, 6, 3], + "block_type": "bottleneck_block_vd", + }, + "Fix_ResNet50_vd_ssld_v2": { + "stackwise_num_blocks": [3, 4, 6, 3], + "block_type": "bottleneck_block_vd", + }, + "ResNet101_vd": { + "stackwise_num_blocks": [3, 4, 23, 3], + "block_type": "bottleneck_block_vd", + }, + "ResNet101_vd_ssld": { + "stackwise_num_blocks": [3, 4, 23, 3], + "block_type": "bottleneck_block_vd", + }, + "ResNet152_vd": { + "stackwise_num_blocks": [3, 8, 36, 3], + "block_type": "bottleneck_block_vd", + }, + "ResNet200_vd": { + "stackwise_num_blocks": [3, 12, 48, 3], + "block_type": "bottleneck_block_vd", + }, +} + + +"""Download Files""" + +# Create the directory if it doesn't exist +os.makedirs("pretrained_models", exist_ok=True) +base_url = "https://paddle-imagenet-models-name.bj.bcebos.com/" + +for arch in configurations.keys(): + tar_file = f"{arch}_pretrained.tar" + download_url = f"{base_url}{tar_file}" + file_path = os.path.join("pretrained_models", tar_file) + + # Download the tar file + print(f"Downloading {tar_file}...") + urllib.request.urlretrieve(download_url, file_path) + + # Extract the tar file + print(f"Extracting {tar_file}...") + with tarfile.open(file_path, "r") as tar: + tar.extractall(path="pretrained_models", filter="data") + + +"""Model Conversion""" + + +def convert_paddle_to_keras(paddle_weights: dict, keras_model: keras.Model): + """Ports a paddle weights dictionary to a Keras model.""" + + def map_residual_layer_name(name: str): + """Translate a Keras ResNet_vd layer name to a PaddleClas ResNet + layer name prefix for a residual block.""" + branch_mapping = { + # this suffix addresses the specific conv layer within a block + 0: "1", + 1: "2a", + 2: "2b", + 3: "2c", + } + match = re.match( + r"^stack(?P<stack>\d)_block(?P<block>\d+)_(?P<conv>\d)_(?P<type>bn|conv)", + name, + ) + assert match is not None + + # ResNet models have two different formats of layer name encodings + # in PaddleClas. 
first try a mapping in the form + # stack2_block3_1_conv -> res4b2_branch2a + paddle_address = ( + f'{int(match["stack"])+2}b{int(match["block"])}' + f'_branch{branch_mapping[int(match["conv"])]}' + ) + if match["type"] == "bn": + paddle_name = f"bn{paddle_address}" + elif match["type"] == "conv": + paddle_name = f"res{paddle_address}" + if any(name.startswith(paddle_name) for name in paddle_weights): + return paddle_name + + # if that was not successful, try a mapping like + # stack2_block3_1_conv -> res4c_branch2a + paddle_address = ( + f'{int(match["stack"])+2}{"abcdefghijkl"[int(match["block"])]}' + f'_branch{branch_mapping[int(match["conv"])]}' + ) + if match["type"] == "bn": + paddle_name = f"bn{paddle_address}" + elif match["type"] == "conv": + paddle_name = f"res{paddle_address}" + return paddle_name + + def map_layer_name(name: str): + """Translate a Keras ResNet_vd layer name to a PaddleClas ResNet layer + name prefix.""" + mapping = { + # stem layers + "conv1_conv": "conv1_1", + "conv1_bn": "bnv1_1", + "conv2_conv": "conv1_2", + "conv2_bn": "bnv1_2", + "conv3_conv": "conv1_3", + "conv3_bn": "bnv1_3", + } + return mapping.get(name) or map_residual_layer_name(name) + + def set_batchnorm_layer( + paddle_name_prefix: str, target_layer: keras.layers.Layer + ): + """Assign Keras BatchNorm layer weigths from Paddle weights.""" + target_layer.set_weights( + [ + paddle_weights.pop(f"{paddle_name_prefix}_scale"), + paddle_weights.pop(f"{paddle_name_prefix}_offset"), + paddle_weights.pop(f"{paddle_name_prefix}_mean"), + paddle_weights.pop(f"{paddle_name_prefix}_variance"), + ] + ) + + def set_conv_layer( + paddle_name_prefix: str, target_layer: keras.layers.Layer + ): + """Assign Keras Conv2D layer weights from Paddle weights.""" + if target_layer.use_bias: + target_layer.set_weights( + [ + np.transpose( + paddle_weights.pop(f"{paddle_name_prefix}_weights"), + (2, 3, 1, 0), + ), + paddle_weights.pop(f"{paddle_name_prefix}_bias"), + ] + ) + else: + target_layer.set_weights( + [ + np.transpose( + paddle_weights.pop(f"{paddle_name_prefix}_weights"), + (2, 3, 1, 0), + ) + ] + ) + + def set_dense_layer( + paddle_name_prefix: str, target_layer: keras.layers.Layer + ): + """Assign Keras Dense layer weights from Paddle weights.""" + if target_layer.use_bias: + target_layer.set_weights( + [ + paddle_weights.pop(f"{paddle_name_prefix}.w_0"), + paddle_weights.pop(f"{paddle_name_prefix}.b_0"), + ] + ) + else: + target_layer.set_weights( + [paddle_weights.pop(f"{paddle_name_prefix}.w_0")] + ) + + for layer in keras_model.backbone.layers: + # iterate over all layers that have parameters in the keras model, + # to ensure we process all weights in the Keras model + if layer.variables: + if isinstance(layer, keras.layers.Conv2D): + set_conv_layer(map_layer_name(layer.name), layer) + elif isinstance(layer, keras.layers.BatchNormalization): + set_batchnorm_layer(map_layer_name(layer.name), layer) + else: + raise TypeError("Unexpected layer type encountered in model") + set_dense_layer("fc_0", keras_model.get_layer("predictions")) + + # ensure we have consumed all weights, i.e. 
there are no leftover + # weights in the paddle model + assert len(paddle_weights) == 0 + + +"""Instantiate model architectures as indicated above and load PaddleClas +weights into the Keras model""" + +for architecture_name, architecture_config in configurations.items(): + print(f"Converting {architecture_name}") + backbone_model = keras_hub.models.ResNetBackbone( + input_conv_filters=[32, 32, 64], + input_conv_kernel_sizes=[3, 3, 3], + stackwise_num_filters=[64, 128, 256, 512], + stackwise_num_strides=[1, 2, 2, 2], + **architecture_config, + ) + image_converter = keras_hub.layers.ResNetImageConverter( + height=224, + width=224, + mean=[0.485, 0.456, 0.406], + variance=[0.229**2, 0.224**2, 0.225**2], + scale=1 / 255.0, + ) + resnet_preprocessor = keras_hub.models.ResNetImageClassifierPreprocessor( + image_converter + ) + classifier_model = keras_hub.models.ResNetImageClassifier( + backbone=backbone_model, + preprocessor=resnet_preprocessor, + num_classes=1000, + ) + paddle_model = paddle.load( + f"pretrained_models/{architecture_name}_pretrained" + ) + convert_paddle_to_keras(paddle_model, classifier_model) + classifier_model.save(f"{architecture_name}.keras") + classifier_model.save_to_preset(f"{architecture_name}") + print(f"Parameter count: {classifier_model.count_params()}") + +"""Check for Numerical Agreement + +Compare results when using PaddleClas with results when using our Keras models. +In general, PaddleClas appears to mainly target command-line utilisation +rather than offering an API. While PaddleClas model architectures can directly +be instantiated, this interface strangely only provides some of the pretrained +models (and doesn't appear to be documented anywhere). + +To ensure that the behaviour and performance of PaddleClas as a command-line +tool match our observed results, we use `PaddleClas` directly here. +""" + +urllib.request.urlretrieve( + "https://storage.googleapis.com/tensorflow/keras-applications/tests/elephant.jpg", + "elephant.jpg", +) + +print(f'{"Model": <25}Error') +for architecture_name in configurations: + # PaddleClas prediction + predictor = paddleclas.PaddleClas(model_name=architecture_name).predictor + # PaddleClas selects the top 5 predictions during + # postprocessing. Turn this off. 
+ predictor.postprocess = None + # for comparable results, manually perform resizing and cropping + preprocess_ops = [ + op + for op in predictor.preprocess_ops + if isinstance( + op, + ( + pc_preproc.NormalizeImage, + pc_preproc.ResizeImage, + pc_preproc.CropImage, + ), + ) + ] + predictor.preprocess_ops = [ + op for op in predictor.preprocess_ops if op not in preprocess_ops + ] + image = np.asarray(Image.open("elephant.jpg"), dtype=np.float32) + for op in preprocess_ops: + image = op(image) + paddle_prediction = predictor.predict(image) + + # Keras prediction + # in contrast to PaddleClas, Keras' predictions are not softmax'ed + keras_model = keras.saving.load_model(f"{architecture_name}.keras") + keras_prediction = keras_model(image[None]).numpy() + keras_prediction = keras.ops.softmax(keras_prediction) + + # compare + max_error = np.max(np.abs(paddle_prediction - keras_prediction)) + print(f"{architecture_name: <25}{max_error}") diff --git a/tools/checkpoint_conversion/convert_sam_checkpoints.py b/tools/checkpoint_conversion/convert_sam_checkpoints.py index 08f4f4a504..69cd1482cb 100644 --- a/tools/checkpoint_conversion/convert_sam_checkpoints.py +++ b/tools/checkpoint_conversion/convert_sam_checkpoints.py @@ -1,3 +1,7 @@ +# Get the huge PyTorch model weights from the following locations +# curl -sSL https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth -o sam_vit_h_4b8939.pth +# curl -sSL https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth -o sam_vit_l_0b3195.pth +# curl -sSL https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth -o sam_vit_b_01ec64.pth import argparse import os diff --git a/tools/checkpoint_conversion/convert_segformer_checkpoints.py b/tools/checkpoint_conversion/convert_segformer_checkpoints.py new file mode 100644 index 0000000000..230cf5227d --- /dev/null +++ b/tools/checkpoint_conversion/convert_segformer_checkpoints.py @@ -0,0 +1,143 @@ +# Usage example +# python tools/checkpoint_conversion/convert_segformer_checkpoints.py --preset "b0_ade20k_512" + +import numpy as np +from absl import app +from absl import flags +from transformers import SegformerForSemanticSegmentation + +import keras_hub +from keras_hub.src.models.segformer.segformer_image_segmenter_preprocessor import ( + SegFormerImageSegmenterPreprocessor, +) + +FLAGS = flags.FLAGS + +PROJECTION_FILTERS = { + "b0_ade20k_512": 256, + "b1_ade20k_512": 256, + "b2_ade20k_512": 768, + "b3_ade20k_512": 768, + "b4_ade20k_512": 768, + "b5_ade20k_640": 768, + "b0_cityscapes_1024": 256, + "b1_cityscapes_1024": 256, + "b2_cityscapes_1024": 768, + "b3_cityscapes_1024": 768, + "b4_cityscapes_1024": 768, + "b5_cityscapes_1024": 768, +} + + +DOWNLOAD_URLS = { + "b0_ade20k_512": "nvidia/segformer-b0-finetuned-ade-512-512", + "b1_ade20k_512": "nvidia/segformer-b1-finetuned-ade-512-512", + "b2_ade20k_512": "nvidia/segformer-b2-finetuned-ade-512-512", + "b3_ade20k_512": "nvidia/segformer-b3-finetuned-ade-512-512", + "b4_ade20k_512": "nvidia/segformer-b4-finetuned-ade-512-512", + "b5_ade20k_640": "nvidia/segformer-b5-finetuned-ade-640-640", + "b0_cityscapes_1024": "nvidia/segformer-b0-finetuned-cityscapes-1024-1024", + "b1_cityscapes_1024": "nvidia/segformer-b1-finetuned-cityscapes-1024-1024", + "b2_cityscapes_1024": "nvidia/segformer-b2-finetuned-cityscapes-1024-1024", + "b3_cityscapes_1024": "nvidia/segformer-b3-finetuned-cityscapes-1024-1024", + "b4_cityscapes_1024": "nvidia/segformer-b4-finetuned-cityscapes-1024-1024", + "b5_cityscapes_1024": 
"nvidia/segformer-b5-finetuned-cityscapes-1024-1024", +} + +flags.DEFINE_string( + "preset", None, f'Must be one of {",".join(DOWNLOAD_URLS.keys())}' +) + + +def set_conv_weights(conv_layer, state_dict): + conv_weights = state_dict["weight"].numpy().transpose(2, 3, 1, 0) + bias = None + if "bias" in state_dict.keys(): + bias = state_dict["bias"].numpy() + conv_layer.set_weights([conv_weights, bias]) + else: + conv_layer.set_weights([conv_weights]) + + +def set_dense_weights(dense_layer, state_dict): + weight = state_dict["weight"].numpy().T + bias = state_dict["bias"].numpy() + dense_layer.set_weights([weight, bias]) + + +def set_batchnorm_weights(bn_layer, state_dict): + gamma = state_dict["weight"].numpy() + beta = state_dict["bias"].numpy() + running_mean = state_dict["running_mean"].numpy() + running_var = state_dict["running_var"].numpy() + + bn_layer.set_weights([gamma, beta, running_mean, running_var]) + + +def main(_): + print("\n-> Loading HuggingFace model") + original_segformer = SegformerForSemanticSegmentation.from_pretrained( + DOWNLOAD_URLS[FLAGS.preset] + ) + + print("\n-> Instantiating KerasHub Model") + + resolution = int(FLAGS.preset.split("_")[-1]) + + encoder = keras_hub.models.MiTBackbone.from_preset( + "mit_" + FLAGS.preset, image_shape=(resolution, resolution, 3) + ) + segformer_backbone = keras_hub.models.SegFormerBackbone( + image_encoder=encoder, + projection_filters=PROJECTION_FILTERS[FLAGS.preset], + ) + num_classes = 150 if "ade20k" in FLAGS.preset else 19 + + preprocessor = SegFormerImageSegmenterPreprocessor() + segformer_segmenter = keras_hub.models.SegFormerImageSegmenter( + backbone=segformer_backbone, + num_classes=num_classes, + preprocessor=preprocessor, + ) + segformer_backbone(np.random.rand(1, resolution, resolution, 3)) + + set_dense_weights( + segformer_backbone.layers[5], + original_segformer.decode_head.linear_c[0].proj.state_dict(), + ) + set_dense_weights( + segformer_backbone.layers[4], + original_segformer.decode_head.linear_c[1].proj.state_dict(), + ) + set_dense_weights( + segformer_backbone.layers[3], + original_segformer.decode_head.linear_c[2].proj.state_dict(), + ) + set_dense_weights( + segformer_backbone.layers[2], + original_segformer.decode_head.linear_c[3].proj.state_dict(), + ) + set_conv_weights( + segformer_backbone.layers[-1].layers[0], + original_segformer.decode_head.linear_fuse.state_dict(), + ) + set_batchnorm_weights( + segformer_backbone.layers[-1].layers[1], + original_segformer.decode_head.batch_norm.state_dict(), + ) + + set_conv_weights( + segformer_segmenter.layers[-2], + original_segformer.decode_head.classifier.state_dict(), + ) + + print("\n-> Converting weights...") + + directory = f"SegFormer_{FLAGS.preset}" + print(f"\n-> Saving converted KerasHub model in {directory}") + segformer_segmenter.save_to_preset(directory) + + +if __name__ == "__main__": + flags.mark_flag_as_required("preset") + app.run(main) diff --git a/tools/checkpoint_conversion/convert_stable_diffusion_3_checkpoints.py b/tools/checkpoint_conversion/convert_stable_diffusion_3_checkpoints.py index 15b9691532..38e19cf107 100644 --- a/tools/checkpoint_conversion/convert_stable_diffusion_3_checkpoints.py +++ b/tools/checkpoint_conversion/convert_stable_diffusion_3_checkpoints.py @@ -113,8 +113,7 @@ def convert_model(preset, height, width): vae, clip_l, clip_g, - height=height, - width=width, + image_shape=(height, width, 3), name="stable_diffusion_3_backbone", ) return backbone @@ -130,23 +129,23 @@ def convert_preprocessor(): vocabulary, merges, 
pad_with_end_token=True, - config_name="clip_l_tokenizer.json", + config_file="clip_l_tokenizer.json", name="clip_l_tokenizer", ) clip_g_tokenizer = CLIPTokenizer( vocabulary, merges, - config_name="clip_g_tokenizer.json", + config_file="clip_g_tokenizer.json", name="clip_g_tokenizer", ) clip_l_preprocessor = CLIPPreprocessor( clip_l_tokenizer, - config_name="clip_l_preprocessor.json", + config_file="clip_l_preprocessor.json", name="clip_l_preprocessor", ) clip_g_preprocessor = CLIPPreprocessor( clip_g_tokenizer, - config_name="clip_g_preprocessor.json", + config_file="clip_g_preprocessor.json", name="clip_g_preprocessor", ) preprocessor = StableDiffusion3TextToImagePreprocessor( @@ -310,19 +309,19 @@ def port_diffuser(preset, filename, model): ) port_dense(loader, model.context_embedding, "context_embedder") port_dense( - loader, model.vector_embedding.layers[0], "y_embedder.mlp.0" + loader, model.vector_embedding.dense1, "y_embedder.mlp.0" ) port_dense( - loader, model.vector_embedding.layers[1], "y_embedder.mlp.2" + loader, model.vector_embedding.dense2, "y_embedder.mlp.2" ) port_dense( loader, - model.timestep_embedding.mlp.layers[0], + model.timestep_embedding.mlp.dense1, "t_embedder.mlp.0", ) port_dense( loader, - model.timestep_embedding.mlp.layers[1], + model.timestep_embedding.mlp.dense2, "t_embedder.mlp.2", ) @@ -338,7 +337,7 @@ def port_diffuser(preset, filename, model): prefix = f"joint_blocks.{i}.{block_name}" port_dense( loader, - block.adaptive_norm_modulation.layers[1], + block.ada_layer_norm.dense, f"{prefix}.adaLN_modulation.1", ) port_dense( @@ -351,18 +350,16 @@ def port_diffuser(preset, filename, model): port_dense( loader, block.attention_proj, f"{prefix}.attn.proj" ) - port_dense(loader, block.mlp.layers[0], f"{prefix}.mlp.fc1") - port_dense(loader, block.mlp.layers[1], f"{prefix}.mlp.fc2") + port_dense(loader, block.mlp.dense1, f"{prefix}.mlp.fc1") + port_dense(loader, block.mlp.dense2, f"{prefix}.mlp.fc2") # Output layer port_dense( loader, - model.output_layer.adaptive_norm_modulation.layers[1], + model.output_ada_layer_norm.dense, "final_layer.adaLN_modulation.1", ) - port_dense( - loader, model.output_layer.output_dense, "final_layer.linear" - ) + port_dense(loader, model.output_dense, "final_layer.linear") return model def port_vae(preset, filename, model): @@ -534,8 +531,7 @@ def main(_): keras_preprocessor.save_to_preset(preset) # Set the image size to 1024, the same as in huggingface/diffusers. - keras_model.height = 1024 - keras_model.width = 1024 + keras_model.image_shape = (1024, 1024, 3) keras_model.save_to_preset(preset) print(f"šŸ Preset saved to ./{preset}.") diff --git a/tools/checkpoint_conversion/convert_vgg_checkpoints.py b/tools/checkpoint_conversion/convert_vgg_checkpoints.py new file mode 100644 index 0000000000..fea9aaf01f --- /dev/null +++ b/tools/checkpoint_conversion/convert_vgg_checkpoints.py @@ -0,0 +1,116 @@ +"""Loads an external VGG model and saves it in Keras format. + +Optionally uploads the model to Keras if the `--upload_uri` flag is passed. + +python tools/checkpoint_conversion/convert_vgg_checkpoints.py \ + --preset vgg11 --upload_uri kaggle://kerashub/vgg/keras/vgg11 +""" + +import os +import shutil + +import keras +import numpy as np +import PIL +import timm +import torch +from absl import app +from absl import flags + +import keras_hub + +PRESET_MAP = { + "vgg11": "timm/vgg11.tv_in1k", + "vgg13": "timm/vgg13.tv_in1k", + "vgg16": "timm/vgg16.tv_in1k", + "vgg19": "timm/vgg19.tv_in1k", + # TODO(jeffcarp): Add BN variants. 
+} + + +PRESET = flags.DEFINE_string( + "preset", + None, + "Must be a valid `VGG` preset from KerasHub", + required=True, +) +UPLOAD_URI = flags.DEFINE_string( + "upload_uri", + None, + 'Could be "kaggle://keras/{variant}/keras/{preset}_int8"', +) + + +def validate_output(keras_model, timm_model): + file = keras.utils.get_file( + origin=( + "https://storage.googleapis.com/keras-cv/" + "models/paligemma/cow_beach_1.png" + ) + ) + image = PIL.Image.open(file) + batch = np.array([image]) + + # Preprocess with Timm. + data_config = timm.data.resolve_model_data_config(timm_model) + data_config["crop_pct"] = 1.0 # Stop timm from cropping. + transforms = timm.data.create_transform(**data_config, is_training=False) + timm_preprocessed = transforms(image) + timm_preprocessed = keras.ops.transpose(timm_preprocessed, axes=(1, 2, 0)) + timm_preprocessed = keras.ops.expand_dims(timm_preprocessed, 0) + + # Preprocess with Keras. + keras_preprocessed = keras_model.preprocessor(batch) + + # Call with Timm. Use the keras preprocessed image so we can keep modeling + # and preprocessing comparisons independent. + timm_batch = keras.ops.transpose(keras_preprocessed, axes=(0, 3, 1, 2)) + timm_batch = torch.from_numpy(np.array(timm_batch)) + timm_outputs = timm_model(timm_batch).detach().numpy() + timm_label = np.argmax(timm_outputs[0]) + + # Call with Keras. + keras_outputs = keras_model.predict(batch) + keras_label = np.argmax(keras_outputs[0]) + + print("šŸ”¶ Keras output:", keras_outputs[0, :10]) + print("šŸ”¶ TIMM output:", timm_outputs[0, :10]) + print("šŸ”¶ Keras label:", keras_label) + print("šŸ”¶ TIMM label:", timm_label) + modeling_diff = np.mean(np.abs(keras_outputs - timm_outputs)) + print("šŸ”¶ Modeling difference:", modeling_diff) + preprocessing_diff = np.mean(np.abs(keras_preprocessed - timm_preprocessed)) + print("šŸ”¶ Preprocessing difference:", preprocessing_diff) + + +def main(_): + preset = PRESET.value + if os.path.exists(preset): + shutil.rmtree(preset) + os.makedirs(preset) + + timm_name = PRESET_MAP[preset] + + timm_model = timm.create_model(timm_name, pretrained=True) + timm_model = timm_model.eval() + print("āœ… Loaded TIMM model.") + print(timm_model) + + keras_model = keras_hub.models.ImageClassifier.from_preset( + "hf://" + timm_name, + ) + print("āœ… Loaded KerasHub model.") + + keras_model.save_to_preset(f"./{preset}") + print(f"šŸ Preset saved to ./{preset}") + + validate_output(keras_model, timm_model) + + upload_uri = UPLOAD_URI.value + if upload_uri: + keras_hub.upload_preset(uri=upload_uri, preset=f"./{preset}") + print(f"šŸ Preset uploaded to {upload_uri}") + + +if __name__ == "__main__": + app.run(main)
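For a quick sanity check of a preset written by the VGG conversion script above, a minimal sketch, assuming the script was run with `--preset vgg11` so the preset directory is `./vgg11` (the 224x224 input size matches the timm VGG configurations):

```python
import numpy as np

import keras_hub

# Load the locally saved preset produced by save_to_preset() above.
classifier = keras_hub.models.ImageClassifier.from_preset("./vgg11")

# A dummy 224x224 RGB batch; the VGG preprocessor attached to the preset
# handles resizing and rescaling.
batch = np.random.uniform(0, 255, size=(1, 224, 224, 3)).astype("float32")
preds = classifier.predict(batch)
print(preds.shape)  # Expected: (1, 1000) for an ImageNet-1k head.
```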