From 67b27969d1b3359f80f7461e40104bbfc495abfb Mon Sep 17 00:00:00 2001 From: divyashreepathihalli Date: Fri, 2 Feb 2024 01:08:39 +0000 Subject: [PATCH 01/38] clip refactor --- keras_cv/models/CLIP/__init__.py | 22 ++ keras_cv/models/CLIP/clip_encoder.py | 262 +++++++++++++++++++++++ keras_cv/models/CLIP/clip_image_model.py | 133 ++++++++++++ keras_cv/models/CLIP/clip_model.py | 198 +++++++++++++++++ keras_cv/models/CLIP/clip_model_test.py | 34 +++ keras_cv/models/CLIP/clip_processor.py | 120 +++++++++++ keras_cv/models/CLIP/clip_text_model.py | 65 ++++++ keras_cv/models/CLIP/clip_tokenizer.py | 194 +++++++++++++++++ 8 files changed, 1028 insertions(+) create mode 100644 keras_cv/models/CLIP/__init__.py create mode 100644 keras_cv/models/CLIP/clip_encoder.py create mode 100644 keras_cv/models/CLIP/clip_image_model.py create mode 100644 keras_cv/models/CLIP/clip_model.py create mode 100644 keras_cv/models/CLIP/clip_model_test.py create mode 100644 keras_cv/models/CLIP/clip_processor.py create mode 100644 keras_cv/models/CLIP/clip_text_model.py create mode 100644 keras_cv/models/CLIP/clip_tokenizer.py diff --git a/keras_cv/models/CLIP/__init__.py b/keras_cv/models/CLIP/__init__.py new file mode 100644 index 0000000000..e15d291766 --- /dev/null +++ b/keras_cv/models/CLIP/__init__.py @@ -0,0 +1,22 @@ +# Copyright 2023 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from keras_cv.models.clip.clip_image_model import ( + CLIPImageEncoder, +) +from keras_cv.models.clip.clip_image_model import ( + CLIPTextEncoder, +) +from keras_cv.models.clip.clip_processor import CLIPProcessor +from keras_cv.models.clip.clip_tokenizer import CLIPTokenizer \ No newline at end of file diff --git a/keras_cv/models/CLIP/clip_encoder.py b/keras_cv/models/CLIP/clip_encoder.py new file mode 100644 index 0000000000..b7c657c357 --- /dev/null +++ b/keras_cv/models/CLIP/clip_encoder.py @@ -0,0 +1,262 @@ +# Copyright 2023 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from keras_cv.backend import keras +from keras_cv.backend import ops + + +def get_initializer(initializer_range=0.02): + """ + Creates a `keras.initializers.TruncatedNormal` with the given range. + + Args: + initializer_range (*float*, defaults to 0.02): Standard deviation of the initializer range. + + Returns: + `keras.initializers.TruncatedNormal`: The truncated normal initializer. 
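+
+    Example (typical use as a layer's `kernel_initializer`):
+
+    ```python
+    dense = keras.layers.Dense(
+        64, kernel_initializer=get_initializer(initializer_range=0.02)
+    )
+    ```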
+ """ + return keras.initializers.TruncatedNormal(stddev=initializer_range) + + +class QuickGELU(keras.layers.Layer): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def call(self, x): + return x * ops.sigmoid(1.702 * x) + + +class ResidualAttention(keras.layers.Layer): + def __init__( + self, + proj_dim, + n_head, + num_hidden_layers, + attn_mask=None, + ): + super().__init__() + self.proj_dim = proj_dim + self.n_head = n_head + self.attn_mask = attn_mask + self.num_hidden_layers = num_hidden_layers + self.fc_std = ops.power(2 * self.proj_dim, -0.5) * 0.02 + + self.in_proj_std = ( + ops.power(self.proj_dim, -0.5) + * (ops.power(2 * self.num_hidden_layers, -0.5)) + * 0.02 + ) + + def attention(self, x): + self.attn_mask = ( + ops.cast(self.attn_mask, dtype=x.dtype) + if self.attn_mask is not None + else None + ) + + return self.attn(x, attention_mask=self.attn_mask) + + def build(self, input_shape): + super().build(input_shape) + self.attn = CLIPAttention( + self.proj_dim, + self.n_head, + self.num_hidden_layers, + name="multi_head_attention", + ) + self.ln_1 = keras.layers.LayerNormalization(epsilon=1e-5, name="ln_1") + self.mlp = keras.Sequential( + [ + keras.layers.Dense( + self.proj_dim * 4, + kernel_initializer=get_initializer(self.in_proj_std), + name="c_fc", + ), + QuickGELU(name="gelu"), + keras.layers.Dense( + self.proj_dim, + kernel_initializer=get_initializer(self.fc_std), + name="c_proj", + ), + ] + ) + self.ln_2 = keras.layers.LayerNormalization(epsilon=1e-5, name="ln_2") + + def call(self, x): + x = x + self.attention(self.ln_1(x)) + x = x + self.mlp(self.ln_2(x)) + return x + + def compute_output_shape(self, inputs_shape): + return inputs_shape + + +class CLIPEncoder(keras.layers.Layer): + def __init__(self, width, layers, heads, attn_mask=None, **kwargs): + super().__init__(**kwargs) + self.width = width + self.layers = layers + self.heads = heads + self.attn_mask = attn_mask + self.resblocks = keras.Sequential( + [ + ResidualAttention( + self.width, self.heads, self.layers, self.attn_mask + ) + for _ in range(self.layers) + ] + ) + + def build(self, input_shape): + super().build(input_shape) + + def call(self, x): + return self.resblocks(x) + + def compute_output_shape(self, inputs_shape): + return inputs_shape + + +class CLIPAttention(keras.layers.Layer): + """ + - Documentation page: https://huggingface.co/docs/transformers/model_doc/clip # noqa: E501 + - Implementation: https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/modeling_clip.py # noqa: E501 + """ + + def __init__( + self, project_dim, num_heads, num_hidden_layers, dropout=0.0, **kwargs + ): + super().__init__(**kwargs) + + self.project_dim = project_dim + self.num_heads = num_heads + self.num_hidden_layers = num_hidden_layers + self.head_dim = self.project_dim // self.num_heads + if self.head_dim * self.num_heads != self.project_dim: + raise ValueError( + f"project_dim must be divisible by num_heads (got `project_dim`" + f": {self.project_dim} and `num_heads`:" + f" {self.num_heads})." 
+ ) + + self.sqrt_att_head_size = ops.sqrt(self.head_dim) + self.scale = self.head_dim**-0.5 + in_proj_std = ( + (self.project_dim**-0.5) + * ((2 * self.num_hidden_layers) ** -0.5) + * 0.02 + ) + out_proj_std = (self.project_dim**-0.5) * 0.02 + self.dropout = dropout + self.q_proj = keras.layers.Dense( + units=self.project_dim, + kernel_initializer=get_initializer(in_proj_std), + name="q_proj", + ) + self.k_proj = keras.layers.Dense( + units=self.project_dim, + kernel_initializer=get_initializer(in_proj_std), + name="k_proj", + ) + self.v_proj = keras.layers.Dense( + units=self.project_dim, + kernel_initializer=get_initializer(in_proj_std), + name="v_proj", + ) + self.out_proj = keras.layers.Dense( + units=self.project_dim, + kernel_initializer=get_initializer(out_proj_std), + name="out_proj", + ) + + def build(self, input_shape): + super().build(input_shape) + self.q_proj.build(input_shape) + self.k_proj.build(input_shape) + self.v_proj.build(input_shape) + self.out_proj.build(input_shape) + + def _transpose_for_scores(self, tensor, batch_size): + """ + Copied from https://github.com/huggingface/transformers/blob/8e164c5400b7b413c7b8fb32e35132001effc970/src/transformers/models/bert/modeling_tf_bert.py#L252 # noqa: E501 + """ + # [batch_size, seq_len, all_head_dim] -> + # [batch_size, seq_len, num_heads, head_dim] + tensor = ops.reshape( + tensor, (batch_size, -1, self.num_heads, self.head_dim) + ) + # [batch_size, seq_len, num_heads, head_dim] -> + # [batch_size, num_heads, seq_len, head_dim] + return ops.transpose(tensor, axes=[0, 2, 1, 3]) + + def call( + self, + x, + attention_mask=None, + causal_attention_mask=None, + output_attentions=None, + training=False, + ): + batch_size = ops.shape(x)[0] + mixed_query_layer = self.q_proj(inputs=x) + mixed_key_layer = self.k_proj(inputs=x) + mixed_value_layer = self.v_proj(inputs=x) + query_layer = self._transpose_for_scores(mixed_query_layer, batch_size) + key_layer = self._transpose_for_scores(mixed_key_layer, batch_size) + value_layer = self._transpose_for_scores(mixed_value_layer, batch_size) + + # Scaled dot product between key and query = raw attention scores. + attention_scores = ops.matmul( + query_layer, ops.transpose(key_layer, axes=[0, 1, 3, 2]) + ) + dk = ops.cast(ops.sqrt(self.head_dim), dtype=attention_scores.dtype) + attention_scores = ops.divide( + attention_scores, dk + ) # (batch_size, num_heads, seq_len_q, seq_len_k) + + # Apply the causal_attention_mask first + if causal_attention_mask is not None: + # Apply the causal attention mask (precomputed for all layers in + # the call() function) + attention_scores = ops.add(attention_scores, causal_attention_mask) + + if attention_mask is not None: + # Apply the attention mask (precomputed for all layers in the + # call() function) + attention_scores = ops.add(attention_scores, attention_mask) + + # Normalize the attention scores to probabilities. + _attention_probs = ops.softmax(attention_scores + 1e-9, axis=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. 
+ attention_probs = keras.layers.Dropout(self.dropout)( + inputs=_attention_probs, training=training + ) + + attn_output = ops.matmul(attention_probs, value_layer) + attn_output = ops.transpose(attn_output, axes=[0, 2, 1, 3]) + + # (batch_size, seq_len_q, project_dim) + attn_output = ops.reshape( + attn_output, (batch_size, -1, self.project_dim) + ) + + attn_output = self.out_proj(attn_output, training=training) + outputs = ( + (attn_output, _attention_probs) + if output_attentions + else attn_output + ) + + return outputs diff --git a/keras_cv/models/CLIP/clip_image_model.py b/keras_cv/models/CLIP/clip_image_model.py new file mode 100644 index 0000000000..c66d964155 --- /dev/null +++ b/keras_cv/models/CLIP/clip_image_model.py @@ -0,0 +1,133 @@ +# Copyright 2023 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from keras_cv.backend import keras +from keras_cv.backend import ops +from keras_cv.models.feature_extractors.clip.clip_encoder import CLIPEncoder +from keras_cv.models.feature_extractors.clip.clip_encoder import get_initializer + + +class CLIPPatchingAndEmbedding(keras.layers.Layer): + def __init__(self, width, patch_size, input_resolution, output_dim): + super().__init__() + + self.conv1 = keras.layers.Conv2D( + filters=width, + kernel_size=patch_size, + strides=patch_size, + padding="valid", + use_bias=False, + data_format="channels_last", + kernel_initializer=get_initializer(0.02), + name="patch_embed.embedding", + ) + self.width = width + self.input_resolution = input_resolution + self.patch_size = patch_size + self.num_patches = ops.power( + (self.input_resolution // self.patch_size), 2 + ) + self.class_embedding_initializer = get_initializer( + ops.power(self.width, -0.5) * 0.02 + ) + self.output_dim = output_dim + + def build(self, input_shape): + self.conv1.build(input_shape) + self.class_embedding = self.add_weight( + shape=((self.width,)), + initializer=self.class_embedding_initializer, + name="patch_embed.class_embedding", + ) + + self.positional_embedding = self.add_weight( + shape=( + ( + (self.input_resolution // self.patch_size) ** 2 + 1, + self.width, + ) + ), + trainable=True, + name="patch_embed.positional_embedding", + ) + + def call(self, x): + batch_size, _, _, _ = ops.shape(x) + patch_embeddings = self.conv1(x) # shape = [*, grid, grid, channel] + + patch_embeddings = ops.reshape( + patch_embeddings, (batch_size, self.num_patches, -1) + ) + class_embeds = ops.broadcast_to( + self.class_embedding, (batch_size, 1, self.width) + ) + embeddings = ops.concatenate( + [class_embeds, patch_embeddings], axis=1 + ) # shape = [*, grid ** 2 + 1, width] + positional_embedding = self.positional_embedding + embeddings = embeddings + positional_embedding + return embeddings + + +class CLIPImageEncoder(keras.Model): + def __init__( + self, + input_resolution: int, + patch_size: int, + width: int, + layers: int, + heads: int, + output_dim: int, + **kwargs, + ): + super().__init__( + **kwargs, + ) + self.input_resolution = input_resolution + self.width = width 
+ self.patch_size = patch_size + self.output_dim = output_dim + + self.embeddings = CLIPPatchingAndEmbedding( + width=self.width, + patch_size=self.patch_size, + input_resolution=self.input_resolution, + output_dim=self.output_dim, + ) + self.pre_norm = keras.layers.LayerNormalization( + epsilon=1e-5, name="ln_1" + ) + self.encoder = CLIPEncoder( + width, + layers, + heads, + name="residual_transformer_encoder", + ) + self.post_norm = keras.layers.LayerNormalization( + epsilon=1e-5, name="ln_2" + ) + self.image_projector = keras.layers.Dense( + output_dim, name="vision_projector", use_bias=False + ) + + def build(self, input_shape): + self.embeddings.build(input_shape) + + def call(self, image): + embeddings = self.embeddings(image) + pre_norm = self.pre_norm(embeddings) + encoded_output = self.encoder(pre_norm) + post_norm = self.post_norm(encoded_output[:, 0, :]) + image_projected_embeddings = self.image_projector(post_norm) + return image_projected_embeddings diff --git a/keras_cv/models/CLIP/clip_model.py b/keras_cv/models/CLIP/clip_model.py new file mode 100644 index 0000000000..6eae744380 --- /dev/null +++ b/keras_cv/models/CLIP/clip_model.py @@ -0,0 +1,198 @@ +# Copyright 2023 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from keras_cv.api_export import keras_cv_export +from keras_cv.backend import keras +from keras_cv.backend import ops +from keras_cv.models.feature_extractors.clip.clip_image_model import CLIPEncoder +from keras_cv.models.feature_extractors.clip.clip_image_model import ( + CLIPImageEncoder, +) +from keras_cv.models.feature_extractors.clip.clip_text_model import ( + CLIPTextEncoder, +) + +MODEL_CONFIGS = { + "CLIP_B32": { + "embed_dim": 512, + "context_length": 77, + "vocab_size": 49408, + "transformer_width": 512, + "transformer_heads": 8, + "transformer_layers": 12, + "vision_layers": 12, + "vision_width": 768, + "image_resolution": 224, + "vision_patch_size": 32, + }, + "CLIP_B16": { + "embed_dim": 512, + "context_length": 77, + "vocab_size": 49408, + "transformer_width": 512, + "transformer_heads": 8, + "transformer_layers": 12, + "vision_layers": 12, + "vision_width": 768, + "image_resolution": 224, + "vision_patch_size": 16, + }, + "CLIP_L14": { + "embed_dim": 768, + "context_length": 77, + "vocab_size": 49408, + "transformer_width": 768, + "transformer_heads": 12, + "transformer_layers": 12, + "vision_layers": 24, + "vision_width": 1024, + "image_resolution": 224, + "vision_patch_size": 14, + }, + "CLIP_L14_336": { + "embed_dim": 768, + "context_length": 77, + "vocab_size": 49408, + "transformer_width": 768, + "transformer_heads": 12, + "transformer_layers": 12, + "vision_layers": 24, + "vision_width": 1024, + "image_resolution": 336, + "vision_patch_size": 14, + }, +} + + +@keras_cv_export( + ["keras_cv.models.CLIP", "keras_cv.models.feature_extractors.CLIP"] +) +class CLIP(keras.Model): + """ + CLIP implements the Contrastive Language-Image Pretraining (CLIP) + architecture, which enables joint learning of visual and textual + 
representations for various downstream tasks. + + Args: + embed_dim (int): The dimensionality of the joint embedding space for + images and texts. + image_resolution (int): The resolution of the input images (both height + and width). + vision_layers (int): The number of layers in the vision (image) encoder. + vision_width (int): The width of the hidden layers in the vision + encoder. + vision_patch_size (int): The size of each square patch in the input + images. + context_length (int): The maximum length of the contextualized text + sequences. + vocab_size (int): The size of the vocabulary for tokenization. + transformer_width (int): The width of the hidden layers in the + transformer-based text encoder. + transformer_heads (int): The number of attention heads in the + transformer-based text encoder. + transformer_layers (int): The number of layers in the transformer-based + text encoder. + """ + + def __init__( + self, + embed_dim, + image_resolution, + vision_layers, + vision_width, + vision_patch_size, + context_length, + vocab_size, + transformer_width, + transformer_heads, + transformer_layers, + ): + super().__init__() + + self.context_length = context_length + + vision_heads = vision_width // 64 + self.image_encoder = CLIPImageEncoder( + input_resolution=image_resolution, + patch_size=vision_patch_size, + width=vision_width, + layers=vision_layers, + heads=vision_heads, + output_dim=embed_dim, + name="clip_encoder", + ) + self.text_encoder = CLIPTextEncoder( + transformer_width=transformer_width, + transformer_layers=transformer_layers, + transformer_heads=transformer_heads, + vocab_size=vocab_size, + embed_dim=embed_dim, + context_length=context_length, + name="text_encoder", + ) + + self.logit_scale = keras.Variable( + ops.ones([]) * ops.log(1 / 0.07), name="logit_scale" + ) + self.image_embeddings = None + self.text_embeddings = None + + def build_attention_mask(self): + mask = ops.ones((self.context_length, self.context_length)) + # Zero out the lower diagonal + mask = ops.triu(mask) + return ops.cast(mask, "float32") + + def encode_images(self, image): + return self.image_encoder(image) + + def encode_text(self, text): + return self.text_encoder(text) + + def call(self, image, text): + self.image_embeddings = self.encode_images(image) + self.text_embeddings = self.encode_text(text) + normalize_image_features = keras.ops.sqrt( + keras.ops.sum( + keras.ops.power(self.image_embeddings, 2), keepdims=True + ) + ) + normalize_text_features = keras.ops.sqrt( + keras.ops.sum( + keras.ops.power(self.text_embeddings, 2), keepdims=True + ) + ) + self.image_embeddings = self.image_embeddings / normalize_image_features + self.text_embeddings = self.text_embeddings / normalize_text_features + + logit_scale = ops.exp(self.logit_scale) + print("logit scale", logit_scale) + print( + "matmul", + ops.matmul( + self.image_embeddings, + ops.transpose(self.text_embeddings), + ), + ) + + logits_per_image = ( + ops.matmul( + self.image_embeddings, + ops.transpose(self.text_embeddings), + ) + * logit_scale + ) + print("logit per image", logits_per_image) + logits_per_text = ops.transpose(logits_per_image) + + return logits_per_image, logits_per_text diff --git a/keras_cv/models/CLIP/clip_model_test.py b/keras_cv/models/CLIP/clip_model_test.py new file mode 100644 index 0000000000..1cc8cc3b10 --- /dev/null +++ b/keras_cv/models/CLIP/clip_model_test.py @@ -0,0 +1,34 @@ +# Copyright 2022 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file 
except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from tensorflow.keras import mixed_precision + +from keras_cv.backend import ops +from keras_cv.backend import random +from keras_cv.models import CLIP +from keras_cv.tests.test_case import TestCase + + +@pytest.mark.tf_only +class StableDiffusionTest(TestCase): + def test_clip_tokenizer(self): + pass + + def test_presets(self): + pass + + @pytest.mark.extra_large + def test_mixed_precision(self): + pass diff --git a/keras_cv/models/CLIP/clip_processor.py b/keras_cv/models/CLIP/clip_processor.py new file mode 100644 index 0000000000..bcf81e6fe5 --- /dev/null +++ b/keras_cv/models/CLIP/clip_processor.py @@ -0,0 +1,120 @@ +# Copyright 2023 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import numpy as np +from keras_nlp.layers import StartEndPacker + +from keras_cv.api_export import keras_cv_export +from keras_cv.backend import keras +from keras_cv.backend import ops +from keras_cv.models.feature_extractors.clip.clip_tokenizer import CLIPTokenizer + + +@keras_cv_export("keras_cv.models.feature_extractors.CLIPProcessor") +class CLIPProcessor: + """ + CLIPProcessor is a utility class that provides functionality for processing + images and texts in the context of the CLIP (Contrastive Language-Image + Pretraining) model. + + Args: + input_resolution (int): The resolution of input images. + vocabulary (str): string or dict, maps token to integer ids. If it is a + string, it should be the file path to a json file. + merges: string or list, contains the merge rule. If it is a string, it + should be the file path to merge rules. The merge rule file should + have one merge rule per line. + + Methods: + process_images(image_path: List[str]): Transforms an image located at + the specified path. + + process_texts(texts: Union[str, List[str]], context_length: int = 77): + Processes a single text or a list of texts, returning packed token + sequences. 
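+
+    Example usage (illustrative; the vocabulary, merges, and image paths
+    below are placeholders for files you supply):
+
+    ```python
+    processor = CLIPProcessor(
+        input_resolution=224,
+        vocabulary="path/to/vocab.json",
+        merges="path/to/merges.txt",
+    )
+    image_batch = processor.process_images(["path/to/image.jpg"])
+    token_batch = processor.process_texts(
+        ["a photo of a cat", "a photo of a dog"]
+    )
+    ```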
+ + """ + + def __init__(self, input_resolution, vocabulary, merges): + self.input_resolution = input_resolution + self.image_transform = self.transform_image + self.tokenizer = CLIPTokenizer( + vocabulary=vocabulary, + merges=merges, + unsplittable_tokens=[""], + ) + self.packer = StartEndPacker( + start_value=self.tokenizer.token_to_id("<|startoftext|>"), + end_value=self.tokenizer.token_to_id("<|endoftext|>"), + pad_value=None, + sequence_length=77, + return_padding_mask=True, + ) + + def transform_image(self, image_path): + input_resolution = self.input_resolution + mean = np.array([0.48145466, 0.4578275, 0.40821073]) + std = np.array([0.26862954, 0.26130258, 0.27577711]) + + image = keras.utils.load_img(image_path) + image = keras.utils.img_to_array(image) + image = ( + ops.image.resize( + image, + (input_resolution, input_resolution), + interpolation="bicubic", + ) + / 255.0 + ) + central_fraction = input_resolution / image.shape[0] + width, height = image.shape[0], image.shape[1] + left = ops.cast((width - width * central_fraction) / 2, dtype="int32") + top = ops.cast((height - height * central_fraction) / 2, dtype="int32") + right = ops.cast((width + width * central_fraction) / 2, dtype="int32") + bottom = ops.cast( + (height + height * central_fraction) / 2, dtype="int32" + ) + + image = ops.slice( + image, [left, top, 0], [right - left, bottom - top, 3] + ) + + image = (image - mean) / std + return image + + def process_images(self, images): + if isinstance(images, str): + images = [images] + + def process_image(image): + if isinstance(image, str): + return self.image_transform(image) + + processed_images = list(map(process_image, images)) + processed_images = ops.stack(processed_images) + return processed_images + + def process_texts(self, texts, context_length: int = 77): + if isinstance(texts, str): + texts = [texts] + + def pack_tokens(text): + tok, _ = self.packer( + self.tokenizer(text), + sequence_length=context_length, + add_start_value=True, + add_end_value=True, + ) + return tok + + return pack_tokens(texts) diff --git a/keras_cv/models/CLIP/clip_text_model.py b/keras_cv/models/CLIP/clip_text_model.py new file mode 100644 index 0000000000..f83108f110 --- /dev/null +++ b/keras_cv/models/CLIP/clip_text_model.py @@ -0,0 +1,65 @@ +from keras_cv.backend import keras +from keras_cv.backend import ops +from keras_cv.models.feature_extractors.clip.clip_encoder import CLIPEncoder + + +class CLIPTextEncoder(keras.Model): + def __init__( + self, + transformer_width, + transformer_layers, + transformer_heads, + vocab_size, + embed_dim, + context_length, + **kwargs, + ): + super().__init__( + **kwargs, + ) + self.context_length = context_length + self.token_embedding = keras.layers.Embedding( + vocab_size, + transformer_width, + name="token_embedding", + ) + + self.vocab_size = vocab_size + self.positional_embedding = self.add_weight( + shape=[self.context_length, transformer_width], + name="positional_embedding", + ) + self.encoder = CLIPEncoder( + width=transformer_width, + layers=transformer_layers, + heads=transformer_heads, + attn_mask=self.build_attention_mask(), + name="residual_transformer_encoder", + ) + self.ln_final = keras.layers.LayerNormalization(name="ln_final") + + self.text_projector = keras.layers.Dense( + embed_dim, name="text_projector", use_bias=False + ) + + def call(self, inputs): + token_embedding = self.token_embedding(inputs) + encoded_output = self.encoder( + token_embedding + self.positional_embedding + ) + layer_norm = self.ln_final(encoded_output) + 
indices = ops.expand_dims( + ops.cast(ops.argmax(inputs, axis=1), "int32"), axis=-1 + ) + selected_features = ops.take_along_axis( + layer_norm, indices[:, :, None], axis=1 + ) + text_features = self.text_projector(selected_features) + output = ops.squeeze(text_features, axis=1) + return output + + def build_attention_mask(self): + mask = ops.ones((self.context_length, self.context_length)) + # Zero out the lower diagonal + mask = ops.triu(mask) + return ops.cast(mask, "float32") diff --git a/keras_cv/models/CLIP/clip_tokenizer.py b/keras_cv/models/CLIP/clip_tokenizer.py new file mode 100644 index 0000000000..8a7b5ac3b9 --- /dev/null +++ b/keras_cv/models/CLIP/clip_tokenizer.py @@ -0,0 +1,194 @@ +# Copyright 2023 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import regex as re +import tensorflow as tf +import tensorflow_text as tf_text +from keras_nlp.tokenizers import BytePairTokenizer + +VOCAB_FILENAME = "keras_cv/models/feature_extractors/clip/vocab.json" +MERGES_FILENAME = "keras_cv/models/feature_extractors/clip/merges.txt" +# As python and TF handles special spaces differently, we need to +# manually handle special spaces during string split. +SPECIAL_WHITESPACES = r"\x{a0}\x{2009}\x{202f}\x{3000}" +SPLIT_PATTERN_1 = ( + r"'s|'t|'re|'ve|'m|'ll|'d" + + r"|[\s{special_spaces}]+[\n\r\t\f६{special_spaces}]| ?\p{L}+|" + + r" ?[\p{N}]+| ?[^\s\p{L}\p{N}{special_spaces}]+" +) +SPLIT_PATTERN_1 = SPLIT_PATTERN_1.replace( + "{special_spaces}", SPECIAL_WHITESPACES +) +SPLIT_PATTERN_2 = rf"""[\s६{SPECIAL_WHITESPACES}]$""" + + +def split_strings_for_bpe(inputs, unsplittable_tokens=None): + # We need to recreate the exact behavior of token presplitting in the + # original gpt2 tokenizer which uses a lookahead. As re2 does not + # support lookahead match, we are using an alternative insert a special + # token "६" before leading space of non-space characters and after the + # trailing space, e.g., " keras" will be "६ keras". + inputs = tf.strings.regex_replace( + inputs, rf"( )([^\s{SPECIAL_WHITESPACES}])", r"६\1\2" + ) + inputs = tf.strings.regex_replace( + inputs, rf"(\s{SPECIAL_WHITESPACES})$", r"\1६" + ) + inputs = tf.strings.regex_replace(inputs, r"\s", "") + if unsplittable_tokens: + alts = create_alts_for_unsplittable_tokens(unsplittable_tokens) + for token, alt in zip(unsplittable_tokens, alts): + escaped_token = re.escape(token) + inputs = tf_text.regex_split(inputs, escaped_token, escaped_token) + inputs = tf.strings.regex_replace(inputs, escaped_token, alt) + raw_tokens = tf_text.regex_split(inputs, SPLIT_PATTERN_1, SPLIT_PATTERN_1) + # Second pass splits out the last whilespace char or "६". + raw_tokens = tf_text.regex_split( + raw_tokens, SPLIT_PATTERN_2, SPLIT_PATTERN_2 + ) + if unsplittable_tokens: + # Replace special tokens alternate with originals. 
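+        # (The "Ĵ"-prefixed stand-ins produced by
+        # `create_alts_for_unsplittable_tokens` are swapped back to the
+        # original special-token strings now that splitting is done.)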
+ for token, alt in zip(unsplittable_tokens, alts): + escaped_alt = re.escape(alt) + raw_tokens = tf.strings.regex_replace( + raw_tokens, escaped_alt, token + ) + + # Add '' to the end of each token + tokens_with_end_tag = tf.strings.regex_replace( + raw_tokens, r"(\p{L}+)", r"\1" + ) + + while tokens_with_end_tag.shape.rank > 2: + tokens_with_end_tag = tokens_with_end_tag.merge_dims(1, 2) + + return remove_strings_from_inputs(tokens_with_end_tag, "६") + + +def create_alts_for_unsplittable_tokens(unsplittable_tokens): + # Create alternates for all special tokens that will be not split during + # tokenization. + alts = [] + prefix = "Ĵ" + # Trim out splitters. + replace_pattern = r"'|\s+|[^\p{L}\p{N}]+" + for token in unsplittable_tokens: + token = re.sub(replace_pattern, "", token) + alts.append(prefix + token) + return alts + + +def bytes_to_unicode(): + bs = ( + list(range(ord("!"), ord("~") + 1)) + + list(range(ord("¡"), ord("¬") + 1)) + + list(range(ord("®"), ord("ÿ") + 1)) + ) + cs = bs[:] + n = 0 + # removes mapping an int to a whitespace character + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8 + n) + n += 1 + cs = [chr(n) for n in cs] + bs = [n.to_bytes(1, "little") for n in bs] + return bs, cs # int to string mapping + + +def remove_strings_from_inputs(tensor, string_to_remove): + """Remove certain strings from input tensor.""" + non_empty_mask = tensor != string_to_remove + flatten_indexes = tf.where(non_empty_mask) + flatten_result = tf.gather_nd(tensor, flatten_indexes) + row_lengths = tf.reduce_sum(tf.cast(non_empty_mask, "int64"), axis=1) + result = tf.RaggedTensor.from_row_lengths( + values=flatten_result, + row_lengths=row_lengths, + ) + return result + + +class CLIPTokenizer(BytePairTokenizer): + def _bpe_merge_and_update_cache(self, tokens): + """Process unseen tokens and add to cache.""" + words = self._transform_bytes(tokens) + tokenized_words = self._bpe_merge(words) + + # For each word, join all its token by a whitespace, + # e.g., ["dragon", "fly"] => "dragon fly" for hash purpose. + tokenized_words = tf.strings.reduce_join( + tokenized_words, + axis=1, + ) + self.cache.insert(tokens, tokenized_words) + + def tokenize(self, inputs): + self._check_vocabulary() + if not isinstance(inputs, (tf.Tensor, tf.RaggedTensor)): + inputs = tf.convert_to_tensor(inputs) + + if self.add_prefix_space: + inputs = tf.strings.join([" ", inputs]) + + scalar_input = inputs.shape.rank == 0 + if scalar_input: + inputs = tf.expand_dims(inputs, 0) + + raw_tokens = split_strings_for_bpe(inputs, self.unsplittable_tokens) + token_row_splits = raw_tokens.row_splits + flat_tokens = raw_tokens.flat_values + # Check cache. + cache_lookup = self.cache.lookup(flat_tokens) + cache_mask = cache_lookup == "" + + has_unseen_words = tf.math.reduce_any( + (cache_lookup == "") & (flat_tokens != "") + ) + + def process_unseen_tokens(): + unseen_tokens = tf.boolean_mask(flat_tokens, cache_mask) + self._bpe_merge_and_update_cache(unseen_tokens) + return self.cache.lookup(flat_tokens) + + # If `has_unseen_words == True`, it means not all tokens are in cache, + # we will process the unseen tokens. Otherwise return the cache lookup. + tokenized_words = tf.cond( + has_unseen_words, + process_unseen_tokens, + lambda: cache_lookup, + ) + tokens = tf.strings.split(tokenized_words, sep=" ") + if self.compute_dtype != tf.string: + # Encode merged tokens. + tokens = self.token_to_id_map.lookup(tokens) + + # Unflatten to match input. 
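+        # `token_row_splits` records where each input string's words began, so
+        # gathering the sub-token row splits at those positions rebuilds one
+        # ragged row of token ids per input string.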
+ tokens = tf.RaggedTensor.from_row_splits( + tokens.flat_values, + tf.gather(tokens.row_splits, token_row_splits), + ) + + # Convert to a dense output if `sequence_length` is set. + if self.sequence_length: + output_shape = tokens.shape.as_list() + output_shape[-1] = self.sequence_length + tokens = tokens.to_tensor(shape=output_shape) + + # Convert to a dense output if input in scalar + if scalar_input: + tokens = tf.squeeze(tokens, 0) + tf.ensure_shape(tokens, shape=[self.sequence_length]) + + return tokens From 88ae6a4de92358108dd5525ee303cb6fd8c44872 Mon Sep 17 00:00:00 2001 From: Divyashree Sreepathihalli Date: Fri, 2 Feb 2024 18:52:28 +0000 Subject: [PATCH 02/38] code cleanup and reformat --- benchmarks/vectorized_randomly_zoomed_crop.py | 8 +++---- .../base_image_augmentation_layer.py | 12 +++++----- .../preprocessing/random_crop_and_resize.py | 8 +++---- .../layers/regularization/squeeze_excite.py | 8 +++---- .../object_detection/box_coco_metrics.py | 6 ++--- keras_cv/models/__init__.py | 1 + .../backbones/densenet/densenet_backbone.py | 6 ++--- .../backbones/resnet_v1/resnet_v1_backbone.py | 6 ++--- .../backbones/resnet_v2/resnet_v2_backbone.py | 6 ++--- .../{CLIP => feature_extractor}/__init__.py | 9 ------- .../models/feature_extractor/clip/__init__.py | 23 ++++++++++++++++++ .../clip}/clip_encoder.py | 3 ++- .../clip}/clip_image_model.py | 10 ++++---- .../clip}/clip_model.py | 24 +++++-------------- .../clip}/clip_model_test.py | 1 - .../clip}/clip_processor.py | 2 +- .../clip}/clip_text_model.py | 2 +- .../clip}/clip_tokenizer.py | 0 .../yolo_v8/yolo_v8_backbone.py | 6 ++--- .../yolo_v8/yolo_v8_detector.py | 6 ++--- .../stable_diffusion/noise_scheduler.py | 4 +++- 21 files changed, 79 insertions(+), 72 deletions(-) rename keras_cv/models/{CLIP => feature_extractor}/__init__.py (67%) create mode 100644 keras_cv/models/feature_extractor/clip/__init__.py rename keras_cv/models/{CLIP => feature_extractor/clip}/clip_encoder.py (99%) rename keras_cv/models/{CLIP => feature_extractor/clip}/clip_image_model.py (93%) rename keras_cv/models/{CLIP => feature_extractor/clip}/clip_model.py (90%) rename keras_cv/models/{CLIP => feature_extractor/clip}/clip_model_test.py (95%) rename keras_cv/models/{CLIP => feature_extractor/clip}/clip_processor.py (98%) rename keras_cv/models/{CLIP => feature_extractor/clip}/clip_text_model.py (96%) rename keras_cv/models/{CLIP => feature_extractor/clip}/clip_tokenizer.py (100%) diff --git a/benchmarks/vectorized_randomly_zoomed_crop.py b/benchmarks/vectorized_randomly_zoomed_crop.py index 3a207ed2e3..4e807fd1ab 100644 --- a/benchmarks/vectorized_randomly_zoomed_crop.py +++ b/benchmarks/vectorized_randomly_zoomed_crop.py @@ -249,10 +249,10 @@ def from_config(cls, config): config["zoom_factor"] ) if isinstance(config["aspect_ratio_factor"], dict): - config["aspect_ratio_factor"] = ( - keras.utils.deserialize_keras_object( - config["aspect_ratio_factor"] - ) + config[ + "aspect_ratio_factor" + ] = keras.utils.deserialize_keras_object( + config["aspect_ratio_factor"] ) return cls(**config) diff --git a/keras_cv/layers/preprocessing/base_image_augmentation_layer.py b/keras_cv/layers/preprocessing/base_image_augmentation_layer.py index 167da7ad0b..ef2e9cefe7 100644 --- a/keras_cv/layers/preprocessing/base_image_augmentation_layer.py +++ b/keras_cv/layers/preprocessing/base_image_augmentation_layer.py @@ -236,15 +236,15 @@ def _compute_output_signature(self, inputs): bounding_boxes = inputs.get(BOUNDING_BOXES, None) if bounding_boxes is not None: - 
fn_output_signature[BOUNDING_BOXES] = ( - self._compute_bounding_box_signature(bounding_boxes) - ) + fn_output_signature[ + BOUNDING_BOXES + ] = self._compute_bounding_box_signature(bounding_boxes) segmentation_masks = inputs.get(SEGMENTATION_MASKS, None) if segmentation_masks is not None: - fn_output_signature[SEGMENTATION_MASKS] = ( - self.compute_image_signature(segmentation_masks) - ) + fn_output_signature[ + SEGMENTATION_MASKS + ] = self.compute_image_signature(segmentation_masks) keypoints = inputs.get(KEYPOINTS, None) if keypoints is not None: diff --git a/keras_cv/layers/preprocessing/random_crop_and_resize.py b/keras_cv/layers/preprocessing/random_crop_and_resize.py index cd947d5835..593515ad09 100644 --- a/keras_cv/layers/preprocessing/random_crop_and_resize.py +++ b/keras_cv/layers/preprocessing/random_crop_and_resize.py @@ -272,10 +272,10 @@ def from_config(cls, config): config["crop_area_factor"] ) if isinstance(config["aspect_ratio_factor"], dict): - config["aspect_ratio_factor"] = ( - keras.utils.deserialize_keras_object( - config["aspect_ratio_factor"] - ) + config[ + "aspect_ratio_factor" + ] = keras.utils.deserialize_keras_object( + config["aspect_ratio_factor"] ) return cls(**config) diff --git a/keras_cv/layers/regularization/squeeze_excite.py b/keras_cv/layers/regularization/squeeze_excite.py index 8cbcc5bd94..cb03cc6942 100644 --- a/keras_cv/layers/regularization/squeeze_excite.py +++ b/keras_cv/layers/regularization/squeeze_excite.py @@ -118,10 +118,10 @@ def get_config(self): @classmethod def from_config(cls, config): if isinstance(config["squeeze_activation"], dict): - config["squeeze_activation"] = ( - keras.saving.deserialize_keras_object( - config["squeeze_activation"] - ) + config[ + "squeeze_activation" + ] = keras.saving.deserialize_keras_object( + config["squeeze_activation"] ) if isinstance(config["excite_activation"], dict): config["excite_activation"] = keras.saving.deserialize_keras_object( diff --git a/keras_cv/metrics/object_detection/box_coco_metrics.py b/keras_cv/metrics/object_detection/box_coco_metrics.py index 47d86ba1c2..a59af8c767 100644 --- a/keras_cv/metrics/object_detection/box_coco_metrics.py +++ b/keras_cv/metrics/object_detection/box_coco_metrics.py @@ -212,9 +212,9 @@ def result_fn(self, force=False): ) result = {} for i, key in enumerate(METRIC_NAMES): - result[self.name_prefix() + METRIC_MAPPING[key]] = ( - py_func_result[i] - ) + result[ + self.name_prefix() + METRIC_MAPPING[key] + ] = py_func_result[i] return result obj.result = types.MethodType(result_fn, obj) diff --git a/keras_cv/models/__init__.py b/keras_cv/models/__init__.py index b9b90b946a..8e6a849a95 100644 --- a/keras_cv/models/__init__.py +++ b/keras_cv/models/__init__.py @@ -183,6 +183,7 @@ from keras_cv.models.backbones.vit_det.vit_det_aliases import ViTDetLBackbone from keras_cv.models.backbones.vit_det.vit_det_backbone import ViTDetBackbone from keras_cv.models.classification.image_classifier import ImageClassifier +from keras_cv.models.feature_extractor.clip import CLIP from keras_cv.models.object_detection.retinanet.retinanet import RetinaNet from keras_cv.models.object_detection.yolo_v8.yolo_v8_backbone import ( YOLOV8Backbone, diff --git a/keras_cv/models/backbones/densenet/densenet_backbone.py b/keras_cv/models/backbones/densenet/densenet_backbone.py index 251f3601ec..28109b64fa 100644 --- a/keras_cv/models/backbones/densenet/densenet_backbone.py +++ b/keras_cv/models/backbones/densenet/densenet_backbone.py @@ -119,9 +119,9 @@ def __init__( 
name=f"conv{len(stackwise_num_repeats) + 1}", ) - pyramid_level_inputs[f"P{len(stackwise_num_repeats) + 1}"] = ( - utils.get_tensor_input_name(x) - ) + pyramid_level_inputs[ + f"P{len(stackwise_num_repeats) + 1}" + ] = utils.get_tensor_input_name(x) x = keras.layers.BatchNormalization( axis=BN_AXIS, epsilon=BN_EPSILON, name="bn" )(x) diff --git a/keras_cv/models/backbones/resnet_v1/resnet_v1_backbone.py b/keras_cv/models/backbones/resnet_v1/resnet_v1_backbone.py index 07c896613c..61046234d3 100644 --- a/keras_cv/models/backbones/resnet_v1/resnet_v1_backbone.py +++ b/keras_cv/models/backbones/resnet_v1/resnet_v1_backbone.py @@ -130,9 +130,9 @@ def __init__( first_shortcut=(block_type == "block" or stack_index > 0), name=f"v2_stack_{stack_index}", ) - pyramid_level_inputs[f"P{stack_index + 2}"] = ( - utils.get_tensor_input_name(x) - ) + pyramid_level_inputs[ + f"P{stack_index + 2}" + ] = utils.get_tensor_input_name(x) # Create model. super().__init__(inputs=inputs, outputs=x, **kwargs) diff --git a/keras_cv/models/backbones/resnet_v2/resnet_v2_backbone.py b/keras_cv/models/backbones/resnet_v2/resnet_v2_backbone.py index 6a0cc74740..a31841f7fc 100644 --- a/keras_cv/models/backbones/resnet_v2/resnet_v2_backbone.py +++ b/keras_cv/models/backbones/resnet_v2/resnet_v2_backbone.py @@ -136,9 +136,9 @@ def __init__( first_shortcut=(block_type == "block" or stack_index > 0), name=f"v2_stack_{stack_index}", ) - pyramid_level_inputs[f"P{stack_index + 2}"] = ( - utils.get_tensor_input_name(x) - ) + pyramid_level_inputs[ + f"P{stack_index + 2}" + ] = utils.get_tensor_input_name(x) x = keras.layers.BatchNormalization( axis=BN_AXIS, epsilon=BN_EPSILON, name="post_bn" diff --git a/keras_cv/models/CLIP/__init__.py b/keras_cv/models/feature_extractor/__init__.py similarity index 67% rename from keras_cv/models/CLIP/__init__.py rename to keras_cv/models/feature_extractor/__init__.py index e15d291766..3992ffb59a 100644 --- a/keras_cv/models/CLIP/__init__.py +++ b/keras_cv/models/feature_extractor/__init__.py @@ -11,12 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -from keras_cv.models.clip.clip_image_model import ( - CLIPImageEncoder, -) -from keras_cv.models.clip.clip_image_model import ( - CLIPTextEncoder, -) -from keras_cv.models.clip.clip_processor import CLIPProcessor -from keras_cv.models.clip.clip_tokenizer import CLIPTokenizer \ No newline at end of file diff --git a/keras_cv/models/feature_extractor/clip/__init__.py b/keras_cv/models/feature_extractor/clip/__init__.py new file mode 100644 index 0000000000..8826871115 --- /dev/null +++ b/keras_cv/models/feature_extractor/clip/__init__.py @@ -0,0 +1,23 @@ +# Copyright 2023 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from keras_cv.models.feature_extractor.clip.clip_image_model import ( + CLIPImageEncoder, +) +from keras_cv.models.feature_extractor.clip.clip_model import CLIP +from keras_cv.models.feature_extractor.clip.clip_processor import CLIPProcessor +from keras_cv.models.feature_extractor.clip.clip_text_model import ( + CLIPTextEncoder, +) +from keras_cv.models.feature_extractor.clip.clip_tokenizer import CLIPTokenizer diff --git a/keras_cv/models/CLIP/clip_encoder.py b/keras_cv/models/feature_extractor/clip/clip_encoder.py similarity index 99% rename from keras_cv/models/CLIP/clip_encoder.py rename to keras_cv/models/feature_extractor/clip/clip_encoder.py index b7c657c357..653189ca7d 100644 --- a/keras_cv/models/CLIP/clip_encoder.py +++ b/keras_cv/models/feature_extractor/clip/clip_encoder.py @@ -43,8 +43,9 @@ def __init__( n_head, num_hidden_layers, attn_mask=None, + **kwargs, ): - super().__init__() + super().__init__(**kwargs) self.proj_dim = proj_dim self.n_head = n_head self.attn_mask = attn_mask diff --git a/keras_cv/models/CLIP/clip_image_model.py b/keras_cv/models/feature_extractor/clip/clip_image_model.py similarity index 93% rename from keras_cv/models/CLIP/clip_image_model.py rename to keras_cv/models/feature_extractor/clip/clip_image_model.py index c66d964155..265399f18b 100644 --- a/keras_cv/models/CLIP/clip_image_model.py +++ b/keras_cv/models/feature_extractor/clip/clip_image_model.py @@ -14,13 +14,15 @@ from keras_cv.backend import keras from keras_cv.backend import ops -from keras_cv.models.feature_extractors.clip.clip_encoder import CLIPEncoder -from keras_cv.models.feature_extractors.clip.clip_encoder import get_initializer +from keras_cv.models.feature_extractor.clip.clip_encoder import CLIPEncoder +from keras_cv.models.feature_extractor.clip.clip_encoder import get_initializer class CLIPPatchingAndEmbedding(keras.layers.Layer): - def __init__(self, width, patch_size, input_resolution, output_dim): - super().__init__() + def __init__( + self, width, patch_size, input_resolution, output_dim, **kwargs + ): + super().__init__(**kwargs) self.conv1 = keras.layers.Conv2D( filters=width, diff --git a/keras_cv/models/CLIP/clip_model.py b/keras_cv/models/feature_extractor/clip/clip_model.py similarity index 90% rename from keras_cv/models/CLIP/clip_model.py rename to keras_cv/models/feature_extractor/clip/clip_model.py index 6eae744380..93c9676351 100644 --- a/keras_cv/models/CLIP/clip_model.py +++ b/keras_cv/models/feature_extractor/clip/clip_model.py @@ -14,11 +14,11 @@ from keras_cv.api_export import keras_cv_export from keras_cv.backend import keras from keras_cv.backend import ops -from keras_cv.models.feature_extractors.clip.clip_image_model import CLIPEncoder -from keras_cv.models.feature_extractors.clip.clip_image_model import ( +from keras_cv.models.feature_extractor.clip.clip_image_model import CLIPEncoder +from keras_cv.models.feature_extractor.clip.clip_image_model import ( CLIPImageEncoder, ) -from keras_cv.models.feature_extractors.clip.clip_text_model import ( +from keras_cv.models.feature_extractor.clip.clip_text_model import ( CLIPTextEncoder, ) @@ -74,9 +74,7 @@ } -@keras_cv_export( - ["keras_cv.models.CLIP", "keras_cv.models.feature_extractors.CLIP"] -) +@keras_cv_export(["keras_cv.models.CLIP"]) class CLIP(keras.Model): """ CLIP implements the Contrastive Language-Image Pretraining (CLIP) @@ -116,8 +114,9 @@ def __init__( transformer_width, transformer_heads, transformer_layers, + **kwargs, ): - super().__init__() + super().__init__(**kwargs) self.context_length 
= context_length @@ -174,17 +173,7 @@ def call(self, image, text): ) self.image_embeddings = self.image_embeddings / normalize_image_features self.text_embeddings = self.text_embeddings / normalize_text_features - logit_scale = ops.exp(self.logit_scale) - print("logit scale", logit_scale) - print( - "matmul", - ops.matmul( - self.image_embeddings, - ops.transpose(self.text_embeddings), - ), - ) - logits_per_image = ( ops.matmul( self.image_embeddings, @@ -192,7 +181,6 @@ def call(self, image, text): ) * logit_scale ) - print("logit per image", logits_per_image) logits_per_text = ops.transpose(logits_per_image) return logits_per_image, logits_per_text diff --git a/keras_cv/models/CLIP/clip_model_test.py b/keras_cv/models/feature_extractor/clip/clip_model_test.py similarity index 95% rename from keras_cv/models/CLIP/clip_model_test.py rename to keras_cv/models/feature_extractor/clip/clip_model_test.py index 1cc8cc3b10..a97106e56d 100644 --- a/keras_cv/models/CLIP/clip_model_test.py +++ b/keras_cv/models/feature_extractor/clip/clip_model_test.py @@ -13,7 +13,6 @@ # limitations under the License. import pytest -from tensorflow.keras import mixed_precision from keras_cv.backend import ops from keras_cv.backend import random diff --git a/keras_cv/models/CLIP/clip_processor.py b/keras_cv/models/feature_extractor/clip/clip_processor.py similarity index 98% rename from keras_cv/models/CLIP/clip_processor.py rename to keras_cv/models/feature_extractor/clip/clip_processor.py index bcf81e6fe5..526952ba74 100644 --- a/keras_cv/models/CLIP/clip_processor.py +++ b/keras_cv/models/feature_extractor/clip/clip_processor.py @@ -17,7 +17,7 @@ from keras_cv.api_export import keras_cv_export from keras_cv.backend import keras from keras_cv.backend import ops -from keras_cv.models.feature_extractors.clip.clip_tokenizer import CLIPTokenizer +from keras_cv.models.feature_extractor.clip.clip_tokenizer import CLIPTokenizer @keras_cv_export("keras_cv.models.feature_extractors.CLIPProcessor") diff --git a/keras_cv/models/CLIP/clip_text_model.py b/keras_cv/models/feature_extractor/clip/clip_text_model.py similarity index 96% rename from keras_cv/models/CLIP/clip_text_model.py rename to keras_cv/models/feature_extractor/clip/clip_text_model.py index f83108f110..cd6ba0add5 100644 --- a/keras_cv/models/CLIP/clip_text_model.py +++ b/keras_cv/models/feature_extractor/clip/clip_text_model.py @@ -1,6 +1,6 @@ from keras_cv.backend import keras from keras_cv.backend import ops -from keras_cv.models.feature_extractors.clip.clip_encoder import CLIPEncoder +from keras_cv.models.feature_extractor.clip.clip_encoder import CLIPEncoder class CLIPTextEncoder(keras.Model): diff --git a/keras_cv/models/CLIP/clip_tokenizer.py b/keras_cv/models/feature_extractor/clip/clip_tokenizer.py similarity index 100% rename from keras_cv/models/CLIP/clip_tokenizer.py rename to keras_cv/models/feature_extractor/clip/clip_tokenizer.py diff --git a/keras_cv/models/object_detection/yolo_v8/yolo_v8_backbone.py b/keras_cv/models/object_detection/yolo_v8/yolo_v8_backbone.py index a2bf4bdd3b..f4bd99fafa 100644 --- a/keras_cv/models/object_detection/yolo_v8/yolo_v8_backbone.py +++ b/keras_cv/models/object_detection/yolo_v8/yolo_v8_backbone.py @@ -178,9 +178,9 @@ def __init__( activation=activation, name=f"{stack_name}_spp_fast", ) - pyramid_level_inputs[f"P{stack_id + 2}"] = ( - utils.get_tensor_input_name(x) - ) + pyramid_level_inputs[ + f"P{stack_id + 2}" + ] = utils.get_tensor_input_name(x) super().__init__(inputs=inputs, outputs=x, **kwargs) 
self.pyramid_level_inputs = pyramid_level_inputs diff --git a/keras_cv/models/object_detection/yolo_v8/yolo_v8_detector.py b/keras_cv/models/object_detection/yolo_v8/yolo_v8_detector.py index 6c17c71a72..bfba44945c 100644 --- a/keras_cv/models/object_detection/yolo_v8/yolo_v8_detector.py +++ b/keras_cv/models/object_detection/yolo_v8/yolo_v8_detector.py @@ -663,9 +663,9 @@ def from_config(cls, config): if prediction_decoder is not None and isinstance( prediction_decoder, dict ): - config["prediction_decoder"] = ( - keras.saving.deserialize_keras_object(prediction_decoder) - ) + config[ + "prediction_decoder" + ] = keras.saving.deserialize_keras_object(prediction_decoder) return cls(**config) @classproperty diff --git a/keras_cv/models/stable_diffusion/noise_scheduler.py b/keras_cv/models/stable_diffusion/noise_scheduler.py index c5c100848c..bd1c0dc51e 100644 --- a/keras_cv/models/stable_diffusion/noise_scheduler.py +++ b/keras_cv/models/stable_diffusion/noise_scheduler.py @@ -54,7 +54,9 @@ def __init__( elif beta_schedule == "scaled_linear": # this schedule is very specific to the latent diffusion model. self.betas = ( - ops.linspace(beta_start**0.5, beta_end**0.5, train_timesteps) + ops.linspace( + beta_start**0.5, beta_end**0.5, train_timesteps + ) ** 2 ) else: From 3aa5c6c03a778ccb427d2dbafb9819f81d49d214 Mon Sep 17 00:00:00 2001 From: Divyashree Sreepathihalli Date: Fri, 2 Feb 2024 20:07:41 +0000 Subject: [PATCH 03/38] update encoder name --- keras_cv/models/feature_extractor/clip/clip_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras_cv/models/feature_extractor/clip/clip_model.py b/keras_cv/models/feature_extractor/clip/clip_model.py index 93c9676351..bf0740c4e9 100644 --- a/keras_cv/models/feature_extractor/clip/clip_model.py +++ b/keras_cv/models/feature_extractor/clip/clip_model.py @@ -128,7 +128,7 @@ def __init__( layers=vision_layers, heads=vision_heads, output_dim=embed_dim, - name="clip_encoder", + name="image_encoder", ) self.text_encoder = CLIPTextEncoder( transformer_width=transformer_width, From 1f648b35b299392d13c096f9b332cd6e5beafe66 Mon Sep 17 00:00:00 2001 From: Divyashree Sreepathihalli Date: Fri, 2 Feb 2024 20:19:24 +0000 Subject: [PATCH 04/38] update clip encoder name --- keras_cv/models/feature_extractor/clip/clip_text_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras_cv/models/feature_extractor/clip/clip_text_model.py b/keras_cv/models/feature_extractor/clip/clip_text_model.py index cd6ba0add5..3665e0b741 100644 --- a/keras_cv/models/feature_extractor/clip/clip_text_model.py +++ b/keras_cv/models/feature_extractor/clip/clip_text_model.py @@ -34,7 +34,7 @@ def __init__( layers=transformer_layers, heads=transformer_heads, attn_mask=self.build_attention_mask(), - name="residual_transformer_encoder", + name="clip_encoder", ) self.ln_final = keras.layers.LayerNormalization(name="ln_final") From 3c4743dd523c412c2c2f0dadca7d09beaae3f999 Mon Sep 17 00:00:00 2001 From: Divyashree Sreepathihalli Date: Fri, 2 Feb 2024 20:23:08 +0000 Subject: [PATCH 05/38] update clip encoder name in image encoder --- keras_cv/models/feature_extractor/clip/clip_image_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras_cv/models/feature_extractor/clip/clip_image_model.py b/keras_cv/models/feature_extractor/clip/clip_image_model.py index 265399f18b..b436e1a8d7 100644 --- a/keras_cv/models/feature_extractor/clip/clip_image_model.py +++ b/keras_cv/models/feature_extractor/clip/clip_image_model.py @@ -114,7 
+114,7 @@ def __init__( width, layers, heads, - name="residual_transformer_encoder", + name="clip_encoder", ) self.post_norm = keras.layers.LayerNormalization( epsilon=1e-5, name="ln_2" From 54ec6e55c75c3ed02e149641a201a111bdd20895 Mon Sep 17 00:00:00 2001 From: Divyashree Sreepathihalli Date: Fri, 2 Feb 2024 23:13:57 +0000 Subject: [PATCH 06/38] add weights conversion script --- .../clip_weights_conversion.ipynb | 765 ++++++++++++++++++ requirements-common.txt | 1 + 2 files changed, 766 insertions(+) create mode 100644 keras_cv/tools/checkpoint_conversion/clip_weights_conversion.ipynb diff --git a/keras_cv/tools/checkpoint_conversion/clip_weights_conversion.ipynb b/keras_cv/tools/checkpoint_conversion/clip_weights_conversion.ipynb new file mode 100644 index 0000000000..4b5e89ef76 --- /dev/null +++ b/keras_cv/tools/checkpoint_conversion/clip_weights_conversion.ipynb @@ -0,0 +1,765 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + } + }, + "cells": [ + { + "cell_type": "markdown", + "source": [ + "# Setup" + ], + "metadata": { + "id": "0DhV6hzOMY0W" + } + }, + { + "cell_type": "code", + "source": [ + "!pip install -q git+https://github.com/divyashreepathihalli/keras-cv.git@CLIP_refactor\n", + "!pip install -q keras-nlp\n", + "!pip install -q tf-keras\n", + "!pip install -q tensorflow-text\n", + "!pip install keras==3.0.2" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "cRzYR-oFgxt1", + "outputId": "e4b01fcd-9f71-4ba7-b8a2-1796f7ef260d" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + " Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", + " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", + " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m950.8/950.8 kB\u001b[0m \u001b[31m5.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25h Building wheel for keras-cv (pyproject.toml) ... 
\u001b[?25l\u001b[?25hdone\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m415.4/415.4 kB\u001b[0m \u001b[31m5.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.2/5.2 MB\u001b[0m \u001b[31m17.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting keras==3.0.2\n", + " Downloading keras-3.0.2-py3-none-any.whl (1.0 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.0/1.0 MB\u001b[0m \u001b[31m8.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: absl-py in /usr/local/lib/python3.10/dist-packages (from keras==3.0.2) (1.4.0)\n", + "Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from keras==3.0.2) (1.23.5)\n", + "Requirement already satisfied: rich in /usr/local/lib/python3.10/dist-packages (from keras==3.0.2) (13.7.0)\n", + "Requirement already satisfied: namex in /usr/local/lib/python3.10/dist-packages (from keras==3.0.2) (0.0.7)\n", + "Requirement already satisfied: h5py in /usr/local/lib/python3.10/dist-packages (from keras==3.0.2) (3.9.0)\n", + "Requirement already satisfied: dm-tree in /usr/local/lib/python3.10/dist-packages (from keras==3.0.2) (0.1.8)\n", + "Requirement already satisfied: markdown-it-py>=2.2.0 in /usr/local/lib/python3.10/dist-packages (from rich->keras==3.0.2) (3.0.0)\n", + "Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /usr/local/lib/python3.10/dist-packages (from rich->keras==3.0.2) (2.16.1)\n", + "Requirement already satisfied: mdurl~=0.1 in /usr/local/lib/python3.10/dist-packages (from markdown-it-py>=2.2.0->rich->keras==3.0.2) (0.1.2)\n", + "Installing collected packages: keras\n", + " Attempting uninstall: keras\n", + " Found existing installation: keras 2.15.0\n", + " Uninstalling keras-2.15.0:\n", + " Successfully uninstalled keras-2.15.0\n", + "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n", + "tensorflow 2.15.0 requires keras<2.16,>=2.15.0, but you have keras 3.0.2 which is incompatible.\u001b[0m\u001b[31m\n", + "\u001b[0mSuccessfully installed keras-3.0.2\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "# Import" + ], + "metadata": { + "id": "mdGT8Em4Mc4b" + } + }, + { + "cell_type": "code", + "source": [ + "from keras_cv.models.feature_extractor.clip import CLIPProcessor\n", + "import keras\n", + "from keras_cv.models import CLIP" + ], + "metadata": { + "id": "GDvJmQuug4-x" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "!wget https://i.imgur.com/8H7XCH0.jpg -O cat.jpg\n", + "!wget http://images.cocodataset.org/val2017/000000039769.jpg -O test.jpg" + ], + "metadata": { + "id": "nuFgha2jTshi", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "b99d73eb-cc97-47d0-f46e-687c9e8b8237" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "--2024-02-02 22:19:20-- https://i.imgur.com/8H7XCH0.jpg\n", + "Resolving i.imgur.com (i.imgur.com)... 151.101.52.193\n", + "Connecting to i.imgur.com (i.imgur.com)|151.101.52.193|:443... connected.\n", + "HTTP request sent, awaiting response... 
200 OK\n", + "Length: 44544 (44K) [image/jpeg]\n", + "Saving to: ‘cat.jpg’\n", + "\n", + "\rcat.jpg 0%[ ] 0 --.-KB/s \rcat.jpg 100%[===================>] 43.50K --.-KB/s in 0.01s \n", + "\n", + "2024-02-02 22:19:20 (3.58 MB/s) - ‘cat.jpg’ saved [44544/44544]\n", + "\n", + "--2024-02-02 22:19:20-- http://images.cocodataset.org/val2017/000000039769.jpg\n", + "Resolving images.cocodataset.org (images.cocodataset.org)... 52.216.78.4, 3.5.1.13, 52.217.139.73, ...\n", + "Connecting to images.cocodataset.org (images.cocodataset.org)|52.216.78.4|:80... connected.\n", + "HTTP request sent, awaiting response... 200 OK\n", + "Length: 173131 (169K) [image/jpeg]\n", + "Saving to: ‘test.jpg’\n", + "\n", + "test.jpg 100%[===================>] 169.07K --.-KB/s in 0.06s \n", + "\n", + "2024-02-02 22:19:20 (2.67 MB/s) - ‘test.jpg’ saved [173131/173131]\n", + "\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "#@title Select which model weights you would like to convert\n", + "MODEL_CONFIGS = {\n", + " \"CLIP_B32\": {\n", + " \"embed_dim\": 512,\n", + " \"context_length\": 77,\n", + " \"vocab_size\": 49408,\n", + " \"transformer_width\": 512,\n", + " \"transformer_heads\": 8,\n", + " \"transformer_layers\": 12,\n", + " \"vision_layers\": 12,\n", + " \"vision_width\": 768,\n", + " \"image_resolution\": 224,\n", + " \"vision_patch_size\": 32,\n", + " },\n", + " \"CLIP_B16\": {\n", + " \"embed_dim\": 512,\n", + " \"context_length\": 77,\n", + " \"vocab_size\": 49408,\n", + " \"transformer_width\": 512,\n", + " \"transformer_heads\": 8,\n", + " \"transformer_layers\": 12,\n", + " \"vision_layers\": 12,\n", + " \"vision_width\": 768,\n", + " \"image_resolution\": 224,\n", + " \"vision_patch_size\": 16,\n", + " },\n", + " \"CLIP_L14\": {\n", + " \"embed_dim\": 768,\n", + " \"context_length\": 77,\n", + " \"vocab_size\": 49408,\n", + " \"transformer_width\": 768,\n", + " \"transformer_heads\": 12,\n", + " \"transformer_layers\": 12,\n", + " \"vision_layers\": 24,\n", + " \"vision_width\": 1024,\n", + " \"image_resolution\": 224,\n", + " \"vision_patch_size\": 14,\n", + " },\n", + " \"CLIP_L14_336\": {\n", + " \"embed_dim\": 768,\n", + " \"context_length\": 77,\n", + " \"vocab_size\": 49408,\n", + " \"transformer_width\": 768,\n", + " \"transformer_heads\": 12,\n", + " \"transformer_layers\": 12,\n", + " \"vision_layers\": 24,\n", + " \"vision_width\": 1024,\n", + " \"image_resolution\": 336,\n", + " \"vision_patch_size\": 14,\n", + " },\n", + "}\n", + "model_map_hf = {\n", + " \"CLIP_B16\": \"openai/clip-vit-base-patch32\",\n", + " \"CLIP_B32\": \"openai/clip-vit-base-patch16\",\n", + " \"CLIP_L14\": \"openai/clip-vit-large-patch14\",\n", + " \"CLIP_L14_336\": \"openai/clip-vit-large-patch14-336\",\n", + "}\n", + "config_name = 'CLIP_L14_336' # @param [\"CLIP_B16\", \"CLIP_B32\", \"CLIP_L14\", \"CLIP_L14_336\"]\n", + "config_name_hf = model_map_hf[config_name]" + ], + "metadata": { + "cellView": "form", + "id": "X3kkmK6h_gFH" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# Keras 3 CLIP" + ], + "metadata": { + "id": "2l3Ll7dMMd-m" + } + }, + { + "cell_type": "code", + "source": [ + "embed_dim = MODEL_CONFIGS[config_name][\"embed_dim\"]\n", + "context_length = MODEL_CONFIGS[config_name][\"context_length\"]\n", + "vocab_size = MODEL_CONFIGS[config_name][\"vocab_size\"]\n", + "transformer_width = MODEL_CONFIGS[config_name][\"transformer_width\"]\n", + "transformer_heads = MODEL_CONFIGS[config_name][\"transformer_heads\"]\n", + "transformer_layers = 
MODEL_CONFIGS[config_name][\"transformer_layers\"]\n", + "vision_layers = MODEL_CONFIGS[config_name][\"vision_layers\"]\n", + "vision_width = MODEL_CONFIGS[config_name][\"vision_width\"]\n", + "vision_patch_size = MODEL_CONFIGS[config_name][\"vision_patch_size\"]\n", + "image_resolution = MODEL_CONFIGS[config_name][\"image_resolution\"]\n", + "model = CLIP(\n", + " embed_dim,\n", + " image_resolution,\n", + " vision_layers,\n", + " vision_width,\n", + " vision_patch_size,\n", + " context_length,\n", + " vocab_size,\n", + " transformer_width,\n", + " transformer_heads,\n", + " transformer_layers,\n", + ")" + ], + "metadata": { + "id": "urhuhwq0Dczo" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 193 + }, + "id": "uE6x7gfqa3Ee", + "outputId": "9a080569-7ab9-49ad-8589-87f335ef2f31" + }, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "\u001b[1mModel: \"clip\"\u001b[0m\n" + ], + "text/html": [ + "
Model: \"clip\"\n",
+              "
\n" + ] + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━┩\n", + "│ image_encoder (\u001b[38;5;33mCLIPImageEncoder\u001b[0m) │ ? │ \u001b[38;5;34m0\u001b[0m (unbuilt) │\n", + "├────────────────────────────────────┼───────────────────────────────┼─────────────┤\n", + "│ text_encoder (\u001b[38;5;33mCLIPTextEncoder\u001b[0m) │ ? │ \u001b[38;5;34m0\u001b[0m (unbuilt) │\n", + "└────────────────────────────────────┴───────────────────────────────┴─────────────┘\n" + ], + "text/html": [ + "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┓\n",
+              "┃ Layer (type)                        Output Shape                       Param # ┃\n",
+              "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━┩\n",
+              "│ image_encoder (CLIPImageEncoder)   │ ?                             │ 0 (unbuilt) │\n",
+              "├────────────────────────────────────┼───────────────────────────────┼─────────────┤\n",
+              "│ text_encoder (CLIPTextEncoder)     │ ?                             │ 0 (unbuilt) │\n",
+              "└────────────────────────────────────┴───────────────────────────────┴─────────────┘\n",
+              "
\n" + ] + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m39,425\u001b[0m (154.00 KB)\n" + ], + "text/html": [ + "
 Total params: 39,425 (154.00 KB)\n",
+              "
\n" + ] + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m39,425\u001b[0m (154.00 KB)\n" + ], + "text/html": [ + "
 Trainable params: 39,425 (154.00 KB)\n",
+              "
\n" + ] + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n" + ], + "text/html": [ + "
 Non-trainable params: 0 (0.00 B)\n",
+              "
\n" + ] + }, + "metadata": {} + } + ], + "source": [ + "model.summary()" + ] + }, + { + "cell_type": "code", + "source": [ + "processor = CLIPProcessor(224, \"vocab.json\", \"merges.txt\")\n", + "image = processor.process_images([\"cat.jpg\"])\n", + "text_input = [\"photo of a cat on a tortoise\", \"tortoise on a dog\", \"a photo of a tortoise\"]\n", + "text = processor.process_texts(text_input)" + ], + "metadata": { + "id": "buXKlNfGTenW" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "image_logits, text_logits = model(image, text)\n", + "output = keras.layers.Softmax()(image_logits)\n", + "print(image_logits)\n", + "print(text_input[keras.ops.argmax(output)])" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "BHSpMv0PT5SX", + "outputId": "566c92c4-fbf3-4e2d-87f1-6112b2cff96f" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "tf.Tensor([[ 0.42190465 0.6262117 -0.2368357 ]], shape=(1, 3), dtype=float32)\n", + "tortoise on a dog\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "model.summary()" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 193 + }, + "id": "GgNBvYCTtmA3", + "outputId": "35b9a26c-325e-4535-c33b-3f67ab112e19" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "\u001b[1mModel: \"clip\"\u001b[0m\n" + ], + "text/html": [ + "
Model: \"clip\"\n",
+              "
\n" + ] + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━┩\n", + "│ image_encoder (\u001b[38;5;33mCLIPImageEncoder\u001b[0m) │ ? │ \u001b[38;5;34m87,849,216\u001b[0m │\n", + "├────────────────────────────────────┼───────────────────────────────┼─────────────┤\n", + "│ text_encoder (\u001b[38;5;33mCLIPTextEncoder\u001b[0m) │ ? │ \u001b[38;5;34m63,428,096\u001b[0m │\n", + "└────────────────────────────────────┴───────────────────────────────┴─────────────┘\n" + ], + "text/html": [ + "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┓\n",
+              "┃ Layer (type)                        Output Shape                       Param # ┃\n",
+              "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━┩\n",
+              "│ image_encoder (CLIPImageEncoder)   │ ?                             │  87,849,216 │\n",
+              "├────────────────────────────────────┼───────────────────────────────┼─────────────┤\n",
+              "│ text_encoder (CLIPTextEncoder)     │ ?                             │  63,428,096 │\n",
+              "└────────────────────────────────────┴───────────────────────────────┴─────────────┘\n",
+              "
\n" + ] + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m151,277,313\u001b[0m (577.08 MB)\n" + ], + "text/html": [ + "
 Total params: 151,277,313 (577.08 MB)\n",
+              "
\n" + ] + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m151,277,313\u001b[0m (577.08 MB)\n" + ], + "text/html": [ + "
 Trainable params: 151,277,313 (577.08 MB)\n",
+              "
\n" + ] + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n" + ], + "text/html": [ + "
 Non-trainable params: 0 (0.00 B)\n",
+              "
\n" + ] + }, + "metadata": {} + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "# HF CLIP" + ], + "metadata": { + "id": "P8DWYq_hVFnz" + } + }, + { + "cell_type": "code", + "source": [ + "from PIL import Image\n", + "import requests\n", + "\n", + "from transformers import CLIPProcessor as CP\n", + "from transformers import CLIPModel as CM\n", + "\n", + "\n", + "\n" + ], + "metadata": { + "id": "3W2prd6C0pxe" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "model_hf = CM.from_pretrained(config_name_hf)\n", + "processor = CP.from_pretrained(config_name_hf)" + ], + "metadata": { + "id": "EntuvOq1MhwU", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "e154a367-2f94-4fa1-e97d-d2f32db7a2cf" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config[\"id2label\"]` will be overriden.\n", + "`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config[\"bos_token_id\"]` will be overriden.\n", + "`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config[\"eos_token_id\"]` will be overriden.\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "\n", + "url = \"https://i.imgur.com/8H7XCH0.jpg\"\n", + "image_hf = Image.open(requests.get(url, stream=True).raw)\n", + "text_inputs = [\"photo of a cat on a tortoise\", \"tortoise on a dog\", \"a photo of a tortoise\"]\n", + "inputs = processor(text=text_inputs, images=image_hf, return_tensors=\"pt\", padding=True)\n", + "\n", + "outputs = model_hf(**inputs)\n", + "logits_per_image = outputs.logits_per_image # this is the image-text similarity score\n", + "probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilitiesprobs\n", + "probs" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "Ep8DRTkv3AwS", + "outputId": "770756bc-8829-484f-b6e5-763fe81e24d0" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "tensor([[0.9957, 0.0023, 0.0020]], grad_fn=)" + ] + }, + "metadata": {}, + "execution_count": 14 + } + ] + }, + { + "cell_type": "code", + "source": [ + "#hugging face weights\n", + "hf_wts = model_hf.state_dict()" + ], + "metadata": { + "id": "wPa0cVnY3cBC" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# Copy weights" + ], + "metadata": { + "id": "ArkCHlVZVKfM" + } + }, + { + "cell_type": "markdown", + "source": [ + "##vision encoder" + ], + "metadata": { + "id": "TUCpKltRG4Gd" + } + }, + { + "cell_type": "code", + "source": [ + "model.logit_scale.assign(hf_wts.pop(\"logit_scale\").numpy())\n", + "model.get_layer(\"image_encoder\").get_layer(\"clip_patching_and_embedding\").class_embedding.assign(hf_wts.pop('vision_model.embeddings.class_embedding').numpy())\n", + "model.get_layer(\"image_encoder\").get_layer(\"clip_patching_and_embedding\").positional_embedding.assign(hf_wts.pop('vision_model.embeddings.position_embedding.weight').numpy())\n", + "model.get_layer(\"image_encoder\").get_layer(\"clip_patching_and_embedding\").conv1.weights[0].assign(hf_wts.pop('vision_model.embeddings.patch_embedding.weight').permute(3, 2, 1, 0).numpy())\n", + 
"model.get_layer(\"image_encoder\").get_layer(\"ln_1\").weights[0].assign(hf_wts.pop('vision_model.pre_layrnorm.weight').numpy())\n", + "model.get_layer(\"image_encoder\").get_layer(\"ln_1\").weights[1].assign(hf_wts.pop('vision_model.pre_layrnorm.bias').numpy())\n", + "model.get_layer(\"image_encoder\").get_layer(\"ln_2\").weights[0].assign(hf_wts.pop('vision_model.post_layernorm.weight').numpy())\n", + "model.get_layer(\"image_encoder\").get_layer(\"ln_2\").weights[1].assign(hf_wts.pop('vision_model.post_layernorm.bias').numpy())\n", + "model.get_layer(\"image_encoder\").get_layer(\"vision_projector\").weights[0].assign(hf_wts.pop('visual_projection.weight').transpose(1,0).numpy())\n" + ], + "metadata": { + "id": "tn_U02N7U2VN" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "for i in range(0,MODEL_CONFIGS[config_name][\"vision_layers\"]):\n", + " if i == 0:\n", + " residual_attention = f\"residual_attention\"\n", + " else:\n", + " residual_attention = f\"residual_attention_{i}\"\n", + " model.get_layer(\"image_encoder\").get_layer(\"clip_encoder\").resblocks.get_layer(residual_attention).attn.q_proj.weights[0].assign(hf_wts.pop(f'vision_model.encoder.layers.{i}.self_attn.q_proj.weight'))\n", + " model.get_layer(\"image_encoder\").get_layer(\"clip_encoder\").resblocks.get_layer(residual_attention).attn.q_proj.weights[1].assign(hf_wts.pop(f'vision_model.encoder.layers.{i}.self_attn.q_proj.bias'))\n", + " model.get_layer(\"image_encoder\").get_layer(\"clip_encoder\").resblocks.get_layer(residual_attention).attn.k_proj.weights[0].assign(hf_wts.pop(f'vision_model.encoder.layers.{i}.self_attn.k_proj.weight'))\n", + " model.get_layer(\"image_encoder\").get_layer(\"clip_encoder\").resblocks.get_layer(residual_attention).attn.k_proj.weights[1].assign(hf_wts.pop(f'vision_model.encoder.layers.{i}.self_attn.k_proj.bias'))\n", + " model.get_layer(\"image_encoder\").get_layer(\"clip_encoder\").resblocks.get_layer(residual_attention).attn.v_proj.weights[0].assign(hf_wts.pop(f'vision_model.encoder.layers.{i}.self_attn.v_proj.weight'))\n", + " model.get_layer(\"image_encoder\").get_layer(\"clip_encoder\").resblocks.get_layer(residual_attention).attn.v_proj.weights[1].assign(hf_wts.pop(f'vision_model.encoder.layers.{i}.self_attn.v_proj.bias'))\n", + " model.get_layer(\"image_encoder\").get_layer(\"clip_encoder\").resblocks.get_layer(residual_attention).attn.out_proj.weights[1].assign(hf_wts.pop(f'vision_model.encoder.layers.{i}.self_attn.out_proj.bias').numpy())\n", + " model.get_layer(\"image_encoder\").get_layer(\"clip_encoder\").resblocks.get_layer(residual_attention).attn.out_proj.weights[0].assign(hf_wts.pop(f'vision_model.encoder.layers.{i}.self_attn.out_proj.weight').numpy())\n", + " model.get_layer(\"image_encoder\").get_layer(\"clip_encoder\").resblocks.get_layer(residual_attention).ln_1.weights[0].assign(hf_wts.pop(f'vision_model.encoder.layers.{i}.layer_norm1.weight').numpy())\n", + " model.get_layer(\"image_encoder\").get_layer(\"clip_encoder\").resblocks.get_layer(residual_attention).ln_1.weights[1].assign(hf_wts.pop(f'vision_model.encoder.layers.{i}.layer_norm1.bias').numpy())\n", + " model.get_layer(\"image_encoder\").get_layer(\"clip_encoder\").resblocks.get_layer(residual_attention).ln_2.weights[0].assign(hf_wts.pop(f'vision_model.encoder.layers.{i}.layer_norm2.weight').numpy())\n", + " 
model.get_layer(\"image_encoder\").get_layer(\"clip_encoder\").resblocks.get_layer(residual_attention).ln_2.weights[1].assign(hf_wts.pop(f'vision_model.encoder.layers.{i}.layer_norm2.bias').numpy())\n", + " model.get_layer(\"image_encoder\").get_layer(\"clip_encoder\").resblocks.get_layer(residual_attention).mlp.get_layer(\"c_fc\").weights[0].assign(hf_wts.pop(f'vision_model.encoder.layers.{i}.mlp.fc1.weight').transpose(1,0).numpy())\n", + " model.get_layer(\"image_encoder\").get_layer(\"clip_encoder\").resblocks.get_layer(residual_attention).mlp.get_layer(\"c_fc\").weights[1].assign(hf_wts.pop(f'vision_model.encoder.layers.{i}.mlp.fc1.bias').numpy())\n", + " model.get_layer(\"image_encoder\").get_layer(\"clip_encoder\").resblocks.get_layer(residual_attention).mlp.get_layer(\"c_proj\").weights[0].assign(hf_wts.pop(f'vision_model.encoder.layers.{i}.mlp.fc2.weight').transpose(1,0).numpy())\n", + " model.get_layer(\"image_encoder\").get_layer(\"clip_encoder\").resblocks.get_layer(residual_attention).mlp.get_layer(\"c_proj\").weights[1].assign(hf_wts.pop(f'vision_model.encoder.layers.{i}.mlp.fc2.bias').numpy())\n", + "\n" + ], + "metadata": { + "id": "qptfuWobZcbT" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## Text encoder" + ], + "metadata": { + "id": "1RN2aVrYG8T3" + } + }, + { + "cell_type": "code", + "source": [ + "num_transformer_layers = MODEL_CONFIGS[config_name][\"vision_layers\"]" + ], + "metadata": { + "id": "5FtDROnynb0N" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "model.get_layer(\"text_encoder\").get_layer(\"text_projector\").weights[0].assign(hf_wts.pop(\"text_projection.weight\").numpy())\n", + "model.get_layer(\"text_encoder\").get_layer(\"token_embedding\").weights[0].assign(hf_wts.pop('text_model.embeddings.token_embedding.weight').numpy())\n", + "model.get_layer(\"text_encoder\").positional_embedding.assign(hf_wts.pop('text_model.embeddings.position_embedding.weight').numpy())\n", + "model.get_layer(\"text_encoder\").get_layer(\"ln_final\").weights[0].assign(hf_wts.pop('text_model.final_layer_norm.weight'))\n", + "model.get_layer(\"text_encoder\").get_layer(\"ln_final\").weights[1].assign(hf_wts.pop('text_model.final_layer_norm.bias'))\n" + ], + "metadata": { + "id": "_1AD7TcbdWEC" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "for i in range(MODEL_CONFIGS[config_name][\"transformer_layers\"]):\n", + " model.get_layer(\"text_encoder\").get_layer(\"clip_encoder\").resblocks.get_layer(f\"residual_attention_{num_transformer_layers+i}\").attn.k_proj.weights[0].assign(hf_wts.pop(f'text_model.encoder.layers.{i}.self_attn.k_proj.weight'))\n", + " model.get_layer(\"text_encoder\").get_layer(\"clip_encoder\").resblocks.get_layer(f\"residual_attention_{num_transformer_layers+i}\").attn.k_proj.weights[1].assign(hf_wts.pop(f'text_model.encoder.layers.{i}.self_attn.k_proj.bias'))\n", + " model.get_layer(\"text_encoder\").get_layer(\"clip_encoder\").resblocks.get_layer(f\"residual_attention_{num_transformer_layers+i}\").attn.q_proj.weights[0].assign(hf_wts.pop(f'text_model.encoder.layers.{i}.self_attn.q_proj.weight'))\n", + " model.get_layer(\"text_encoder\").get_layer(\"clip_encoder\").resblocks.get_layer(f\"residual_attention_{num_transformer_layers+i}\").attn.q_proj.weights[1].assign(hf_wts.pop(f'text_model.encoder.layers.{i}.self_attn.q_proj.bias'))\n", + " 
model.get_layer(\"text_encoder\").get_layer(\"clip_encoder\").resblocks.get_layer(f\"residual_attention_{num_transformer_layers+i}\").attn.v_proj.weights[0].assign(hf_wts.pop(f'text_model.encoder.layers.{i}.self_attn.v_proj.weight'))\n", + " model.get_layer(\"text_encoder\").get_layer(\"clip_encoder\").resblocks.get_layer(f\"residual_attention_{num_transformer_layers+i}\").attn.v_proj.weights[1].assign(hf_wts.pop(f'text_model.encoder.layers.{i}.self_attn.v_proj.bias'))\n", + " model.get_layer(\"text_encoder\").get_layer(\"clip_encoder\").resblocks.get_layer(f\"residual_attention_{num_transformer_layers+i}\").attn.out_proj.weights[0].assign(hf_wts.pop(f'text_model.encoder.layers.{i}.self_attn.out_proj.weight'))\n", + " model.get_layer(\"text_encoder\").get_layer(\"clip_encoder\").resblocks.get_layer(f\"residual_attention_{num_transformer_layers+i}\").attn.out_proj.weights[1].assign(hf_wts.pop(f'text_model.encoder.layers.{i}.self_attn.out_proj.bias'))\n", + " model.get_layer(\"text_encoder\").get_layer(\"clip_encoder\").resblocks.get_layer(f\"residual_attention_{num_transformer_layers+i}\").ln_1.weights[0].assign(hf_wts.pop(f'text_model.encoder.layers.{i}.layer_norm1.weight').numpy())\n", + " model.get_layer(\"text_encoder\").get_layer(\"clip_encoder\").resblocks.get_layer(f\"residual_attention_{num_transformer_layers+i}\").ln_1.weights[1].assign(hf_wts.pop(f'text_model.encoder.layers.{i}.layer_norm1.bias').numpy())\n", + " model.get_layer(\"text_encoder\").get_layer(\"clip_encoder\").resblocks.get_layer(f\"residual_attention_{num_transformer_layers+i}\").ln_2.weights[0].assign(hf_wts.pop(f'text_model.encoder.layers.{i}.layer_norm2.weight').numpy())\n", + " model.get_layer(\"text_encoder\").get_layer(\"clip_encoder\").resblocks.get_layer(f\"residual_attention_{num_transformer_layers+i}\").ln_2.weights[1].assign(hf_wts.pop(f'text_model.encoder.layers.{i}.layer_norm2.bias').numpy())\n", + " model.get_layer(\"text_encoder\").get_layer(\"clip_encoder\").resblocks.get_layer(f\"residual_attention_{num_transformer_layers+i}\").mlp.get_layer(\"c_fc\").weights[0].assign(hf_wts.pop(f'text_model.encoder.layers.{i}.mlp.fc1.weight').transpose(1,0).numpy())\n", + " model.get_layer(\"text_encoder\").get_layer(\"clip_encoder\").resblocks.get_layer(f\"residual_attention_{num_transformer_layers+i}\").mlp.get_layer(\"c_fc\").weights[1].assign(hf_wts.pop(f'text_model.encoder.layers.{i}.mlp.fc1.bias').numpy())\n", + " model.get_layer(\"text_encoder\").get_layer(\"clip_encoder\").resblocks.get_layer(f\"residual_attention_{num_transformer_layers+i}\").mlp.get_layer(\"c_proj\").weights[0].assign(hf_wts.pop(f'text_model.encoder.layers.{i}.mlp.fc2.weight').transpose(1,0).numpy())\n", + " model.get_layer(\"text_encoder\").get_layer(\"clip_encoder\").resblocks.get_layer(f\"residual_attention_{num_transformer_layers+i}\").mlp.get_layer(\"c_proj\").weights[1].assign(hf_wts.pop(f'text_model.encoder.layers.{i}.mlp.fc2.bias').numpy())\n" + ], + "metadata": { + "id": "s6leOiFO6V2U" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# verify that we copied all weights\n", + "hf_wts.keys()" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "Bgen7hxCCeZ7", + "outputId": "c777d6f1-4aa7-4f3e-8fd7-759364364c44" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "odict_keys([])" + ] + }, + "metadata": {}, + "execution_count": 22 + } + ] + }, + { + "cell_type": "markdown", + "source": 
[ + "# save weights" + ], + "metadata": { + "id": "wlfDdO-mid62" + } + }, + { + "cell_type": "code", + "source": [ + "model.save_weights(\"clip-vit-base-patch32.weights.h5\")" + ], + "metadata": { + "id": "QscCUUZFiqBV" + }, + "execution_count": null, + "outputs": [] + } + ] +} \ No newline at end of file diff --git a/requirements-common.txt b/requirements-common.txt index fc21cc5f96..038a886a9b 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -13,4 +13,5 @@ isort black pytest build +keras-nlp namex \ No newline at end of file From 286d0c22254d1ae49767448f2560a86857bbfe84 Mon Sep 17 00:00:00 2001 From: Divyashree Sreepathihalli Date: Fri, 2 Feb 2024 23:24:34 +0000 Subject: [PATCH 07/38] update setup to install keras-nlp --- setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.py b/setup.py index 19dc42248c..e00c348978 100644 --- a/setup.py +++ b/setup.py @@ -75,6 +75,7 @@ def is_pure(self): "regex", "tensorflow-datasets", "keras-core", + "keras-nlp", "kagglehub", ], extras_require={ From 209e5dafa279ae7469664e1c7fec1f4e7e168f8d Mon Sep 17 00:00:00 2001 From: Divyashree Sreepathihalli Date: Fri, 2 Feb 2024 23:28:22 +0000 Subject: [PATCH 08/38] new black formatting --- benchmarks/vectorized_randomly_zoomed_crop.py | 8 +- .../base_image_augmentation_layer.py | 12 +- .../preprocessing/random_crop_and_resize.py | 8 +- .../layers/regularization/squeeze_excite.py | 8 +- .../object_detection/box_coco_metrics.py | 6 +- .../backbones/densenet/densenet_backbone.py | 6 +- .../backbones/resnet_v1/resnet_v1_backbone.py | 6 +- .../backbones/resnet_v2/resnet_v2_backbone.py | 6 +- .../yolo_v8/yolo_v8_backbone.py | 6 +- .../yolo_v8/yolo_v8_detector.py | 6 +- .../stable_diffusion/noise_scheduler.py | 4 +- .../clip_weights_conversion.ipynb | 1689 ++++++++++------- 12 files changed, 1015 insertions(+), 750 deletions(-) diff --git a/benchmarks/vectorized_randomly_zoomed_crop.py b/benchmarks/vectorized_randomly_zoomed_crop.py index 4e807fd1ab..3a207ed2e3 100644 --- a/benchmarks/vectorized_randomly_zoomed_crop.py +++ b/benchmarks/vectorized_randomly_zoomed_crop.py @@ -249,10 +249,10 @@ def from_config(cls, config): config["zoom_factor"] ) if isinstance(config["aspect_ratio_factor"], dict): - config[ - "aspect_ratio_factor" - ] = keras.utils.deserialize_keras_object( - config["aspect_ratio_factor"] + config["aspect_ratio_factor"] = ( + keras.utils.deserialize_keras_object( + config["aspect_ratio_factor"] + ) ) return cls(**config) diff --git a/keras_cv/layers/preprocessing/base_image_augmentation_layer.py b/keras_cv/layers/preprocessing/base_image_augmentation_layer.py index ef2e9cefe7..167da7ad0b 100644 --- a/keras_cv/layers/preprocessing/base_image_augmentation_layer.py +++ b/keras_cv/layers/preprocessing/base_image_augmentation_layer.py @@ -236,15 +236,15 @@ def _compute_output_signature(self, inputs): bounding_boxes = inputs.get(BOUNDING_BOXES, None) if bounding_boxes is not None: - fn_output_signature[ - BOUNDING_BOXES - ] = self._compute_bounding_box_signature(bounding_boxes) + fn_output_signature[BOUNDING_BOXES] = ( + self._compute_bounding_box_signature(bounding_boxes) + ) segmentation_masks = inputs.get(SEGMENTATION_MASKS, None) if segmentation_masks is not None: - fn_output_signature[ - SEGMENTATION_MASKS - ] = self.compute_image_signature(segmentation_masks) + fn_output_signature[SEGMENTATION_MASKS] = ( + self.compute_image_signature(segmentation_masks) + ) keypoints = inputs.get(KEYPOINTS, None) if keypoints is not None: diff --git 
a/keras_cv/layers/preprocessing/random_crop_and_resize.py b/keras_cv/layers/preprocessing/random_crop_and_resize.py index 593515ad09..cd947d5835 100644 --- a/keras_cv/layers/preprocessing/random_crop_and_resize.py +++ b/keras_cv/layers/preprocessing/random_crop_and_resize.py @@ -272,10 +272,10 @@ def from_config(cls, config): config["crop_area_factor"] ) if isinstance(config["aspect_ratio_factor"], dict): - config[ - "aspect_ratio_factor" - ] = keras.utils.deserialize_keras_object( - config["aspect_ratio_factor"] + config["aspect_ratio_factor"] = ( + keras.utils.deserialize_keras_object( + config["aspect_ratio_factor"] + ) ) return cls(**config) diff --git a/keras_cv/layers/regularization/squeeze_excite.py b/keras_cv/layers/regularization/squeeze_excite.py index cb03cc6942..8cbcc5bd94 100644 --- a/keras_cv/layers/regularization/squeeze_excite.py +++ b/keras_cv/layers/regularization/squeeze_excite.py @@ -118,10 +118,10 @@ def get_config(self): @classmethod def from_config(cls, config): if isinstance(config["squeeze_activation"], dict): - config[ - "squeeze_activation" - ] = keras.saving.deserialize_keras_object( - config["squeeze_activation"] + config["squeeze_activation"] = ( + keras.saving.deserialize_keras_object( + config["squeeze_activation"] + ) ) if isinstance(config["excite_activation"], dict): config["excite_activation"] = keras.saving.deserialize_keras_object( diff --git a/keras_cv/metrics/object_detection/box_coco_metrics.py b/keras_cv/metrics/object_detection/box_coco_metrics.py index a59af8c767..47d86ba1c2 100644 --- a/keras_cv/metrics/object_detection/box_coco_metrics.py +++ b/keras_cv/metrics/object_detection/box_coco_metrics.py @@ -212,9 +212,9 @@ def result_fn(self, force=False): ) result = {} for i, key in enumerate(METRIC_NAMES): - result[ - self.name_prefix() + METRIC_MAPPING[key] - ] = py_func_result[i] + result[self.name_prefix() + METRIC_MAPPING[key]] = ( + py_func_result[i] + ) return result obj.result = types.MethodType(result_fn, obj) diff --git a/keras_cv/models/backbones/densenet/densenet_backbone.py b/keras_cv/models/backbones/densenet/densenet_backbone.py index 28109b64fa..251f3601ec 100644 --- a/keras_cv/models/backbones/densenet/densenet_backbone.py +++ b/keras_cv/models/backbones/densenet/densenet_backbone.py @@ -119,9 +119,9 @@ def __init__( name=f"conv{len(stackwise_num_repeats) + 1}", ) - pyramid_level_inputs[ - f"P{len(stackwise_num_repeats) + 1}" - ] = utils.get_tensor_input_name(x) + pyramid_level_inputs[f"P{len(stackwise_num_repeats) + 1}"] = ( + utils.get_tensor_input_name(x) + ) x = keras.layers.BatchNormalization( axis=BN_AXIS, epsilon=BN_EPSILON, name="bn" )(x) diff --git a/keras_cv/models/backbones/resnet_v1/resnet_v1_backbone.py b/keras_cv/models/backbones/resnet_v1/resnet_v1_backbone.py index 61046234d3..07c896613c 100644 --- a/keras_cv/models/backbones/resnet_v1/resnet_v1_backbone.py +++ b/keras_cv/models/backbones/resnet_v1/resnet_v1_backbone.py @@ -130,9 +130,9 @@ def __init__( first_shortcut=(block_type == "block" or stack_index > 0), name=f"v2_stack_{stack_index}", ) - pyramid_level_inputs[ - f"P{stack_index + 2}" - ] = utils.get_tensor_input_name(x) + pyramid_level_inputs[f"P{stack_index + 2}"] = ( + utils.get_tensor_input_name(x) + ) # Create model. 
super().__init__(inputs=inputs, outputs=x, **kwargs) diff --git a/keras_cv/models/backbones/resnet_v2/resnet_v2_backbone.py b/keras_cv/models/backbones/resnet_v2/resnet_v2_backbone.py index a31841f7fc..6a0cc74740 100644 --- a/keras_cv/models/backbones/resnet_v2/resnet_v2_backbone.py +++ b/keras_cv/models/backbones/resnet_v2/resnet_v2_backbone.py @@ -136,9 +136,9 @@ def __init__( first_shortcut=(block_type == "block" or stack_index > 0), name=f"v2_stack_{stack_index}", ) - pyramid_level_inputs[ - f"P{stack_index + 2}" - ] = utils.get_tensor_input_name(x) + pyramid_level_inputs[f"P{stack_index + 2}"] = ( + utils.get_tensor_input_name(x) + ) x = keras.layers.BatchNormalization( axis=BN_AXIS, epsilon=BN_EPSILON, name="post_bn" diff --git a/keras_cv/models/object_detection/yolo_v8/yolo_v8_backbone.py b/keras_cv/models/object_detection/yolo_v8/yolo_v8_backbone.py index f4bd99fafa..a2bf4bdd3b 100644 --- a/keras_cv/models/object_detection/yolo_v8/yolo_v8_backbone.py +++ b/keras_cv/models/object_detection/yolo_v8/yolo_v8_backbone.py @@ -178,9 +178,9 @@ def __init__( activation=activation, name=f"{stack_name}_spp_fast", ) - pyramid_level_inputs[ - f"P{stack_id + 2}" - ] = utils.get_tensor_input_name(x) + pyramid_level_inputs[f"P{stack_id + 2}"] = ( + utils.get_tensor_input_name(x) + ) super().__init__(inputs=inputs, outputs=x, **kwargs) self.pyramid_level_inputs = pyramid_level_inputs diff --git a/keras_cv/models/object_detection/yolo_v8/yolo_v8_detector.py b/keras_cv/models/object_detection/yolo_v8/yolo_v8_detector.py index bfba44945c..6c17c71a72 100644 --- a/keras_cv/models/object_detection/yolo_v8/yolo_v8_detector.py +++ b/keras_cv/models/object_detection/yolo_v8/yolo_v8_detector.py @@ -663,9 +663,9 @@ def from_config(cls, config): if prediction_decoder is not None and isinstance( prediction_decoder, dict ): - config[ - "prediction_decoder" - ] = keras.saving.deserialize_keras_object(prediction_decoder) + config["prediction_decoder"] = ( + keras.saving.deserialize_keras_object(prediction_decoder) + ) return cls(**config) @classproperty diff --git a/keras_cv/models/stable_diffusion/noise_scheduler.py b/keras_cv/models/stable_diffusion/noise_scheduler.py index bd1c0dc51e..c5c100848c 100644 --- a/keras_cv/models/stable_diffusion/noise_scheduler.py +++ b/keras_cv/models/stable_diffusion/noise_scheduler.py @@ -54,9 +54,7 @@ def __init__( elif beta_schedule == "scaled_linear": # this schedule is very specific to the latent diffusion model. 
self.betas = ( - ops.linspace( - beta_start**0.5, beta_end**0.5, train_timesteps - ) + ops.linspace(beta_start**0.5, beta_end**0.5, train_timesteps) ** 2 ) else: diff --git a/keras_cv/tools/checkpoint_conversion/clip_weights_conversion.ipynb b/keras_cv/tools/checkpoint_conversion/clip_weights_conversion.ipynb index 4b5e89ef76..9e4a771b5f 100644 --- a/keras_cv/tools/checkpoint_conversion/clip_weights_conversion.ipynb +++ b/keras_cv/tools/checkpoint_conversion/clip_weights_conversion.ipynb @@ -1,765 +1,1032 @@ { - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + } + }, + "cells": [ + { + "cell_type": "markdown", + "source": [ + "# Setup" + ], + "metadata": { + "id": "0DhV6hzOMY0W" + } + }, + { + "cell_type": "code", + "source": [ + "!pip install -q git+https://github.com/divyashreepathihalli/keras-cv.git@CLIP_refactor\n", + "!pip install -q keras-nlp\n", + "!pip install -q tf-keras\n", + "!pip install -q tensorflow-text\n", + "!pip install keras==3.0.2" + ], + "metadata": { "colab": { - "provenance": [] + "base_uri": "https://localhost:8080/" }, - "kernelspec": { - "name": "python3", - "display_name": "Python 3" - }, - "language_info": { - "name": "python" + "id": "cRzYR-oFgxt1", + "outputId": "e4b01fcd-9f71-4ba7-b8a2-1796f7ef260d" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + " Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", + " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", + " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m950.8/950.8 kB\u001b[0m \u001b[31m5.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25h Building wheel for keras-cv (pyproject.toml) ... 
\u001b[?25l\u001b[?25hdone\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m415.4/415.4 kB\u001b[0m \u001b[31m5.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.2/5.2 MB\u001b[0m \u001b[31m17.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting keras==3.0.2\n", + " Downloading keras-3.0.2-py3-none-any.whl (1.0 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.0/1.0 MB\u001b[0m \u001b[31m8.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: absl-py in /usr/local/lib/python3.10/dist-packages (from keras==3.0.2) (1.4.0)\n", + "Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from keras==3.0.2) (1.23.5)\n", + "Requirement already satisfied: rich in /usr/local/lib/python3.10/dist-packages (from keras==3.0.2) (13.7.0)\n", + "Requirement already satisfied: namex in /usr/local/lib/python3.10/dist-packages (from keras==3.0.2) (0.0.7)\n", + "Requirement already satisfied: h5py in /usr/local/lib/python3.10/dist-packages (from keras==3.0.2) (3.9.0)\n", + "Requirement already satisfied: dm-tree in /usr/local/lib/python3.10/dist-packages (from keras==3.0.2) (0.1.8)\n", + "Requirement already satisfied: markdown-it-py>=2.2.0 in /usr/local/lib/python3.10/dist-packages (from rich->keras==3.0.2) (3.0.0)\n", + "Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /usr/local/lib/python3.10/dist-packages (from rich->keras==3.0.2) (2.16.1)\n", + "Requirement already satisfied: mdurl~=0.1 in /usr/local/lib/python3.10/dist-packages (from markdown-it-py>=2.2.0->rich->keras==3.0.2) (0.1.2)\n", + "Installing collected packages: keras\n", + " Attempting uninstall: keras\n", + " Found existing installation: keras 2.15.0\n", + " Uninstalling keras-2.15.0:\n", + " Successfully uninstalled keras-2.15.0\n", + "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n", + "tensorflow 2.15.0 requires keras<2.16,>=2.15.0, but you have keras 3.0.2 which is incompatible.\u001b[0m\u001b[31m\n", + "\u001b[0mSuccessfully installed keras-3.0.2\n" + ] } + ] }, - "cells": [ - { - "cell_type": "markdown", - "source": [ - "# Setup" - ], - "metadata": { - "id": "0DhV6hzOMY0W" - } - }, - { - "cell_type": "code", - "source": [ - "!pip install -q git+https://github.com/divyashreepathihalli/keras-cv.git@CLIP_refactor\n", - "!pip install -q keras-nlp\n", - "!pip install -q tf-keras\n", - "!pip install -q tensorflow-text\n", - "!pip install keras==3.0.2" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "cRzYR-oFgxt1", - "outputId": "e4b01fcd-9f71-4ba7-b8a2-1796f7ef260d" - }, - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - " Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", - " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", - " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m950.8/950.8 kB\u001b[0m \u001b[31m5.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25h Building wheel for keras-cv (pyproject.toml) ... 
\u001b[?25l\u001b[?25hdone\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m415.4/415.4 kB\u001b[0m \u001b[31m5.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.2/5.2 MB\u001b[0m \u001b[31m17.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hCollecting keras==3.0.2\n", - " Downloading keras-3.0.2-py3-none-any.whl (1.0 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.0/1.0 MB\u001b[0m \u001b[31m8.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hRequirement already satisfied: absl-py in /usr/local/lib/python3.10/dist-packages (from keras==3.0.2) (1.4.0)\n", - "Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from keras==3.0.2) (1.23.5)\n", - "Requirement already satisfied: rich in /usr/local/lib/python3.10/dist-packages (from keras==3.0.2) (13.7.0)\n", - "Requirement already satisfied: namex in /usr/local/lib/python3.10/dist-packages (from keras==3.0.2) (0.0.7)\n", - "Requirement already satisfied: h5py in /usr/local/lib/python3.10/dist-packages (from keras==3.0.2) (3.9.0)\n", - "Requirement already satisfied: dm-tree in /usr/local/lib/python3.10/dist-packages (from keras==3.0.2) (0.1.8)\n", - "Requirement already satisfied: markdown-it-py>=2.2.0 in /usr/local/lib/python3.10/dist-packages (from rich->keras==3.0.2) (3.0.0)\n", - "Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /usr/local/lib/python3.10/dist-packages (from rich->keras==3.0.2) (2.16.1)\n", - "Requirement already satisfied: mdurl~=0.1 in /usr/local/lib/python3.10/dist-packages (from markdown-it-py>=2.2.0->rich->keras==3.0.2) (0.1.2)\n", - "Installing collected packages: keras\n", - " Attempting uninstall: keras\n", - " Found existing installation: keras 2.15.0\n", - " Uninstalling keras-2.15.0:\n", - " Successfully uninstalled keras-2.15.0\n", - "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. 
This behaviour is the source of the following dependency conflicts.\n", - "tensorflow 2.15.0 requires keras<2.16,>=2.15.0, but you have keras 3.0.2 which is incompatible.\u001b[0m\u001b[31m\n", - "\u001b[0mSuccessfully installed keras-3.0.2\n" - ] - } - ] - }, - { - "cell_type": "markdown", - "source": [ - "# Import" - ], - "metadata": { - "id": "mdGT8Em4Mc4b" - } + { + "cell_type": "markdown", + "source": [ + "# Import" + ], + "metadata": { + "id": "mdGT8Em4Mc4b" + } + }, + { + "cell_type": "code", + "source": [ + "from keras_cv.models.feature_extractor.clip import CLIPProcessor\n", + "import keras\n", + "from keras_cv.models import CLIP" + ], + "metadata": { + "id": "GDvJmQuug4-x" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "!wget https://i.imgur.com/8H7XCH0.jpg -O cat.jpg\n", + "!wget http://images.cocodataset.org/val2017/000000039769.jpg -O test.jpg" + ], + "metadata": { + "id": "nuFgha2jTshi", + "colab": { + "base_uri": "https://localhost:8080/" }, + "outputId": "b99d73eb-cc97-47d0-f46e-687c9e8b8237" + }, + "execution_count": null, + "outputs": [ { - "cell_type": "code", - "source": [ - "from keras_cv.models.feature_extractor.clip import CLIPProcessor\n", - "import keras\n", - "from keras_cv.models import CLIP" - ], - "metadata": { - "id": "GDvJmQuug4-x" - }, - "execution_count": null, - "outputs": [] + "output_type": "stream", + "name": "stdout", + "text": [ + "--2024-02-02 22:19:20-- https://i.imgur.com/8H7XCH0.jpg\n", + "Resolving i.imgur.com (i.imgur.com)... 151.101.52.193\n", + "Connecting to i.imgur.com (i.imgur.com)|151.101.52.193|:443... connected.\n", + "HTTP request sent, awaiting response... 200 OK\n", + "Length: 44544 (44K) [image/jpeg]\n", + "Saving to: ‘cat.jpg’\n", + "\n", + "\rcat.jpg 0%[ ] 0 --.-KB/s \rcat.jpg 100%[===================>] 43.50K --.-KB/s in 0.01s \n", + "\n", + "2024-02-02 22:19:20 (3.58 MB/s) - ‘cat.jpg’ saved [44544/44544]\n", + "\n", + "--2024-02-02 22:19:20-- http://images.cocodataset.org/val2017/000000039769.jpg\n", + "Resolving images.cocodataset.org (images.cocodataset.org)... 52.216.78.4, 3.5.1.13, 52.217.139.73, ...\n", + "Connecting to images.cocodataset.org (images.cocodataset.org)|52.216.78.4|:80... connected.\n", + "HTTP request sent, awaiting response... 
200 OK\n", + "Length: 173131 (169K) [image/jpeg]\n", + "Saving to: ‘test.jpg’\n", + "\n", + "test.jpg 100%[===================>] 169.07K --.-KB/s in 0.06s \n", + "\n", + "2024-02-02 22:19:20 (2.67 MB/s) - ‘test.jpg’ saved [173131/173131]\n", + "\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "# @title Select which model weights you would like to convert\n", + "MODEL_CONFIGS = {\n", + " \"CLIP_B32\": {\n", + " \"embed_dim\": 512,\n", + " \"context_length\": 77,\n", + " \"vocab_size\": 49408,\n", + " \"transformer_width\": 512,\n", + " \"transformer_heads\": 8,\n", + " \"transformer_layers\": 12,\n", + " \"vision_layers\": 12,\n", + " \"vision_width\": 768,\n", + " \"image_resolution\": 224,\n", + " \"vision_patch_size\": 32,\n", + " },\n", + " \"CLIP_B16\": {\n", + " \"embed_dim\": 512,\n", + " \"context_length\": 77,\n", + " \"vocab_size\": 49408,\n", + " \"transformer_width\": 512,\n", + " \"transformer_heads\": 8,\n", + " \"transformer_layers\": 12,\n", + " \"vision_layers\": 12,\n", + " \"vision_width\": 768,\n", + " \"image_resolution\": 224,\n", + " \"vision_patch_size\": 16,\n", + " },\n", + " \"CLIP_L14\": {\n", + " \"embed_dim\": 768,\n", + " \"context_length\": 77,\n", + " \"vocab_size\": 49408,\n", + " \"transformer_width\": 768,\n", + " \"transformer_heads\": 12,\n", + " \"transformer_layers\": 12,\n", + " \"vision_layers\": 24,\n", + " \"vision_width\": 1024,\n", + " \"image_resolution\": 224,\n", + " \"vision_patch_size\": 14,\n", + " },\n", + " \"CLIP_L14_336\": {\n", + " \"embed_dim\": 768,\n", + " \"context_length\": 77,\n", + " \"vocab_size\": 49408,\n", + " \"transformer_width\": 768,\n", + " \"transformer_heads\": 12,\n", + " \"transformer_layers\": 12,\n", + " \"vision_layers\": 24,\n", + " \"vision_width\": 1024,\n", + " \"image_resolution\": 336,\n", + " \"vision_patch_size\": 14,\n", + " },\n", + "}\n", + "model_map_hf = {\n", + " \"CLIP_B16\": \"openai/clip-vit-base-patch32\",\n", + " \"CLIP_B32\": \"openai/clip-vit-base-patch16\",\n", + " \"CLIP_L14\": \"openai/clip-vit-large-patch14\",\n", + " \"CLIP_L14_336\": \"openai/clip-vit-large-patch14-336\",\n", + "}\n", + "config_name = \"CLIP_L14_336\" # @param [\"CLIP_B16\", \"CLIP_B32\", \"CLIP_L14\", \"CLIP_L14_336\"]\n", + "config_name_hf = model_map_hf[config_name]" + ], + "metadata": { + "cellView": "form", + "id": "X3kkmK6h_gFH" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# Keras 3 CLIP" + ], + "metadata": { + "id": "2l3Ll7dMMd-m" + } + }, + { + "cell_type": "code", + "source": [ + "embed_dim = MODEL_CONFIGS[config_name][\"embed_dim\"]\n", + "context_length = MODEL_CONFIGS[config_name][\"context_length\"]\n", + "vocab_size = MODEL_CONFIGS[config_name][\"vocab_size\"]\n", + "transformer_width = MODEL_CONFIGS[config_name][\"transformer_width\"]\n", + "transformer_heads = MODEL_CONFIGS[config_name][\"transformer_heads\"]\n", + "transformer_layers = MODEL_CONFIGS[config_name][\"transformer_layers\"]\n", + "vision_layers = MODEL_CONFIGS[config_name][\"vision_layers\"]\n", + "vision_width = MODEL_CONFIGS[config_name][\"vision_width\"]\n", + "vision_patch_size = MODEL_CONFIGS[config_name][\"vision_patch_size\"]\n", + "image_resolution = MODEL_CONFIGS[config_name][\"image_resolution\"]\n", + "model = CLIP(\n", + " embed_dim,\n", + " image_resolution,\n", + " vision_layers,\n", + " vision_width,\n", + " vision_patch_size,\n", + " context_length,\n", + " vocab_size,\n", + " transformer_width,\n", + " transformer_heads,\n", + " transformer_layers,\n", + 
")" + ], + "metadata": { + "id": "urhuhwq0Dczo" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 193 }, + "id": "uE6x7gfqa3Ee", + "outputId": "9a080569-7ab9-49ad-8589-87f335ef2f31" + }, + "outputs": [ { - "cell_type": "code", - "source": [ - "!wget https://i.imgur.com/8H7XCH0.jpg -O cat.jpg\n", - "!wget http://images.cocodataset.org/val2017/000000039769.jpg -O test.jpg" + "output_type": "display_data", + "data": { + "text/plain": [ + "\u001b[1mModel: \"clip\"\u001b[0m\n" ], - "metadata": { - "id": "nuFgha2jTshi", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "b99d73eb-cc97-47d0-f46e-687c9e8b8237" - }, - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "--2024-02-02 22:19:20-- https://i.imgur.com/8H7XCH0.jpg\n", - "Resolving i.imgur.com (i.imgur.com)... 151.101.52.193\n", - "Connecting to i.imgur.com (i.imgur.com)|151.101.52.193|:443... connected.\n", - "HTTP request sent, awaiting response... 200 OK\n", - "Length: 44544 (44K) [image/jpeg]\n", - "Saving to: ‘cat.jpg’\n", - "\n", - "\rcat.jpg 0%[ ] 0 --.-KB/s \rcat.jpg 100%[===================>] 43.50K --.-KB/s in 0.01s \n", - "\n", - "2024-02-02 22:19:20 (3.58 MB/s) - ‘cat.jpg’ saved [44544/44544]\n", - "\n", - "--2024-02-02 22:19:20-- http://images.cocodataset.org/val2017/000000039769.jpg\n", - "Resolving images.cocodataset.org (images.cocodataset.org)... 52.216.78.4, 3.5.1.13, 52.217.139.73, ...\n", - "Connecting to images.cocodataset.org (images.cocodataset.org)|52.216.78.4|:80... connected.\n", - "HTTP request sent, awaiting response... 200 OK\n", - "Length: 173131 (169K) [image/jpeg]\n", - "Saving to: ‘test.jpg’\n", - "\n", - "test.jpg 100%[===================>] 169.07K --.-KB/s in 0.06s \n", - "\n", - "2024-02-02 22:19:20 (2.67 MB/s) - ‘test.jpg’ saved [173131/173131]\n", - "\n" - ] - } + "text/html": [ + "
Model: \"clip\"\n",
+       "
\n" ] + }, + "metadata": {} }, { - "cell_type": "code", - "source": [ - "#@title Select which model weights you would like to convert\n", - "MODEL_CONFIGS = {\n", - " \"CLIP_B32\": {\n", - " \"embed_dim\": 512,\n", - " \"context_length\": 77,\n", - " \"vocab_size\": 49408,\n", - " \"transformer_width\": 512,\n", - " \"transformer_heads\": 8,\n", - " \"transformer_layers\": 12,\n", - " \"vision_layers\": 12,\n", - " \"vision_width\": 768,\n", - " \"image_resolution\": 224,\n", - " \"vision_patch_size\": 32,\n", - " },\n", - " \"CLIP_B16\": {\n", - " \"embed_dim\": 512,\n", - " \"context_length\": 77,\n", - " \"vocab_size\": 49408,\n", - " \"transformer_width\": 512,\n", - " \"transformer_heads\": 8,\n", - " \"transformer_layers\": 12,\n", - " \"vision_layers\": 12,\n", - " \"vision_width\": 768,\n", - " \"image_resolution\": 224,\n", - " \"vision_patch_size\": 16,\n", - " },\n", - " \"CLIP_L14\": {\n", - " \"embed_dim\": 768,\n", - " \"context_length\": 77,\n", - " \"vocab_size\": 49408,\n", - " \"transformer_width\": 768,\n", - " \"transformer_heads\": 12,\n", - " \"transformer_layers\": 12,\n", - " \"vision_layers\": 24,\n", - " \"vision_width\": 1024,\n", - " \"image_resolution\": 224,\n", - " \"vision_patch_size\": 14,\n", - " },\n", - " \"CLIP_L14_336\": {\n", - " \"embed_dim\": 768,\n", - " \"context_length\": 77,\n", - " \"vocab_size\": 49408,\n", - " \"transformer_width\": 768,\n", - " \"transformer_heads\": 12,\n", - " \"transformer_layers\": 12,\n", - " \"vision_layers\": 24,\n", - " \"vision_width\": 1024,\n", - " \"image_resolution\": 336,\n", - " \"vision_patch_size\": 14,\n", - " },\n", - "}\n", - "model_map_hf = {\n", - " \"CLIP_B16\": \"openai/clip-vit-base-patch32\",\n", - " \"CLIP_B32\": \"openai/clip-vit-base-patch16\",\n", - " \"CLIP_L14\": \"openai/clip-vit-large-patch14\",\n", - " \"CLIP_L14_336\": \"openai/clip-vit-large-patch14-336\",\n", - "}\n", - "config_name = 'CLIP_L14_336' # @param [\"CLIP_B16\", \"CLIP_B32\", \"CLIP_L14\", \"CLIP_L14_336\"]\n", - "config_name_hf = model_map_hf[config_name]" - ], - "metadata": { - "cellView": "form", - "id": "X3kkmK6h_gFH" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "source": [ - "# Keras 3 CLIP" - ], - "metadata": { - "id": "2l3Ll7dMMd-m" - } - }, - { - "cell_type": "code", - "source": [ - "embed_dim = MODEL_CONFIGS[config_name][\"embed_dim\"]\n", - "context_length = MODEL_CONFIGS[config_name][\"context_length\"]\n", - "vocab_size = MODEL_CONFIGS[config_name][\"vocab_size\"]\n", - "transformer_width = MODEL_CONFIGS[config_name][\"transformer_width\"]\n", - "transformer_heads = MODEL_CONFIGS[config_name][\"transformer_heads\"]\n", - "transformer_layers = MODEL_CONFIGS[config_name][\"transformer_layers\"]\n", - "vision_layers = MODEL_CONFIGS[config_name][\"vision_layers\"]\n", - "vision_width = MODEL_CONFIGS[config_name][\"vision_width\"]\n", - "vision_patch_size = MODEL_CONFIGS[config_name][\"vision_patch_size\"]\n", - "image_resolution = MODEL_CONFIGS[config_name][\"image_resolution\"]\n", - "model = CLIP(\n", - " embed_dim,\n", - " image_resolution,\n", - " vision_layers,\n", - " vision_width,\n", - " vision_patch_size,\n", - " context_length,\n", - " vocab_size,\n", - " transformer_width,\n", - " transformer_heads,\n", - " transformer_layers,\n", - ")" - ], - "metadata": { - "id": "urhuhwq0Dczo" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 
193 - }, - "id": "uE6x7gfqa3Ee", - "outputId": "9a080569-7ab9-49ad-8589-87f335ef2f31" - }, - "outputs": [ - { - "output_type": "display_data", - "data": { - "text/plain": [ - "\u001b[1mModel: \"clip\"\u001b[0m\n" - ], - "text/html": [ - "
Model: \"clip\"\n",
-              "
\n" - ] - }, - "metadata": {} - }, - { - "output_type": "display_data", - "data": { - "text/plain": [ - "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┓\n", - "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n", - "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━┩\n", - "│ image_encoder (\u001b[38;5;33mCLIPImageEncoder\u001b[0m) │ ? │ \u001b[38;5;34m0\u001b[0m (unbuilt) │\n", - "├────────────────────────────────────┼───────────────────────────────┼─────────────┤\n", - "│ text_encoder (\u001b[38;5;33mCLIPTextEncoder\u001b[0m) │ ? │ \u001b[38;5;34m0\u001b[0m (unbuilt) │\n", - "└────────────────────────────────────┴───────────────────────────────┴─────────────┘\n" - ], - "text/html": [ - "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┓\n",
-              "┃ Layer (type)                        Output Shape                       Param # ┃\n",
-              "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━┩\n",
-              "│ image_encoder (CLIPImageEncoder)   │ ?                             │ 0 (unbuilt) │\n",
-              "├────────────────────────────────────┼───────────────────────────────┼─────────────┤\n",
-              "│ text_encoder (CLIPTextEncoder)     │ ?                             │ 0 (unbuilt) │\n",
-              "└────────────────────────────────────┴───────────────────────────────┴─────────────┘\n",
-              "
\n" - ] - }, - "metadata": {} - }, - { - "output_type": "display_data", - "data": { - "text/plain": [ - "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m39,425\u001b[0m (154.00 KB)\n" - ], - "text/html": [ - "
 Total params: 39,425 (154.00 KB)\n",
-              "
\n" - ] - }, - "metadata": {} - }, - { - "output_type": "display_data", - "data": { - "text/plain": [ - "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m39,425\u001b[0m (154.00 KB)\n" - ], - "text/html": [ - "
 Trainable params: 39,425 (154.00 KB)\n",
-              "
\n" - ] - }, - "metadata": {} - }, - { - "output_type": "display_data", - "data": { - "text/plain": [ - "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n" - ], - "text/html": [ - "
 Non-trainable params: 0 (0.00 B)\n",
-              "
\n" - ] - }, - "metadata": {} - } + "output_type": "display_data", + "data": { + "text/plain": [ + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━┩\n", + "│ image_encoder (\u001b[38;5;33mCLIPImageEncoder\u001b[0m) │ ? │ \u001b[38;5;34m0\u001b[0m (unbuilt) │\n", + "├────────────────────────────────────┼───────────────────────────────┼─────────────┤\n", + "│ text_encoder (\u001b[38;5;33mCLIPTextEncoder\u001b[0m) │ ? │ \u001b[38;5;34m0\u001b[0m (unbuilt) │\n", + "└────────────────────────────────────┴───────────────────────────────┴─────────────┘\n" ], - "source": [ - "model.summary()" + "text/html": [ + "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┓\n",
+       "┃ Layer (type)                        Output Shape                       Param # ┃\n",
+       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━┩\n",
+       "│ image_encoder (CLIPImageEncoder)   │ ?                             │ 0 (unbuilt) │\n",
+       "├────────────────────────────────────┼───────────────────────────────┼─────────────┤\n",
+       "│ text_encoder (CLIPTextEncoder)     │ ?                             │ 0 (unbuilt) │\n",
+       "└────────────────────────────────────┴───────────────────────────────┴─────────────┘\n",
+       "
\n" ] + }, + "metadata": {} }, { - "cell_type": "code", - "source": [ - "processor = CLIPProcessor(224, \"vocab.json\", \"merges.txt\")\n", - "image = processor.process_images([\"cat.jpg\"])\n", - "text_input = [\"photo of a cat on a tortoise\", \"tortoise on a dog\", \"a photo of a tortoise\"]\n", - "text = processor.process_texts(text_input)" + "output_type": "display_data", + "data": { + "text/plain": [ + "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m39,425\u001b[0m (154.00 KB)\n" ], - "metadata": { - "id": "buXKlNfGTenW" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "source": [ - "image_logits, text_logits = model(image, text)\n", - "output = keras.layers.Softmax()(image_logits)\n", - "print(image_logits)\n", - "print(text_input[keras.ops.argmax(output)])" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "BHSpMv0PT5SX", - "outputId": "566c92c4-fbf3-4e2d-87f1-6112b2cff96f" - }, - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "tf.Tensor([[ 0.42190465 0.6262117 -0.2368357 ]], shape=(1, 3), dtype=float32)\n", - "tortoise on a dog\n" - ] - } + "text/html": [ + "
 Total params: 39,425 (154.00 KB)\n",
+       "
\n" ] + }, + "metadata": {} }, { - "cell_type": "code", - "source": [ - "model.summary()" + "output_type": "display_data", + "data": { + "text/plain": [ + "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m39,425\u001b[0m (154.00 KB)\n" ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 193 - }, - "id": "GgNBvYCTtmA3", - "outputId": "35b9a26c-325e-4535-c33b-3f67ab112e19" - }, - "execution_count": null, - "outputs": [ - { - "output_type": "display_data", - "data": { - "text/plain": [ - "\u001b[1mModel: \"clip\"\u001b[0m\n" - ], - "text/html": [ - "
Model: \"clip\"\n",
-              "
\n" - ] - }, - "metadata": {} - }, - { - "output_type": "display_data", - "data": { - "text/plain": [ - "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┓\n", - "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n", - "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━┩\n", - "│ image_encoder (\u001b[38;5;33mCLIPImageEncoder\u001b[0m) │ ? │ \u001b[38;5;34m87,849,216\u001b[0m │\n", - "├────────────────────────────────────┼───────────────────────────────┼─────────────┤\n", - "│ text_encoder (\u001b[38;5;33mCLIPTextEncoder\u001b[0m) │ ? │ \u001b[38;5;34m63,428,096\u001b[0m │\n", - "└────────────────────────────────────┴───────────────────────────────┴─────────────┘\n" - ], - "text/html": [ - "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┓\n",
-              "┃ Layer (type)                        Output Shape                       Param # ┃\n",
-              "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━┩\n",
-              "│ image_encoder (CLIPImageEncoder)   │ ?                             │  87,849,216 │\n",
-              "├────────────────────────────────────┼───────────────────────────────┼─────────────┤\n",
-              "│ text_encoder (CLIPTextEncoder)     │ ?                             │  63,428,096 │\n",
-              "└────────────────────────────────────┴───────────────────────────────┴─────────────┘\n",
-              "
\n" - ] - }, - "metadata": {} - }, - { - "output_type": "display_data", - "data": { - "text/plain": [ - "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m151,277,313\u001b[0m (577.08 MB)\n" - ], - "text/html": [ - "
 Total params: 151,277,313 (577.08 MB)\n",
-              "
\n" - ] - }, - "metadata": {} - }, - { - "output_type": "display_data", - "data": { - "text/plain": [ - "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m151,277,313\u001b[0m (577.08 MB)\n" - ], - "text/html": [ - "
 Trainable params: 151,277,313 (577.08 MB)\n",
-              "
\n" - ] - }, - "metadata": {} - }, - { - "output_type": "display_data", - "data": { - "text/plain": [ - "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n" - ], - "text/html": [ - "
 Non-trainable params: 0 (0.00 B)\n",
-              "
\n" - ] - }, - "metadata": {} - } + "text/html": [ + "
 Trainable params: 39,425 (154.00 KB)\n",
+       "
\n" ] + }, + "metadata": {} }, { - "cell_type": "markdown", - "source": [ - "# HF CLIP" + "output_type": "display_data", + "data": { + "text/plain": [ + "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n" ], - "metadata": { - "id": "P8DWYq_hVFnz" - } + "text/html": [ + "
 Non-trainable params: 0 (0.00 B)\n",
+       "
\n" + ] + }, + "metadata": {} + } + ], + "source": [ + "model.summary()" + ] + }, + { + "cell_type": "code", + "source": [ + "processor = CLIPProcessor(224, \"vocab.json\", \"merges.txt\")\n", + "image = processor.process_images([\"cat.jpg\"])\n", + "text_input = [\n", + " \"photo of a cat on a tortoise\",\n", + " \"tortoise on a dog\",\n", + " \"a photo of a tortoise\",\n", + "]\n", + "text = processor.process_texts(text_input)" + ], + "metadata": { + "id": "buXKlNfGTenW" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "image_logits, text_logits = model(image, text)\n", + "output = keras.layers.Softmax()(image_logits)\n", + "print(image_logits)\n", + "print(text_input[keras.ops.argmax(output)])" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" }, + "id": "BHSpMv0PT5SX", + "outputId": "566c92c4-fbf3-4e2d-87f1-6112b2cff96f" + }, + "execution_count": null, + "outputs": [ { - "cell_type": "code", - "source": [ - "from PIL import Image\n", - "import requests\n", - "\n", - "from transformers import CLIPProcessor as CP\n", - "from transformers import CLIPModel as CM\n", - "\n", - "\n", - "\n" - ], - "metadata": { - "id": "3W2prd6C0pxe" - }, - "execution_count": null, - "outputs": [] + "output_type": "stream", + "name": "stdout", + "text": [ + "tf.Tensor([[ 0.42190465 0.6262117 -0.2368357 ]], shape=(1, 3), dtype=float32)\n", + "tortoise on a dog\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "model.summary()" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 193 }, + "id": "GgNBvYCTtmA3", + "outputId": "35b9a26c-325e-4535-c33b-3f67ab112e19" + }, + "execution_count": null, + "outputs": [ { - "cell_type": "code", - "source": [ - "model_hf = CM.from_pretrained(config_name_hf)\n", - "processor = CP.from_pretrained(config_name_hf)" + "output_type": "display_data", + "data": { + "text/plain": [ + "\u001b[1mModel: \"clip\"\u001b[0m\n" ], - "metadata": { - "id": "EntuvOq1MhwU", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "e154a367-2f94-4fa1-e97d-d2f32db7a2cf" - }, - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "name": "stderr", - "text": [ - "`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config[\"id2label\"]` will be overriden.\n", - "`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config[\"bos_token_id\"]` will be overriden.\n", - "`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config[\"eos_token_id\"]` will be overriden.\n" - ] - } + "text/html": [ + "
Model: \"clip\"\n",
+       "
\n" ] + }, + "metadata": {} }, { - "cell_type": "code", - "source": [ - "\n", - "url = \"https://i.imgur.com/8H7XCH0.jpg\"\n", - "image_hf = Image.open(requests.get(url, stream=True).raw)\n", - "text_inputs = [\"photo of a cat on a tortoise\", \"tortoise on a dog\", \"a photo of a tortoise\"]\n", - "inputs = processor(text=text_inputs, images=image_hf, return_tensors=\"pt\", padding=True)\n", - "\n", - "outputs = model_hf(**inputs)\n", - "logits_per_image = outputs.logits_per_image # this is the image-text similarity score\n", - "probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilitiesprobs\n", - "probs" + "output_type": "display_data", + "data": { + "text/plain": [ + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━┩\n", + "│ image_encoder (\u001b[38;5;33mCLIPImageEncoder\u001b[0m) │ ? │ \u001b[38;5;34m87,849,216\u001b[0m │\n", + "├────────────────────────────────────┼───────────────────────────────┼─────────────┤\n", + "│ text_encoder (\u001b[38;5;33mCLIPTextEncoder\u001b[0m) │ ? │ \u001b[38;5;34m63,428,096\u001b[0m │\n", + "└────────────────────────────────────┴───────────────────────────────┴─────────────┘\n" ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "Ep8DRTkv3AwS", - "outputId": "770756bc-8829-484f-b6e5-763fe81e24d0" - }, - "execution_count": null, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "tensor([[0.9957, 0.0023, 0.0020]], grad_fn=)" - ] - }, - "metadata": {}, - "execution_count": 14 - } + "text/html": [ + "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┓\n",
+       "┃ Layer (type)                        Output Shape                       Param # ┃\n",
+       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━┩\n",
+       "│ image_encoder (CLIPImageEncoder)   │ ?                             │  87,849,216 │\n",
+       "├────────────────────────────────────┼───────────────────────────────┼─────────────┤\n",
+       "│ text_encoder (CLIPTextEncoder)     │ ?                             │  63,428,096 │\n",
+       "└────────────────────────────────────┴───────────────────────────────┴─────────────┘\n",
+       "
\n" ] + }, + "metadata": {} }, { - "cell_type": "code", - "source": [ - "#hugging face weights\n", - "hf_wts = model_hf.state_dict()" - ], - "metadata": { - "id": "wPa0cVnY3cBC" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "source": [ - "# Copy weights" - ], - "metadata": { - "id": "ArkCHlVZVKfM" - } - }, - { - "cell_type": "markdown", - "source": [ - "##vision encoder" - ], - "metadata": { - "id": "TUCpKltRG4Gd" - } - }, - { - "cell_type": "code", - "source": [ - "model.logit_scale.assign(hf_wts.pop(\"logit_scale\").numpy())\n", - "model.get_layer(\"image_encoder\").get_layer(\"clip_patching_and_embedding\").class_embedding.assign(hf_wts.pop('vision_model.embeddings.class_embedding').numpy())\n", - "model.get_layer(\"image_encoder\").get_layer(\"clip_patching_and_embedding\").positional_embedding.assign(hf_wts.pop('vision_model.embeddings.position_embedding.weight').numpy())\n", - "model.get_layer(\"image_encoder\").get_layer(\"clip_patching_and_embedding\").conv1.weights[0].assign(hf_wts.pop('vision_model.embeddings.patch_embedding.weight').permute(3, 2, 1, 0).numpy())\n", - "model.get_layer(\"image_encoder\").get_layer(\"ln_1\").weights[0].assign(hf_wts.pop('vision_model.pre_layrnorm.weight').numpy())\n", - "model.get_layer(\"image_encoder\").get_layer(\"ln_1\").weights[1].assign(hf_wts.pop('vision_model.pre_layrnorm.bias').numpy())\n", - "model.get_layer(\"image_encoder\").get_layer(\"ln_2\").weights[0].assign(hf_wts.pop('vision_model.post_layernorm.weight').numpy())\n", - "model.get_layer(\"image_encoder\").get_layer(\"ln_2\").weights[1].assign(hf_wts.pop('vision_model.post_layernorm.bias').numpy())\n", - "model.get_layer(\"image_encoder\").get_layer(\"vision_projector\").weights[0].assign(hf_wts.pop('visual_projection.weight').transpose(1,0).numpy())\n" - ], - "metadata": { - "id": "tn_U02N7U2VN" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "source": [ - "for i in range(0,MODEL_CONFIGS[config_name][\"vision_layers\"]):\n", - " if i == 0:\n", - " residual_attention = f\"residual_attention\"\n", - " else:\n", - " residual_attention = f\"residual_attention_{i}\"\n", - " model.get_layer(\"image_encoder\").get_layer(\"clip_encoder\").resblocks.get_layer(residual_attention).attn.q_proj.weights[0].assign(hf_wts.pop(f'vision_model.encoder.layers.{i}.self_attn.q_proj.weight'))\n", - " model.get_layer(\"image_encoder\").get_layer(\"clip_encoder\").resblocks.get_layer(residual_attention).attn.q_proj.weights[1].assign(hf_wts.pop(f'vision_model.encoder.layers.{i}.self_attn.q_proj.bias'))\n", - " model.get_layer(\"image_encoder\").get_layer(\"clip_encoder\").resblocks.get_layer(residual_attention).attn.k_proj.weights[0].assign(hf_wts.pop(f'vision_model.encoder.layers.{i}.self_attn.k_proj.weight'))\n", - " model.get_layer(\"image_encoder\").get_layer(\"clip_encoder\").resblocks.get_layer(residual_attention).attn.k_proj.weights[1].assign(hf_wts.pop(f'vision_model.encoder.layers.{i}.self_attn.k_proj.bias'))\n", - " model.get_layer(\"image_encoder\").get_layer(\"clip_encoder\").resblocks.get_layer(residual_attention).attn.v_proj.weights[0].assign(hf_wts.pop(f'vision_model.encoder.layers.{i}.self_attn.v_proj.weight'))\n", - " model.get_layer(\"image_encoder\").get_layer(\"clip_encoder\").resblocks.get_layer(residual_attention).attn.v_proj.weights[1].assign(hf_wts.pop(f'vision_model.encoder.layers.{i}.self_attn.v_proj.bias'))\n", - " 
model.get_layer(\"image_encoder\").get_layer(\"clip_encoder\").resblocks.get_layer(residual_attention).attn.out_proj.weights[1].assign(hf_wts.pop(f'vision_model.encoder.layers.{i}.self_attn.out_proj.bias').numpy())\n", - " model.get_layer(\"image_encoder\").get_layer(\"clip_encoder\").resblocks.get_layer(residual_attention).attn.out_proj.weights[0].assign(hf_wts.pop(f'vision_model.encoder.layers.{i}.self_attn.out_proj.weight').numpy())\n", - " model.get_layer(\"image_encoder\").get_layer(\"clip_encoder\").resblocks.get_layer(residual_attention).ln_1.weights[0].assign(hf_wts.pop(f'vision_model.encoder.layers.{i}.layer_norm1.weight').numpy())\n", - " model.get_layer(\"image_encoder\").get_layer(\"clip_encoder\").resblocks.get_layer(residual_attention).ln_1.weights[1].assign(hf_wts.pop(f'vision_model.encoder.layers.{i}.layer_norm1.bias').numpy())\n", - " model.get_layer(\"image_encoder\").get_layer(\"clip_encoder\").resblocks.get_layer(residual_attention).ln_2.weights[0].assign(hf_wts.pop(f'vision_model.encoder.layers.{i}.layer_norm2.weight').numpy())\n", - " model.get_layer(\"image_encoder\").get_layer(\"clip_encoder\").resblocks.get_layer(residual_attention).ln_2.weights[1].assign(hf_wts.pop(f'vision_model.encoder.layers.{i}.layer_norm2.bias').numpy())\n", - " model.get_layer(\"image_encoder\").get_layer(\"clip_encoder\").resblocks.get_layer(residual_attention).mlp.get_layer(\"c_fc\").weights[0].assign(hf_wts.pop(f'vision_model.encoder.layers.{i}.mlp.fc1.weight').transpose(1,0).numpy())\n", - " model.get_layer(\"image_encoder\").get_layer(\"clip_encoder\").resblocks.get_layer(residual_attention).mlp.get_layer(\"c_fc\").weights[1].assign(hf_wts.pop(f'vision_model.encoder.layers.{i}.mlp.fc1.bias').numpy())\n", - " model.get_layer(\"image_encoder\").get_layer(\"clip_encoder\").resblocks.get_layer(residual_attention).mlp.get_layer(\"c_proj\").weights[0].assign(hf_wts.pop(f'vision_model.encoder.layers.{i}.mlp.fc2.weight').transpose(1,0).numpy())\n", - " model.get_layer(\"image_encoder\").get_layer(\"clip_encoder\").resblocks.get_layer(residual_attention).mlp.get_layer(\"c_proj\").weights[1].assign(hf_wts.pop(f'vision_model.encoder.layers.{i}.mlp.fc2.bias').numpy())\n", - "\n" + "output_type": "display_data", + "data": { + "text/plain": [ + "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m151,277,313\u001b[0m (577.08 MB)\n" ], - "metadata": { - "id": "qptfuWobZcbT" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "source": [ - "## Text encoder" - ], - "metadata": { - "id": "1RN2aVrYG8T3" - } + "text/html": [ + "
 Total params: 151,277,313 (577.08 MB)\n",
+       "
\n" + ] + }, + "metadata": {} }, { - "cell_type": "code", - "source": [ - "num_transformer_layers = MODEL_CONFIGS[config_name][\"vision_layers\"]" + "output_type": "display_data", + "data": { + "text/plain": [ + "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m151,277,313\u001b[0m (577.08 MB)\n" ], - "metadata": { - "id": "5FtDROnynb0N" - }, - "execution_count": null, - "outputs": [] + "text/html": [ + "
 Trainable params: 151,277,313 (577.08 MB)\n",
+       "
\n" + ] + }, + "metadata": {} }, { - "cell_type": "code", - "source": [ - "model.get_layer(\"text_encoder\").get_layer(\"text_projector\").weights[0].assign(hf_wts.pop(\"text_projection.weight\").numpy())\n", - "model.get_layer(\"text_encoder\").get_layer(\"token_embedding\").weights[0].assign(hf_wts.pop('text_model.embeddings.token_embedding.weight').numpy())\n", - "model.get_layer(\"text_encoder\").positional_embedding.assign(hf_wts.pop('text_model.embeddings.position_embedding.weight').numpy())\n", - "model.get_layer(\"text_encoder\").get_layer(\"ln_final\").weights[0].assign(hf_wts.pop('text_model.final_layer_norm.weight'))\n", - "model.get_layer(\"text_encoder\").get_layer(\"ln_final\").weights[1].assign(hf_wts.pop('text_model.final_layer_norm.bias'))\n" + "output_type": "display_data", + "data": { + "text/plain": [ + "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n" ], - "metadata": { - "id": "_1AD7TcbdWEC" - }, - "execution_count": null, - "outputs": [] + "text/html": [ + "
 Non-trainable params: 0 (0.00 B)\n",
+       "
\n" + ] + }, + "metadata": {} + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "# HF CLIP" + ], + "metadata": { + "id": "P8DWYq_hVFnz" + } + }, + { + "cell_type": "code", + "source": [ + "from PIL import Image\n", + "import requests\n", + "\n", + "from transformers import CLIPProcessor as CP\n", + "from transformers import CLIPModel as CM" + ], + "metadata": { + "id": "3W2prd6C0pxe" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "model_hf = CM.from_pretrained(config_name_hf)\n", + "processor = CP.from_pretrained(config_name_hf)" + ], + "metadata": { + "id": "EntuvOq1MhwU", + "colab": { + "base_uri": "https://localhost:8080/" }, + "outputId": "e154a367-2f94-4fa1-e97d-d2f32db7a2cf" + }, + "execution_count": null, + "outputs": [ { - "cell_type": "code", - "source": [ - "for i in range(MODEL_CONFIGS[config_name][\"transformer_layers\"]):\n", - " model.get_layer(\"text_encoder\").get_layer(\"clip_encoder\").resblocks.get_layer(f\"residual_attention_{num_transformer_layers+i}\").attn.k_proj.weights[0].assign(hf_wts.pop(f'text_model.encoder.layers.{i}.self_attn.k_proj.weight'))\n", - " model.get_layer(\"text_encoder\").get_layer(\"clip_encoder\").resblocks.get_layer(f\"residual_attention_{num_transformer_layers+i}\").attn.k_proj.weights[1].assign(hf_wts.pop(f'text_model.encoder.layers.{i}.self_attn.k_proj.bias'))\n", - " model.get_layer(\"text_encoder\").get_layer(\"clip_encoder\").resblocks.get_layer(f\"residual_attention_{num_transformer_layers+i}\").attn.q_proj.weights[0].assign(hf_wts.pop(f'text_model.encoder.layers.{i}.self_attn.q_proj.weight'))\n", - " model.get_layer(\"text_encoder\").get_layer(\"clip_encoder\").resblocks.get_layer(f\"residual_attention_{num_transformer_layers+i}\").attn.q_proj.weights[1].assign(hf_wts.pop(f'text_model.encoder.layers.{i}.self_attn.q_proj.bias'))\n", - " model.get_layer(\"text_encoder\").get_layer(\"clip_encoder\").resblocks.get_layer(f\"residual_attention_{num_transformer_layers+i}\").attn.v_proj.weights[0].assign(hf_wts.pop(f'text_model.encoder.layers.{i}.self_attn.v_proj.weight'))\n", - " model.get_layer(\"text_encoder\").get_layer(\"clip_encoder\").resblocks.get_layer(f\"residual_attention_{num_transformer_layers+i}\").attn.v_proj.weights[1].assign(hf_wts.pop(f'text_model.encoder.layers.{i}.self_attn.v_proj.bias'))\n", - " model.get_layer(\"text_encoder\").get_layer(\"clip_encoder\").resblocks.get_layer(f\"residual_attention_{num_transformer_layers+i}\").attn.out_proj.weights[0].assign(hf_wts.pop(f'text_model.encoder.layers.{i}.self_attn.out_proj.weight'))\n", - " model.get_layer(\"text_encoder\").get_layer(\"clip_encoder\").resblocks.get_layer(f\"residual_attention_{num_transformer_layers+i}\").attn.out_proj.weights[1].assign(hf_wts.pop(f'text_model.encoder.layers.{i}.self_attn.out_proj.bias'))\n", - " model.get_layer(\"text_encoder\").get_layer(\"clip_encoder\").resblocks.get_layer(f\"residual_attention_{num_transformer_layers+i}\").ln_1.weights[0].assign(hf_wts.pop(f'text_model.encoder.layers.{i}.layer_norm1.weight').numpy())\n", - " model.get_layer(\"text_encoder\").get_layer(\"clip_encoder\").resblocks.get_layer(f\"residual_attention_{num_transformer_layers+i}\").ln_1.weights[1].assign(hf_wts.pop(f'text_model.encoder.layers.{i}.layer_norm1.bias').numpy())\n", - " 
model.get_layer(\"text_encoder\").get_layer(\"clip_encoder\").resblocks.get_layer(f\"residual_attention_{num_transformer_layers+i}\").ln_2.weights[0].assign(hf_wts.pop(f'text_model.encoder.layers.{i}.layer_norm2.weight').numpy())\n", - " model.get_layer(\"text_encoder\").get_layer(\"clip_encoder\").resblocks.get_layer(f\"residual_attention_{num_transformer_layers+i}\").ln_2.weights[1].assign(hf_wts.pop(f'text_model.encoder.layers.{i}.layer_norm2.bias').numpy())\n", - " model.get_layer(\"text_encoder\").get_layer(\"clip_encoder\").resblocks.get_layer(f\"residual_attention_{num_transformer_layers+i}\").mlp.get_layer(\"c_fc\").weights[0].assign(hf_wts.pop(f'text_model.encoder.layers.{i}.mlp.fc1.weight').transpose(1,0).numpy())\n", - " model.get_layer(\"text_encoder\").get_layer(\"clip_encoder\").resblocks.get_layer(f\"residual_attention_{num_transformer_layers+i}\").mlp.get_layer(\"c_fc\").weights[1].assign(hf_wts.pop(f'text_model.encoder.layers.{i}.mlp.fc1.bias').numpy())\n", - " model.get_layer(\"text_encoder\").get_layer(\"clip_encoder\").resblocks.get_layer(f\"residual_attention_{num_transformer_layers+i}\").mlp.get_layer(\"c_proj\").weights[0].assign(hf_wts.pop(f'text_model.encoder.layers.{i}.mlp.fc2.weight').transpose(1,0).numpy())\n", - " model.get_layer(\"text_encoder\").get_layer(\"clip_encoder\").resblocks.get_layer(f\"residual_attention_{num_transformer_layers+i}\").mlp.get_layer(\"c_proj\").weights[1].assign(hf_wts.pop(f'text_model.encoder.layers.{i}.mlp.fc2.bias').numpy())\n" - ], - "metadata": { - "id": "s6leOiFO6V2U" - }, - "execution_count": null, - "outputs": [] + "output_type": "stream", + "name": "stderr", + "text": [ + "`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config[\"id2label\"]` will be overriden.\n", + "`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config[\"bos_token_id\"]` will be overriden.\n", + "`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. 
The value `text_config[\"eos_token_id\"]` will be overriden.\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "url = \"https://i.imgur.com/8H7XCH0.jpg\"\n", + "image_hf = Image.open(requests.get(url, stream=True).raw)\n", + "text_inputs = [\n", + " \"photo of a cat on a tortoise\",\n", + " \"tortoise on a dog\",\n", + " \"a photo of a tortoise\",\n", + "]\n", + "inputs = processor(\n", + " text=text_inputs, images=image_hf, return_tensors=\"pt\", padding=True\n", + ")\n", + "\n", + "outputs = model_hf(**inputs)\n", + "logits_per_image = (\n", + " outputs.logits_per_image\n", + ") # this is the image-text similarity score\n", + "probs = logits_per_image.softmax(\n", + " dim=1\n", + ") # we can take the softmax to get the label probabilitiesprobs\n", + "probs" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" }, + "id": "Ep8DRTkv3AwS", + "outputId": "770756bc-8829-484f-b6e5-763fe81e24d0" + }, + "execution_count": null, + "outputs": [ { - "cell_type": "code", - "source": [ - "# verify that we copied all weights\n", - "hf_wts.keys()" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "Bgen7hxCCeZ7", - "outputId": "c777d6f1-4aa7-4f3e-8fd7-759364364c44" - }, - "execution_count": null, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "odict_keys([])" - ] - }, - "metadata": {}, - "execution_count": 22 - } + "output_type": "execute_result", + "data": { + "text/plain": [ + "tensor([[0.9957, 0.0023, 0.0020]], grad_fn=)" ] + }, + "metadata": {}, + "execution_count": 14 + } + ] + }, + { + "cell_type": "code", + "source": [ + "# hugging face weights\n", + "hf_wts = model_hf.state_dict()" + ], + "metadata": { + "id": "wPa0cVnY3cBC" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# Copy weights" + ], + "metadata": { + "id": "ArkCHlVZVKfM" + } + }, + { + "cell_type": "markdown", + "source": [ + "##vision encoder" + ], + "metadata": { + "id": "TUCpKltRG4Gd" + } + }, + { + "cell_type": "code", + "source": [ + "model.logit_scale.assign(hf_wts.pop(\"logit_scale\").numpy())\n", + "model.get_layer(\"image_encoder\").get_layer(\n", + " \"clip_patching_and_embedding\"\n", + ").class_embedding.assign(\n", + " hf_wts.pop(\"vision_model.embeddings.class_embedding\").numpy()\n", + ")\n", + "model.get_layer(\"image_encoder\").get_layer(\n", + " \"clip_patching_and_embedding\"\n", + ").positional_embedding.assign(\n", + " hf_wts.pop(\"vision_model.embeddings.position_embedding.weight\").numpy()\n", + ")\n", + "model.get_layer(\"image_encoder\").get_layer(\n", + " \"clip_patching_and_embedding\"\n", + ").conv1.weights[0].assign(\n", + " hf_wts.pop(\"vision_model.embeddings.patch_embedding.weight\")\n", + " .permute(3, 2, 1, 0)\n", + " .numpy()\n", + ")\n", + "model.get_layer(\"image_encoder\").get_layer(\"ln_1\").weights[0].assign(\n", + " hf_wts.pop(\"vision_model.pre_layrnorm.weight\").numpy()\n", + ")\n", + "model.get_layer(\"image_encoder\").get_layer(\"ln_1\").weights[1].assign(\n", + " hf_wts.pop(\"vision_model.pre_layrnorm.bias\").numpy()\n", + ")\n", + "model.get_layer(\"image_encoder\").get_layer(\"ln_2\").weights[0].assign(\n", + " hf_wts.pop(\"vision_model.post_layernorm.weight\").numpy()\n", + ")\n", + "model.get_layer(\"image_encoder\").get_layer(\"ln_2\").weights[1].assign(\n", + " hf_wts.pop(\"vision_model.post_layernorm.bias\").numpy()\n", + ")\n", + "model.get_layer(\"image_encoder\").get_layer(\"vision_projector\").weights[\n", + " 0\n", + 
"].assign(hf_wts.pop(\"visual_projection.weight\").transpose(1, 0).numpy())" + ], + "metadata": { + "id": "tn_U02N7U2VN" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "for i in range(0, MODEL_CONFIGS[config_name][\"vision_layers\"]):\n", + " if i == 0:\n", + " residual_attention = f\"residual_attention\"\n", + " else:\n", + " residual_attention = f\"residual_attention_{i}\"\n", + " model.get_layer(\"image_encoder\").get_layer(\n", + " \"clip_encoder\"\n", + " ).resblocks.get_layer(residual_attention).attn.q_proj.weights[0].assign(\n", + " hf_wts.pop(f\"vision_model.encoder.layers.{i}.self_attn.q_proj.weight\")\n", + " )\n", + " model.get_layer(\"image_encoder\").get_layer(\n", + " \"clip_encoder\"\n", + " ).resblocks.get_layer(residual_attention).attn.q_proj.weights[1].assign(\n", + " hf_wts.pop(f\"vision_model.encoder.layers.{i}.self_attn.q_proj.bias\")\n", + " )\n", + " model.get_layer(\"image_encoder\").get_layer(\n", + " \"clip_encoder\"\n", + " ).resblocks.get_layer(residual_attention).attn.k_proj.weights[0].assign(\n", + " hf_wts.pop(f\"vision_model.encoder.layers.{i}.self_attn.k_proj.weight\")\n", + " )\n", + " model.get_layer(\"image_encoder\").get_layer(\n", + " \"clip_encoder\"\n", + " ).resblocks.get_layer(residual_attention).attn.k_proj.weights[1].assign(\n", + " hf_wts.pop(f\"vision_model.encoder.layers.{i}.self_attn.k_proj.bias\")\n", + " )\n", + " model.get_layer(\"image_encoder\").get_layer(\n", + " \"clip_encoder\"\n", + " ).resblocks.get_layer(residual_attention).attn.v_proj.weights[0].assign(\n", + " hf_wts.pop(f\"vision_model.encoder.layers.{i}.self_attn.v_proj.weight\")\n", + " )\n", + " model.get_layer(\"image_encoder\").get_layer(\n", + " \"clip_encoder\"\n", + " ).resblocks.get_layer(residual_attention).attn.v_proj.weights[1].assign(\n", + " hf_wts.pop(f\"vision_model.encoder.layers.{i}.self_attn.v_proj.bias\")\n", + " )\n", + " model.get_layer(\"image_encoder\").get_layer(\n", + " \"clip_encoder\"\n", + " ).resblocks.get_layer(residual_attention).attn.out_proj.weights[1].assign(\n", + " hf_wts.pop(\n", + " f\"vision_model.encoder.layers.{i}.self_attn.out_proj.bias\"\n", + " ).numpy()\n", + " )\n", + " model.get_layer(\"image_encoder\").get_layer(\n", + " \"clip_encoder\"\n", + " ).resblocks.get_layer(residual_attention).attn.out_proj.weights[0].assign(\n", + " hf_wts.pop(\n", + " f\"vision_model.encoder.layers.{i}.self_attn.out_proj.weight\"\n", + " ).numpy()\n", + " )\n", + " model.get_layer(\"image_encoder\").get_layer(\n", + " \"clip_encoder\"\n", + " ).resblocks.get_layer(residual_attention).ln_1.weights[0].assign(\n", + " hf_wts.pop(\n", + " f\"vision_model.encoder.layers.{i}.layer_norm1.weight\"\n", + " ).numpy()\n", + " )\n", + " model.get_layer(\"image_encoder\").get_layer(\n", + " \"clip_encoder\"\n", + " ).resblocks.get_layer(residual_attention).ln_1.weights[1].assign(\n", + " hf_wts.pop(f\"vision_model.encoder.layers.{i}.layer_norm1.bias\").numpy()\n", + " )\n", + " model.get_layer(\"image_encoder\").get_layer(\n", + " \"clip_encoder\"\n", + " ).resblocks.get_layer(residual_attention).ln_2.weights[0].assign(\n", + " hf_wts.pop(\n", + " f\"vision_model.encoder.layers.{i}.layer_norm2.weight\"\n", + " ).numpy()\n", + " )\n", + " model.get_layer(\"image_encoder\").get_layer(\n", + " \"clip_encoder\"\n", + " ).resblocks.get_layer(residual_attention).ln_2.weights[1].assign(\n", + " hf_wts.pop(f\"vision_model.encoder.layers.{i}.layer_norm2.bias\").numpy()\n", + " )\n", + " 
model.get_layer(\"image_encoder\").get_layer(\n", + " \"clip_encoder\"\n", + " ).resblocks.get_layer(residual_attention).mlp.get_layer(\"c_fc\").weights[\n", + " 0\n", + " ].assign(\n", + " hf_wts.pop(f\"vision_model.encoder.layers.{i}.mlp.fc1.weight\")\n", + " .transpose(1, 0)\n", + " .numpy()\n", + " )\n", + " model.get_layer(\"image_encoder\").get_layer(\n", + " \"clip_encoder\"\n", + " ).resblocks.get_layer(residual_attention).mlp.get_layer(\"c_fc\").weights[\n", + " 1\n", + " ].assign(\n", + " hf_wts.pop(f\"vision_model.encoder.layers.{i}.mlp.fc1.bias\").numpy()\n", + " )\n", + " model.get_layer(\"image_encoder\").get_layer(\n", + " \"clip_encoder\"\n", + " ).resblocks.get_layer(residual_attention).mlp.get_layer(\"c_proj\").weights[\n", + " 0\n", + " ].assign(\n", + " hf_wts.pop(f\"vision_model.encoder.layers.{i}.mlp.fc2.weight\")\n", + " .transpose(1, 0)\n", + " .numpy()\n", + " )\n", + " model.get_layer(\"image_encoder\").get_layer(\n", + " \"clip_encoder\"\n", + " ).resblocks.get_layer(residual_attention).mlp.get_layer(\"c_proj\").weights[\n", + " 1\n", + " ].assign(\n", + " hf_wts.pop(f\"vision_model.encoder.layers.{i}.mlp.fc2.bias\").numpy()\n", + " )" + ], + "metadata": { + "id": "qptfuWobZcbT" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## Text encoder" + ], + "metadata": { + "id": "1RN2aVrYG8T3" + } + }, + { + "cell_type": "code", + "source": [ + "num_transformer_layers = MODEL_CONFIGS[config_name][\"vision_layers\"]" + ], + "metadata": { + "id": "5FtDROnynb0N" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "model.get_layer(\"text_encoder\").get_layer(\"text_projector\").weights[0].assign(\n", + " hf_wts.pop(\"text_projection.weight\").numpy()\n", + ")\n", + "model.get_layer(\"text_encoder\").get_layer(\"token_embedding\").weights[0].assign(\n", + " hf_wts.pop(\"text_model.embeddings.token_embedding.weight\").numpy()\n", + ")\n", + "model.get_layer(\"text_encoder\").positional_embedding.assign(\n", + " hf_wts.pop(\"text_model.embeddings.position_embedding.weight\").numpy()\n", + ")\n", + "model.get_layer(\"text_encoder\").get_layer(\"ln_final\").weights[0].assign(\n", + " hf_wts.pop(\"text_model.final_layer_norm.weight\")\n", + ")\n", + "model.get_layer(\"text_encoder\").get_layer(\"ln_final\").weights[1].assign(\n", + " hf_wts.pop(\"text_model.final_layer_norm.bias\")\n", + ")" + ], + "metadata": { + "id": "_1AD7TcbdWEC" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "for i in range(MODEL_CONFIGS[config_name][\"transformer_layers\"]):\n", + " model.get_layer(\"text_encoder\").get_layer(\n", + " \"clip_encoder\"\n", + " ).resblocks.get_layer(\n", + " f\"residual_attention_{num_transformer_layers+i}\"\n", + " ).attn.k_proj.weights[\n", + " 0\n", + " ].assign(\n", + " hf_wts.pop(f\"text_model.encoder.layers.{i}.self_attn.k_proj.weight\")\n", + " )\n", + " model.get_layer(\"text_encoder\").get_layer(\n", + " \"clip_encoder\"\n", + " ).resblocks.get_layer(\n", + " f\"residual_attention_{num_transformer_layers+i}\"\n", + " ).attn.k_proj.weights[\n", + " 1\n", + " ].assign(\n", + " hf_wts.pop(f\"text_model.encoder.layers.{i}.self_attn.k_proj.bias\")\n", + " )\n", + " model.get_layer(\"text_encoder\").get_layer(\n", + " \"clip_encoder\"\n", + " ).resblocks.get_layer(\n", + " f\"residual_attention_{num_transformer_layers+i}\"\n", + " ).attn.q_proj.weights[\n", + " 0\n", + " ].assign(\n", + " 
hf_wts.pop(f\"text_model.encoder.layers.{i}.self_attn.q_proj.weight\")\n", + " )\n", + " model.get_layer(\"text_encoder\").get_layer(\n", + " \"clip_encoder\"\n", + " ).resblocks.get_layer(\n", + " f\"residual_attention_{num_transformer_layers+i}\"\n", + " ).attn.q_proj.weights[\n", + " 1\n", + " ].assign(\n", + " hf_wts.pop(f\"text_model.encoder.layers.{i}.self_attn.q_proj.bias\")\n", + " )\n", + " model.get_layer(\"text_encoder\").get_layer(\n", + " \"clip_encoder\"\n", + " ).resblocks.get_layer(\n", + " f\"residual_attention_{num_transformer_layers+i}\"\n", + " ).attn.v_proj.weights[\n", + " 0\n", + " ].assign(\n", + " hf_wts.pop(f\"text_model.encoder.layers.{i}.self_attn.v_proj.weight\")\n", + " )\n", + " model.get_layer(\"text_encoder\").get_layer(\n", + " \"clip_encoder\"\n", + " ).resblocks.get_layer(\n", + " f\"residual_attention_{num_transformer_layers+i}\"\n", + " ).attn.v_proj.weights[\n", + " 1\n", + " ].assign(\n", + " hf_wts.pop(f\"text_model.encoder.layers.{i}.self_attn.v_proj.bias\")\n", + " )\n", + " model.get_layer(\"text_encoder\").get_layer(\n", + " \"clip_encoder\"\n", + " ).resblocks.get_layer(\n", + " f\"residual_attention_{num_transformer_layers+i}\"\n", + " ).attn.out_proj.weights[\n", + " 0\n", + " ].assign(\n", + " hf_wts.pop(f\"text_model.encoder.layers.{i}.self_attn.out_proj.weight\")\n", + " )\n", + " model.get_layer(\"text_encoder\").get_layer(\n", + " \"clip_encoder\"\n", + " ).resblocks.get_layer(\n", + " f\"residual_attention_{num_transformer_layers+i}\"\n", + " ).attn.out_proj.weights[\n", + " 1\n", + " ].assign(\n", + " hf_wts.pop(f\"text_model.encoder.layers.{i}.self_attn.out_proj.bias\")\n", + " )\n", + " model.get_layer(\"text_encoder\").get_layer(\n", + " \"clip_encoder\"\n", + " ).resblocks.get_layer(\n", + " f\"residual_attention_{num_transformer_layers+i}\"\n", + " ).ln_1.weights[\n", + " 0\n", + " ].assign(\n", + " hf_wts.pop(f\"text_model.encoder.layers.{i}.layer_norm1.weight\").numpy()\n", + " )\n", + " model.get_layer(\"text_encoder\").get_layer(\n", + " \"clip_encoder\"\n", + " ).resblocks.get_layer(\n", + " f\"residual_attention_{num_transformer_layers+i}\"\n", + " ).ln_1.weights[\n", + " 1\n", + " ].assign(\n", + " hf_wts.pop(f\"text_model.encoder.layers.{i}.layer_norm1.bias\").numpy()\n", + " )\n", + " model.get_layer(\"text_encoder\").get_layer(\n", + " \"clip_encoder\"\n", + " ).resblocks.get_layer(\n", + " f\"residual_attention_{num_transformer_layers+i}\"\n", + " ).ln_2.weights[\n", + " 0\n", + " ].assign(\n", + " hf_wts.pop(f\"text_model.encoder.layers.{i}.layer_norm2.weight\").numpy()\n", + " )\n", + " model.get_layer(\"text_encoder\").get_layer(\n", + " \"clip_encoder\"\n", + " ).resblocks.get_layer(\n", + " f\"residual_attention_{num_transformer_layers+i}\"\n", + " ).ln_2.weights[\n", + " 1\n", + " ].assign(\n", + " hf_wts.pop(f\"text_model.encoder.layers.{i}.layer_norm2.bias\").numpy()\n", + " )\n", + " model.get_layer(\"text_encoder\").get_layer(\n", + " \"clip_encoder\"\n", + " ).resblocks.get_layer(\n", + " f\"residual_attention_{num_transformer_layers+i}\"\n", + " ).mlp.get_layer(\n", + " \"c_fc\"\n", + " ).weights[\n", + " 0\n", + " ].assign(\n", + " hf_wts.pop(f\"text_model.encoder.layers.{i}.mlp.fc1.weight\")\n", + " .transpose(1, 0)\n", + " .numpy()\n", + " )\n", + " model.get_layer(\"text_encoder\").get_layer(\n", + " \"clip_encoder\"\n", + " ).resblocks.get_layer(\n", + " f\"residual_attention_{num_transformer_layers+i}\"\n", + " ).mlp.get_layer(\n", + " \"c_fc\"\n", + " ).weights[\n", + " 1\n", + " ].assign(\n", + " 
hf_wts.pop(f\"text_model.encoder.layers.{i}.mlp.fc1.bias\").numpy()\n", + " )\n", + " model.get_layer(\"text_encoder\").get_layer(\n", + " \"clip_encoder\"\n", + " ).resblocks.get_layer(\n", + " f\"residual_attention_{num_transformer_layers+i}\"\n", + " ).mlp.get_layer(\n", + " \"c_proj\"\n", + " ).weights[\n", + " 0\n", + " ].assign(\n", + " hf_wts.pop(f\"text_model.encoder.layers.{i}.mlp.fc2.weight\")\n", + " .transpose(1, 0)\n", + " .numpy()\n", + " )\n", + " model.get_layer(\"text_encoder\").get_layer(\n", + " \"clip_encoder\"\n", + " ).resblocks.get_layer(\n", + " f\"residual_attention_{num_transformer_layers+i}\"\n", + " ).mlp.get_layer(\n", + " \"c_proj\"\n", + " ).weights[\n", + " 1\n", + " ].assign(\n", + " hf_wts.pop(f\"text_model.encoder.layers.{i}.mlp.fc2.bias\").numpy()\n", + " )" + ], + "metadata": { + "id": "s6leOiFO6V2U" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# verify that we copied all weights\n", + "hf_wts.keys()" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" }, + "id": "Bgen7hxCCeZ7", + "outputId": "c777d6f1-4aa7-4f3e-8fd7-759364364c44" + }, + "execution_count": null, + "outputs": [ { - "cell_type": "markdown", - "source": [ - "# save weights" - ], - "metadata": { - "id": "wlfDdO-mid62" - } - }, - { - "cell_type": "code", - "source": [ - "model.save_weights(\"clip-vit-base-patch32.weights.h5\")" - ], - "metadata": { - "id": "QscCUUZFiqBV" - }, - "execution_count": null, - "outputs": [] + "output_type": "execute_result", + "data": { + "text/plain": [ + "odict_keys([])" + ] + }, + "metadata": {}, + "execution_count": 22 } - ] + ] + }, + { + "cell_type": "markdown", + "source": [ + "# save weights" + ], + "metadata": { + "id": "wlfDdO-mid62" + } + }, + { + "cell_type": "code", + "source": [ + "model.save_weights(\"clip-vit-base-patch32.weights.h5\")" + ], + "metadata": { + "id": "QscCUUZFiqBV" + }, + "execution_count": null, + "outputs": [] + } + ] } \ No newline at end of file From 91e6ea92a5513f783dc8d63baa20373d52b1f79c Mon Sep 17 00:00:00 2001 From: Divyashree Sreepathihalli Date: Fri, 2 Feb 2024 23:52:38 +0000 Subject: [PATCH 09/38] add preset file --- .../feature_extractor/clip/clip_presets.py | 68 +++++++++++++++++++ 1 file changed, 68 insertions(+) create mode 100644 keras_cv/models/feature_extractor/clip/clip_presets.py diff --git a/keras_cv/models/feature_extractor/clip/clip_presets.py b/keras_cv/models/feature_extractor/clip/clip_presets.py new file mode 100644 index 0000000000..9f6dda87aa --- /dev/null +++ b/keras_cv/models/feature_extractor/clip/clip_presets.py @@ -0,0 +1,68 @@ +"""CLIP presets.""" + +clip_presets = { + "clip-vit-base-patch16": { + "metadata": { + "description": ( + "The model uses a ViT-B/16 Transformer architecture as an " + "image encoder and uses a masked self-attention Transformer as " + "a text encoder. These encoders are trained to maximize the " + "similarity of (image, text) pairs via a contrastive loss. The " + "model uses a patch size of 16 and input images of size (224, " + "224)" + ), + "params": 149620737, + "official_name": "CLIP", + "path": "clip", + }, + "kaggle_handle": "kaggle://keras/yolov8/keras/yolo_v8_m_pascalvoc/2", + }, + "clip-vit-base-patch32": { + "metadata": { + "description": ( + "The model uses a ViT-B/32 Transformer architecture as an " + "image encoder and uses a masked self-attention Transformer as " + "a text encoder. 
These encoders are trained to maximize the " + "similarity of (image, text) pairs via a contrastive loss.The " + "model uses a patch size of 32 and input images of size (224, " + "224)" + ), + "params": 151277313, + "official_name": "CLIP", + "path": "clip", + }, + "kaggle_handle": "kaggle://keras/yolov8/keras/yolo_v8_m_pascalvoc/2", + }, + "clip-vit-large-patch14": { + "metadata": { + "description": ( + "The model uses a ViT-L/14 Transformer architecture as an " + "image encoder and uses a masked self-attention Transformer as " + "a text encoder. These encoders are trained to maximize the " + "similarity of (image, text) pairs via a contrastive loss.The " + "model uses a patch size of 14 and input images of size (224, " + "224)" + ), + "params": 427616513, + "official_name": "CLIP", + "path": "clip", + }, + "kaggle_handle": "kaggle://keras/yolov8/keras/yolo_v8_m_pascalvoc/2", + }, + "clip-vit-large-patch14-336": { + "metadata": { + "description": ( + "The model uses a ViT-L/14 Transformer architecture as an " + "image encoder and uses a masked self-attention Transformer as " + "a text encoder. These encoders are trained to maximize the " + "similarity of (image, text) pairs via a contrastive loss.The " + "model uses a patch size of 14 and input images of size (336, " + "336)" + ), + "params": 427944193, + "official_name": "CLIP", + "path": "clip", + }, + "kaggle_handle": "", + }, +} From 2219bc28a997991ebac3f24e897c74d86ee9ee31 Mon Sep 17 00:00:00 2001 From: Divyashree Sreepathihalli Date: Sat, 3 Feb 2024 00:03:56 +0000 Subject: [PATCH 10/38] update array --- keras_cv/models/feature_extractor/clip/clip_processor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/keras_cv/models/feature_extractor/clip/clip_processor.py b/keras_cv/models/feature_extractor/clip/clip_processor.py index 526952ba74..5505e87f11 100644 --- a/keras_cv/models/feature_extractor/clip/clip_processor.py +++ b/keras_cv/models/feature_extractor/clip/clip_processor.py @@ -63,8 +63,8 @@ def __init__(self, input_resolution, vocabulary, merges): def transform_image(self, image_path): input_resolution = self.input_resolution - mean = np.array([0.48145466, 0.4578275, 0.40821073]) - std = np.array([0.26862954, 0.26130258, 0.27577711]) + mean = ops.array([0.48145466, 0.4578275, 0.40821073]) + std = ops.array([0.26862954, 0.26130258, 0.27577711]) image = keras.utils.load_img(image_path) image = keras.utils.img_to_array(image) From 957b6c81d0bcf4b7d922e5602eebe91be0e684ce Mon Sep 17 00:00:00 2001 From: Divyashree Sreepathihalli Date: Sat, 3 Feb 2024 00:13:53 +0000 Subject: [PATCH 11/38] update clip prests kaggle handle --- keras_cv/models/feature_extractor/clip/clip_presets.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/keras_cv/models/feature_extractor/clip/clip_presets.py b/keras_cv/models/feature_extractor/clip/clip_presets.py index 9f6dda87aa..29aa08c822 100644 --- a/keras_cv/models/feature_extractor/clip/clip_presets.py +++ b/keras_cv/models/feature_extractor/clip/clip_presets.py @@ -15,7 +15,7 @@ "official_name": "CLIP", "path": "clip", }, - "kaggle_handle": "kaggle://keras/yolov8/keras/yolo_v8_m_pascalvoc/2", + "kaggle_handle": "TBD", }, "clip-vit-base-patch32": { "metadata": { @@ -31,7 +31,7 @@ "official_name": "CLIP", "path": "clip", }, - "kaggle_handle": "kaggle://keras/yolov8/keras/yolo_v8_m_pascalvoc/2", + "kaggle_handle": "TBD", }, "clip-vit-large-patch14": { "metadata": { @@ -47,7 +47,7 @@ "official_name": "CLIP", "path": "clip", }, - "kaggle_handle": 
"kaggle://keras/yolov8/keras/yolo_v8_m_pascalvoc/2", + "kaggle_handle": "TBD", }, "clip-vit-large-patch14-336": { "metadata": { @@ -63,6 +63,6 @@ "official_name": "CLIP", "path": "clip", }, - "kaggle_handle": "", + "kaggle_handle": "TBD", }, } From 160d2a9506c2f4cb9a841fd3020af9f16a533539 Mon Sep 17 00:00:00 2001 From: Divyashree Sreepathihalli Date: Wed, 7 Feb 2024 23:26:44 +0000 Subject: [PATCH 12/38] update text model --- .../feature_extractor/clip/clip_text_model.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/keras_cv/models/feature_extractor/clip/clip_text_model.py b/keras_cv/models/feature_extractor/clip/clip_text_model.py index 3665e0b741..345d6a3c2c 100644 --- a/keras_cv/models/feature_extractor/clip/clip_text_model.py +++ b/keras_cv/models/feature_extractor/clip/clip_text_model.py @@ -44,13 +44,24 @@ def __init__( def call(self, inputs): token_embedding = self.token_embedding(inputs) + input_shape = token_embedding.shape + position_ids = ops.expand_dims( + ops.arange(start=0, stop=input_shape[-1]), axis=0 + ) + position_embeds = ops.take( + self.positional_embedding, indices=position_ids + ) + position_embeds = ops.tile( + position_embeds, repeats=(input_shape[0], 1, 1) + ) encoded_output = self.encoder( - token_embedding + self.positional_embedding + token_embedding + position_embeds ) layer_norm = self.ln_final(encoded_output) indices = ops.expand_dims( ops.cast(ops.argmax(inputs, axis=1), "int32"), axis=-1 ) + print("incides", indices) selected_features = ops.take_along_axis( layer_norm, indices[:, :, None], axis=1 ) From 681120c513931483e7bae48be4da49a920812d5b Mon Sep 17 00:00:00 2001 From: Divyashree Sreepathihalli Date: Thu, 8 Feb 2024 02:27:38 +0000 Subject: [PATCH 13/38] update text encoder --- .../feature_extractor/clip/clip_model.py | 6 ----- .../feature_extractor/clip/clip_text_model.py | 25 +++++-------------- 2 files changed, 6 insertions(+), 25 deletions(-) diff --git a/keras_cv/models/feature_extractor/clip/clip_model.py b/keras_cv/models/feature_extractor/clip/clip_model.py index bf0740c4e9..f56a7609e4 100644 --- a/keras_cv/models/feature_extractor/clip/clip_model.py +++ b/keras_cv/models/feature_extractor/clip/clip_model.py @@ -146,12 +146,6 @@ def __init__( self.image_embeddings = None self.text_embeddings = None - def build_attention_mask(self): - mask = ops.ones((self.context_length, self.context_length)) - # Zero out the lower diagonal - mask = ops.triu(mask) - return ops.cast(mask, "float32") - def encode_images(self, image): return self.image_encoder(image) diff --git a/keras_cv/models/feature_extractor/clip/clip_text_model.py b/keras_cv/models/feature_extractor/clip/clip_text_model.py index 345d6a3c2c..efe0ac49cb 100644 --- a/keras_cv/models/feature_extractor/clip/clip_text_model.py +++ b/keras_cv/models/feature_extractor/clip/clip_text_model.py @@ -29,11 +29,15 @@ def __init__( shape=[self.context_length, transformer_width], name="positional_embedding", ) + mask = ops.ones((self.context_length, self.context_length)) + # Zero out the lower diagonal + mask = ops.triu(mask) + mask = ops.cast(mask, "float32") self.encoder = CLIPEncoder( width=transformer_width, layers=transformer_layers, heads=transformer_heads, - attn_mask=self.build_attention_mask(), + attn_mask=mask, name="clip_encoder", ) self.ln_final = keras.layers.LayerNormalization(name="ln_final") @@ -44,33 +48,16 @@ def __init__( def call(self, inputs): token_embedding = self.token_embedding(inputs) - input_shape = token_embedding.shape - position_ids = 
ops.expand_dims( - ops.arange(start=0, stop=input_shape[-1]), axis=0 - ) - position_embeds = ops.take( - self.positional_embedding, indices=position_ids - ) - position_embeds = ops.tile( - position_embeds, repeats=(input_shape[0], 1, 1) - ) encoded_output = self.encoder( - token_embedding + position_embeds + token_embedding + self.positional_embedding ) layer_norm = self.ln_final(encoded_output) indices = ops.expand_dims( ops.cast(ops.argmax(inputs, axis=1), "int32"), axis=-1 ) - print("incides", indices) selected_features = ops.take_along_axis( layer_norm, indices[:, :, None], axis=1 ) text_features = self.text_projector(selected_features) output = ops.squeeze(text_features, axis=1) return output - - def build_attention_mask(self): - mask = ops.ones((self.context_length, self.context_length)) - # Zero out the lower diagonal - mask = ops.triu(mask) - return ops.cast(mask, "float32") From df73f2343261d844a0bf3a2c117e0ab03ec25f26 Mon Sep 17 00:00:00 2001 From: Divyashree Sreepathihalli Date: Thu, 8 Feb 2024 06:12:45 +0000 Subject: [PATCH 14/38] update position embeddings --- .../models/feature_extractor/clip/clip_text_model.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/keras_cv/models/feature_extractor/clip/clip_text_model.py b/keras_cv/models/feature_extractor/clip/clip_text_model.py index efe0ac49cb..b95dd518d4 100644 --- a/keras_cv/models/feature_extractor/clip/clip_text_model.py +++ b/keras_cv/models/feature_extractor/clip/clip_text_model.py @@ -25,8 +25,9 @@ def __init__( ) self.vocab_size = vocab_size - self.positional_embedding = self.add_weight( - shape=[self.context_length, transformer_width], + self.positional_embedding = keras.layers.Embedding( + self.context_length, + transformer_width, name="positional_embedding", ) mask = ops.ones((self.context_length, self.context_length)) @@ -48,8 +49,10 @@ def __init__( def call(self, inputs): token_embedding = self.token_embedding(inputs) + position_ids = ops.expand_dims(ops.arange(self.context_length, dtype="int32"), 0) + position_embedding = self.positional_embedding(position_ids) encoded_output = self.encoder( - token_embedding + self.positional_embedding + token_embedding + position_embedding ) layer_norm = self.ln_final(encoded_output) indices = ops.expand_dims( From 80bde9c641c4a44f7f8e331e555dc3dda16ae063 Mon Sep 17 00:00:00 2001 From: Divyashree Sreepathihalli Date: Thu, 8 Feb 2024 06:35:22 +0000 Subject: [PATCH 15/38] update positonal embeddings --- .../models/feature_extractor/clip/clip_text_model.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/keras_cv/models/feature_extractor/clip/clip_text_model.py b/keras_cv/models/feature_extractor/clip/clip_text_model.py index b95dd518d4..41a0f57ca7 100644 --- a/keras_cv/models/feature_extractor/clip/clip_text_model.py +++ b/keras_cv/models/feature_extractor/clip/clip_text_model.py @@ -49,11 +49,14 @@ def __init__( def call(self, inputs): token_embedding = self.token_embedding(inputs) - position_ids = ops.expand_dims(ops.arange(self.context_length, dtype="int32"), 0) + position_ids = ops.expand_dims( + ops.arange(self.context_length, dtype="int32"), 0 + ) position_embedding = self.positional_embedding(position_ids) - encoded_output = self.encoder( - token_embedding + position_embedding + position_embedding = ops.tile( + position_embedding, repeats=(inputs.shape[0], 1, 1) ) + encoded_output = self.encoder(token_embedding + position_embedding) layer_norm = self.ln_final(encoded_output) indices = ops.expand_dims( 
ops.cast(ops.argmax(inputs, axis=1), "int32"), axis=-1 From 5f7b23bfcc28764c342be07e747a794c599d2dbd Mon Sep 17 00:00:00 2001 From: Divyashree Sreepathihalli Date: Thu, 8 Feb 2024 17:37:28 +0000 Subject: [PATCH 16/38] add attention masks --- .../feature_extractor/clip/clip_encoder.py | 38 ++++++++++++------- .../feature_extractor/clip/clip_model.py | 10 +++-- .../feature_extractor/clip/clip_processor.py | 3 +- .../feature_extractor/clip/clip_text_model.py | 11 +++++- 4 files changed, 40 insertions(+), 22 deletions(-) diff --git a/keras_cv/models/feature_extractor/clip/clip_encoder.py b/keras_cv/models/feature_extractor/clip/clip_encoder.py index 653189ca7d..1be8ff1991 100644 --- a/keras_cv/models/feature_extractor/clip/clip_encoder.py +++ b/keras_cv/models/feature_extractor/clip/clip_encoder.py @@ -58,14 +58,23 @@ def __init__( * 0.02 ) - def attention(self, x): + def attention(self, x, attention_mask=None): self.attn_mask = ( ops.cast(self.attn_mask, dtype=x.dtype) if self.attn_mask is not None else None ) + attention_mask = ( + ops.cast(attention_mask, dtype=x.dtype) + if attention_mask is not None + else None + ) - return self.attn(x, attention_mask=self.attn_mask) + return self.attn( + x, + attention_mask=attention_mask, + causal_attention_mask=self.attn_mask, + ) def build(self, input_shape): super().build(input_shape) @@ -93,8 +102,8 @@ def build(self, input_shape): ) self.ln_2 = keras.layers.LayerNormalization(epsilon=1e-5, name="ln_2") - def call(self, x): - x = x + self.attention(self.ln_1(x)) + def call(self, x, attention_mask=None): + x = x + self.attention(self.ln_1(x), attention_mask=attention_mask) x = x + self.mlp(self.ln_2(x)) return x @@ -109,20 +118,21 @@ def __init__(self, width, layers, heads, attn_mask=None, **kwargs): self.layers = layers self.heads = heads self.attn_mask = attn_mask - self.resblocks = keras.Sequential( - [ - ResidualAttention( - self.width, self.heads, self.layers, self.attn_mask - ) - for _ in range(self.layers) - ] - ) + self.resblocks = [ + ResidualAttention( + self.width, self.heads, self.layers, self.attn_mask + ) + for _ in range(self.layers) + ] def build(self, input_shape): super().build(input_shape) + self.resblocks.build() - def call(self, x): - return self.resblocks(x) + def call(self, x, attention_mask=None): + for block in self.resblocks: + x = block(x, attention_mask=attention_mask) + return x def compute_output_shape(self, inputs_shape): return inputs_shape diff --git a/keras_cv/models/feature_extractor/clip/clip_model.py b/keras_cv/models/feature_extractor/clip/clip_model.py index f56a7609e4..04dd816f07 100644 --- a/keras_cv/models/feature_extractor/clip/clip_model.py +++ b/keras_cv/models/feature_extractor/clip/clip_model.py @@ -149,12 +149,14 @@ def __init__( def encode_images(self, image): return self.image_encoder(image) - def encode_text(self, text): - return self.text_encoder(text) + def encode_text(self, text, attention_mask=None): + return self.text_encoder(text, attention_mask=attention_mask) - def call(self, image, text): + def call(self, image, text, attention_mask=None): self.image_embeddings = self.encode_images(image) - self.text_embeddings = self.encode_text(text) + self.text_embeddings = self.encode_text( + text, attention_mask=attention_mask + ) normalize_image_features = keras.ops.sqrt( keras.ops.sum( keras.ops.power(self.image_embeddings, 2), keepdims=True diff --git a/keras_cv/models/feature_extractor/clip/clip_processor.py b/keras_cv/models/feature_extractor/clip/clip_processor.py index 5505e87f11..80183fcb0e 
100644 --- a/keras_cv/models/feature_extractor/clip/clip_processor.py +++ b/keras_cv/models/feature_extractor/clip/clip_processor.py @@ -109,12 +109,11 @@ def process_texts(self, texts, context_length: int = 77): texts = [texts] def pack_tokens(text): - tok, _ = self.packer( + return self.packer( self.tokenizer(text), sequence_length=context_length, add_start_value=True, add_end_value=True, ) - return tok return pack_tokens(texts) diff --git a/keras_cv/models/feature_extractor/clip/clip_text_model.py b/keras_cv/models/feature_extractor/clip/clip_text_model.py index 41a0f57ca7..bb096c805d 100644 --- a/keras_cv/models/feature_extractor/clip/clip_text_model.py +++ b/keras_cv/models/feature_extractor/clip/clip_text_model.py @@ -47,7 +47,7 @@ def __init__( embed_dim, name="text_projector", use_bias=False ) - def call(self, inputs): + def call(self, inputs, attention_mask=None): token_embedding = self.token_embedding(inputs) position_ids = ops.expand_dims( ops.arange(self.context_length, dtype="int32"), 0 @@ -56,7 +56,14 @@ def call(self, inputs): position_embedding = ops.tile( position_embedding, repeats=(inputs.shape[0], 1, 1) ) - encoded_output = self.encoder(token_embedding + position_embedding) + attention_mask = ops.cast(attention_mask, dtype="float32") + expanded_mask = ops.tile( + attention_mask[:, None, None, :], (1, 1, self.context_length, 1) + ) + expanded_mask = (1.0 - expanded_mask) * (-1e8) + encoded_output = self.encoder( + token_embedding + position_embedding, attention_mask=expanded_mask + ) layer_norm = self.ln_final(encoded_output) indices = ops.expand_dims( ops.cast(ops.argmax(inputs, axis=1), "int32"), axis=-1 From 7530eed43c765e35fbe5e7529a1a5ea460956156 Mon Sep 17 00:00:00 2001 From: Divyashree Sreepathihalli Date: Thu, 8 Feb 2024 19:15:05 +0000 Subject: [PATCH 17/38] update expanded mask --- keras_cv/models/feature_extractor/clip/clip_text_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras_cv/models/feature_extractor/clip/clip_text_model.py b/keras_cv/models/feature_extractor/clip/clip_text_model.py index bb096c805d..eb2287ee80 100644 --- a/keras_cv/models/feature_extractor/clip/clip_text_model.py +++ b/keras_cv/models/feature_extractor/clip/clip_text_model.py @@ -60,7 +60,7 @@ def call(self, inputs, attention_mask=None): expanded_mask = ops.tile( attention_mask[:, None, None, :], (1, 1, self.context_length, 1) ) - expanded_mask = (1.0 - expanded_mask) * (-1e8) + # expanded_mask = (1.0 - expanded_mask) * (-1e8) encoded_output = self.encoder( token_embedding + position_embedding, attention_mask=expanded_mask ) From 0211bd47f8267c68a685f785eaeebf56405524f4 Mon Sep 17 00:00:00 2001 From: Divyashree Sreepathihalli Date: Thu, 8 Feb 2024 19:24:46 +0000 Subject: [PATCH 18/38] revert previous commit --- keras_cv/models/feature_extractor/clip/clip_text_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras_cv/models/feature_extractor/clip/clip_text_model.py b/keras_cv/models/feature_extractor/clip/clip_text_model.py index eb2287ee80..bb096c805d 100644 --- a/keras_cv/models/feature_extractor/clip/clip_text_model.py +++ b/keras_cv/models/feature_extractor/clip/clip_text_model.py @@ -60,7 +60,7 @@ def call(self, inputs, attention_mask=None): expanded_mask = ops.tile( attention_mask[:, None, None, :], (1, 1, self.context_length, 1) ) - # expanded_mask = (1.0 - expanded_mask) * (-1e8) + expanded_mask = (1.0 - expanded_mask) * (-1e8) encoded_output = self.encoder( token_embedding + position_embedding, attention_mask=expanded_mask ) 
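The commits around this point adjust how the tokenizer's padding mask and the causal mask reach the attention layer: the text encoder turns the 0/1 padding mask into a large negative additive bias via (1.0 - expanded_mask) * (-1e8), and a later commit folds it together with the causal mask through ops.add. The snippet below is a minimal NumPy sketch of that additive convention, not KerasCV code; the function name is made up, and the upper-triangular orientation of the causal bias is an assumed standard convention rather than something taken from these diffs.

import numpy as np

def combined_attention_bias(padding_mask, neg_inf=-1e8):
    # padding_mask: (batch, seq_len), 1 for real tokens, 0 for padding.
    _, seq_len = padding_mask.shape
    # Assumed standard causal bias: positions strictly above the diagonal
    # (future tokens) get a large negative value, everything else stays 0.
    causal_bias = np.triu(
        np.full((seq_len, seq_len), neg_inf, dtype="float32"), k=1
    )
    # Expand the padding mask to (batch, 1, query_len, key_len) and turn it
    # into an additive bias, mirroring (1.0 - expanded_mask) * (-1e8) above.
    expanded = np.tile(
        padding_mask[:, None, None, :].astype("float32"), (1, 1, seq_len, 1)
    )
    padding_bias = (1.0 - expanded) * neg_inf
    # Combining the two by addition (the ops.add pattern used in the encoder
    # changes below) yields one tensor that is added to the raw attention
    # scores before the softmax, so masked positions end up with ~0 weight.
    return causal_bias + padding_bias

bias = combined_attention_bias(np.array([[1, 1, 1, 0, 0]]))
print(bias.shape)  # (1, 1, 5, 5)

Keeping both masks additive is what allows the encoder to pass a single combined mask to the attention layer instead of carrying the causal mask as a separate argument.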
From d488b7523af0fea4ac94cade1f84b763548b2a7e Mon Sep 17 00:00:00 2001 From: Divyashree Sreepathihalli Date: Thu, 8 Feb 2024 19:33:48 +0000 Subject: [PATCH 19/38] change causal masks --- keras_cv/models/feature_extractor/clip/clip_text_model.py | 1 + 1 file changed, 1 insertion(+) diff --git a/keras_cv/models/feature_extractor/clip/clip_text_model.py b/keras_cv/models/feature_extractor/clip/clip_text_model.py index bb096c805d..d98e13160b 100644 --- a/keras_cv/models/feature_extractor/clip/clip_text_model.py +++ b/keras_cv/models/feature_extractor/clip/clip_text_model.py @@ -34,6 +34,7 @@ def __init__( # Zero out the lower diagonal mask = ops.triu(mask) mask = ops.cast(mask, "float32") + mask = (1.0 - mask) * (-1e8) self.encoder = CLIPEncoder( width=transformer_width, layers=transformer_layers, From d9d126430d5e1ae3ae18291ce60f75c2b7d2532a Mon Sep 17 00:00:00 2001 From: Divyashree Sreepathihalli Date: Thu, 8 Feb 2024 19:42:00 +0000 Subject: [PATCH 20/38] undo previous commit --- keras_cv/models/feature_extractor/clip/clip_text_model.py | 1 - 1 file changed, 1 deletion(-) diff --git a/keras_cv/models/feature_extractor/clip/clip_text_model.py b/keras_cv/models/feature_extractor/clip/clip_text_model.py index d98e13160b..bb096c805d 100644 --- a/keras_cv/models/feature_extractor/clip/clip_text_model.py +++ b/keras_cv/models/feature_extractor/clip/clip_text_model.py @@ -34,7 +34,6 @@ def __init__( # Zero out the lower diagonal mask = ops.triu(mask) mask = ops.cast(mask, "float32") - mask = (1.0 - mask) * (-1e8) self.encoder = CLIPEncoder( width=transformer_width, layers=transformer_layers, From 64d66b54d172e16235e07993779086155921059a Mon Sep 17 00:00:00 2001 From: Divyashree Sreepathihalli Date: Thu, 8 Feb 2024 23:27:34 +0000 Subject: [PATCH 21/38] update attention masks --- keras_cv/models/feature_extractor/clip/clip_encoder.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/keras_cv/models/feature_extractor/clip/clip_encoder.py b/keras_cv/models/feature_extractor/clip/clip_encoder.py index 1be8ff1991..d1146985c3 100644 --- a/keras_cv/models/feature_extractor/clip/clip_encoder.py +++ b/keras_cv/models/feature_extractor/clip/clip_encoder.py @@ -69,11 +69,11 @@ def attention(self, x, attention_mask=None): if attention_mask is not None else None ) + mask = ops.add(self.attn_mask, attention_mask) return self.attn( x, - attention_mask=attention_mask, - causal_attention_mask=self.attn_mask, + attention_mask=mask, ) def build(self, input_shape): From de0be1907fb66ac68e96b97548ef799bd34afebe Mon Sep 17 00:00:00 2001 From: Divyashree Sreepathihalli Date: Thu, 8 Feb 2024 23:51:02 +0000 Subject: [PATCH 22/38] update clip encoder --- .../models/feature_extractor/clip/clip_encoder.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/keras_cv/models/feature_extractor/clip/clip_encoder.py b/keras_cv/models/feature_extractor/clip/clip_encoder.py index d1146985c3..8be0cd05fa 100644 --- a/keras_cv/models/feature_extractor/clip/clip_encoder.py +++ b/keras_cv/models/feature_extractor/clip/clip_encoder.py @@ -59,17 +59,18 @@ def __init__( ) def attention(self, x, attention_mask=None): - self.attn_mask = ( + mask = ( ops.cast(self.attn_mask, dtype=x.dtype) if self.attn_mask is not None else None ) - attention_mask = ( - ops.cast(attention_mask, dtype=x.dtype) - if attention_mask is not None - else None - ) - mask = ops.add(self.attn_mask, attention_mask) + if attention_mask is not None: + attention_mask = ( + ops.cast(attention_mask, dtype=x.dtype) + if 
attention_mask is not None + else None + ) + mask = ops.add(self.attn_mask, attention_mask) return self.attn( x, From 4b8c1efd8a9f5b8e03d38dd297797449f81ec82e Mon Sep 17 00:00:00 2001 From: Divyashree Sreepathihalli Date: Fri, 9 Feb 2024 00:48:53 +0000 Subject: [PATCH 23/38] add print statements --- keras_cv/models/feature_extractor/clip/clip_text_model.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/keras_cv/models/feature_extractor/clip/clip_text_model.py b/keras_cv/models/feature_extractor/clip/clip_text_model.py index bb096c805d..715c3a8dd3 100644 --- a/keras_cv/models/feature_extractor/clip/clip_text_model.py +++ b/keras_cv/models/feature_extractor/clip/clip_text_model.py @@ -64,6 +64,7 @@ def call(self, inputs, attention_mask=None): encoded_output = self.encoder( token_embedding + position_embedding, attention_mask=expanded_mask ) + print("encoded_output", encoded_output) layer_norm = self.ln_final(encoded_output) indices = ops.expand_dims( ops.cast(ops.argmax(inputs, axis=1), "int32"), axis=-1 @@ -71,6 +72,7 @@ def call(self, inputs, attention_mask=None): selected_features = ops.take_along_axis( layer_norm, indices[:, :, None], axis=1 ) + print("pooler output", selected_features) text_features = self.text_projector(selected_features) output = ops.squeeze(text_features, axis=1) return output From 54f02e81b0bec5035b004673c5a4c9371286a339 Mon Sep 17 00:00:00 2001 From: Divyashree Sreepathihalli Date: Fri, 9 Feb 2024 01:31:54 +0000 Subject: [PATCH 24/38] update the pooler output --- keras_cv/models/feature_extractor/clip/clip_text_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras_cv/models/feature_extractor/clip/clip_text_model.py b/keras_cv/models/feature_extractor/clip/clip_text_model.py index 715c3a8dd3..fc765fc67b 100644 --- a/keras_cv/models/feature_extractor/clip/clip_text_model.py +++ b/keras_cv/models/feature_extractor/clip/clip_text_model.py @@ -67,7 +67,7 @@ def call(self, inputs, attention_mask=None): print("encoded_output", encoded_output) layer_norm = self.ln_final(encoded_output) indices = ops.expand_dims( - ops.cast(ops.argmax(inputs, axis=1), "int32"), axis=-1 + ops.cast(ops.argmax(inputs, axis=-1), "int32"), axis=-1 ) selected_features = ops.take_along_axis( layer_norm, indices[:, :, None], axis=1 From f8316387f084d2cbb37adf38dae46be6f0ea466a Mon Sep 17 00:00:00 2001 From: Divyashree Sreepathihalli Date: Fri, 9 Feb 2024 07:01:33 +0000 Subject: [PATCH 25/38] remove print statements --- keras_cv/models/feature_extractor/clip/clip_model_test.py | 7 ++++--- keras_cv/models/feature_extractor/clip/clip_text_model.py | 2 -- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/keras_cv/models/feature_extractor/clip/clip_model_test.py b/keras_cv/models/feature_extractor/clip/clip_model_test.py index a97106e56d..1889db139f 100644 --- a/keras_cv/models/feature_extractor/clip/clip_model_test.py +++ b/keras_cv/models/feature_extractor/clip/clip_model_test.py @@ -15,7 +15,6 @@ import pytest from keras_cv.backend import ops -from keras_cv.backend import random from keras_cv.models import CLIP from keras_cv.tests.test_case import TestCase @@ -28,6 +27,8 @@ def test_clip_tokenizer(self): def test_presets(self): pass - @pytest.mark.extra_large - def test_mixed_precision(self): + def test_image_encoder_golden_values(self): + pass + + def test_text_encoder_golden_values(self): pass diff --git a/keras_cv/models/feature_extractor/clip/clip_text_model.py b/keras_cv/models/feature_extractor/clip/clip_text_model.py index fc765fc67b..e9abdada60 
100644 --- a/keras_cv/models/feature_extractor/clip/clip_text_model.py +++ b/keras_cv/models/feature_extractor/clip/clip_text_model.py @@ -64,7 +64,6 @@ def call(self, inputs, attention_mask=None): encoded_output = self.encoder( token_embedding + position_embedding, attention_mask=expanded_mask ) - print("encoded_output", encoded_output) layer_norm = self.ln_final(encoded_output) indices = ops.expand_dims( ops.cast(ops.argmax(inputs, axis=-1), "int32"), axis=-1 @@ -72,7 +71,6 @@ def call(self, inputs, attention_mask=None): selected_features = ops.take_along_axis( layer_norm, indices[:, :, None], axis=1 ) - print("pooler output", selected_features) text_features = self.text_projector(selected_features) output = ops.squeeze(text_features, axis=1) return output From 3868bb5e9a827368ee79436ce087be4afa4b3acb Mon Sep 17 00:00:00 2001 From: Divyashree Sreepathihalli Date: Fri, 9 Feb 2024 23:23:49 +0000 Subject: [PATCH 26/38] add tests and preset --- .../feature_extractor/clip/clip_model.py | 94 +++++---------- .../feature_extractor/clip/clip_model_test.py | 107 +++++++++++++++++- .../feature_extractor/clip/clip_presets.py | 8 +- 3 files changed, 136 insertions(+), 73 deletions(-) diff --git a/keras_cv/models/feature_extractor/clip/clip_model.py b/keras_cv/models/feature_extractor/clip/clip_model.py index 04dd816f07..a03948b909 100644 --- a/keras_cv/models/feature_extractor/clip/clip_model.py +++ b/keras_cv/models/feature_extractor/clip/clip_model.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import copy + from keras_cv.api_export import keras_cv_export from keras_cv.backend import keras from keras_cv.backend import ops @@ -18,68 +20,23 @@ from keras_cv.models.feature_extractor.clip.clip_image_model import ( CLIPImageEncoder, ) +from keras_cv.models.feature_extractor.clip.clip_presets import ( # noqa: E501 + clip_presets, +) from keras_cv.models.feature_extractor.clip.clip_text_model import ( CLIPTextEncoder, ) - -MODEL_CONFIGS = { - "CLIP_B32": { - "embed_dim": 512, - "context_length": 77, - "vocab_size": 49408, - "transformer_width": 512, - "transformer_heads": 8, - "transformer_layers": 12, - "vision_layers": 12, - "vision_width": 768, - "image_resolution": 224, - "vision_patch_size": 32, - }, - "CLIP_B16": { - "embed_dim": 512, - "context_length": 77, - "vocab_size": 49408, - "transformer_width": 512, - "transformer_heads": 8, - "transformer_layers": 12, - "vision_layers": 12, - "vision_width": 768, - "image_resolution": 224, - "vision_patch_size": 16, - }, - "CLIP_L14": { - "embed_dim": 768, - "context_length": 77, - "vocab_size": 49408, - "transformer_width": 768, - "transformer_heads": 12, - "transformer_layers": 12, - "vision_layers": 24, - "vision_width": 1024, - "image_resolution": 224, - "vision_patch_size": 14, - }, - "CLIP_L14_336": { - "embed_dim": 768, - "context_length": 77, - "vocab_size": 49408, - "transformer_width": 768, - "transformer_heads": 12, - "transformer_layers": 12, - "vision_layers": 24, - "vision_width": 1024, - "image_resolution": 336, - "vision_patch_size": 14, - }, -} +from keras_cv.models.task import Task +from keras_cv.utils.python_utils import classproperty @keras_cv_export(["keras_cv.models.CLIP"]) -class CLIP(keras.Model): +class CLIP(Task): """ CLIP implements the Contrastive Language-Image Pretraining (CLIP) architecture, which enables joint learning of visual and textual - representations for various 
downstream tasks. + representations for various downstream tasks. The deafult base model + achitecture will be set to clip-vit-base-patch32. Args: embed_dim (int): The dimensionality of the joint embedding space for @@ -104,16 +61,16 @@ class CLIP(keras.Model): def __init__( self, - embed_dim, - image_resolution, - vision_layers, - vision_width, - vision_patch_size, - context_length, - vocab_size, - transformer_width, - transformer_heads, - transformer_layers, + embed_dim=512, + image_resolution=224, + vision_layers=12, + vision_width=768, + vision_patch_size=32, + context_length=77, + vocab_size=49408, + transformer_width=768, + transformer_heads=8, + transformer_layers=12, **kwargs, ): super().__init__(**kwargs) @@ -180,3 +137,14 @@ def call(self, image, text, attention_mask=None): logits_per_text = ops.transpose(logits_per_image) return logits_per_image, logits_per_text + + @classproperty + def presets(cls): + """Dictionary of preset names and configurations.""" + return copy.deepcopy({**clip_presets}) + + @classproperty + def presets_with_weights(cls): + """Dictionary of preset names and configurations that include + weights.""" + return copy.deepcopy({**clip_presets}) diff --git a/keras_cv/models/feature_extractor/clip/clip_model_test.py b/keras_cv/models/feature_extractor/clip/clip_model_test.py index 1889db139f..f7cb6a6698 100644 --- a/keras_cv/models/feature_extractor/clip/clip_model_test.py +++ b/keras_cv/models/feature_extractor/clip/clip_model_test.py @@ -12,23 +12,118 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os + +import numpy as np import pytest +from keras_cv.backend import keras from keras_cv.backend import ops +from keras_cv.backend.config import keras_3 from keras_cv.models import CLIP +from keras_cv.models.feature_extractor.clip import CLIPImageEncoder +from keras_cv.models.feature_extractor.clip import CLIPProcessor +from keras_cv.models.feature_extractor.clip import CLIPTextEncoder +from keras_cv.models.feature_extractor.clip import CLIPTokenizer from keras_cv.tests.test_case import TestCase +VOCAB_PATH = keras.utils.get_file( + None, + "https://storage.googleapis.com/keras-cv/models/clip/vocab.json", +) +MERGE_PATH = keras.utils.get_file( + None, + "https://storage.googleapis.com/keras-cv/models/clip/merges.txt", +) + +MODEL_PATH = keras.utils.get_file( + None, + "https://storage.googleapis.com/keras-cv/models/clip/clip-vit-base-patch32.weights.h5", +) + @pytest.mark.tf_only -class StableDiffusionTest(TestCase): - def test_clip_tokenizer(self): - pass +class CLIPTest(TestCase): + @pytest.mark.large + def test_clip_model_golden_values(self): + model = CLIP() + model.load_weights(MODEL_PATH) + processed_image = np.ones(shape=[1, 224, 224, 3]) + processed_text = np.ones(shape=[3, 77]) + attention_mask = np.ones(shape=[3, 77]) + image_logits, text_logits = model( + processed_image, processed_text, attention_mask + ) + print(image_logits) + self.assertAllClose(image_logits, [[3.75321, 3.75321, 3.7532094]]) + self.assertAllClose( + text_logits, ops.transpose([[3.75321, 3.75321, 3.7532094]]) + ) + + def test_clip_preprocessor(self): + processor = CLIPProcessor(224, VOCAB_PATH, MERGE_PATH) + processed_text, attention_mask = processor.process_texts( + ["mountains", "cat on tortoise", "two cats"] + ) + @pytest.mark.large def test_presets(self): - pass + self.skipTest("TODO: Enable after Kaggle model is public") + model = CLIP.from_preset("clip-vit-base-patch32") + processed_image = np.ones(shape=[1, 224, 
224, 3]) + processed_text = np.ones(shape=[3, 77]) + attention_mask = np.ones(shape=[3, 77]) + image_logits, text_logits = model( + processed_image, processed_text, attention_mask + ) + @pytest.mark.large def test_image_encoder_golden_values(self): - pass + model = CLIP() + model.load_weights(MODEL_PATH) + processed_image = np.ones(shape=[1, 224, 224, 3]) + processed_text = np.ones(shape=[3, 77]) + attention_mask = np.ones(shape=[3, 77]) + model(processed_image, processed_text, attention_mask) + self.assertAllClose( + model.image_embeddings[:, :5], + [[0.03867503, -0.05168268, -0.07742637, 0.06213959, -0.0895554]], + ) + @pytest.mark.large def test_text_encoder_golden_values(self): - pass + model = CLIP() + model.load_weights(MODEL_PATH) + processed_image = np.ones(shape=[1, 224, 224, 3]) + processed_text = np.ones(shape=[3, 77]) + attention_mask = np.ones(shape=[3, 77]) + model(processed_image, processed_text, attention_mask) + print(model.text_embeddings) + self.assertAllClose( + model.text_embeddings[0, :3], + [0.01148358, 0.03956496, -0.0104028], + ) + + @pytest.mark.large # Saving is slow, so mark these large. + def test_saved_model(self): + model = CLIP() + model.load_weights(MODEL_PATH) + processed_image = np.ones(shape=[1, 224, 224, 3]) + processed_text = np.ones(shape=[3, 77]) + attention_mask = np.ones(shape=[3, 77]) + model_output, _ = model(processed_image, processed_text, attention_mask) + save_path = os.path.join(self.get_temp_dir(), "model.keras") + if keras_3(): + model.save(save_path) + else: + model.save(save_path, save_format="keras_v3") + restored_model = keras.models.load_model(save_path) + + # Check we got the real object back. + self.assertIsInstance(restored_model, CLIP) + + # Check that output matches. + restored_output, _ = restored_model( + processed_image, processed_text, attention_mask + ) + self.assertAllClose(model_output, restored_output) diff --git a/keras_cv/models/feature_extractor/clip/clip_presets.py b/keras_cv/models/feature_extractor/clip/clip_presets.py index 29aa08c822..041e87a24c 100644 --- a/keras_cv/models/feature_extractor/clip/clip_presets.py +++ b/keras_cv/models/feature_extractor/clip/clip_presets.py @@ -15,7 +15,7 @@ "official_name": "CLIP", "path": "clip", }, - "kaggle_handle": "TBD", + "kaggle_handle": "kaggle://keras/clip/keras/clip-vit-base-patch16/1", }, "clip-vit-base-patch32": { "metadata": { @@ -31,7 +31,7 @@ "official_name": "CLIP", "path": "clip", }, - "kaggle_handle": "TBD", + "kaggle_handle": "kaggle://keras/clip/keras/clip-vit-base-patch32/1", }, "clip-vit-large-patch14": { "metadata": { @@ -47,7 +47,7 @@ "official_name": "CLIP", "path": "clip", }, - "kaggle_handle": "TBD", + "kaggle_handle": "kaggle://keras/clip/keras/clip-vit-large-patch14/1", }, "clip-vit-large-patch14-336": { "metadata": { @@ -63,6 +63,6 @@ "official_name": "CLIP", "path": "clip", }, - "kaggle_handle": "TBD", + "kaggle_handle": "kaggle://keras/clip/keras/clip-vit-large-patch14/1", }, } From 95d9e10c79283bc4d21f9ca323285748b2a12747 Mon Sep 17 00:00:00 2001 From: Divyashree Sreepathihalli Date: Tue, 13 Feb 2024 02:12:47 +0000 Subject: [PATCH 27/38] cleanup and reformat --- .../feature_extractor/clip/clip_encoder.py | 142 +++++++++++------- .../clip/clip_image_model.py | 36 ++++- .../feature_extractor/clip/clip_model.py | 29 +++- .../feature_extractor/clip/clip_model_test.py | 11 +- .../feature_extractor/clip/clip_text_model.py | 39 +++-- 5 files changed, 183 insertions(+), 74 deletions(-) diff --git a/keras_cv/models/feature_extractor/clip/clip_encoder.py 
b/keras_cv/models/feature_extractor/clip/clip_encoder.py index 8be0cd05fa..38b47813c9 100644 --- a/keras_cv/models/feature_extractor/clip/clip_encoder.py +++ b/keras_cv/models/feature_extractor/clip/clip_encoder.py @@ -40,15 +40,13 @@ class ResidualAttention(keras.layers.Layer): def __init__( self, proj_dim, - n_head, + num_heads, num_hidden_layers, - attn_mask=None, **kwargs, ): super().__init__(**kwargs) self.proj_dim = proj_dim - self.n_head = n_head - self.attn_mask = attn_mask + self.num_heads = num_heads self.num_hidden_layers = num_hidden_layers self.fc_std = ops.power(2 * self.proj_dim, -0.5) * 0.02 @@ -58,19 +56,21 @@ def __init__( * 0.02 ) - def attention(self, x, attention_mask=None): - mask = ( - ops.cast(self.attn_mask, dtype=x.dtype) - if self.attn_mask is not None - else None - ) + def attention(self, x, causal_attention_mask=None, attention_mask=None): + mask = None + if causal_attention_mask is not None: + mask = ( + ops.cast(causal_attention_mask, dtype=x.dtype) + if causal_attention_mask is not None + else None + ) if attention_mask is not None: attention_mask = ( ops.cast(attention_mask, dtype=x.dtype) if attention_mask is not None else None ) - mask = ops.add(self.attn_mask, attention_mask) + mask = ops.add(causal_attention_mask, attention_mask) return self.attn( x, @@ -81,7 +81,7 @@ def build(self, input_shape): super().build(input_shape) self.attn = CLIPAttention( self.proj_dim, - self.n_head, + self.num_heads, self.num_hidden_layers, name="multi_head_attention", ) @@ -103,41 +103,77 @@ def build(self, input_shape): ) self.ln_2 = keras.layers.LayerNormalization(epsilon=1e-5, name="ln_2") - def call(self, x, attention_mask=None): - x = x + self.attention(self.ln_1(x), attention_mask=attention_mask) + def call(self, x, causal_attention_mask=None, attention_mask=None): + x = x + self.attention( + self.ln_1(x), + causal_attention_mask=causal_attention_mask, + attention_mask=attention_mask, + ) x = x + self.mlp(self.ln_2(x)) return x def compute_output_shape(self, inputs_shape): return inputs_shape + def get_config(self): + config = super().get_config() + config.update( + { + "proj_dim": self.proj_dim, + "num_heads": self.num_heads, + "num_hidden_layers": self.num_hidden_layers, + } + ) + return config + class CLIPEncoder(keras.layers.Layer): - def __init__(self, width, layers, heads, attn_mask=None, **kwargs): + def __init__(self, width, num_layers, heads, **kwargs): super().__init__(**kwargs) self.width = width - self.layers = layers + self.num_layers = num_layers self.heads = heads - self.attn_mask = attn_mask self.resblocks = [ ResidualAttention( - self.width, self.heads, self.layers, self.attn_mask + self.width, + self.heads, + self.num_layers, ) - for _ in range(self.layers) + for _ in range(self.num_layers) ] def build(self, input_shape): super().build(input_shape) - self.resblocks.build() + self.resblocks.build(input_shape) - def call(self, x, attention_mask=None): + def call( + self, + x, + causal_attention_mask=None, + attention_mask=None, + ): for block in self.resblocks: - x = block(x, attention_mask=attention_mask) + x = block( + x, + causal_attention_mask=causal_attention_mask, + attention_mask=attention_mask, + ) return x def compute_output_shape(self, inputs_shape): return inputs_shape + def get_config(self): + config = super().get_config() + config.update( + { + "width": self.width, + "num_layers": self.num_layers, + "heads": self.heads, + } + ) + return config + class CLIPAttention(keras.layers.Layer): """ @@ -146,58 +182,53 @@ class 
CLIPAttention(keras.layers.Layer): """ def __init__( - self, project_dim, num_heads, num_hidden_layers, dropout=0.0, **kwargs + self, proj_dim, num_heads, num_hidden_layers, dropout=0.0, **kwargs ): super().__init__(**kwargs) - self.project_dim = project_dim + self.proj_dim = proj_dim self.num_heads = num_heads self.num_hidden_layers = num_hidden_layers - self.head_dim = self.project_dim // self.num_heads - if self.head_dim * self.num_heads != self.project_dim: + self.dropout = dropout + self.head_dim = self.proj_dim // self.num_heads + if self.head_dim * self.num_heads != self.proj_dim: raise ValueError( - f"project_dim must be divisible by num_heads (got `project_dim`" - f": {self.project_dim} and `num_heads`:" + f"proj_dim must be divisible by num_heads (got `proj_dim`" + f": {self.proj_dim} and `num_heads`:" f" {self.num_heads})." ) - self.sqrt_att_head_size = ops.sqrt(self.head_dim) self.scale = self.head_dim**-0.5 + + def build(self, input_shape): + super().build(input_shape) in_proj_std = ( - (self.project_dim**-0.5) + (self.proj_dim**-0.5) * ((2 * self.num_hidden_layers) ** -0.5) * 0.02 ) - out_proj_std = (self.project_dim**-0.5) * 0.02 - self.dropout = dropout + out_proj_std = (self.proj_dim**-0.5) * 0.02 self.q_proj = keras.layers.Dense( - units=self.project_dim, + units=self.proj_dim, kernel_initializer=get_initializer(in_proj_std), name="q_proj", ) self.k_proj = keras.layers.Dense( - units=self.project_dim, + units=self.proj_dim, kernel_initializer=get_initializer(in_proj_std), name="k_proj", ) self.v_proj = keras.layers.Dense( - units=self.project_dim, + units=self.proj_dim, kernel_initializer=get_initializer(in_proj_std), name="v_proj", ) self.out_proj = keras.layers.Dense( - units=self.project_dim, + units=self.proj_dim, kernel_initializer=get_initializer(out_proj_std), name="out_proj", ) - def build(self, input_shape): - super().build(input_shape) - self.q_proj.build(input_shape) - self.k_proj.build(input_shape) - self.v_proj.build(input_shape) - self.out_proj.build(input_shape) - def _transpose_for_scores(self, tensor, batch_size): """ Copied from https://github.com/huggingface/transformers/blob/8e164c5400b7b413c7b8fb32e35132001effc970/src/transformers/models/bert/modeling_tf_bert.py#L252 # noqa: E501 @@ -215,7 +246,6 @@ def call( self, x, attention_mask=None, - causal_attention_mask=None, output_attentions=None, training=False, ): @@ -236,12 +266,6 @@ def call( attention_scores, dk ) # (batch_size, num_heads, seq_len_q, seq_len_k) - # Apply the causal_attention_mask first - if causal_attention_mask is not None: - # Apply the causal attention mask (precomputed for all layers in - # the call() function) - attention_scores = ops.add(attention_scores, causal_attention_mask) - if attention_mask is not None: # Apply the attention mask (precomputed for all layers in the # call() function) @@ -259,10 +283,8 @@ def call( attn_output = ops.matmul(attention_probs, value_layer) attn_output = ops.transpose(attn_output, axes=[0, 2, 1, 3]) - # (batch_size, seq_len_q, project_dim) - attn_output = ops.reshape( - attn_output, (batch_size, -1, self.project_dim) - ) + # (batch_size, seq_len_q, proj_dim) + attn_output = ops.reshape(attn_output, (batch_size, -1, self.proj_dim)) attn_output = self.out_proj(attn_output, training=training) outputs = ( @@ -272,3 +294,15 @@ def call( ) return outputs + + def get_config(self): + config = super().get_config() + config.update( + { + "proj_dim": self.proj_dim, + "num_heads": self.num_heads, + "num_hidden_layers": self.num_hidden_layers, + "dropout": 
self.dropout, + } + ) + return config diff --git a/keras_cv/models/feature_extractor/clip/clip_image_model.py b/keras_cv/models/feature_extractor/clip/clip_image_model.py index b436e1a8d7..4136f6d8a9 100644 --- a/keras_cv/models/feature_extractor/clip/clip_image_model.py +++ b/keras_cv/models/feature_extractor/clip/clip_image_model.py @@ -81,6 +81,18 @@ def call(self, x): embeddings = embeddings + positional_embedding return embeddings + def get_config(self): + config = super().get_config() + config.update( + { + "width": self.width, + "patch_size": self.patch_size, + "input_resolution": self.input_resolution, + "output_dim": self.output_dim, + } + ) + return config + class CLIPImageEncoder(keras.Model): def __init__( @@ -88,7 +100,7 @@ def __init__( input_resolution: int, patch_size: int, width: int, - layers: int, + num_layers: int, heads: int, output_dim: int, **kwargs, @@ -100,6 +112,8 @@ def __init__( self.width = width self.patch_size = patch_size self.output_dim = output_dim + self.heads = heads + self.num_layers = num_layers self.embeddings = CLIPPatchingAndEmbedding( width=self.width, @@ -111,9 +125,9 @@ def __init__( epsilon=1e-5, name="ln_1" ) self.encoder = CLIPEncoder( - width, - layers, - heads, + self.width, + self.num_layers, + self.heads, name="clip_encoder", ) self.post_norm = keras.layers.LayerNormalization( @@ -133,3 +147,17 @@ def call(self, image): post_norm = self.post_norm(encoded_output[:, 0, :]) image_projected_embeddings = self.image_projector(post_norm) return image_projected_embeddings + + def get_config(self): + config = super().get_config() + config.update( + { + "input_resolution": self.input_resolution, + "patch_size": self.patch_size, + "width": self.width, + "layers": self.num_layers, + "heads": self.heads, + "output_dim": self.output_dim, + } + ) + return config diff --git a/keras_cv/models/feature_extractor/clip/clip_model.py b/keras_cv/models/feature_extractor/clip/clip_model.py index a03948b909..54e95eacd1 100644 --- a/keras_cv/models/feature_extractor/clip/clip_model.py +++ b/keras_cv/models/feature_extractor/clip/clip_model.py @@ -75,14 +75,23 @@ def __init__( ): super().__init__(**kwargs) + self.embed_dim = embed_dim + self.image_resolution = image_resolution + self.vision_layers = vision_layers + self.vision_width = vision_width + self.vision_patch_size = vision_patch_size self.context_length = context_length + self.vocab_size = vocab_size + self.transformer_width = transformer_width + self.transformer_heads = transformer_heads + self.transformer_layers = transformer_layers vision_heads = vision_width // 64 self.image_encoder = CLIPImageEncoder( input_resolution=image_resolution, patch_size=vision_patch_size, width=vision_width, - layers=vision_layers, + num_layers=vision_layers, heads=vision_heads, output_dim=embed_dim, name="image_encoder", @@ -148,3 +157,21 @@ def presets_with_weights(cls): """Dictionary of preset names and configurations that include weights.""" return copy.deepcopy({**clip_presets}) + + def get_config(self): + config = super().get_config() + config.update( + { + "embed_dim": self.embed_dim, + "image_resolution": self.image_resolution, + "vision_layers": self.vision_layers, + "vision_width": self.vision_width, + "vision_patch_size": self.vision_patch_size, + "context_length": self.context_length, + "vocab_size": self.vocab_size, + "transformer_width": self.transformer_width, + "transformer_heads": self.transformer_heads, + "transformer_layers": self.transformer_layers, + } + ) + return config diff --git 
a/keras_cv/models/feature_extractor/clip/clip_model_test.py b/keras_cv/models/feature_extractor/clip/clip_model_test.py index f7cb6a6698..f47268a33b 100644 --- a/keras_cv/models/feature_extractor/clip/clip_model_test.py +++ b/keras_cv/models/feature_extractor/clip/clip_model_test.py @@ -55,12 +55,13 @@ def test_clip_model_golden_values(self): processed_image, processed_text, attention_mask ) print(image_logits) - self.assertAllClose(image_logits, [[3.75321, 3.75321, 3.7532094]]) + self.assertAllClose(image_logits, [[3.747046, 3.747046, 3.747046]]) self.assertAllClose( - text_logits, ops.transpose([[3.75321, 3.75321, 3.7532094]]) + text_logits, ops.transpose([[3.747046, 3.747046, 3.747046]]) ) def test_clip_preprocessor(self): + self.skipTest("TODO: Enable after Kaggle model is public") processor = CLIPProcessor(224, VOCAB_PATH, MERGE_PATH) processed_text, attention_mask = processor.process_texts( ["mountains", "cat on tortoise", "two cats"] @@ -87,13 +88,12 @@ def test_image_encoder_golden_values(self): model(processed_image, processed_text, attention_mask) self.assertAllClose( model.image_embeddings[:, :5], - [[0.03867503, -0.05168268, -0.07742637, 0.06213959, -0.0895554]], + [[0.038646, -0.051685, -0.077413, 0.062127, -0.089566]], ) @pytest.mark.large def test_text_encoder_golden_values(self): model = CLIP() - model.load_weights(MODEL_PATH) processed_image = np.ones(shape=[1, 224, 224, 3]) processed_text = np.ones(shape=[3, 77]) attention_mask = np.ones(shape=[3, 77]) @@ -101,13 +101,12 @@ def test_text_encoder_golden_values(self): print(model.text_embeddings) self.assertAllClose( model.text_embeddings[0, :3], - [0.01148358, 0.03956496, -0.0104028], + [0.011359, 0.039782, -0.010593], ) @pytest.mark.large # Saving is slow, so mark these large. def test_saved_model(self): model = CLIP() - model.load_weights(MODEL_PATH) processed_image = np.ones(shape=[1, 224, 224, 3]) processed_text = np.ones(shape=[3, 77]) attention_mask = np.ones(shape=[3, 77]) diff --git a/keras_cv/models/feature_extractor/clip/clip_text_model.py b/keras_cv/models/feature_extractor/clip/clip_text_model.py index e9abdada60..322d91de5b 100644 --- a/keras_cv/models/feature_extractor/clip/clip_text_model.py +++ b/keras_cv/models/feature_extractor/clip/clip_text_model.py @@ -17,28 +17,27 @@ def __init__( super().__init__( **kwargs, ) + self.transformer_width = transformer_width + self.transformer_layers = transformer_layers + self.transformer_heads = transformer_heads + self.vocab_size = vocab_size + self.embed_dim = embed_dim self.context_length = context_length self.token_embedding = keras.layers.Embedding( vocab_size, transformer_width, name="token_embedding", ) - - self.vocab_size = vocab_size self.positional_embedding = keras.layers.Embedding( self.context_length, transformer_width, name="positional_embedding", ) - mask = ops.ones((self.context_length, self.context_length)) - # Zero out the lower diagonal - mask = ops.triu(mask) - mask = ops.cast(mask, "float32") + self.encoder = CLIPEncoder( width=transformer_width, - layers=transformer_layers, + num_layers=transformer_layers, heads=transformer_heads, - attn_mask=mask, name="clip_encoder", ) self.ln_final = keras.layers.LayerNormalization(name="ln_final") @@ -56,13 +55,21 @@ def call(self, inputs, attention_mask=None): position_embedding = ops.tile( position_embedding, repeats=(inputs.shape[0], 1, 1) ) + causal_attention_mask = ops.ones( + (self.context_length, self.context_length) + ) + # Zero out the lower diagonal + causal_attention_mask = 
ops.triu(causal_attention_mask) + causal_attention_mask = ops.cast(causal_attention_mask, "float32") attention_mask = ops.cast(attention_mask, dtype="float32") expanded_mask = ops.tile( attention_mask[:, None, None, :], (1, 1, self.context_length, 1) ) expanded_mask = (1.0 - expanded_mask) * (-1e8) encoded_output = self.encoder( - token_embedding + position_embedding, attention_mask=expanded_mask + token_embedding + position_embedding, + causal_attention_mask=causal_attention_mask, + attention_mask=expanded_mask, ) layer_norm = self.ln_final(encoded_output) indices = ops.expand_dims( @@ -74,3 +81,17 @@ def call(self, inputs, attention_mask=None): text_features = self.text_projector(selected_features) output = ops.squeeze(text_features, axis=1) return output + + def get_config(self): + config = super().get_config() + config.update( + { + "transformer_width": self.transformer_width, + "transformer_layers": self.transformer_layers, + "transformer_heads": self.transformer_heads, + "vocab_size": self.vocab_size, + "embed_dim": self.embed_dim, + "context_length": self.context_length, + } + ) + return config From d4c7e163e283f065b3b2dd0b05ee45554adc27d2 Mon Sep 17 00:00:00 2001 From: Divyashree Sreepathihalli Date: Wed, 14 Feb 2024 02:11:02 +0000 Subject: [PATCH 28/38] update build --- .../feature_extractor/clip/clip_encoder.py | 75 +-- .../clip/clip_image_model.py | 5 + .../feature_extractor/clip/clip_model.py | 32 +- .../feature_extractor/clip/clip_model_test.py | 13 +- .../feature_extractor/clip/clip_processor.py | 20 +- .../feature_extractor/clip/clip_text_model.py | 8 + .../clip_weights_conversion.ipynb | 486 +++++++++--------- requirements-common.txt | 4 +- setup.py | 2 +- 9 files changed, 341 insertions(+), 304 deletions(-) diff --git a/keras_cv/models/feature_extractor/clip/clip_encoder.py b/keras_cv/models/feature_extractor/clip/clip_encoder.py index 38b47813c9..41fea5aeee 100644 --- a/keras_cv/models/feature_extractor/clip/clip_encoder.py +++ b/keras_cv/models/feature_extractor/clip/clip_encoder.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import numpy as np + from keras_cv.backend import keras from keras_cv.backend import ops @@ -20,7 +22,8 @@ def get_initializer(initializer_range=0.02): Creates a `keras.initializers.TruncatedNormal` with the given range. Args: - initializer_range (*float*, defaults to 0.02): Standard deviation of the initializer range. + initializer_range (*float*, defaults to 0.02): Standard deviation of the + initializer range. Returns: `keras.initializers.TruncatedNormal`: The truncated normal initializer. 
@@ -48,13 +51,34 @@ def __init__( self.proj_dim = proj_dim self.num_heads = num_heads self.num_hidden_layers = num_hidden_layers - self.fc_std = ops.power(2 * self.proj_dim, -0.5) * 0.02 + self.fc_std = np.power(2 * self.proj_dim, -0.5) * 0.02 self.in_proj_std = ( - ops.power(self.proj_dim, -0.5) - * (ops.power(2 * self.num_hidden_layers, -0.5)) + np.power(self.proj_dim, -0.5) + * (np.power(2 * self.num_hidden_layers, -0.5)) * 0.02 ) + self.attn = CLIPAttention( + self.proj_dim, + self.num_heads, + self.num_hidden_layers, + name="multi_head_attention", + ) + self.ln_1 = keras.layers.LayerNormalization(epsilon=1e-5, name="ln_1") + self.mlp = keras.Sequential( + [ + keras.layers.Dense( + self.proj_dim * 4, + name="c_fc", + ), + QuickGELU(name="gelu"), + keras.layers.Dense( + self.proj_dim, + name="c_proj", + ), + ] + ) + self.ln_2 = keras.layers.LayerNormalization(epsilon=1e-5, name="ln_2") def attention(self, x, causal_attention_mask=None, attention_mask=None): mask = None @@ -75,33 +99,14 @@ def attention(self, x, causal_attention_mask=None, attention_mask=None): return self.attn( x, attention_mask=mask, - ) + )[0] def build(self, input_shape): super().build(input_shape) - self.attn = CLIPAttention( - self.proj_dim, - self.num_heads, - self.num_hidden_layers, - name="multi_head_attention", - ) - self.ln_1 = keras.layers.LayerNormalization(epsilon=1e-5, name="ln_1") - self.mlp = keras.Sequential( - [ - keras.layers.Dense( - self.proj_dim * 4, - kernel_initializer=get_initializer(self.in_proj_std), - name="c_fc", - ), - QuickGELU(name="gelu"), - keras.layers.Dense( - self.proj_dim, - kernel_initializer=get_initializer(self.fc_std), - name="c_proj", - ), - ] - ) - self.ln_2 = keras.layers.LayerNormalization(epsilon=1e-5, name="ln_2") + self.attn.build(None) + self.ln_1.build([None, None, self.proj_dim]) + self.mlp.build(None) + self.ln_2.build([None, None, self.proj_dim]) def call(self, x, causal_attention_mask=None, attention_mask=None): x = x + self.attention( @@ -144,7 +149,7 @@ def __init__(self, width, num_layers, heads, **kwargs): def build(self, input_shape): super().build(input_shape) - self.resblocks.build(input_shape) + map(lambda blocks: blocks.build(input_shape), self.resblocks) def call( self, @@ -199,9 +204,6 @@ def __init__( ) self.scale = self.head_dim**-0.5 - - def build(self, input_shape): - super().build(input_shape) in_proj_std = ( (self.proj_dim**-0.5) * ((2 * self.num_hidden_layers) ** -0.5) @@ -229,6 +231,13 @@ def build(self, input_shape): name="out_proj", ) + def build(self, input_shape): + super().build(input_shape) + self.q_proj.build([None, None, self.proj_dim]) + self.k_proj.build([None, None, self.proj_dim]) + self.v_proj.build([None, None, self.proj_dim]) + self.out_proj.build([None, None, self.proj_dim]) + def _transpose_for_scores(self, tensor, batch_size): """ Copied from https://github.com/huggingface/transformers/blob/8e164c5400b7b413c7b8fb32e35132001effc970/src/transformers/models/bert/modeling_tf_bert.py#L252 # noqa: E501 @@ -290,7 +299,7 @@ def call( outputs = ( (attn_output, _attention_probs) if output_attentions - else attn_output + else (attn_output,) ) return outputs diff --git a/keras_cv/models/feature_extractor/clip/clip_image_model.py b/keras_cv/models/feature_extractor/clip/clip_image_model.py index 4136f6d8a9..c36c5dece4 100644 --- a/keras_cv/models/feature_extractor/clip/clip_image_model.py +++ b/keras_cv/models/feature_extractor/clip/clip_image_model.py @@ -138,7 +138,12 @@ def __init__( ) def build(self, input_shape): + 
super().build(input_shape) self.embeddings.build(input_shape) + self.pre_norm.build([None, None, self.width]) + self.encoder.build(None) + self.post_norm.build([None, self.width]) + self.image_projector.build([None, None, self.width]) def call(self, image): embeddings = self.embeddings(image) diff --git a/keras_cv/models/feature_extractor/clip/clip_model.py b/keras_cv/models/feature_extractor/clip/clip_model.py index 54e95eacd1..c219caaf35 100644 --- a/keras_cv/models/feature_extractor/clip/clip_model.py +++ b/keras_cv/models/feature_extractor/clip/clip_model.py @@ -16,7 +16,6 @@ from keras_cv.api_export import keras_cv_export from keras_cv.backend import keras from keras_cv.backend import ops -from keras_cv.models.feature_extractor.clip.clip_image_model import CLIPEncoder from keras_cv.models.feature_extractor.clip.clip_image_model import ( CLIPImageEncoder, ) @@ -86,23 +85,23 @@ def __init__( self.transformer_heads = transformer_heads self.transformer_layers = transformer_layers - vision_heads = vision_width // 64 + vision_heads = self.vision_width // 64 self.image_encoder = CLIPImageEncoder( - input_resolution=image_resolution, - patch_size=vision_patch_size, - width=vision_width, - num_layers=vision_layers, + input_resolution=self.image_resolution, + patch_size=self.vision_patch_size, + width=self.vision_width, + num_layers=self.vision_layers, heads=vision_heads, - output_dim=embed_dim, + output_dim=self.embed_dim, name="image_encoder", ) self.text_encoder = CLIPTextEncoder( - transformer_width=transformer_width, - transformer_layers=transformer_layers, - transformer_heads=transformer_heads, - vocab_size=vocab_size, - embed_dim=embed_dim, - context_length=context_length, + transformer_width=self.transformer_width, + transformer_layers=self.transformer_layers, + transformer_heads=self.transformer_heads, + vocab_size=self.vocab_size, + embed_dim=self.embed_dim, + context_length=self.context_length, name="text_encoder", ) @@ -112,6 +111,13 @@ def __init__( self.image_embeddings = None self.text_embeddings = None + def build(self, input_shape): + super().build(input_shape) + self.text_encoder.build([None, self.context_length]) + self.image_encoder.build( + [None, self.image_resolution, self.image_resolution, 3] + ) + def encode_images(self, image): return self.image_encoder(image) diff --git a/keras_cv/models/feature_extractor/clip/clip_model_test.py b/keras_cv/models/feature_extractor/clip/clip_model_test.py index f47268a33b..19013ca798 100644 --- a/keras_cv/models/feature_extractor/clip/clip_model_test.py +++ b/keras_cv/models/feature_extractor/clip/clip_model_test.py @@ -21,10 +21,7 @@ from keras_cv.backend import ops from keras_cv.backend.config import keras_3 from keras_cv.models import CLIP -from keras_cv.models.feature_extractor.clip import CLIPImageEncoder from keras_cv.models.feature_extractor.clip import CLIPProcessor -from keras_cv.models.feature_extractor.clip import CLIPTextEncoder -from keras_cv.models.feature_extractor.clip import CLIPTokenizer from keras_cv.tests.test_case import TestCase VOCAB_PATH = keras.utils.get_file( @@ -38,7 +35,7 @@ MODEL_PATH = keras.utils.get_file( None, - "https://storage.googleapis.com/keras-cv/models/clip/clip-vit-base-patch32.weights.h5", + "https://storage.googleapis.com/keras-cv/models/clip/clip-vit-base-patch32.weights.h5", # noqa: E501 ) @@ -55,9 +52,9 @@ def test_clip_model_golden_values(self): processed_image, processed_text, attention_mask ) print(image_logits) - self.assertAllClose(image_logits, [[3.747046, 3.747046, 3.747046]]) + 
self.assertAllClose(image_logits, [[2.932678, 2.932678, 2.932675]]) self.assertAllClose( - text_logits, ops.transpose([[3.747046, 3.747046, 3.747046]]) + text_logits, ops.transpose([[2.932678, 2.932678, 2.932675]]) ) def test_clip_preprocessor(self): @@ -88,7 +85,7 @@ def test_image_encoder_golden_values(self): model(processed_image, processed_text, attention_mask) self.assertAllClose( model.image_embeddings[:, :5], - [[0.038646, -0.051685, -0.077413, 0.062127, -0.089566]], + [[0.023215, 0.026526, 0.008914, -0.091689, 0.021791]], ) @pytest.mark.large @@ -101,7 +98,7 @@ def test_text_encoder_golden_values(self): print(model.text_embeddings) self.assertAllClose( model.text_embeddings[0, :3], - [0.011359, 0.039782, -0.010593], + [-0.018502, 0.000906, 0.020372], ) @pytest.mark.large # Saving is slow, so mark these large. diff --git a/keras_cv/models/feature_extractor/clip/clip_processor.py b/keras_cv/models/feature_extractor/clip/clip_processor.py index 80183fcb0e..80e616cc02 100644 --- a/keras_cv/models/feature_extractor/clip/clip_processor.py +++ b/keras_cv/models/feature_extractor/clip/clip_processor.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import numpy as np from keras_nlp.layers import StartEndPacker from keras_cv.api_export import keras_cv_export @@ -45,12 +44,14 @@ class CLIPProcessor: """ - def __init__(self, input_resolution, vocabulary, merges): + def __init__(self, input_resolution, vocabulary, merges, **kwargs): self.input_resolution = input_resolution + self.vocabulary = vocabulary + self.merges = merges self.image_transform = self.transform_image self.tokenizer = CLIPTokenizer( - vocabulary=vocabulary, - merges=merges, + vocabulary=self.vocabulary, + merges=self.merges, unsplittable_tokens=[""], ) self.packer = StartEndPacker( @@ -117,3 +118,14 @@ def pack_tokens(text): ) return pack_tokens(texts) + + def get_config(self): + config = super().get_config() + config.update( + { + "input_resolution": self.input_resolution, + "vocabulary": self.vocabulary, + "merges": self.merges, + } + ) + return config diff --git a/keras_cv/models/feature_extractor/clip/clip_text_model.py b/keras_cv/models/feature_extractor/clip/clip_text_model.py index 322d91de5b..d40c86d7d2 100644 --- a/keras_cv/models/feature_extractor/clip/clip_text_model.py +++ b/keras_cv/models/feature_extractor/clip/clip_text_model.py @@ -46,6 +46,14 @@ def __init__( embed_dim, name="text_projector", use_bias=False ) + def build(self, input_shape): + super().build(input_shape) + self.token_embedding.build(input_shape) + self.positional_embedding.build([1, self.context_length]) + self.encoder.build(None) + self.ln_final.build([None, None, self.transformer_width]) + self.text_projector.build([None, None, self.transformer_width]) + def call(self, inputs, attention_mask=None): token_embedding = self.token_embedding(inputs) position_ids = ops.expand_dims( diff --git a/keras_cv/tools/checkpoint_conversion/clip_weights_conversion.ipynb b/keras_cv/tools/checkpoint_conversion/clip_weights_conversion.ipynb index 9e4a771b5f..13e443669a 100644 --- a/keras_cv/tools/checkpoint_conversion/clip_weights_conversion.ipynb +++ b/keras_cv/tools/checkpoint_conversion/clip_weights_conversion.ipynb @@ -1,37 +1,17 @@ { - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "provenance": [] - }, - "kernelspec": { - "name": "python3", - "display_name": "Python 3" - }, - 
"language_info": { - "name": "python" - } - }, "cells": [ { "cell_type": "markdown", - "source": [ - "# Setup" - ], "metadata": { "id": "0DhV6hzOMY0W" - } + }, + "source": [ + "# Setup" + ] }, { "cell_type": "code", - "source": [ - "!pip install -q git+https://github.com/divyashreepathihalli/keras-cv.git@CLIP_refactor\n", - "!pip install -q keras-nlp\n", - "!pip install -q tf-keras\n", - "!pip install -q tensorflow-text\n", - "!pip install keras==3.0.2" - ], + "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -39,11 +19,10 @@ "id": "cRzYR-oFgxt1", "outputId": "e4b01fcd-9f71-4ba7-b8a2-1796f7ef260d" }, - "execution_count": null, "outputs": [ { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ " Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", @@ -74,48 +53,51 @@ "\u001b[0mSuccessfully installed keras-3.0.2\n" ] } + ], + "source": [ + "!pip install -q git+https://github.com/divyashreepathihalli/keras-cv.git@CLIP_refactor\n", + "!pip install -q keras-nlp\n", + "!pip install -q tf-keras\n", + "!pip install -q tensorflow-text\n", + "!pip install keras==3.0.2" ] }, { "cell_type": "markdown", - "source": [ - "# Import" - ], "metadata": { "id": "mdGT8Em4Mc4b" - } + }, + "source": [ + "# Import" + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "GDvJmQuug4-x" + }, + "outputs": [], "source": [ "from keras_cv.models.feature_extractor.clip import CLIPProcessor\n", "import keras\n", "from keras_cv.models import CLIP" - ], - "metadata": { - "id": "GDvJmQuug4-x" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "code", - "source": [ - "!wget https://i.imgur.com/8H7XCH0.jpg -O cat.jpg\n", - "!wget http://images.cocodataset.org/val2017/000000039769.jpg -O test.jpg" - ], + "execution_count": null, "metadata": { - "id": "nuFgha2jTshi", "colab": { "base_uri": "https://localhost:8080/" }, + "id": "nuFgha2jTshi", "outputId": "b99d73eb-cc97-47d0-f46e-687c9e8b8237" }, - "execution_count": null, "outputs": [ { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "--2024-02-02 22:19:20-- https://i.imgur.com/8H7XCH0.jpg\n", "Resolving i.imgur.com (i.imgur.com)... 
151.101.52.193\n", @@ -141,10 +123,20 @@ "\n" ] } + ], + "source": [ + "!wget https://i.imgur.com/8H7XCH0.jpg -O cat.jpg\n", + "!wget http://images.cocodataset.org/val2017/000000039769.jpg -O test.jpg" ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "X3kkmK6h_gFH" + }, + "outputs": [], "source": [ "# @title Select which model weights you would like to convert\n", "MODEL_CONFIGS = {\n", @@ -205,25 +197,24 @@ "}\n", "config_name = \"CLIP_L14_336\" # @param [\"CLIP_B16\", \"CLIP_B32\", \"CLIP_L14\", \"CLIP_L14_336\"]\n", "config_name_hf = model_map_hf[config_name]" - ], - "metadata": { - "cellView": "form", - "id": "X3kkmK6h_gFH" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", - "source": [ - "# Keras 3 CLIP" - ], "metadata": { "id": "2l3Ll7dMMd-m" - } + }, + "source": [ + "# Keras 3 CLIP" + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "urhuhwq0Dczo" + }, + "outputs": [], "source": [ "embed_dim = MODEL_CONFIGS[config_name][\"embed_dim\"]\n", "context_length = MODEL_CONFIGS[config_name][\"context_length\"]\n", @@ -247,12 +238,7 @@ " transformer_heads,\n", " transformer_layers,\n", ")" - ], - "metadata": { - "id": "urhuhwq0Dczo" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "code", @@ -267,30 +253,20 @@ }, "outputs": [ { - "output_type": "display_data", "data": { - "text/plain": [ - "\u001b[1mModel: \"clip\"\u001b[0m\n" - ], "text/html": [ "
Model: \"clip\"\n",
        "
\n" + ], + "text/plain": [ + "\u001b[1mModel: \"clip\"\u001b[0m\n" ] }, - "metadata": {} + "metadata": {}, + "output_type": "display_data" }, { - "output_type": "display_data", "data": { - "text/plain": [ - "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┓\n", - "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n", - "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━┩\n", - "│ image_encoder (\u001b[38;5;33mCLIPImageEncoder\u001b[0m) │ ? │ \u001b[38;5;34m0\u001b[0m (unbuilt) │\n", - "├────────────────────────────────────┼───────────────────────────────┼─────────────┤\n", - "│ text_encoder (\u001b[38;5;33mCLIPTextEncoder\u001b[0m) │ ? │ \u001b[38;5;34m0\u001b[0m (unbuilt) │\n", - "└────────────────────────────────────┴───────────────────────────────┴─────────────┘\n" - ], "text/html": [ "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┓\n",
        "┃ Layer (type)                        Output Shape                       Param # ┃\n",
@@ -300,48 +276,58 @@
        "│ text_encoder (CLIPTextEncoder)     │ ?                             │ 0 (unbuilt) │\n",
        "└────────────────────────────────────┴───────────────────────────────┴─────────────┘\n",
        "
\n" + ], + "text/plain": [ + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━┩\n", + "│ image_encoder (\u001b[38;5;33mCLIPImageEncoder\u001b[0m) │ ? │ \u001b[38;5;34m0\u001b[0m (unbuilt) │\n", + "├────────────────────────────────────┼───────────────────────────────┼─────────────┤\n", + "│ text_encoder (\u001b[38;5;33mCLIPTextEncoder\u001b[0m) │ ? │ \u001b[38;5;34m0\u001b[0m (unbuilt) │\n", + "└────────────────────────────────────┴───────────────────────────────┴─────────────┘\n" ] }, - "metadata": {} + "metadata": {}, + "output_type": "display_data" }, { - "output_type": "display_data", "data": { - "text/plain": [ - "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m39,425\u001b[0m (154.00 KB)\n" - ], "text/html": [ "
 Total params: 39,425 (154.00 KB)\n",
        "
\n" + ], + "text/plain": [ + "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m39,425\u001b[0m (154.00 KB)\n" ] }, - "metadata": {} + "metadata": {}, + "output_type": "display_data" }, { - "output_type": "display_data", "data": { - "text/plain": [ - "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m39,425\u001b[0m (154.00 KB)\n" - ], "text/html": [ "
 Trainable params: 39,425 (154.00 KB)\n",
        "
\n" + ], + "text/plain": [ + "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m39,425\u001b[0m (154.00 KB)\n" ] }, - "metadata": {} + "metadata": {}, + "output_type": "display_data" }, { - "output_type": "display_data", "data": { - "text/plain": [ - "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n" - ], "text/html": [ "
 Non-trainable params: 0 (0.00 B)\n",
        "
\n" + ], + "text/plain": [ + "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n" ] }, - "metadata": {} + "metadata": {}, + "output_type": "display_data" } ], "source": [ @@ -350,6 +336,11 @@ }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "buXKlNfGTenW" + }, + "outputs": [], "source": [ "processor = CLIPProcessor(224, \"vocab.json\", \"merges.txt\")\n", "image = processor.process_images([\"cat.jpg\"])\n", @@ -359,21 +350,11 @@ " \"a photo of a tortoise\",\n", "]\n", "text = processor.process_texts(text_input)" - ], - "metadata": { - "id": "buXKlNfGTenW" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "code", - "source": [ - "image_logits, text_logits = model(image, text)\n", - "output = keras.layers.Softmax()(image_logits)\n", - "print(image_logits)\n", - "print(text_input[keras.ops.argmax(output)])" - ], + "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -381,23 +362,26 @@ "id": "BHSpMv0PT5SX", "outputId": "566c92c4-fbf3-4e2d-87f1-6112b2cff96f" }, - "execution_count": null, "outputs": [ { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "tf.Tensor([[ 0.42190465 0.6262117 -0.2368357 ]], shape=(1, 3), dtype=float32)\n", "tortoise on a dog\n" ] } + ], + "source": [ + "image_logits, text_logits = model(image, text)\n", + "output = keras.layers.Softmax()(image_logits)\n", + "print(image_logits)\n", + "print(text_input[keras.ops.argmax(output)])" ] }, { "cell_type": "code", - "source": [ - "model.summary()" - ], + "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", @@ -406,33 +390,22 @@ "id": "GgNBvYCTtmA3", "outputId": "35b9a26c-325e-4535-c33b-3f67ab112e19" }, - "execution_count": null, "outputs": [ { - "output_type": "display_data", "data": { - "text/plain": [ - "\u001b[1mModel: \"clip\"\u001b[0m\n" - ], "text/html": [ "
Model: \"clip\"\n",
        "
\n" + ], + "text/plain": [ + "\u001b[1mModel: \"clip\"\u001b[0m\n" ] }, - "metadata": {} + "metadata": {}, + "output_type": "display_data" }, { - "output_type": "display_data", "data": { - "text/plain": [ - "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┓\n", - "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n", - "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━┩\n", - "│ image_encoder (\u001b[38;5;33mCLIPImageEncoder\u001b[0m) │ ? │ \u001b[38;5;34m87,849,216\u001b[0m │\n", - "├────────────────────────────────────┼───────────────────────────────┼─────────────┤\n", - "│ text_encoder (\u001b[38;5;33mCLIPTextEncoder\u001b[0m) │ ? │ \u001b[38;5;34m63,428,096\u001b[0m │\n", - "└────────────────────────────────────┴───────────────────────────────┴─────────────┘\n" - ], "text/html": [ "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┓\n",
        "┃ Layer (type)                        Output Shape                       Param # ┃\n",
@@ -442,103 +415,136 @@
        "│ text_encoder (CLIPTextEncoder)     │ ?                             │  63,428,096 │\n",
        "└────────────────────────────────────┴───────────────────────────────┴─────────────┘\n",
        "
\n" + ], + "text/plain": [ + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━┩\n", + "│ image_encoder (\u001b[38;5;33mCLIPImageEncoder\u001b[0m) │ ? │ \u001b[38;5;34m87,849,216\u001b[0m │\n", + "├────────────────────────────────────┼───────────────────────────────┼─────────────┤\n", + "│ text_encoder (\u001b[38;5;33mCLIPTextEncoder\u001b[0m) │ ? │ \u001b[38;5;34m63,428,096\u001b[0m │\n", + "└────────────────────────────────────┴───────────────────────────────┴─────────────┘\n" ] }, - "metadata": {} + "metadata": {}, + "output_type": "display_data" }, { - "output_type": "display_data", "data": { - "text/plain": [ - "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m151,277,313\u001b[0m (577.08 MB)\n" - ], "text/html": [ "
 Total params: 151,277,313 (577.08 MB)\n",
        "
\n" + ], + "text/plain": [ + "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m151,277,313\u001b[0m (577.08 MB)\n" ] }, - "metadata": {} + "metadata": {}, + "output_type": "display_data" }, { - "output_type": "display_data", "data": { - "text/plain": [ - "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m151,277,313\u001b[0m (577.08 MB)\n" - ], "text/html": [ "
 Trainable params: 151,277,313 (577.08 MB)\n",
        "
\n" + ], + "text/plain": [ + "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m151,277,313\u001b[0m (577.08 MB)\n" ] }, - "metadata": {} + "metadata": {}, + "output_type": "display_data" }, { - "output_type": "display_data", "data": { - "text/plain": [ - "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n" - ], "text/html": [ "
 Non-trainable params: 0 (0.00 B)\n",
        "
\n" + ], + "text/plain": [ + "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n" ] }, - "metadata": {} + "metadata": {}, + "output_type": "display_data" } + ], + "source": [ + "model.summary()" ] }, { "cell_type": "markdown", - "source": [ - "# HF CLIP" - ], "metadata": { "id": "P8DWYq_hVFnz" - } + }, + "source": [ + "# HF CLIP" + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "3W2prd6C0pxe" + }, + "outputs": [], "source": [ "from PIL import Image\n", "import requests\n", "\n", "from transformers import CLIPProcessor as CP\n", "from transformers import CLIPModel as CM" - ], - "metadata": { - "id": "3W2prd6C0pxe" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "code", - "source": [ - "model_hf = CM.from_pretrained(config_name_hf)\n", - "processor = CP.from_pretrained(config_name_hf)" - ], + "execution_count": null, "metadata": { - "id": "EntuvOq1MhwU", "colab": { "base_uri": "https://localhost:8080/" }, + "id": "EntuvOq1MhwU", "outputId": "e154a367-2f94-4fa1-e97d-d2f32db7a2cf" }, - "execution_count": null, "outputs": [ { - "output_type": "stream", "name": "stderr", + "output_type": "stream", "text": [ "`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config[\"id2label\"]` will be overriden.\n", "`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config[\"bos_token_id\"]` will be overriden.\n", "`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config[\"eos_token_id\"]` will be overriden.\n" ] } + ], + "source": [ + "model_hf = CM.from_pretrained(config_name_hf)\n", + "processor = CP.from_pretrained(config_name_hf)" ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "Ep8DRTkv3AwS", + "outputId": "770756bc-8829-484f-b6e5-763fe81e24d0" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[0.9957, 0.0023, 0.0020]], grad_fn=)" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "url = \"https://i.imgur.com/8H7XCH0.jpg\"\n", "image_hf = Image.open(requests.get(url, stream=True).raw)\n", @@ -559,60 +565,45 @@ " dim=1\n", ") # we can take the softmax to get the label probabilitiesprobs\n", "probs" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "Ep8DRTkv3AwS", - "outputId": "770756bc-8829-484f-b6e5-763fe81e24d0" - }, - "execution_count": null, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "tensor([[0.9957, 0.0023, 0.0020]], grad_fn=)" - ] - }, - "metadata": {}, - "execution_count": 14 - } ] }, { "cell_type": "code", - "source": [ - "# hugging face weights\n", - "hf_wts = model_hf.state_dict()" - ], + "execution_count": null, "metadata": { "id": "wPa0cVnY3cBC" }, - "execution_count": null, - "outputs": [] + "outputs": [], + "source": [ + "# hugging face weights\n", + "hf_wts = model_hf.state_dict()" + ] }, { "cell_type": "markdown", - "source": [ - "# Copy weights" - ], "metadata": { "id": "ArkCHlVZVKfM" - } + }, + "source": [ + "# Copy weights" + ] }, { "cell_type": "markdown", - "source": [ - "##vision encoder" - ], "metadata": { "id": "TUCpKltRG4Gd" - } + }, + "source": [ + "##vision encoder" + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "tn_U02N7U2VN" + }, + "outputs": [], "source": [ 
"model.logit_scale.assign(hf_wts.pop(\"logit_scale\").numpy())\n", "model.get_layer(\"image_encoder\").get_layer(\n", @@ -647,15 +638,15 @@ "model.get_layer(\"image_encoder\").get_layer(\"vision_projector\").weights[\n", " 0\n", "].assign(hf_wts.pop(\"visual_projection.weight\").transpose(1, 0).numpy())" - ], - "metadata": { - "id": "tn_U02N7U2VN" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "qptfuWobZcbT" + }, + "outputs": [], "source": [ "for i in range(0, MODEL_CONFIGS[config_name][\"vision_layers\"]):\n", " if i == 0:\n", @@ -762,35 +753,35 @@ " ].assign(\n", " hf_wts.pop(f\"vision_model.encoder.layers.{i}.mlp.fc2.bias\").numpy()\n", " )" - ], - "metadata": { - "id": "qptfuWobZcbT" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", - "source": [ - "## Text encoder" - ], "metadata": { "id": "1RN2aVrYG8T3" - } + }, + "source": [ + "## Text encoder" + ] }, { "cell_type": "code", - "source": [ - "num_transformer_layers = MODEL_CONFIGS[config_name][\"vision_layers\"]" - ], + "execution_count": null, "metadata": { "id": "5FtDROnynb0N" }, - "execution_count": null, - "outputs": [] + "outputs": [], + "source": [ + "num_transformer_layers = MODEL_CONFIGS[config_name][\"vision_layers\"]" + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "_1AD7TcbdWEC" + }, + "outputs": [], "source": [ "model.get_layer(\"text_encoder\").get_layer(\"text_projector\").weights[0].assign(\n", " hf_wts.pop(\"text_projection.weight\").numpy()\n", @@ -807,15 +798,15 @@ "model.get_layer(\"text_encoder\").get_layer(\"ln_final\").weights[1].assign(\n", " hf_wts.pop(\"text_model.final_layer_norm.bias\")\n", ")" - ], - "metadata": { - "id": "_1AD7TcbdWEC" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "s6leOiFO6V2U" + }, + "outputs": [], "source": [ "for i in range(MODEL_CONFIGS[config_name][\"transformer_layers\"]):\n", " model.get_layer(\"text_encoder\").get_layer(\n", @@ -974,19 +965,11 @@ " ].assign(\n", " hf_wts.pop(f\"text_model.encoder.layers.{i}.mlp.fc2.bias\").numpy()\n", " )" - ], - "metadata": { - "id": "s6leOiFO6V2U" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "code", - "source": [ - "# verify that we copied all weights\n", - "hf_wts.keys()" - ], + "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -994,39 +977,56 @@ "id": "Bgen7hxCCeZ7", "outputId": "c777d6f1-4aa7-4f3e-8fd7-759364364c44" }, - "execution_count": null, "outputs": [ { - "output_type": "execute_result", "data": { "text/plain": [ "odict_keys([])" ] }, + "execution_count": 22, "metadata": {}, - "execution_count": 22 + "output_type": "execute_result" } + ], + "source": [ + "# verify that we copied all weights\n", + "hf_wts.keys()" ] }, { "cell_type": "markdown", - "source": [ - "# save weights" - ], "metadata": { "id": "wlfDdO-mid62" - } + }, + "source": [ + "# save weights" + ] }, { "cell_type": "code", - "source": [ - "model.save_weights(\"clip-vit-base-patch32.weights.h5\")" - ], + "execution_count": null, "metadata": { "id": "QscCUUZFiqBV" }, - "execution_count": null, - "outputs": [] + "outputs": [], + "source": [ + "model.save_weights(\"clip-vit-base-patch32.weights.h5\")" + ] + } + ], + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" } - ] -} 
\ No newline at end of file
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
diff --git a/requirements-common.txt b/requirements-common.txt
index 038a886a9b..d5b0f20776 100644
--- a/requirements-common.txt
+++ b/requirements-common.txt
@@ -13,5 +13,5 @@ isort
 black
 pytest
 build
-keras-nlp
-namex
\ No newline at end of file
+keras_nlp_nightly
+namex
diff --git a/setup.py b/setup.py
index e00c348978..61571beb00 100644
--- a/setup.py
+++ b/setup.py
@@ -75,7 +75,7 @@ def is_pure(self):
         "regex",
         "tensorflow-datasets",
         "keras-core",
-        "keras-nlp",
+        "keras_nlp_nightly",
         "kagglehub",
     ],
     extras_require={

From 305fb0af28dbd203b91629929a61c35603fd6893 Mon Sep 17 00:00:00 2001
From: Divyashree Sreepathihalli
Date: Wed, 14 Feb 2024 02:14:20 +0000
Subject: [PATCH 29/38] add copywrite to presets file

---
 .../models/feature_extractor/clip/clip_presets.py | 13 +++++++++++++
 1 file changed, 13 insertions(+)

diff --git a/keras_cv/models/feature_extractor/clip/clip_presets.py b/keras_cv/models/feature_extractor/clip/clip_presets.py
index 041e87a24c..ef26d4b045 100644
--- a/keras_cv/models/feature_extractor/clip/clip_presets.py
+++ b/keras_cv/models/feature_extractor/clip/clip_presets.py
@@ -1,3 +1,16 @@
+# Copyright 2023 The KerasCV Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
"""CLIP presets.""" clip_presets = { From 9e6ff3bac03104e19582fbb14d68590187dbe7a4 Mon Sep 17 00:00:00 2001 From: Divyashree Sreepathihalli Date: Wed, 14 Feb 2024 22:55:14 +0000 Subject: [PATCH 30/38] fix build state errors --- .../feature_extractor/clip/clip_encoder.py | 33 ++++++++++--------- .../feature_extractor/clip/clip_model_test.py | 1 - .../feature_extractor/clip/clip_text_model.py | 13 ++++++++ 3 files changed, 30 insertions(+), 17 deletions(-) diff --git a/keras_cv/models/feature_extractor/clip/clip_encoder.py b/keras_cv/models/feature_extractor/clip/clip_encoder.py index 41fea5aeee..95aabfac28 100644 --- a/keras_cv/models/feature_extractor/clip/clip_encoder.py +++ b/keras_cv/models/feature_extractor/clip/clip_encoder.py @@ -65,18 +65,14 @@ def __init__( name="multi_head_attention", ) self.ln_1 = keras.layers.LayerNormalization(epsilon=1e-5, name="ln_1") - self.mlp = keras.Sequential( - [ - keras.layers.Dense( - self.proj_dim * 4, - name="c_fc", - ), - QuickGELU(name="gelu"), - keras.layers.Dense( - self.proj_dim, - name="c_proj", - ), - ] + self.mlp_dense_1 = keras.layers.Dense( + self.proj_dim * 4, + name="c_fc", + ) + self.mlp_activation = QuickGELU(name="gelu") + self.mlp_dense_2 = keras.layers.Dense( + self.proj_dim, + name="c_proj", ) self.ln_2 = keras.layers.LayerNormalization(epsilon=1e-5, name="ln_2") @@ -105,16 +101,20 @@ def build(self, input_shape): super().build(input_shape) self.attn.build(None) self.ln_1.build([None, None, self.proj_dim]) - self.mlp.build(None) + self.mlp_dense_1.build([None, None, self.proj_dim]) + self.mlp_dense_2.build([None, None, self.proj_dim * 4]) self.ln_2.build([None, None, self.proj_dim]) def call(self, x, causal_attention_mask=None, attention_mask=None): - x = x + self.attention( + attn_x = x + self.attention( self.ln_1(x), causal_attention_mask=causal_attention_mask, attention_mask=attention_mask, ) - x = x + self.mlp(self.ln_2(x)) + x = self.mlp_dense_1(self.ln_2(attn_x)) + x = self.mlp_activation(x) + x = self.mlp_dense_2(x) + x = attn_x + x return x def compute_output_shape(self, inputs_shape): @@ -149,7 +149,8 @@ def __init__(self, width, num_layers, heads, **kwargs): def build(self, input_shape): super().build(input_shape) - map(lambda blocks: blocks.build(input_shape), self.resblocks) + for block in self.resblocks: + block.build(input_shape) def call( self, diff --git a/keras_cv/models/feature_extractor/clip/clip_model_test.py b/keras_cv/models/feature_extractor/clip/clip_model_test.py index 19013ca798..90bcd85774 100644 --- a/keras_cv/models/feature_extractor/clip/clip_model_test.py +++ b/keras_cv/models/feature_extractor/clip/clip_model_test.py @@ -117,7 +117,6 @@ def test_saved_model(self): # Check we got the real object back. self.assertIsInstance(restored_model, CLIP) - # Check that output matches. restored_output, _ = restored_model( processed_image, processed_text, attention_mask diff --git a/keras_cv/models/feature_extractor/clip/clip_text_model.py b/keras_cv/models/feature_extractor/clip/clip_text_model.py index d40c86d7d2..5fc92990d2 100644 --- a/keras_cv/models/feature_extractor/clip/clip_text_model.py +++ b/keras_cv/models/feature_extractor/clip/clip_text_model.py @@ -1,3 +1,16 @@ +# Copyright 2023 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from keras_cv.backend import keras from keras_cv.backend import ops from keras_cv.models.feature_extractor.clip.clip_encoder import CLIPEncoder From 1c88b7e73bcb32aa8861f0b726917b4e1c32e598 Mon Sep 17 00:00:00 2001 From: Divyashree Sreepathihalli Date: Fri, 16 Feb 2024 19:20:53 +0000 Subject: [PATCH 31/38] update github actions and add preprocessor test --- .github/workflows/actions.yml | 3 +++ keras_cv/models/feature_extractor/clip/clip_model.py | 8 +++++++- .../models/feature_extractor/clip/clip_model_test.py | 10 ++++++++-- setup.py | 1 - 4 files changed, 18 insertions(+), 4 deletions(-) diff --git a/.github/workflows/actions.yml b/.github/workflows/actions.yml index 8be69b967b..316e623c57 100644 --- a/.github/workflows/actions.yml +++ b/.github/workflows/actions.yml @@ -38,6 +38,8 @@ jobs: pip install torch>=2.0.1+cpu pip install "jax[cpu]" pip install keras-core + pip install keras-nlp-nightly --no-deps + pip install tensorflow-text==2.15 pip install -e ".[tests]" --progress-bar off --upgrade - name: Test with pytest env: @@ -75,6 +77,7 @@ jobs: run: | pip install -r requirements.txt pip install -e ".[tests]" --progress-bar off --upgrade + pip install keras-nlp-nightly - name: Test with pytest env: TEST_CUSTOM_OPS: false # TODO(ianstenbit): test custom ops, or figure out what our story is here diff --git a/keras_cv/models/feature_extractor/clip/clip_model.py b/keras_cv/models/feature_extractor/clip/clip_model.py index c219caaf35..22a41119c8 100644 --- a/keras_cv/models/feature_extractor/clip/clip_model.py +++ b/keras_cv/models/feature_extractor/clip/clip_model.py @@ -73,7 +73,13 @@ def __init__( **kwargs, ): super().__init__(**kwargs) - + try: + import keras_nlp # noqa: F401 + except ImportError: + raise ImportError( + "CLIP model requires keras-nlp. Please pip " + "install keras-nlp." 
+ ) self.embed_dim = embed_dim self.image_resolution = image_resolution self.vision_layers = vision_layers diff --git a/keras_cv/models/feature_extractor/clip/clip_model_test.py b/keras_cv/models/feature_extractor/clip/clip_model_test.py index 90bcd85774..709b2545d0 100644 --- a/keras_cv/models/feature_extractor/clip/clip_model_test.py +++ b/keras_cv/models/feature_extractor/clip/clip_model_test.py @@ -58,10 +58,16 @@ def test_clip_model_golden_values(self): ) def test_clip_preprocessor(self): - self.skipTest("TODO: Enable after Kaggle model is public") + # self.skipTest("TODO: Enable after Kaggle model is public") processor = CLIPProcessor(224, VOCAB_PATH, MERGE_PATH) processed_text, attention_mask = processor.process_texts( - ["mountains", "cat on tortoise", "two cats"] + ["mountains", "cat on tortoise"] + ) + self.assertAllClose( + processed_text[:, :3], [[49406, 5873, 49407], [49406, 2368, 525]] + ) + self.assertAllClose( + attention_mask[0, :5], [True, True, True, False, False] ) @pytest.mark.large diff --git a/setup.py b/setup.py index 61571beb00..19dc42248c 100644 --- a/setup.py +++ b/setup.py @@ -75,7 +75,6 @@ def is_pure(self): "regex", "tensorflow-datasets", "keras-core", - "keras_nlp_nightly", "kagglehub", ], extras_require={ From eb2bd44339edc9e7d461601cac80aed22ea197e8 Mon Sep 17 00:00:00 2001 From: Divyashree Sreepathihalli Date: Fri, 16 Feb 2024 22:12:26 +0000 Subject: [PATCH 32/38] incorporate review comments --- .../models/feature_extractor/clip/clip_image_model.py | 1 + .../models/feature_extractor/clip/clip_model_test.py | 1 - keras_cv/models/feature_extractor/clip/clip_presets.py | 2 +- keras_cv/models/feature_extractor/clip/clip_tokenizer.py | 9 +++++++-- requirements-common.txt | 1 - 5 files changed, 9 insertions(+), 5 deletions(-) diff --git a/keras_cv/models/feature_extractor/clip/clip_image_model.py b/keras_cv/models/feature_extractor/clip/clip_image_model.py index c36c5dece4..efb62c945c 100644 --- a/keras_cv/models/feature_extractor/clip/clip_image_model.py +++ b/keras_cv/models/feature_extractor/clip/clip_image_model.py @@ -46,6 +46,7 @@ def __init__( self.output_dim = output_dim def build(self, input_shape): + super().build(input_shape) self.conv1.build(input_shape) self.class_embedding = self.add_weight( shape=((self.width,)), diff --git a/keras_cv/models/feature_extractor/clip/clip_model_test.py b/keras_cv/models/feature_extractor/clip/clip_model_test.py index 709b2545d0..3a0c367c96 100644 --- a/keras_cv/models/feature_extractor/clip/clip_model_test.py +++ b/keras_cv/models/feature_extractor/clip/clip_model_test.py @@ -39,7 +39,6 @@ ) -@pytest.mark.tf_only class CLIPTest(TestCase): @pytest.mark.large def test_clip_model_golden_values(self): diff --git a/keras_cv/models/feature_extractor/clip/clip_presets.py b/keras_cv/models/feature_extractor/clip/clip_presets.py index ef26d4b045..4198a76c55 100644 --- a/keras_cv/models/feature_extractor/clip/clip_presets.py +++ b/keras_cv/models/feature_extractor/clip/clip_presets.py @@ -76,6 +76,6 @@ "official_name": "CLIP", "path": "clip", }, - "kaggle_handle": "kaggle://keras/clip/keras/clip-vit-large-patch14/1", + "kaggle_handle": "kaggle://keras/clip/keras/clip-vit-large-patch14-336/1", }, } diff --git a/keras_cv/models/feature_extractor/clip/clip_tokenizer.py b/keras_cv/models/feature_extractor/clip/clip_tokenizer.py index 8a7b5ac3b9..8a975be594 100644 --- a/keras_cv/models/feature_extractor/clip/clip_tokenizer.py +++ b/keras_cv/models/feature_extractor/clip/clip_tokenizer.py @@ -14,8 +14,13 @@ import regex as re import 
tensorflow as tf import tensorflow_text as tf_text -from keras_nlp.tokenizers import BytePairTokenizer - +try: + from keras_nlp.tokenizers import BytePairTokenizer +except ImportError: + raise ImportError( + "CLIP model requires keras-nlp. Please pip " + "install keras-nlp." + ) VOCAB_FILENAME = "keras_cv/models/feature_extractors/clip/vocab.json" MERGES_FILENAME = "keras_cv/models/feature_extractors/clip/merges.txt" # As python and TF handles special spaces differently, we need to diff --git a/requirements-common.txt b/requirements-common.txt index d5b0f20776..29f7ee9a19 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -13,5 +13,4 @@ isort black pytest build -keras_nlp_nightly namex From 38e00b7eb16110803c0887a9277a47dd8b82937f Mon Sep 17 00:00:00 2001 From: Divyashree Sreepathihalli Date: Tue, 20 Feb 2024 23:44:03 +0000 Subject: [PATCH 33/38] add modifications from review --- .../feature_extractor/clip/clip_encoder.py | 27 ++++++++++--------- .../clip/clip_image_model.py | 14 +++++----- .../feature_extractor/clip/clip_model.py | 20 ++++++-------- .../feature_extractor/clip/clip_presets.py | 2 +- .../feature_extractor/clip/clip_text_model.py | 1 + .../feature_extractor/clip/clip_tokenizer.py | 27 +++---------------- 6 files changed, 35 insertions(+), 56 deletions(-) diff --git a/keras_cv/models/feature_extractor/clip/clip_encoder.py b/keras_cv/models/feature_extractor/clip/clip_encoder.py index 95aabfac28..aeb345c857 100644 --- a/keras_cv/models/feature_extractor/clip/clip_encoder.py +++ b/keras_cv/models/feature_extractor/clip/clip_encoder.py @@ -106,15 +106,19 @@ def build(self, input_shape): self.ln_2.build([None, None, self.proj_dim]) def call(self, x, causal_attention_mask=None, attention_mask=None): - attn_x = x + self.attention( - self.ln_1(x), + residual = x + x = self.ln_1(x) + x = self.attention( + x, causal_attention_mask=causal_attention_mask, attention_mask=attention_mask, ) - x = self.mlp_dense_1(self.ln_2(attn_x)) + x = x + residual + residual = x + x = self.mlp_dense_1(self.ln_2(residual)) x = self.mlp_activation(x) x = self.mlp_dense_2(x) - x = attn_x + x + x = residual + x return x def compute_output_shape(self, inputs_shape): @@ -183,8 +187,7 @@ def get_config(self): class CLIPAttention(keras.layers.Layer): """ - - Documentation page: https://huggingface.co/docs/transformers/model_doc/clip # noqa: E501 - - Implementation: https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/modeling_clip.py # noqa: E501 + Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/modeling_clip.py # noqa: E501 """ def __init__( @@ -241,7 +244,7 @@ def build(self, input_shape): def _transpose_for_scores(self, tensor, batch_size): """ - Copied from https://github.com/huggingface/transformers/blob/8e164c5400b7b413c7b8fb32e35132001effc970/src/transformers/models/bert/modeling_tf_bert.py#L252 # noqa: E501 + Adapted from https://github.com/huggingface/transformers/blob/8e164c5400b7b413c7b8fb32e35132001effc970/src/transformers/models/bert/modeling_tf_bert.py#L252 # noqa: E501 """ # [batch_size, seq_len, all_head_dim] -> # [batch_size, seq_len, num_heads, head_dim] @@ -282,15 +285,15 @@ def call( attention_scores = ops.add(attention_scores, attention_mask) # Normalize the attention scores to probabilities. 
- _attention_probs = ops.softmax(attention_scores + 1e-9, axis=-1) + attention_probs = ops.softmax(attention_scores, axis=-1) # This is actually dropping out entire tokens to attend to, which might # seem a bit unusual, but is taken from the original Transformer paper. - attention_probs = keras.layers.Dropout(self.dropout)( - inputs=_attention_probs, training=training + dropout_attention_probs = keras.layers.Dropout(self.dropout)( + inputs=attention_probs, training=training ) - attn_output = ops.matmul(attention_probs, value_layer) + attn_output = ops.matmul(dropout_attention_probs, value_layer) attn_output = ops.transpose(attn_output, axes=[0, 2, 1, 3]) # (batch_size, seq_len_q, proj_dim) @@ -298,7 +301,7 @@ def call( attn_output = self.out_proj(attn_output, training=training) outputs = ( - (attn_output, _attention_probs) + (attn_output, attention_probs) if output_attentions else (attn_output,) ) diff --git a/keras_cv/models/feature_extractor/clip/clip_image_model.py b/keras_cv/models/feature_extractor/clip/clip_image_model.py index efb62c945c..30f2d67405 100644 --- a/keras_cv/models/feature_extractor/clip/clip_image_model.py +++ b/keras_cv/models/feature_extractor/clip/clip_image_model.py @@ -66,7 +66,7 @@ def build(self, input_shape): ) def call(self, x): - batch_size, _, _, _ = ops.shape(x) + batch_size = ops.shape(x)[0] patch_embeddings = self.conv1(x) # shape = [*, grid, grid, channel] patch_embeddings = ops.reshape( @@ -98,12 +98,12 @@ def get_config(self): class CLIPImageEncoder(keras.Model): def __init__( self, - input_resolution: int, - patch_size: int, - width: int, - num_layers: int, - heads: int, - output_dim: int, + input_resolution, + patch_size, + width, + num_layers, + heads, + output_dim, **kwargs, ): super().__init__( diff --git a/keras_cv/models/feature_extractor/clip/clip_model.py b/keras_cv/models/feature_extractor/clip/clip_model.py index 22a41119c8..35d2fd40f3 100644 --- a/keras_cv/models/feature_extractor/clip/clip_model.py +++ b/keras_cv/models/feature_extractor/clip/clip_model.py @@ -32,10 +32,10 @@ @keras_cv_export(["keras_cv.models.CLIP"]) class CLIP(Task): """ - CLIP implements the Contrastive Language-Image Pretraining (CLIP) - architecture, which enables joint learning of visual and textual - representations for various downstream tasks. The deafult base model - achitecture will be set to clip-vit-base-patch32. + CLIP implements the Contrastive Language-Image Pretraining (CLIP) + architecture, which enables joint learning of visual and textual + representations for various downstream tasks. The deafult base model + achitecture will be set to clip-vit-base-patch32. 
Args: embed_dim (int): The dimensionality of the joint embedding space for @@ -135,15 +135,11 @@ def call(self, image, text, attention_mask=None): self.text_embeddings = self.encode_text( text, attention_mask=attention_mask ) - normalize_image_features = keras.ops.sqrt( - keras.ops.sum( - keras.ops.power(self.image_embeddings, 2), keepdims=True - ) + normalize_image_features = ops.sqrt( + ops.sum(ops.power(self.image_embeddings, 2), keepdims=True) ) - normalize_text_features = keras.ops.sqrt( - keras.ops.sum( - keras.ops.power(self.text_embeddings, 2), keepdims=True - ) + normalize_text_features = ops.sqrt( + ops.sum(ops.power(self.text_embeddings, 2), keepdims=True) ) self.image_embeddings = self.image_embeddings / normalize_image_features self.text_embeddings = self.text_embeddings / normalize_text_features diff --git a/keras_cv/models/feature_extractor/clip/clip_presets.py b/keras_cv/models/feature_extractor/clip/clip_presets.py index 4198a76c55..e58e5f55b5 100644 --- a/keras_cv/models/feature_extractor/clip/clip_presets.py +++ b/keras_cv/models/feature_extractor/clip/clip_presets.py @@ -76,6 +76,6 @@ "official_name": "CLIP", "path": "clip", }, - "kaggle_handle": "kaggle://keras/clip/keras/clip-vit-large-patch14-336/1", + "kaggle_handle": "kaggle://keras/clip/keras/clip-vit-large-patch14-336/1", # noqa: E501 }, } diff --git a/keras_cv/models/feature_extractor/clip/clip_text_model.py b/keras_cv/models/feature_extractor/clip/clip_text_model.py index 5fc92990d2..36142b1a2d 100644 --- a/keras_cv/models/feature_extractor/clip/clip_text_model.py +++ b/keras_cv/models/feature_extractor/clip/clip_text_model.py @@ -82,6 +82,7 @@ def call(self, inputs, attention_mask=None): # Zero out the lower diagonal causal_attention_mask = ops.triu(causal_attention_mask) causal_attention_mask = ops.cast(causal_attention_mask, "float32") + causal_attention_mask = (1.0 - causal_attention_mask) * (-1e8) attention_mask = ops.cast(attention_mask, dtype="float32") expanded_mask = ops.tile( attention_mask[:, None, None, :], (1, 1, self.context_length, 1) diff --git a/keras_cv/models/feature_extractor/clip/clip_tokenizer.py b/keras_cv/models/feature_extractor/clip/clip_tokenizer.py index 8a975be594..f08cdfd7c5 100644 --- a/keras_cv/models/feature_extractor/clip/clip_tokenizer.py +++ b/keras_cv/models/feature_extractor/clip/clip_tokenizer.py @@ -14,15 +14,13 @@ import regex as re import tensorflow as tf import tensorflow_text as tf_text + try: from keras_nlp.tokenizers import BytePairTokenizer except ImportError: raise ImportError( - "CLIP model requires keras-nlp. Please pip " - "install keras-nlp." - ) -VOCAB_FILENAME = "keras_cv/models/feature_extractors/clip/vocab.json" -MERGES_FILENAME = "keras_cv/models/feature_extractors/clip/merges.txt" + "CLIP model requires keras-nlp. Please pip " "install keras-nlp." + ) # As python and TF handles special spaces differently, we need to # manually handle special spaces during string split. 
SPECIAL_WHITESPACES = r"\x{a0}\x{2009}\x{202f}\x{3000}" @@ -93,25 +91,6 @@ def create_alts_for_unsplittable_tokens(unsplittable_tokens): return alts -def bytes_to_unicode(): - bs = ( - list(range(ord("!"), ord("~") + 1)) - + list(range(ord("¡"), ord("¬") + 1)) - + list(range(ord("®"), ord("ÿ") + 1)) - ) - cs = bs[:] - n = 0 - # removes mapping an int to a whitespace character - for b in range(2**8): - if b not in bs: - bs.append(b) - cs.append(2**8 + n) - n += 1 - cs = [chr(n) for n in cs] - bs = [n.to_bytes(1, "little") for n in bs] - return bs, cs # int to string mapping - - def remove_strings_from_inputs(tensor, string_to_remove): """Remove certain strings from input tensor.""" non_empty_mask = tensor != string_to_remove From 8eeb88e1735b53aa9e20f404651ce1cfd993b364 Mon Sep 17 00:00:00 2001 From: Divyashree Sreepathihalli Date: Wed, 21 Feb 2024 00:02:13 +0000 Subject: [PATCH 34/38] change import checks --- .../feature_extractor/clip/clip_image_model.py | 11 ++++++----- .../models/feature_extractor/clip/clip_model.py | 15 +++++++++------ .../feature_extractor/clip/clip_tokenizer.py | 13 ++++++++++--- 3 files changed, 25 insertions(+), 14 deletions(-) diff --git a/keras_cv/models/feature_extractor/clip/clip_image_model.py b/keras_cv/models/feature_extractor/clip/clip_image_model.py index 30f2d67405..1718768116 100644 --- a/keras_cv/models/feature_extractor/clip/clip_image_model.py +++ b/keras_cv/models/feature_extractor/clip/clip_image_model.py @@ -121,6 +121,7 @@ def __init__( patch_size=self.patch_size, input_resolution=self.input_resolution, output_dim=self.output_dim, + name="clip_patch_embedding", ) self.pre_norm = keras.layers.LayerNormalization( epsilon=1e-5, name="ln_1" @@ -147,11 +148,11 @@ def build(self, input_shape): self.image_projector.build([None, None, self.width]) def call(self, image): - embeddings = self.embeddings(image) - pre_norm = self.pre_norm(embeddings) - encoded_output = self.encoder(pre_norm) - post_norm = self.post_norm(encoded_output[:, 0, :]) - image_projected_embeddings = self.image_projector(post_norm) + x = self.embeddings(image) + x = self.pre_norm(x) + x = self.encoder(x) + x = self.post_norm(x[:, 0, :]) + image_projected_embeddings = self.image_projector(x) return image_projected_embeddings def get_config(self): diff --git a/keras_cv/models/feature_extractor/clip/clip_model.py b/keras_cv/models/feature_extractor/clip/clip_model.py index 35d2fd40f3..cbe9f0ee33 100644 --- a/keras_cv/models/feature_extractor/clip/clip_model.py +++ b/keras_cv/models/feature_extractor/clip/clip_model.py @@ -28,6 +28,11 @@ from keras_cv.models.task import Task from keras_cv.utils.python_utils import classproperty +try: + import keras_nlp +except ImportError: + keras_nlp = None + @keras_cv_export(["keras_cv.models.CLIP"]) class CLIP(Task): @@ -73,12 +78,10 @@ def __init__( **kwargs, ): super().__init__(**kwargs) - try: - import keras_nlp # noqa: F401 - except ImportError: - raise ImportError( - "CLIP model requires keras-nlp. Please pip " - "install keras-nlp." + if keras_nlp is None: + raise ValueError( + "ClipTokenizer requires keras-nlp. 
Please install " + "using pip `pip install keras-nlp`" ) self.embed_dim = embed_dim self.image_resolution = image_resolution diff --git a/keras_cv/models/feature_extractor/clip/clip_tokenizer.py b/keras_cv/models/feature_extractor/clip/clip_tokenizer.py index f08cdfd7c5..65f15882ec 100644 --- a/keras_cv/models/feature_extractor/clip/clip_tokenizer.py +++ b/keras_cv/models/feature_extractor/clip/clip_tokenizer.py @@ -18,9 +18,8 @@ try: from keras_nlp.tokenizers import BytePairTokenizer except ImportError: - raise ImportError( - "CLIP model requires keras-nlp. Please pip " "install keras-nlp." - ) + keras_nlp = None + # As python and TF handles special spaces differently, we need to # manually handle special spaces during string split. SPECIAL_WHITESPACES = r"\x{a0}\x{2009}\x{202f}\x{3000}" @@ -105,6 +104,14 @@ def remove_strings_from_inputs(tensor, string_to_remove): class CLIPTokenizer(BytePairTokenizer): + def __init__(self, **kwargs): + super().__init__(**kwargs) + if keras_nlp is None: + raise ValueError( + "ClipTokenizer requires keras-nlp. Please install " + "using pip `pip install keras-nlp`" + ) + def _bpe_merge_and_update_cache(self, tokens): """Process unseen tokens and add to cache.""" words = self._transform_bytes(tokens) From d5b2534b7fdda65a26df5605d943efd6b98c09ec Mon Sep 17 00:00:00 2001 From: Divyashree Sreepathihalli Date: Wed, 21 Feb 2024 00:22:06 +0000 Subject: [PATCH 35/38] update keras_nlp import check --- keras_cv/models/feature_extractor/clip/clip_tokenizer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/keras_cv/models/feature_extractor/clip/clip_tokenizer.py b/keras_cv/models/feature_extractor/clip/clip_tokenizer.py index 65f15882ec..fccbf680e3 100644 --- a/keras_cv/models/feature_extractor/clip/clip_tokenizer.py +++ b/keras_cv/models/feature_extractor/clip/clip_tokenizer.py @@ -16,6 +16,7 @@ import tensorflow_text as tf_text try: + import keras_nlp from keras_nlp.tokenizers import BytePairTokenizer except ImportError: keras_nlp = None From 9a6646494fb8db238d52bd4a0d19bd2a1163c3c3 Mon Sep 17 00:00:00 2001 From: Divyashree Sreepathihalli Date: Wed, 21 Feb 2024 00:54:10 +0000 Subject: [PATCH 36/38] update kokoro tests --- .kokoro/github/ubuntu/gpu/build.sh | 5 +++++ keras_cv/models/feature_extractor/clip/clip_model_test.py | 8 +++++++- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/.kokoro/github/ubuntu/gpu/build.sh b/.kokoro/github/ubuntu/gpu/build.sh index 9d07218317..fedfcd0566 100644 --- a/.kokoro/github/ubuntu/gpu/build.sh +++ b/.kokoro/github/ubuntu/gpu/build.sh @@ -29,21 +29,26 @@ then pip install --extra-index-url https://download.pytorch.org/whl/cpu torch==2.1.0+cpu pip install torchvision~=0.16.0 pip install "jax[cpu]" + pip install keras-nlp-nightly --no-deps + pip install tensorflow-text==2.15 elif [ "$KERAS_BACKEND" == "tensorflow" ] then echo "TensorFlow backend detected." pip install -r requirements-tensorflow-cuda.txt --progress-bar off + pip install keras-nlp-nightly elif [ "$KERAS_BACKEND" == "jax" ] then echo "JAX backend detected." pip install -r requirements-jax-cuda.txt --progress-bar off + pip install keras-nlp-nightly elif [ "$KERAS_BACKEND" == "torch" ] then echo "PyTorch backend detected." pip install -r requirements-torch-cuda.txt --progress-bar off + pip install keras-nlp-nightly fi pip install --no-deps -e "." 
--progress-bar off diff --git a/keras_cv/models/feature_extractor/clip/clip_model_test.py b/keras_cv/models/feature_extractor/clip/clip_model_test.py index 3a0c367c96..d5c777c653 100644 --- a/keras_cv/models/feature_extractor/clip/clip_model_test.py +++ b/keras_cv/models/feature_extractor/clip/clip_model_test.py @@ -16,6 +16,7 @@ import numpy as np import pytest +from tensorflow import data as tf_data from keras_cv.backend import keras from keras_cv.backend import ops @@ -57,7 +58,6 @@ def test_clip_model_golden_values(self): ) def test_clip_preprocessor(self): - # self.skipTest("TODO: Enable after Kaggle model is public") processor = CLIPProcessor(224, VOCAB_PATH, MERGE_PATH) processed_text, attention_mask = processor.process_texts( ["mountains", "cat on tortoise"] @@ -69,6 +69,12 @@ def test_clip_preprocessor(self): attention_mask[0, :5], [True, True, True, False, False] ) + def test_clip_preprocessor_tf_data(self): + processor = CLIPProcessor(224, VOCAB_PATH, MERGE_PATH) + text_input = ["a bus", "a dog", "a cat"] + dataset = tf_data.Dataset.from_tensor_slices(text_input) + dataset.map(processor.process_texts) + @pytest.mark.large def test_presets(self): self.skipTest("TODO: Enable after Kaggle model is public") From a0b8e300ef5fac19864af1d4e38df55dcf34659d Mon Sep 17 00:00:00 2001 From: Divyashree Sreepathihalli Date: Wed, 21 Feb 2024 21:41:46 +0000 Subject: [PATCH 37/38] update kaggle preset version --- keras_cv/models/feature_extractor/clip/clip_presets.py | 8 ++++---- keras_cv/models/feature_extractor/clip/clip_text_model.py | 1 - 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/keras_cv/models/feature_extractor/clip/clip_presets.py b/keras_cv/models/feature_extractor/clip/clip_presets.py index e58e5f55b5..6b4d98727e 100644 --- a/keras_cv/models/feature_extractor/clip/clip_presets.py +++ b/keras_cv/models/feature_extractor/clip/clip_presets.py @@ -28,7 +28,7 @@ "official_name": "CLIP", "path": "clip", }, - "kaggle_handle": "kaggle://keras/clip/keras/clip-vit-base-patch16/1", + "kaggle_handle": "kaggle://keras/clip/keras/clip-vit-base-patch16/2", }, "clip-vit-base-patch32": { "metadata": { @@ -44,7 +44,7 @@ "official_name": "CLIP", "path": "clip", }, - "kaggle_handle": "kaggle://keras/clip/keras/clip-vit-base-patch32/1", + "kaggle_handle": "kaggle://keras/clip/keras/clip-vit-base-patch32/2", }, "clip-vit-large-patch14": { "metadata": { @@ -60,7 +60,7 @@ "official_name": "CLIP", "path": "clip", }, - "kaggle_handle": "kaggle://keras/clip/keras/clip-vit-large-patch14/1", + "kaggle_handle": "kaggle://keras/clip/keras/clip-vit-large-patch14/2", }, "clip-vit-large-patch14-336": { "metadata": { @@ -76,6 +76,6 @@ "official_name": "CLIP", "path": "clip", }, - "kaggle_handle": "kaggle://keras/clip/keras/clip-vit-large-patch14-336/1", # noqa: E501 + "kaggle_handle": "kaggle://keras/clip/keras/clip-vit-large-patch14-336/2", # noqa: E501 }, } diff --git a/keras_cv/models/feature_extractor/clip/clip_text_model.py b/keras_cv/models/feature_extractor/clip/clip_text_model.py index 36142b1a2d..5fc92990d2 100644 --- a/keras_cv/models/feature_extractor/clip/clip_text_model.py +++ b/keras_cv/models/feature_extractor/clip/clip_text_model.py @@ -82,7 +82,6 @@ def call(self, inputs, attention_mask=None): # Zero out the lower diagonal causal_attention_mask = ops.triu(causal_attention_mask) causal_attention_mask = ops.cast(causal_attention_mask, "float32") - causal_attention_mask = (1.0 - causal_attention_mask) * (-1e8) attention_mask = ops.cast(attention_mask, dtype="float32") expanded_mask = 
ops.tile( attention_mask[:, None, None, :], (1, 1, self.context_length, 1) From fe2ac12288fde73dadb088ffec5144536da6abd9 Mon Sep 17 00:00:00 2001 From: Divyashree Sreepathihalli Date: Wed, 21 Feb 2024 22:25:34 +0000 Subject: [PATCH 38/38] update install instructions for keras-nlp --- keras_cv/models/feature_extractor/clip/clip_model.py | 2 +- keras_cv/models/feature_extractor/clip/clip_tokenizer.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/keras_cv/models/feature_extractor/clip/clip_model.py b/keras_cv/models/feature_extractor/clip/clip_model.py index cbe9f0ee33..e81dbd5d09 100644 --- a/keras_cv/models/feature_extractor/clip/clip_model.py +++ b/keras_cv/models/feature_extractor/clip/clip_model.py @@ -81,7 +81,7 @@ def __init__( if keras_nlp is None: raise ValueError( "ClipTokenizer requires keras-nlp. Please install " - "using pip `pip install keras-nlp`" + "using pip `pip install -U keras-nlp && pip install -U keras`" ) self.embed_dim = embed_dim self.image_resolution = image_resolution diff --git a/keras_cv/models/feature_extractor/clip/clip_tokenizer.py b/keras_cv/models/feature_extractor/clip/clip_tokenizer.py index fccbf680e3..66b4d7cef6 100644 --- a/keras_cv/models/feature_extractor/clip/clip_tokenizer.py +++ b/keras_cv/models/feature_extractor/clip/clip_tokenizer.py @@ -110,7 +110,7 @@ def __init__(self, **kwargs): if keras_nlp is None: raise ValueError( "ClipTokenizer requires keras-nlp. Please install " - "using pip `pip install keras-nlp`" + "using pip `pip install -U keras-nlp && pip install -U keras`" ) def _bpe_merge_and_update_cache(self, tokens):