diff --git a/.github/workflows/actions.yml b/.github/workflows/actions.yml index 8be69b967b..316e623c57 100644 --- a/.github/workflows/actions.yml +++ b/.github/workflows/actions.yml @@ -38,6 +38,8 @@ jobs: pip install torch>=2.0.1+cpu pip install "jax[cpu]" pip install keras-core + pip install keras-nlp-nightly --no-deps + pip install tensorflow-text==2.15 pip install -e ".[tests]" --progress-bar off --upgrade - name: Test with pytest env: @@ -75,6 +77,7 @@ jobs: run: | pip install -r requirements.txt pip install -e ".[tests]" --progress-bar off --upgrade + pip install keras-nlp-nightly - name: Test with pytest env: TEST_CUSTOM_OPS: false # TODO(ianstenbit): test custom ops, or figure out what our story is here diff --git a/.kokoro/github/ubuntu/gpu/build.sh b/.kokoro/github/ubuntu/gpu/build.sh index 9d07218317..fedfcd0566 100644 --- a/.kokoro/github/ubuntu/gpu/build.sh +++ b/.kokoro/github/ubuntu/gpu/build.sh @@ -29,21 +29,26 @@ then pip install --extra-index-url https://download.pytorch.org/whl/cpu torch==2.1.0+cpu pip install torchvision~=0.16.0 pip install "jax[cpu]" + pip install keras-nlp-nightly --no-deps + pip install tensorflow-text==2.15 elif [ "$KERAS_BACKEND" == "tensorflow" ] then echo "TensorFlow backend detected." pip install -r requirements-tensorflow-cuda.txt --progress-bar off + pip install keras-nlp-nightly elif [ "$KERAS_BACKEND" == "jax" ] then echo "JAX backend detected." pip install -r requirements-jax-cuda.txt --progress-bar off + pip install keras-nlp-nightly elif [ "$KERAS_BACKEND" == "torch" ] then echo "PyTorch backend detected." pip install -r requirements-torch-cuda.txt --progress-bar off + pip install keras-nlp-nightly fi pip install --no-deps -e "." --progress-bar off diff --git a/keras_cv/models/__init__.py b/keras_cv/models/__init__.py index b9b90b946a..8e6a849a95 100644 --- a/keras_cv/models/__init__.py +++ b/keras_cv/models/__init__.py @@ -183,6 +183,7 @@ from keras_cv.models.backbones.vit_det.vit_det_aliases import ViTDetLBackbone from keras_cv.models.backbones.vit_det.vit_det_backbone import ViTDetBackbone from keras_cv.models.classification.image_classifier import ImageClassifier +from keras_cv.models.feature_extractor.clip import CLIP from keras_cv.models.object_detection.retinanet.retinanet import RetinaNet from keras_cv.models.object_detection.yolo_v8.yolo_v8_backbone import ( YOLOV8Backbone, diff --git a/keras_cv/models/feature_extractor/__init__.py b/keras_cv/models/feature_extractor/__init__.py new file mode 100644 index 0000000000..3992ffb59a --- /dev/null +++ b/keras_cv/models/feature_extractor/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2023 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
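Note on the CI changes above: `keras-nlp-nightly` (plus `tensorflow-text` on the torch/JAX images) is installed because the new CLIP tokenizer and processor depend on `keras_nlp`, and `keras_cv/models/__init__.py` now re-exports the model as `keras_cv.models.CLIP`. Below is a minimal usage sketch of that export, mirroring the shapes used in the unit tests; it assumes a build of keras-cv from this branch with keras-nlp available and is not part of the diff itself.

```python
import numpy as np

from keras_cv.models import CLIP  # new export added in keras_cv/models/__init__.py

# Defaults correspond to the ViT-B/32 configuration: 224x224 images,
# a 77-token context length, and a 512-dim joint embedding space.
model = CLIP()

images = np.ones((1, 224, 224, 3), dtype="float32")
tokens = np.ones((3, 77), dtype="int32")
attention_mask = np.ones((3, 77), dtype="int32")

# Returns (logits_per_image, logits_per_text); here one image is scored
# against three text sequences, so logits_per_image has shape (1, 3).
image_logits, text_logits = model(images, tokens, attention_mask)
```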
diff --git a/keras_cv/models/feature_extractor/clip/__init__.py b/keras_cv/models/feature_extractor/clip/__init__.py new file mode 100644 index 0000000000..8826871115 --- /dev/null +++ b/keras_cv/models/feature_extractor/clip/__init__.py @@ -0,0 +1,23 @@ +# Copyright 2023 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from keras_cv.models.feature_extractor.clip.clip_image_model import ( + CLIPImageEncoder, +) +from keras_cv.models.feature_extractor.clip.clip_model import CLIP +from keras_cv.models.feature_extractor.clip.clip_processor import CLIPProcessor +from keras_cv.models.feature_extractor.clip.clip_text_model import ( + CLIPTextEncoder, +) +from keras_cv.models.feature_extractor.clip.clip_tokenizer import CLIPTokenizer diff --git a/keras_cv/models/feature_extractor/clip/clip_encoder.py b/keras_cv/models/feature_extractor/clip/clip_encoder.py new file mode 100644 index 0000000000..aeb345c857 --- /dev/null +++ b/keras_cv/models/feature_extractor/clip/clip_encoder.py @@ -0,0 +1,321 @@ +# Copyright 2023 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import numpy as np + +from keras_cv.backend import keras +from keras_cv.backend import ops + + +def get_initializer(initializer_range=0.02): + """ + Creates a `keras.initializers.TruncatedNormal` with the given range. + + Args: + initializer_range (*float*, defaults to 0.02): Standard deviation of the + initializer range. + + Returns: + `keras.initializers.TruncatedNormal`: The truncated normal initializer. 
+ """ + return keras.initializers.TruncatedNormal(stddev=initializer_range) + + +class QuickGELU(keras.layers.Layer): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def call(self, x): + return x * ops.sigmoid(1.702 * x) + + +class ResidualAttention(keras.layers.Layer): + def __init__( + self, + proj_dim, + num_heads, + num_hidden_layers, + **kwargs, + ): + super().__init__(**kwargs) + self.proj_dim = proj_dim + self.num_heads = num_heads + self.num_hidden_layers = num_hidden_layers + self.fc_std = np.power(2 * self.proj_dim, -0.5) * 0.02 + + self.in_proj_std = ( + np.power(self.proj_dim, -0.5) + * (np.power(2 * self.num_hidden_layers, -0.5)) + * 0.02 + ) + self.attn = CLIPAttention( + self.proj_dim, + self.num_heads, + self.num_hidden_layers, + name="multi_head_attention", + ) + self.ln_1 = keras.layers.LayerNormalization(epsilon=1e-5, name="ln_1") + self.mlp_dense_1 = keras.layers.Dense( + self.proj_dim * 4, + name="c_fc", + ) + self.mlp_activation = QuickGELU(name="gelu") + self.mlp_dense_2 = keras.layers.Dense( + self.proj_dim, + name="c_proj", + ) + self.ln_2 = keras.layers.LayerNormalization(epsilon=1e-5, name="ln_2") + + def attention(self, x, causal_attention_mask=None, attention_mask=None): + mask = None + if causal_attention_mask is not None: + mask = ( + ops.cast(causal_attention_mask, dtype=x.dtype) + if causal_attention_mask is not None + else None + ) + if attention_mask is not None: + attention_mask = ( + ops.cast(attention_mask, dtype=x.dtype) + if attention_mask is not None + else None + ) + mask = ops.add(causal_attention_mask, attention_mask) + + return self.attn( + x, + attention_mask=mask, + )[0] + + def build(self, input_shape): + super().build(input_shape) + self.attn.build(None) + self.ln_1.build([None, None, self.proj_dim]) + self.mlp_dense_1.build([None, None, self.proj_dim]) + self.mlp_dense_2.build([None, None, self.proj_dim * 4]) + self.ln_2.build([None, None, self.proj_dim]) + + def call(self, x, causal_attention_mask=None, attention_mask=None): + residual = x + x = self.ln_1(x) + x = self.attention( + x, + causal_attention_mask=causal_attention_mask, + attention_mask=attention_mask, + ) + x = x + residual + residual = x + x = self.mlp_dense_1(self.ln_2(residual)) + x = self.mlp_activation(x) + x = self.mlp_dense_2(x) + x = residual + x + return x + + def compute_output_shape(self, inputs_shape): + return inputs_shape + + def get_config(self): + config = super().get_config() + config.update( + { + "proj_dim": self.proj_dim, + "num_heads": self.num_heads, + "num_hidden_layers": self.num_hidden_layers, + } + ) + return config + + +class CLIPEncoder(keras.layers.Layer): + def __init__(self, width, num_layers, heads, **kwargs): + super().__init__(**kwargs) + self.width = width + self.num_layers = num_layers + self.heads = heads + self.resblocks = [ + ResidualAttention( + self.width, + self.heads, + self.num_layers, + ) + for _ in range(self.num_layers) + ] + + def build(self, input_shape): + super().build(input_shape) + for block in self.resblocks: + block.build(input_shape) + + def call( + self, + x, + causal_attention_mask=None, + attention_mask=None, + ): + for block in self.resblocks: + x = block( + x, + causal_attention_mask=causal_attention_mask, + attention_mask=attention_mask, + ) + return x + + def compute_output_shape(self, inputs_shape): + return inputs_shape + + def get_config(self): + config = super().get_config() + config.update( + { + "width": self.width, + "num_layers": self.num_layers, + "heads": self.heads, + } + ) + return 
config + + +class CLIPAttention(keras.layers.Layer): + """ + Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/modeling_clip.py # noqa: E501 + """ + + def __init__( + self, proj_dim, num_heads, num_hidden_layers, dropout=0.0, **kwargs + ): + super().__init__(**kwargs) + + self.proj_dim = proj_dim + self.num_heads = num_heads + self.num_hidden_layers = num_hidden_layers + self.dropout = dropout + self.head_dim = self.proj_dim // self.num_heads + if self.head_dim * self.num_heads != self.proj_dim: + raise ValueError( + f"proj_dim must be divisible by num_heads (got `proj_dim`" + f": {self.proj_dim} and `num_heads`:" + f" {self.num_heads})." + ) + + self.scale = self.head_dim**-0.5 + in_proj_std = ( + (self.proj_dim**-0.5) + * ((2 * self.num_hidden_layers) ** -0.5) + * 0.02 + ) + out_proj_std = (self.proj_dim**-0.5) * 0.02 + self.q_proj = keras.layers.Dense( + units=self.proj_dim, + kernel_initializer=get_initializer(in_proj_std), + name="q_proj", + ) + self.k_proj = keras.layers.Dense( + units=self.proj_dim, + kernel_initializer=get_initializer(in_proj_std), + name="k_proj", + ) + self.v_proj = keras.layers.Dense( + units=self.proj_dim, + kernel_initializer=get_initializer(in_proj_std), + name="v_proj", + ) + self.out_proj = keras.layers.Dense( + units=self.proj_dim, + kernel_initializer=get_initializer(out_proj_std), + name="out_proj", + ) + + def build(self, input_shape): + super().build(input_shape) + self.q_proj.build([None, None, self.proj_dim]) + self.k_proj.build([None, None, self.proj_dim]) + self.v_proj.build([None, None, self.proj_dim]) + self.out_proj.build([None, None, self.proj_dim]) + + def _transpose_for_scores(self, tensor, batch_size): + """ + Adapted from https://github.com/huggingface/transformers/blob/8e164c5400b7b413c7b8fb32e35132001effc970/src/transformers/models/bert/modeling_tf_bert.py#L252 # noqa: E501 + """ + # [batch_size, seq_len, all_head_dim] -> + # [batch_size, seq_len, num_heads, head_dim] + tensor = ops.reshape( + tensor, (batch_size, -1, self.num_heads, self.head_dim) + ) + # [batch_size, seq_len, num_heads, head_dim] -> + # [batch_size, num_heads, seq_len, head_dim] + return ops.transpose(tensor, axes=[0, 2, 1, 3]) + + def call( + self, + x, + attention_mask=None, + output_attentions=None, + training=False, + ): + batch_size = ops.shape(x)[0] + mixed_query_layer = self.q_proj(inputs=x) + mixed_key_layer = self.k_proj(inputs=x) + mixed_value_layer = self.v_proj(inputs=x) + query_layer = self._transpose_for_scores(mixed_query_layer, batch_size) + key_layer = self._transpose_for_scores(mixed_key_layer, batch_size) + value_layer = self._transpose_for_scores(mixed_value_layer, batch_size) + + # Scaled dot product between key and query = raw attention scores. + attention_scores = ops.matmul( + query_layer, ops.transpose(key_layer, axes=[0, 1, 3, 2]) + ) + dk = ops.cast(ops.sqrt(self.head_dim), dtype=attention_scores.dtype) + attention_scores = ops.divide( + attention_scores, dk + ) # (batch_size, num_heads, seq_len_q, seq_len_k) + + if attention_mask is not None: + # Apply the attention mask (precomputed for all layers in the + # call() function) + attention_scores = ops.add(attention_scores, attention_mask) + + # Normalize the attention scores to probabilities. + attention_probs = ops.softmax(attention_scores, axis=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. 
+ dropout_attention_probs = keras.layers.Dropout(self.dropout)( + inputs=attention_probs, training=training + ) + + attn_output = ops.matmul(dropout_attention_probs, value_layer) + attn_output = ops.transpose(attn_output, axes=[0, 2, 1, 3]) + + # (batch_size, seq_len_q, proj_dim) + attn_output = ops.reshape(attn_output, (batch_size, -1, self.proj_dim)) + + attn_output = self.out_proj(attn_output, training=training) + outputs = ( + (attn_output, attention_probs) + if output_attentions + else (attn_output,) + ) + + return outputs + + def get_config(self): + config = super().get_config() + config.update( + { + "proj_dim": self.proj_dim, + "num_heads": self.num_heads, + "num_hidden_layers": self.num_hidden_layers, + "dropout": self.dropout, + } + ) + return config diff --git a/keras_cv/models/feature_extractor/clip/clip_image_model.py b/keras_cv/models/feature_extractor/clip/clip_image_model.py new file mode 100644 index 0000000000..1718768116 --- /dev/null +++ b/keras_cv/models/feature_extractor/clip/clip_image_model.py @@ -0,0 +1,170 @@ +# Copyright 2023 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from keras_cv.backend import keras +from keras_cv.backend import ops +from keras_cv.models.feature_extractor.clip.clip_encoder import CLIPEncoder +from keras_cv.models.feature_extractor.clip.clip_encoder import get_initializer + + +class CLIPPatchingAndEmbedding(keras.layers.Layer): + def __init__( + self, width, patch_size, input_resolution, output_dim, **kwargs + ): + super().__init__(**kwargs) + + self.conv1 = keras.layers.Conv2D( + filters=width, + kernel_size=patch_size, + strides=patch_size, + padding="valid", + use_bias=False, + data_format="channels_last", + kernel_initializer=get_initializer(0.02), + name="patch_embed.embedding", + ) + self.width = width + self.input_resolution = input_resolution + self.patch_size = patch_size + self.num_patches = ops.power( + (self.input_resolution // self.patch_size), 2 + ) + self.class_embedding_initializer = get_initializer( + ops.power(self.width, -0.5) * 0.02 + ) + self.output_dim = output_dim + + def build(self, input_shape): + super().build(input_shape) + self.conv1.build(input_shape) + self.class_embedding = self.add_weight( + shape=((self.width,)), + initializer=self.class_embedding_initializer, + name="patch_embed.class_embedding", + ) + + self.positional_embedding = self.add_weight( + shape=( + ( + (self.input_resolution // self.patch_size) ** 2 + 1, + self.width, + ) + ), + trainable=True, + name="patch_embed.positional_embedding", + ) + + def call(self, x): + batch_size = ops.shape(x)[0] + patch_embeddings = self.conv1(x) # shape = [*, grid, grid, channel] + + patch_embeddings = ops.reshape( + patch_embeddings, (batch_size, self.num_patches, -1) + ) + class_embeds = ops.broadcast_to( + self.class_embedding, (batch_size, 1, self.width) + ) + embeddings = ops.concatenate( + [class_embeds, patch_embeddings], axis=1 + ) # shape = [*, grid ** 2 + 1, width] + positional_embedding = self.positional_embedding + embeddings = 
embeddings + positional_embedding + return embeddings + + def get_config(self): + config = super().get_config() + config.update( + { + "width": self.width, + "patch_size": self.patch_size, + "input_resolution": self.input_resolution, + "output_dim": self.output_dim, + } + ) + return config + + +class CLIPImageEncoder(keras.Model): + def __init__( + self, + input_resolution, + patch_size, + width, + num_layers, + heads, + output_dim, + **kwargs, + ): + super().__init__( + **kwargs, + ) + self.input_resolution = input_resolution + self.width = width + self.patch_size = patch_size + self.output_dim = output_dim + self.heads = heads + self.num_layers = num_layers + + self.embeddings = CLIPPatchingAndEmbedding( + width=self.width, + patch_size=self.patch_size, + input_resolution=self.input_resolution, + output_dim=self.output_dim, + name="clip_patch_embedding", + ) + self.pre_norm = keras.layers.LayerNormalization( + epsilon=1e-5, name="ln_1" + ) + self.encoder = CLIPEncoder( + self.width, + self.num_layers, + self.heads, + name="clip_encoder", + ) + self.post_norm = keras.layers.LayerNormalization( + epsilon=1e-5, name="ln_2" + ) + self.image_projector = keras.layers.Dense( + output_dim, name="vision_projector", use_bias=False + ) + + def build(self, input_shape): + super().build(input_shape) + self.embeddings.build(input_shape) + self.pre_norm.build([None, None, self.width]) + self.encoder.build(None) + self.post_norm.build([None, self.width]) + self.image_projector.build([None, None, self.width]) + + def call(self, image): + x = self.embeddings(image) + x = self.pre_norm(x) + x = self.encoder(x) + x = self.post_norm(x[:, 0, :]) + image_projected_embeddings = self.image_projector(x) + return image_projected_embeddings + + def get_config(self): + config = super().get_config() + config.update( + { + "input_resolution": self.input_resolution, + "patch_size": self.patch_size, + "width": self.width, + "layers": self.num_layers, + "heads": self.heads, + "output_dim": self.output_dim, + } + ) + return config diff --git a/keras_cv/models/feature_extractor/clip/clip_model.py b/keras_cv/models/feature_extractor/clip/clip_model.py new file mode 100644 index 0000000000..e81dbd5d09 --- /dev/null +++ b/keras_cv/models/feature_extractor/clip/clip_model.py @@ -0,0 +1,188 @@ +# Copyright 2023 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
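Reviewer note (not part of the diff): the vision tower defined in `clip_image_model.py` above can also be exercised on its own. A rough sketch, with constructor arguments chosen to match the ViT-B/32-style defaults used elsewhere in this PR:

```python
import numpy as np

from keras_cv.models.feature_extractor.clip import CLIPImageEncoder

# ViT-B/32-style settings: 224x224 inputs split into 32x32 patches,
# 12 transformer layers of width 768 with 12 heads (768 // 64),
# projected down to a 512-dim embedding.
image_encoder = CLIPImageEncoder(
    input_resolution=224,
    patch_size=32,
    width=768,
    num_layers=12,
    heads=12,
    output_dim=512,
)

images = np.ones((2, 224, 224, 3), dtype="float32")
embeddings = image_encoder(images)  # expected shape: (2, 512)
```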
+import copy
+
+from keras_cv.api_export import keras_cv_export
+from keras_cv.backend import keras
+from keras_cv.backend import ops
+from keras_cv.models.feature_extractor.clip.clip_image_model import (
+    CLIPImageEncoder,
+)
+from keras_cv.models.feature_extractor.clip.clip_presets import (  # noqa: E501
+    clip_presets,
+)
+from keras_cv.models.feature_extractor.clip.clip_text_model import (
+    CLIPTextEncoder,
+)
+from keras_cv.models.task import Task
+from keras_cv.utils.python_utils import classproperty
+
+try:
+    import keras_nlp
+except ImportError:
+    keras_nlp = None
+
+
+@keras_cv_export(["keras_cv.models.CLIP"])
+class CLIP(Task):
+    """
+    CLIP implements the Contrastive Language-Image Pretraining (CLIP)
+    architecture, which enables joint learning of visual and textual
+    representations for various downstream tasks. The default base model
+    architecture will be set to clip-vit-base-patch32.
+
+    Args:
+        embed_dim (int): The dimensionality of the joint embedding space for
+            images and texts.
+        image_resolution (int): The resolution of the input images (both height
+            and width).
+        vision_layers (int): The number of layers in the vision (image) encoder.
+        vision_width (int): The width of the hidden layers in the vision
+            encoder.
+        vision_patch_size (int): The size of each square patch in the input
+            images.
+        context_length (int): The maximum length of the contextualized text
+            sequences.
+        vocab_size (int): The size of the vocabulary for tokenization.
+        transformer_width (int): The width of the hidden layers in the
+            transformer-based text encoder.
+        transformer_heads (int): The number of attention heads in the
+            transformer-based text encoder.
+        transformer_layers (int): The number of layers in the transformer-based
+            text encoder.
+    """
+
+    def __init__(
+        self,
+        embed_dim=512,
+        image_resolution=224,
+        vision_layers=12,
+        vision_width=768,
+        vision_patch_size=32,
+        context_length=77,
+        vocab_size=49408,
+        transformer_width=768,
+        transformer_heads=8,
+        transformer_layers=12,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        if keras_nlp is None:
+            raise ValueError(
+                "CLIP requires keras-nlp. 
Please install " + "using pip `pip install -U keras-nlp && pip install -U keras`" + ) + self.embed_dim = embed_dim + self.image_resolution = image_resolution + self.vision_layers = vision_layers + self.vision_width = vision_width + self.vision_patch_size = vision_patch_size + self.context_length = context_length + self.vocab_size = vocab_size + self.transformer_width = transformer_width + self.transformer_heads = transformer_heads + self.transformer_layers = transformer_layers + + vision_heads = self.vision_width // 64 + self.image_encoder = CLIPImageEncoder( + input_resolution=self.image_resolution, + patch_size=self.vision_patch_size, + width=self.vision_width, + num_layers=self.vision_layers, + heads=vision_heads, + output_dim=self.embed_dim, + name="image_encoder", + ) + self.text_encoder = CLIPTextEncoder( + transformer_width=self.transformer_width, + transformer_layers=self.transformer_layers, + transformer_heads=self.transformer_heads, + vocab_size=self.vocab_size, + embed_dim=self.embed_dim, + context_length=self.context_length, + name="text_encoder", + ) + + self.logit_scale = keras.Variable( + ops.ones([]) * ops.log(1 / 0.07), name="logit_scale" + ) + self.image_embeddings = None + self.text_embeddings = None + + def build(self, input_shape): + super().build(input_shape) + self.text_encoder.build([None, self.context_length]) + self.image_encoder.build( + [None, self.image_resolution, self.image_resolution, 3] + ) + + def encode_images(self, image): + return self.image_encoder(image) + + def encode_text(self, text, attention_mask=None): + return self.text_encoder(text, attention_mask=attention_mask) + + def call(self, image, text, attention_mask=None): + self.image_embeddings = self.encode_images(image) + self.text_embeddings = self.encode_text( + text, attention_mask=attention_mask + ) + normalize_image_features = ops.sqrt( + ops.sum(ops.power(self.image_embeddings, 2), keepdims=True) + ) + normalize_text_features = ops.sqrt( + ops.sum(ops.power(self.text_embeddings, 2), keepdims=True) + ) + self.image_embeddings = self.image_embeddings / normalize_image_features + self.text_embeddings = self.text_embeddings / normalize_text_features + logit_scale = ops.exp(self.logit_scale) + logits_per_image = ( + ops.matmul( + self.image_embeddings, + ops.transpose(self.text_embeddings), + ) + * logit_scale + ) + logits_per_text = ops.transpose(logits_per_image) + + return logits_per_image, logits_per_text + + @classproperty + def presets(cls): + """Dictionary of preset names and configurations.""" + return copy.deepcopy({**clip_presets}) + + @classproperty + def presets_with_weights(cls): + """Dictionary of preset names and configurations that include + weights.""" + return copy.deepcopy({**clip_presets}) + + def get_config(self): + config = super().get_config() + config.update( + { + "embed_dim": self.embed_dim, + "image_resolution": self.image_resolution, + "vision_layers": self.vision_layers, + "vision_width": self.vision_width, + "vision_patch_size": self.vision_patch_size, + "context_length": self.context_length, + "vocab_size": self.vocab_size, + "transformer_width": self.transformer_width, + "transformer_heads": self.transformer_heads, + "transformer_layers": self.transformer_layers, + } + ) + return config diff --git a/keras_cv/models/feature_extractor/clip/clip_model_test.py b/keras_cv/models/feature_extractor/clip/clip_model_test.py new file mode 100644 index 0000000000..d5c777c653 --- /dev/null +++ b/keras_cv/models/feature_extractor/clip/clip_model_test.py @@ -0,0 +1,135 @@ +# 
Copyright 2022 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import numpy as np +import pytest +from tensorflow import data as tf_data + +from keras_cv.backend import keras +from keras_cv.backend import ops +from keras_cv.backend.config import keras_3 +from keras_cv.models import CLIP +from keras_cv.models.feature_extractor.clip import CLIPProcessor +from keras_cv.tests.test_case import TestCase + +VOCAB_PATH = keras.utils.get_file( + None, + "https://storage.googleapis.com/keras-cv/models/clip/vocab.json", +) +MERGE_PATH = keras.utils.get_file( + None, + "https://storage.googleapis.com/keras-cv/models/clip/merges.txt", +) + +MODEL_PATH = keras.utils.get_file( + None, + "https://storage.googleapis.com/keras-cv/models/clip/clip-vit-base-patch32.weights.h5", # noqa: E501 +) + + +class CLIPTest(TestCase): + @pytest.mark.large + def test_clip_model_golden_values(self): + model = CLIP() + model.load_weights(MODEL_PATH) + processed_image = np.ones(shape=[1, 224, 224, 3]) + processed_text = np.ones(shape=[3, 77]) + attention_mask = np.ones(shape=[3, 77]) + image_logits, text_logits = model( + processed_image, processed_text, attention_mask + ) + print(image_logits) + self.assertAllClose(image_logits, [[2.932678, 2.932678, 2.932675]]) + self.assertAllClose( + text_logits, ops.transpose([[2.932678, 2.932678, 2.932675]]) + ) + + def test_clip_preprocessor(self): + processor = CLIPProcessor(224, VOCAB_PATH, MERGE_PATH) + processed_text, attention_mask = processor.process_texts( + ["mountains", "cat on tortoise"] + ) + self.assertAllClose( + processed_text[:, :3], [[49406, 5873, 49407], [49406, 2368, 525]] + ) + self.assertAllClose( + attention_mask[0, :5], [True, True, True, False, False] + ) + + def test_clip_preprocessor_tf_data(self): + processor = CLIPProcessor(224, VOCAB_PATH, MERGE_PATH) + text_input = ["a bus", "a dog", "a cat"] + dataset = tf_data.Dataset.from_tensor_slices(text_input) + dataset.map(processor.process_texts) + + @pytest.mark.large + def test_presets(self): + self.skipTest("TODO: Enable after Kaggle model is public") + model = CLIP.from_preset("clip-vit-base-patch32") + processed_image = np.ones(shape=[1, 224, 224, 3]) + processed_text = np.ones(shape=[3, 77]) + attention_mask = np.ones(shape=[3, 77]) + image_logits, text_logits = model( + processed_image, processed_text, attention_mask + ) + + @pytest.mark.large + def test_image_encoder_golden_values(self): + model = CLIP() + model.load_weights(MODEL_PATH) + processed_image = np.ones(shape=[1, 224, 224, 3]) + processed_text = np.ones(shape=[3, 77]) + attention_mask = np.ones(shape=[3, 77]) + model(processed_image, processed_text, attention_mask) + self.assertAllClose( + model.image_embeddings[:, :5], + [[0.023215, 0.026526, 0.008914, -0.091689, 0.021791]], + ) + + @pytest.mark.large + def test_text_encoder_golden_values(self): + model = CLIP() + processed_image = np.ones(shape=[1, 224, 224, 3]) + processed_text = np.ones(shape=[3, 77]) + attention_mask = np.ones(shape=[3, 77]) + 
model(processed_image, processed_text, attention_mask) + print(model.text_embeddings) + self.assertAllClose( + model.text_embeddings[0, :3], + [-0.018502, 0.000906, 0.020372], + ) + + @pytest.mark.large # Saving is slow, so mark these large. + def test_saved_model(self): + model = CLIP() + processed_image = np.ones(shape=[1, 224, 224, 3]) + processed_text = np.ones(shape=[3, 77]) + attention_mask = np.ones(shape=[3, 77]) + model_output, _ = model(processed_image, processed_text, attention_mask) + save_path = os.path.join(self.get_temp_dir(), "model.keras") + if keras_3(): + model.save(save_path) + else: + model.save(save_path, save_format="keras_v3") + restored_model = keras.models.load_model(save_path) + + # Check we got the real object back. + self.assertIsInstance(restored_model, CLIP) + # Check that output matches. + restored_output, _ = restored_model( + processed_image, processed_text, attention_mask + ) + self.assertAllClose(model_output, restored_output) diff --git a/keras_cv/models/feature_extractor/clip/clip_presets.py b/keras_cv/models/feature_extractor/clip/clip_presets.py new file mode 100644 index 0000000000..6b4d98727e --- /dev/null +++ b/keras_cv/models/feature_extractor/clip/clip_presets.py @@ -0,0 +1,81 @@ +# Copyright 2023 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""CLIP presets.""" + +clip_presets = { + "clip-vit-base-patch16": { + "metadata": { + "description": ( + "The model uses a ViT-B/16 Transformer architecture as an " + "image encoder and uses a masked self-attention Transformer as " + "a text encoder. These encoders are trained to maximize the " + "similarity of (image, text) pairs via a contrastive loss. The " + "model uses a patch size of 16 and input images of size (224, " + "224)" + ), + "params": 149620737, + "official_name": "CLIP", + "path": "clip", + }, + "kaggle_handle": "kaggle://keras/clip/keras/clip-vit-base-patch16/2", + }, + "clip-vit-base-patch32": { + "metadata": { + "description": ( + "The model uses a ViT-B/32 Transformer architecture as an " + "image encoder and uses a masked self-attention Transformer as " + "a text encoder. These encoders are trained to maximize the " + "similarity of (image, text) pairs via a contrastive loss.The " + "model uses a patch size of 32 and input images of size (224, " + "224)" + ), + "params": 151277313, + "official_name": "CLIP", + "path": "clip", + }, + "kaggle_handle": "kaggle://keras/clip/keras/clip-vit-base-patch32/2", + }, + "clip-vit-large-patch14": { + "metadata": { + "description": ( + "The model uses a ViT-L/14 Transformer architecture as an " + "image encoder and uses a masked self-attention Transformer as " + "a text encoder. 
These encoders are trained to maximize the " + "similarity of (image, text) pairs via a contrastive loss.The " + "model uses a patch size of 14 and input images of size (224, " + "224)" + ), + "params": 427616513, + "official_name": "CLIP", + "path": "clip", + }, + "kaggle_handle": "kaggle://keras/clip/keras/clip-vit-large-patch14/2", + }, + "clip-vit-large-patch14-336": { + "metadata": { + "description": ( + "The model uses a ViT-L/14 Transformer architecture as an " + "image encoder and uses a masked self-attention Transformer as " + "a text encoder. These encoders are trained to maximize the " + "similarity of (image, text) pairs via a contrastive loss.The " + "model uses a patch size of 14 and input images of size (336, " + "336)" + ), + "params": 427944193, + "official_name": "CLIP", + "path": "clip", + }, + "kaggle_handle": "kaggle://keras/clip/keras/clip-vit-large-patch14-336/2", # noqa: E501 + }, +} diff --git a/keras_cv/models/feature_extractor/clip/clip_processor.py b/keras_cv/models/feature_extractor/clip/clip_processor.py new file mode 100644 index 0000000000..80e616cc02 --- /dev/null +++ b/keras_cv/models/feature_extractor/clip/clip_processor.py @@ -0,0 +1,131 @@ +# Copyright 2023 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from keras_nlp.layers import StartEndPacker + +from keras_cv.api_export import keras_cv_export +from keras_cv.backend import keras +from keras_cv.backend import ops +from keras_cv.models.feature_extractor.clip.clip_tokenizer import CLIPTokenizer + + +@keras_cv_export("keras_cv.models.feature_extractors.CLIPProcessor") +class CLIPProcessor: + """ + CLIPProcessor is a utility class that provides functionality for processing + images and texts in the context of the CLIP (Contrastive Language-Image + Pretraining) model. + + Args: + input_resolution (int): The resolution of input images. + vocabulary (str): string or dict, maps token to integer ids. If it is a + string, it should be the file path to a json file. + merges: string or list, contains the merge rule. If it is a string, it + should be the file path to merge rules. The merge rule file should + have one merge rule per line. + + Methods: + process_images(image_path: List[str]): Transforms an image located at + the specified path. + + process_texts(texts: Union[str, List[str]], context_length: int = 77): + Processes a single text or a list of texts, returning packed token + sequences. 
+ + """ + + def __init__(self, input_resolution, vocabulary, merges, **kwargs): + self.input_resolution = input_resolution + self.vocabulary = vocabulary + self.merges = merges + self.image_transform = self.transform_image + self.tokenizer = CLIPTokenizer( + vocabulary=self.vocabulary, + merges=self.merges, + unsplittable_tokens=[""], + ) + self.packer = StartEndPacker( + start_value=self.tokenizer.token_to_id("<|startoftext|>"), + end_value=self.tokenizer.token_to_id("<|endoftext|>"), + pad_value=None, + sequence_length=77, + return_padding_mask=True, + ) + + def transform_image(self, image_path): + input_resolution = self.input_resolution + mean = ops.array([0.48145466, 0.4578275, 0.40821073]) + std = ops.array([0.26862954, 0.26130258, 0.27577711]) + + image = keras.utils.load_img(image_path) + image = keras.utils.img_to_array(image) + image = ( + ops.image.resize( + image, + (input_resolution, input_resolution), + interpolation="bicubic", + ) + / 255.0 + ) + central_fraction = input_resolution / image.shape[0] + width, height = image.shape[0], image.shape[1] + left = ops.cast((width - width * central_fraction) / 2, dtype="int32") + top = ops.cast((height - height * central_fraction) / 2, dtype="int32") + right = ops.cast((width + width * central_fraction) / 2, dtype="int32") + bottom = ops.cast( + (height + height * central_fraction) / 2, dtype="int32" + ) + + image = ops.slice( + image, [left, top, 0], [right - left, bottom - top, 3] + ) + + image = (image - mean) / std + return image + + def process_images(self, images): + if isinstance(images, str): + images = [images] + + def process_image(image): + if isinstance(image, str): + return self.image_transform(image) + + processed_images = list(map(process_image, images)) + processed_images = ops.stack(processed_images) + return processed_images + + def process_texts(self, texts, context_length: int = 77): + if isinstance(texts, str): + texts = [texts] + + def pack_tokens(text): + return self.packer( + self.tokenizer(text), + sequence_length=context_length, + add_start_value=True, + add_end_value=True, + ) + + return pack_tokens(texts) + + def get_config(self): + config = super().get_config() + config.update( + { + "input_resolution": self.input_resolution, + "vocabulary": self.vocabulary, + "merges": self.merges, + } + ) + return config diff --git a/keras_cv/models/feature_extractor/clip/clip_text_model.py b/keras_cv/models/feature_extractor/clip/clip_text_model.py new file mode 100644 index 0000000000..5fc92990d2 --- /dev/null +++ b/keras_cv/models/feature_extractor/clip/clip_text_model.py @@ -0,0 +1,118 @@ +# Copyright 2023 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from keras_cv.backend import keras +from keras_cv.backend import ops +from keras_cv.models.feature_extractor.clip.clip_encoder import CLIPEncoder + + +class CLIPTextEncoder(keras.Model): + def __init__( + self, + transformer_width, + transformer_layers, + transformer_heads, + vocab_size, + embed_dim, + context_length, + **kwargs, + ): + super().__init__( + **kwargs, + ) + self.transformer_width = transformer_width + self.transformer_layers = transformer_layers + self.transformer_heads = transformer_heads + self.vocab_size = vocab_size + self.embed_dim = embed_dim + self.context_length = context_length + self.token_embedding = keras.layers.Embedding( + vocab_size, + transformer_width, + name="token_embedding", + ) + self.positional_embedding = keras.layers.Embedding( + self.context_length, + transformer_width, + name="positional_embedding", + ) + + self.encoder = CLIPEncoder( + width=transformer_width, + num_layers=transformer_layers, + heads=transformer_heads, + name="clip_encoder", + ) + self.ln_final = keras.layers.LayerNormalization(name="ln_final") + + self.text_projector = keras.layers.Dense( + embed_dim, name="text_projector", use_bias=False + ) + + def build(self, input_shape): + super().build(input_shape) + self.token_embedding.build(input_shape) + self.positional_embedding.build([1, self.context_length]) + self.encoder.build(None) + self.ln_final.build([None, None, self.transformer_width]) + self.text_projector.build([None, None, self.transformer_width]) + + def call(self, inputs, attention_mask=None): + token_embedding = self.token_embedding(inputs) + position_ids = ops.expand_dims( + ops.arange(self.context_length, dtype="int32"), 0 + ) + position_embedding = self.positional_embedding(position_ids) + position_embedding = ops.tile( + position_embedding, repeats=(inputs.shape[0], 1, 1) + ) + causal_attention_mask = ops.ones( + (self.context_length, self.context_length) + ) + # Zero out the lower diagonal + causal_attention_mask = ops.triu(causal_attention_mask) + causal_attention_mask = ops.cast(causal_attention_mask, "float32") + attention_mask = ops.cast(attention_mask, dtype="float32") + expanded_mask = ops.tile( + attention_mask[:, None, None, :], (1, 1, self.context_length, 1) + ) + expanded_mask = (1.0 - expanded_mask) * (-1e8) + encoded_output = self.encoder( + token_embedding + position_embedding, + causal_attention_mask=causal_attention_mask, + attention_mask=expanded_mask, + ) + layer_norm = self.ln_final(encoded_output) + indices = ops.expand_dims( + ops.cast(ops.argmax(inputs, axis=-1), "int32"), axis=-1 + ) + selected_features = ops.take_along_axis( + layer_norm, indices[:, :, None], axis=1 + ) + text_features = self.text_projector(selected_features) + output = ops.squeeze(text_features, axis=1) + return output + + def get_config(self): + config = super().get_config() + config.update( + { + "transformer_width": self.transformer_width, + "transformer_layers": self.transformer_layers, + "transformer_heads": self.transformer_heads, + "vocab_size": self.vocab_size, + "embed_dim": self.embed_dim, + "context_length": self.context_length, + } + ) + return config diff --git a/keras_cv/models/feature_extractor/clip/clip_tokenizer.py b/keras_cv/models/feature_extractor/clip/clip_tokenizer.py new file mode 100644 index 0000000000..66b4d7cef6 --- /dev/null +++ b/keras_cv/models/feature_extractor/clip/clip_tokenizer.py @@ -0,0 +1,186 @@ +# Copyright 2023 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except 
in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import regex as re +import tensorflow as tf +import tensorflow_text as tf_text + +try: + import keras_nlp + from keras_nlp.tokenizers import BytePairTokenizer +except ImportError: + keras_nlp = None + +# As python and TF handles special spaces differently, we need to +# manually handle special spaces during string split. +SPECIAL_WHITESPACES = r"\x{a0}\x{2009}\x{202f}\x{3000}" +SPLIT_PATTERN_1 = ( + r"'s|'t|'re|'ve|'m|'ll|'d" + + r"|[\s{special_spaces}]+[\n\r\t\f६{special_spaces}]| ?\p{L}+|" + + r" ?[\p{N}]+| ?[^\s\p{L}\p{N}{special_spaces}]+" +) +SPLIT_PATTERN_1 = SPLIT_PATTERN_1.replace( + "{special_spaces}", SPECIAL_WHITESPACES +) +SPLIT_PATTERN_2 = rf"""[\s६{SPECIAL_WHITESPACES}]$""" + + +def split_strings_for_bpe(inputs, unsplittable_tokens=None): + # We need to recreate the exact behavior of token presplitting in the + # original gpt2 tokenizer which uses a lookahead. As re2 does not + # support lookahead match, we are using an alternative insert a special + # token "६" before leading space of non-space characters and after the + # trailing space, e.g., " keras" will be "६ keras". + inputs = tf.strings.regex_replace( + inputs, rf"( )([^\s{SPECIAL_WHITESPACES}])", r"६\1\2" + ) + inputs = tf.strings.regex_replace( + inputs, rf"(\s{SPECIAL_WHITESPACES})$", r"\1६" + ) + inputs = tf.strings.regex_replace(inputs, r"\s", "") + if unsplittable_tokens: + alts = create_alts_for_unsplittable_tokens(unsplittable_tokens) + for token, alt in zip(unsplittable_tokens, alts): + escaped_token = re.escape(token) + inputs = tf_text.regex_split(inputs, escaped_token, escaped_token) + inputs = tf.strings.regex_replace(inputs, escaped_token, alt) + raw_tokens = tf_text.regex_split(inputs, SPLIT_PATTERN_1, SPLIT_PATTERN_1) + # Second pass splits out the last whilespace char or "६". + raw_tokens = tf_text.regex_split( + raw_tokens, SPLIT_PATTERN_2, SPLIT_PATTERN_2 + ) + if unsplittable_tokens: + # Replace special tokens alternate with originals. + for token, alt in zip(unsplittable_tokens, alts): + escaped_alt = re.escape(alt) + raw_tokens = tf.strings.regex_replace( + raw_tokens, escaped_alt, token + ) + + # Add '' to the end of each token + tokens_with_end_tag = tf.strings.regex_replace( + raw_tokens, r"(\p{L}+)", r"\1" + ) + + while tokens_with_end_tag.shape.rank > 2: + tokens_with_end_tag = tokens_with_end_tag.merge_dims(1, 2) + + return remove_strings_from_inputs(tokens_with_end_tag, "६") + + +def create_alts_for_unsplittable_tokens(unsplittable_tokens): + # Create alternates for all special tokens that will be not split during + # tokenization. + alts = [] + prefix = "Ĵ" + # Trim out splitters. 
+ replace_pattern = r"'|\s+|[^\p{L}\p{N}]+" + for token in unsplittable_tokens: + token = re.sub(replace_pattern, "", token) + alts.append(prefix + token) + return alts + + +def remove_strings_from_inputs(tensor, string_to_remove): + """Remove certain strings from input tensor.""" + non_empty_mask = tensor != string_to_remove + flatten_indexes = tf.where(non_empty_mask) + flatten_result = tf.gather_nd(tensor, flatten_indexes) + row_lengths = tf.reduce_sum(tf.cast(non_empty_mask, "int64"), axis=1) + result = tf.RaggedTensor.from_row_lengths( + values=flatten_result, + row_lengths=row_lengths, + ) + return result + + +class CLIPTokenizer(BytePairTokenizer): + def __init__(self, **kwargs): + super().__init__(**kwargs) + if keras_nlp is None: + raise ValueError( + "ClipTokenizer requires keras-nlp. Please install " + "using pip `pip install -U keras-nlp && pip install -U keras`" + ) + + def _bpe_merge_and_update_cache(self, tokens): + """Process unseen tokens and add to cache.""" + words = self._transform_bytes(tokens) + tokenized_words = self._bpe_merge(words) + + # For each word, join all its token by a whitespace, + # e.g., ["dragon", "fly"] => "dragon fly" for hash purpose. + tokenized_words = tf.strings.reduce_join( + tokenized_words, + axis=1, + ) + self.cache.insert(tokens, tokenized_words) + + def tokenize(self, inputs): + self._check_vocabulary() + if not isinstance(inputs, (tf.Tensor, tf.RaggedTensor)): + inputs = tf.convert_to_tensor(inputs) + + if self.add_prefix_space: + inputs = tf.strings.join([" ", inputs]) + + scalar_input = inputs.shape.rank == 0 + if scalar_input: + inputs = tf.expand_dims(inputs, 0) + + raw_tokens = split_strings_for_bpe(inputs, self.unsplittable_tokens) + token_row_splits = raw_tokens.row_splits + flat_tokens = raw_tokens.flat_values + # Check cache. + cache_lookup = self.cache.lookup(flat_tokens) + cache_mask = cache_lookup == "" + + has_unseen_words = tf.math.reduce_any( + (cache_lookup == "") & (flat_tokens != "") + ) + + def process_unseen_tokens(): + unseen_tokens = tf.boolean_mask(flat_tokens, cache_mask) + self._bpe_merge_and_update_cache(unseen_tokens) + return self.cache.lookup(flat_tokens) + + # If `has_unseen_words == True`, it means not all tokens are in cache, + # we will process the unseen tokens. Otherwise return the cache lookup. + tokenized_words = tf.cond( + has_unseen_words, + process_unseen_tokens, + lambda: cache_lookup, + ) + tokens = tf.strings.split(tokenized_words, sep=" ") + if self.compute_dtype != tf.string: + # Encode merged tokens. + tokens = self.token_to_id_map.lookup(tokens) + + # Unflatten to match input. + tokens = tf.RaggedTensor.from_row_splits( + tokens.flat_values, + tf.gather(tokens.row_splits, token_row_splits), + ) + + # Convert to a dense output if `sequence_length` is set. 
+ if self.sequence_length: + output_shape = tokens.shape.as_list() + output_shape[-1] = self.sequence_length + tokens = tokens.to_tensor(shape=output_shape) + + # Convert to a dense output if input in scalar + if scalar_input: + tokens = tf.squeeze(tokens, 0) + tf.ensure_shape(tokens, shape=[self.sequence_length]) + + return tokens diff --git a/keras_cv/tools/checkpoint_conversion/clip_weights_conversion.ipynb b/keras_cv/tools/checkpoint_conversion/clip_weights_conversion.ipynb new file mode 100644 index 0000000000..13e443669a --- /dev/null +++ b/keras_cv/tools/checkpoint_conversion/clip_weights_conversion.ipynb @@ -0,0 +1,1032 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "0DhV6hzOMY0W" + }, + "source": [ + "# Setup" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "cRzYR-oFgxt1", + "outputId": "e4b01fcd-9f71-4ba7-b8a2-1796f7ef260d" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", + " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", + " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m950.8/950.8 kB\u001b[0m \u001b[31m5.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25h Building wheel for keras-cv (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m415.4/415.4 kB\u001b[0m \u001b[31m5.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.2/5.2 MB\u001b[0m \u001b[31m17.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting keras==3.0.2\n", + " Downloading keras-3.0.2-py3-none-any.whl (1.0 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.0/1.0 MB\u001b[0m \u001b[31m8.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: absl-py in /usr/local/lib/python3.10/dist-packages (from keras==3.0.2) (1.4.0)\n", + "Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from keras==3.0.2) (1.23.5)\n", + "Requirement already satisfied: rich in /usr/local/lib/python3.10/dist-packages (from keras==3.0.2) (13.7.0)\n", + "Requirement already satisfied: namex in /usr/local/lib/python3.10/dist-packages (from keras==3.0.2) (0.0.7)\n", + "Requirement already satisfied: h5py in /usr/local/lib/python3.10/dist-packages (from keras==3.0.2) (3.9.0)\n", + "Requirement already satisfied: dm-tree in /usr/local/lib/python3.10/dist-packages (from keras==3.0.2) (0.1.8)\n", + "Requirement already satisfied: markdown-it-py>=2.2.0 in /usr/local/lib/python3.10/dist-packages (from rich->keras==3.0.2) (3.0.0)\n", + "Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /usr/local/lib/python3.10/dist-packages (from rich->keras==3.0.2) (2.16.1)\n", + "Requirement already satisfied: mdurl~=0.1 in /usr/local/lib/python3.10/dist-packages (from markdown-it-py>=2.2.0->rich->keras==3.0.2) (0.1.2)\n", + "Installing collected packages: keras\n", + " Attempting uninstall: keras\n", + " Found existing installation: keras 2.15.0\n", + " Uninstalling keras-2.15.0:\n", + " Successfully uninstalled keras-2.15.0\n", + "\u001b[31mERROR: pip's dependency resolver does not currently take 
into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n", + "tensorflow 2.15.0 requires keras<2.16,>=2.15.0, but you have keras 3.0.2 which is incompatible.\u001b[0m\u001b[31m\n", + "\u001b[0mSuccessfully installed keras-3.0.2\n" + ] + } + ], + "source": [ + "!pip install -q git+https://github.com/divyashreepathihalli/keras-cv.git@CLIP_refactor\n", + "!pip install -q keras-nlp\n", + "!pip install -q tf-keras\n", + "!pip install -q tensorflow-text\n", + "!pip install keras==3.0.2" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "mdGT8Em4Mc4b" + }, + "source": [ + "# Import" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "GDvJmQuug4-x" + }, + "outputs": [], + "source": [ + "from keras_cv.models.feature_extractor.clip import CLIPProcessor\n", + "import keras\n", + "from keras_cv.models import CLIP" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "nuFgha2jTshi", + "outputId": "b99d73eb-cc97-47d0-f46e-687c9e8b8237" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "--2024-02-02 22:19:20-- https://i.imgur.com/8H7XCH0.jpg\n", + "Resolving i.imgur.com (i.imgur.com)... 151.101.52.193\n", + "Connecting to i.imgur.com (i.imgur.com)|151.101.52.193|:443... connected.\n", + "HTTP request sent, awaiting response... 200 OK\n", + "Length: 44544 (44K) [image/jpeg]\n", + "Saving to: ‘cat.jpg’\n", + "\n", + "\rcat.jpg 0%[ ] 0 --.-KB/s \rcat.jpg 100%[===================>] 43.50K --.-KB/s in 0.01s \n", + "\n", + "2024-02-02 22:19:20 (3.58 MB/s) - ‘cat.jpg’ saved [44544/44544]\n", + "\n", + "--2024-02-02 22:19:20-- http://images.cocodataset.org/val2017/000000039769.jpg\n", + "Resolving images.cocodataset.org (images.cocodataset.org)... 52.216.78.4, 3.5.1.13, 52.217.139.73, ...\n", + "Connecting to images.cocodataset.org (images.cocodataset.org)|52.216.78.4|:80... connected.\n", + "HTTP request sent, awaiting response... 
200 OK\n", + "Length: 173131 (169K) [image/jpeg]\n", + "Saving to: ‘test.jpg’\n", + "\n", + "test.jpg 100%[===================>] 169.07K --.-KB/s in 0.06s \n", + "\n", + "2024-02-02 22:19:20 (2.67 MB/s) - ‘test.jpg’ saved [173131/173131]\n", + "\n" + ] + } + ], + "source": [ + "!wget https://i.imgur.com/8H7XCH0.jpg -O cat.jpg\n", + "!wget http://images.cocodataset.org/val2017/000000039769.jpg -O test.jpg" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "X3kkmK6h_gFH" + }, + "outputs": [], + "source": [ + "# @title Select which model weights you would like to convert\n", + "MODEL_CONFIGS = {\n", + " \"CLIP_B32\": {\n", + " \"embed_dim\": 512,\n", + " \"context_length\": 77,\n", + " \"vocab_size\": 49408,\n", + " \"transformer_width\": 512,\n", + " \"transformer_heads\": 8,\n", + " \"transformer_layers\": 12,\n", + " \"vision_layers\": 12,\n", + " \"vision_width\": 768,\n", + " \"image_resolution\": 224,\n", + " \"vision_patch_size\": 32,\n", + " },\n", + " \"CLIP_B16\": {\n", + " \"embed_dim\": 512,\n", + " \"context_length\": 77,\n", + " \"vocab_size\": 49408,\n", + " \"transformer_width\": 512,\n", + " \"transformer_heads\": 8,\n", + " \"transformer_layers\": 12,\n", + " \"vision_layers\": 12,\n", + " \"vision_width\": 768,\n", + " \"image_resolution\": 224,\n", + " \"vision_patch_size\": 16,\n", + " },\n", + " \"CLIP_L14\": {\n", + " \"embed_dim\": 768,\n", + " \"context_length\": 77,\n", + " \"vocab_size\": 49408,\n", + " \"transformer_width\": 768,\n", + " \"transformer_heads\": 12,\n", + " \"transformer_layers\": 12,\n", + " \"vision_layers\": 24,\n", + " \"vision_width\": 1024,\n", + " \"image_resolution\": 224,\n", + " \"vision_patch_size\": 14,\n", + " },\n", + " \"CLIP_L14_336\": {\n", + " \"embed_dim\": 768,\n", + " \"context_length\": 77,\n", + " \"vocab_size\": 49408,\n", + " \"transformer_width\": 768,\n", + " \"transformer_heads\": 12,\n", + " \"transformer_layers\": 12,\n", + " \"vision_layers\": 24,\n", + " \"vision_width\": 1024,\n", + " \"image_resolution\": 336,\n", + " \"vision_patch_size\": 14,\n", + " },\n", + "}\n", + "model_map_hf = {\n", + " \"CLIP_B16\": \"openai/clip-vit-base-patch32\",\n", + " \"CLIP_B32\": \"openai/clip-vit-base-patch16\",\n", + " \"CLIP_L14\": \"openai/clip-vit-large-patch14\",\n", + " \"CLIP_L14_336\": \"openai/clip-vit-large-patch14-336\",\n", + "}\n", + "config_name = \"CLIP_L14_336\" # @param [\"CLIP_B16\", \"CLIP_B32\", \"CLIP_L14\", \"CLIP_L14_336\"]\n", + "config_name_hf = model_map_hf[config_name]" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "2l3Ll7dMMd-m" + }, + "source": [ + "# Keras 3 CLIP" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "urhuhwq0Dczo" + }, + "outputs": [], + "source": [ + "embed_dim = MODEL_CONFIGS[config_name][\"embed_dim\"]\n", + "context_length = MODEL_CONFIGS[config_name][\"context_length\"]\n", + "vocab_size = MODEL_CONFIGS[config_name][\"vocab_size\"]\n", + "transformer_width = MODEL_CONFIGS[config_name][\"transformer_width\"]\n", + "transformer_heads = MODEL_CONFIGS[config_name][\"transformer_heads\"]\n", + "transformer_layers = MODEL_CONFIGS[config_name][\"transformer_layers\"]\n", + "vision_layers = MODEL_CONFIGS[config_name][\"vision_layers\"]\n", + "vision_width = MODEL_CONFIGS[config_name][\"vision_width\"]\n", + "vision_patch_size = MODEL_CONFIGS[config_name][\"vision_patch_size\"]\n", + "image_resolution = MODEL_CONFIGS[config_name][\"image_resolution\"]\n", + "model = CLIP(\n", + " 
embed_dim,\n", + " image_resolution,\n", + " vision_layers,\n", + " vision_width,\n", + " vision_patch_size,\n", + " context_length,\n", + " vocab_size,\n", + " transformer_width,\n", + " transformer_heads,\n", + " transformer_layers,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 193 + }, + "id": "uE6x7gfqa3Ee", + "outputId": "9a080569-7ab9-49ad-8589-87f335ef2f31" + }, + "outputs": [ + { + "data": { + "text/html": [ + "
Model: \"clip\"\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1mModel: \"clip\"\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┓\n",
+       "┃ Layer (type)                        Output Shape                       Param # ┃\n",
+       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━┩\n",
+       "│ image_encoder (CLIPImageEncoder)   │ ?                             │ 0 (unbuilt) │\n",
+       "├────────────────────────────────────┼───────────────────────────────┼─────────────┤\n",
+       "│ text_encoder (CLIPTextEncoder)     │ ?                             │ 0 (unbuilt) │\n",
+       "└────────────────────────────────────┴───────────────────────────────┴─────────────┘\n",
+       "
\n" + ], + "text/plain": [ + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━┩\n", + "│ image_encoder (\u001b[38;5;33mCLIPImageEncoder\u001b[0m) │ ? │ \u001b[38;5;34m0\u001b[0m (unbuilt) │\n", + "├────────────────────────────────────┼───────────────────────────────┼─────────────┤\n", + "│ text_encoder (\u001b[38;5;33mCLIPTextEncoder\u001b[0m) │ ? │ \u001b[38;5;34m0\u001b[0m (unbuilt) │\n", + "└────────────────────────────────────┴───────────────────────────────┴─────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
 Total params: 39,425 (154.00 KB)\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m39,425\u001b[0m (154.00 KB)\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
 Trainable params: 39,425 (154.00 KB)\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m39,425\u001b[0m (154.00 KB)\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
 Non-trainable params: 0 (0.00 B)\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "model.summary()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "buXKlNfGTenW" + }, + "outputs": [], + "source": [ + "processor = CLIPProcessor(224, \"vocab.json\", \"merges.txt\")\n", + "image = processor.process_images([\"cat.jpg\"])\n", + "text_input = [\n", + " \"photo of a cat on a tortoise\",\n", + " \"tortoise on a dog\",\n", + " \"a photo of a tortoise\",\n", + "]\n", + "text = processor.process_texts(text_input)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "BHSpMv0PT5SX", + "outputId": "566c92c4-fbf3-4e2d-87f1-6112b2cff96f" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tf.Tensor([[ 0.42190465 0.6262117 -0.2368357 ]], shape=(1, 3), dtype=float32)\n", + "tortoise on a dog\n" + ] + } + ], + "source": [ + "image_logits, text_logits = model(image, text)\n", + "output = keras.layers.Softmax()(image_logits)\n", + "print(image_logits)\n", + "print(text_input[keras.ops.argmax(output)])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 193 + }, + "id": "GgNBvYCTtmA3", + "outputId": "35b9a26c-325e-4535-c33b-3f67ab112e19" + }, + "outputs": [ + { + "data": { + "text/html": [ + "
Model: \"clip\"\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1mModel: \"clip\"\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┓\n",
+       "┃ Layer (type)                        Output Shape                       Param # ┃\n",
+       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━┩\n",
+       "│ image_encoder (CLIPImageEncoder)   │ ?                             │  87,849,216 │\n",
+       "├────────────────────────────────────┼───────────────────────────────┼─────────────┤\n",
+       "│ text_encoder (CLIPTextEncoder)     │ ?                             │  63,428,096 │\n",
+       "└────────────────────────────────────┴───────────────────────────────┴─────────────┘\n",
+       "
\n" + ], + "text/plain": [ + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━┩\n", + "│ image_encoder (\u001b[38;5;33mCLIPImageEncoder\u001b[0m) │ ? │ \u001b[38;5;34m87,849,216\u001b[0m │\n", + "├────────────────────────────────────┼───────────────────────────────┼─────────────┤\n", + "│ text_encoder (\u001b[38;5;33mCLIPTextEncoder\u001b[0m) │ ? │ \u001b[38;5;34m63,428,096\u001b[0m │\n", + "└────────────────────────────────────┴───────────────────────────────┴─────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
 Total params: 151,277,313 (577.08 MB)\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m151,277,313\u001b[0m (577.08 MB)\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
 Trainable params: 151,277,313 (577.08 MB)\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m151,277,313\u001b[0m (577.08 MB)\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
 Non-trainable params: 0 (0.00 B)\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "model.summary()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "P8DWYq_hVFnz" + }, + "source": [ + "# HF CLIP" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "3W2prd6C0pxe" + }, + "outputs": [], + "source": [ + "from PIL import Image\n", + "import requests\n", + "\n", + "from transformers import CLIPProcessor as CP\n", + "from transformers import CLIPModel as CM" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "EntuvOq1MhwU", + "outputId": "e154a367-2f94-4fa1-e97d-d2f32db7a2cf" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config[\"id2label\"]` will be overriden.\n", + "`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config[\"bos_token_id\"]` will be overriden.\n", + "`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config[\"eos_token_id\"]` will be overriden.\n" + ] + } + ], + "source": [ + "model_hf = CM.from_pretrained(config_name_hf)\n", + "processor = CP.from_pretrained(config_name_hf)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "Ep8DRTkv3AwS", + "outputId": "770756bc-8829-484f-b6e5-763fe81e24d0" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[0.9957, 0.0023, 0.0020]], grad_fn=)" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "url = \"https://i.imgur.com/8H7XCH0.jpg\"\n", + "image_hf = Image.open(requests.get(url, stream=True).raw)\n", + "text_inputs = [\n", + " \"photo of a cat on a tortoise\",\n", + " \"tortoise on a dog\",\n", + " \"a photo of a tortoise\",\n", + "]\n", + "inputs = processor(\n", + " text=text_inputs, images=image_hf, return_tensors=\"pt\", padding=True\n", + ")\n", + "\n", + "outputs = model_hf(**inputs)\n", + "logits_per_image = (\n", + " outputs.logits_per_image\n", + ") # this is the image-text similarity score\n", + "probs = logits_per_image.softmax(\n", + " dim=1\n", + ") # we can take the softmax to get the label probabilitiesprobs\n", + "probs" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "wPa0cVnY3cBC" + }, + "outputs": [], + "source": [ + "# hugging face weights\n", + "hf_wts = model_hf.state_dict()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ArkCHlVZVKfM" + }, + "source": [ + "# Copy weights" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "TUCpKltRG4Gd" + }, + "source": [ + "##vision encoder" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "tn_U02N7U2VN" + }, + "outputs": [], + "source": [ + "model.logit_scale.assign(hf_wts.pop(\"logit_scale\").numpy())\n", + "model.get_layer(\"image_encoder\").get_layer(\n", + " \"clip_patching_and_embedding\"\n", + ").class_embedding.assign(\n", + " hf_wts.pop(\"vision_model.embeddings.class_embedding\").numpy()\n", + ")\n", + "model.get_layer(\"image_encoder\").get_layer(\n", + " \"clip_patching_and_embedding\"\n", + 
").positional_embedding.assign(\n", + " hf_wts.pop(\"vision_model.embeddings.position_embedding.weight\").numpy()\n", + ")\n", + "model.get_layer(\"image_encoder\").get_layer(\n", + " \"clip_patching_and_embedding\"\n", + ").conv1.weights[0].assign(\n", + " hf_wts.pop(\"vision_model.embeddings.patch_embedding.weight\")\n", + " .permute(3, 2, 1, 0)\n", + " .numpy()\n", + ")\n", + "model.get_layer(\"image_encoder\").get_layer(\"ln_1\").weights[0].assign(\n", + " hf_wts.pop(\"vision_model.pre_layrnorm.weight\").numpy()\n", + ")\n", + "model.get_layer(\"image_encoder\").get_layer(\"ln_1\").weights[1].assign(\n", + " hf_wts.pop(\"vision_model.pre_layrnorm.bias\").numpy()\n", + ")\n", + "model.get_layer(\"image_encoder\").get_layer(\"ln_2\").weights[0].assign(\n", + " hf_wts.pop(\"vision_model.post_layernorm.weight\").numpy()\n", + ")\n", + "model.get_layer(\"image_encoder\").get_layer(\"ln_2\").weights[1].assign(\n", + " hf_wts.pop(\"vision_model.post_layernorm.bias\").numpy()\n", + ")\n", + "model.get_layer(\"image_encoder\").get_layer(\"vision_projector\").weights[\n", + " 0\n", + "].assign(hf_wts.pop(\"visual_projection.weight\").transpose(1, 0).numpy())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "qptfuWobZcbT" + }, + "outputs": [], + "source": [ + "for i in range(0, MODEL_CONFIGS[config_name][\"vision_layers\"]):\n", + " if i == 0:\n", + " residual_attention = f\"residual_attention\"\n", + " else:\n", + " residual_attention = f\"residual_attention_{i}\"\n", + " model.get_layer(\"image_encoder\").get_layer(\n", + " \"clip_encoder\"\n", + " ).resblocks.get_layer(residual_attention).attn.q_proj.weights[0].assign(\n", + " hf_wts.pop(f\"vision_model.encoder.layers.{i}.self_attn.q_proj.weight\")\n", + " )\n", + " model.get_layer(\"image_encoder\").get_layer(\n", + " \"clip_encoder\"\n", + " ).resblocks.get_layer(residual_attention).attn.q_proj.weights[1].assign(\n", + " hf_wts.pop(f\"vision_model.encoder.layers.{i}.self_attn.q_proj.bias\")\n", + " )\n", + " model.get_layer(\"image_encoder\").get_layer(\n", + " \"clip_encoder\"\n", + " ).resblocks.get_layer(residual_attention).attn.k_proj.weights[0].assign(\n", + " hf_wts.pop(f\"vision_model.encoder.layers.{i}.self_attn.k_proj.weight\")\n", + " )\n", + " model.get_layer(\"image_encoder\").get_layer(\n", + " \"clip_encoder\"\n", + " ).resblocks.get_layer(residual_attention).attn.k_proj.weights[1].assign(\n", + " hf_wts.pop(f\"vision_model.encoder.layers.{i}.self_attn.k_proj.bias\")\n", + " )\n", + " model.get_layer(\"image_encoder\").get_layer(\n", + " \"clip_encoder\"\n", + " ).resblocks.get_layer(residual_attention).attn.v_proj.weights[0].assign(\n", + " hf_wts.pop(f\"vision_model.encoder.layers.{i}.self_attn.v_proj.weight\")\n", + " )\n", + " model.get_layer(\"image_encoder\").get_layer(\n", + " \"clip_encoder\"\n", + " ).resblocks.get_layer(residual_attention).attn.v_proj.weights[1].assign(\n", + " hf_wts.pop(f\"vision_model.encoder.layers.{i}.self_attn.v_proj.bias\")\n", + " )\n", + " model.get_layer(\"image_encoder\").get_layer(\n", + " \"clip_encoder\"\n", + " ).resblocks.get_layer(residual_attention).attn.out_proj.weights[1].assign(\n", + " hf_wts.pop(\n", + " f\"vision_model.encoder.layers.{i}.self_attn.out_proj.bias\"\n", + " ).numpy()\n", + " )\n", + " model.get_layer(\"image_encoder\").get_layer(\n", + " \"clip_encoder\"\n", + " ).resblocks.get_layer(residual_attention).attn.out_proj.weights[0].assign(\n", + " hf_wts.pop(\n", + " 
f\"vision_model.encoder.layers.{i}.self_attn.out_proj.weight\"\n", + " ).numpy()\n", + " )\n", + " model.get_layer(\"image_encoder\").get_layer(\n", + " \"clip_encoder\"\n", + " ).resblocks.get_layer(residual_attention).ln_1.weights[0].assign(\n", + " hf_wts.pop(\n", + " f\"vision_model.encoder.layers.{i}.layer_norm1.weight\"\n", + " ).numpy()\n", + " )\n", + " model.get_layer(\"image_encoder\").get_layer(\n", + " \"clip_encoder\"\n", + " ).resblocks.get_layer(residual_attention).ln_1.weights[1].assign(\n", + " hf_wts.pop(f\"vision_model.encoder.layers.{i}.layer_norm1.bias\").numpy()\n", + " )\n", + " model.get_layer(\"image_encoder\").get_layer(\n", + " \"clip_encoder\"\n", + " ).resblocks.get_layer(residual_attention).ln_2.weights[0].assign(\n", + " hf_wts.pop(\n", + " f\"vision_model.encoder.layers.{i}.layer_norm2.weight\"\n", + " ).numpy()\n", + " )\n", + " model.get_layer(\"image_encoder\").get_layer(\n", + " \"clip_encoder\"\n", + " ).resblocks.get_layer(residual_attention).ln_2.weights[1].assign(\n", + " hf_wts.pop(f\"vision_model.encoder.layers.{i}.layer_norm2.bias\").numpy()\n", + " )\n", + " model.get_layer(\"image_encoder\").get_layer(\n", + " \"clip_encoder\"\n", + " ).resblocks.get_layer(residual_attention).mlp.get_layer(\"c_fc\").weights[\n", + " 0\n", + " ].assign(\n", + " hf_wts.pop(f\"vision_model.encoder.layers.{i}.mlp.fc1.weight\")\n", + " .transpose(1, 0)\n", + " .numpy()\n", + " )\n", + " model.get_layer(\"image_encoder\").get_layer(\n", + " \"clip_encoder\"\n", + " ).resblocks.get_layer(residual_attention).mlp.get_layer(\"c_fc\").weights[\n", + " 1\n", + " ].assign(\n", + " hf_wts.pop(f\"vision_model.encoder.layers.{i}.mlp.fc1.bias\").numpy()\n", + " )\n", + " model.get_layer(\"image_encoder\").get_layer(\n", + " \"clip_encoder\"\n", + " ).resblocks.get_layer(residual_attention).mlp.get_layer(\"c_proj\").weights[\n", + " 0\n", + " ].assign(\n", + " hf_wts.pop(f\"vision_model.encoder.layers.{i}.mlp.fc2.weight\")\n", + " .transpose(1, 0)\n", + " .numpy()\n", + " )\n", + " model.get_layer(\"image_encoder\").get_layer(\n", + " \"clip_encoder\"\n", + " ).resblocks.get_layer(residual_attention).mlp.get_layer(\"c_proj\").weights[\n", + " 1\n", + " ].assign(\n", + " hf_wts.pop(f\"vision_model.encoder.layers.{i}.mlp.fc2.bias\").numpy()\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "1RN2aVrYG8T3" + }, + "source": [ + "## Text encoder" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "5FtDROnynb0N" + }, + "outputs": [], + "source": [ + "num_transformer_layers = MODEL_CONFIGS[config_name][\"vision_layers\"]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "_1AD7TcbdWEC" + }, + "outputs": [], + "source": [ + "model.get_layer(\"text_encoder\").get_layer(\"text_projector\").weights[0].assign(\n", + " hf_wts.pop(\"text_projection.weight\").numpy()\n", + ")\n", + "model.get_layer(\"text_encoder\").get_layer(\"token_embedding\").weights[0].assign(\n", + " hf_wts.pop(\"text_model.embeddings.token_embedding.weight\").numpy()\n", + ")\n", + "model.get_layer(\"text_encoder\").positional_embedding.assign(\n", + " hf_wts.pop(\"text_model.embeddings.position_embedding.weight\").numpy()\n", + ")\n", + "model.get_layer(\"text_encoder\").get_layer(\"ln_final\").weights[0].assign(\n", + " hf_wts.pop(\"text_model.final_layer_norm.weight\")\n", + ")\n", + "model.get_layer(\"text_encoder\").get_layer(\"ln_final\").weights[1].assign(\n", + " hf_wts.pop(\"text_model.final_layer_norm.bias\")\n", + 
")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "s6leOiFO6V2U" + }, + "outputs": [], + "source": [ + "for i in range(MODEL_CONFIGS[config_name][\"transformer_layers\"]):\n", + " model.get_layer(\"text_encoder\").get_layer(\n", + " \"clip_encoder\"\n", + " ).resblocks.get_layer(\n", + " f\"residual_attention_{num_transformer_layers+i}\"\n", + " ).attn.k_proj.weights[\n", + " 0\n", + " ].assign(\n", + " hf_wts.pop(f\"text_model.encoder.layers.{i}.self_attn.k_proj.weight\")\n", + " )\n", + " model.get_layer(\"text_encoder\").get_layer(\n", + " \"clip_encoder\"\n", + " ).resblocks.get_layer(\n", + " f\"residual_attention_{num_transformer_layers+i}\"\n", + " ).attn.k_proj.weights[\n", + " 1\n", + " ].assign(\n", + " hf_wts.pop(f\"text_model.encoder.layers.{i}.self_attn.k_proj.bias\")\n", + " )\n", + " model.get_layer(\"text_encoder\").get_layer(\n", + " \"clip_encoder\"\n", + " ).resblocks.get_layer(\n", + " f\"residual_attention_{num_transformer_layers+i}\"\n", + " ).attn.q_proj.weights[\n", + " 0\n", + " ].assign(\n", + " hf_wts.pop(f\"text_model.encoder.layers.{i}.self_attn.q_proj.weight\")\n", + " )\n", + " model.get_layer(\"text_encoder\").get_layer(\n", + " \"clip_encoder\"\n", + " ).resblocks.get_layer(\n", + " f\"residual_attention_{num_transformer_layers+i}\"\n", + " ).attn.q_proj.weights[\n", + " 1\n", + " ].assign(\n", + " hf_wts.pop(f\"text_model.encoder.layers.{i}.self_attn.q_proj.bias\")\n", + " )\n", + " model.get_layer(\"text_encoder\").get_layer(\n", + " \"clip_encoder\"\n", + " ).resblocks.get_layer(\n", + " f\"residual_attention_{num_transformer_layers+i}\"\n", + " ).attn.v_proj.weights[\n", + " 0\n", + " ].assign(\n", + " hf_wts.pop(f\"text_model.encoder.layers.{i}.self_attn.v_proj.weight\")\n", + " )\n", + " model.get_layer(\"text_encoder\").get_layer(\n", + " \"clip_encoder\"\n", + " ).resblocks.get_layer(\n", + " f\"residual_attention_{num_transformer_layers+i}\"\n", + " ).attn.v_proj.weights[\n", + " 1\n", + " ].assign(\n", + " hf_wts.pop(f\"text_model.encoder.layers.{i}.self_attn.v_proj.bias\")\n", + " )\n", + " model.get_layer(\"text_encoder\").get_layer(\n", + " \"clip_encoder\"\n", + " ).resblocks.get_layer(\n", + " f\"residual_attention_{num_transformer_layers+i}\"\n", + " ).attn.out_proj.weights[\n", + " 0\n", + " ].assign(\n", + " hf_wts.pop(f\"text_model.encoder.layers.{i}.self_attn.out_proj.weight\")\n", + " )\n", + " model.get_layer(\"text_encoder\").get_layer(\n", + " \"clip_encoder\"\n", + " ).resblocks.get_layer(\n", + " f\"residual_attention_{num_transformer_layers+i}\"\n", + " ).attn.out_proj.weights[\n", + " 1\n", + " ].assign(\n", + " hf_wts.pop(f\"text_model.encoder.layers.{i}.self_attn.out_proj.bias\")\n", + " )\n", + " model.get_layer(\"text_encoder\").get_layer(\n", + " \"clip_encoder\"\n", + " ).resblocks.get_layer(\n", + " f\"residual_attention_{num_transformer_layers+i}\"\n", + " ).ln_1.weights[\n", + " 0\n", + " ].assign(\n", + " hf_wts.pop(f\"text_model.encoder.layers.{i}.layer_norm1.weight\").numpy()\n", + " )\n", + " model.get_layer(\"text_encoder\").get_layer(\n", + " \"clip_encoder\"\n", + " ).resblocks.get_layer(\n", + " f\"residual_attention_{num_transformer_layers+i}\"\n", + " ).ln_1.weights[\n", + " 1\n", + " ].assign(\n", + " hf_wts.pop(f\"text_model.encoder.layers.{i}.layer_norm1.bias\").numpy()\n", + " )\n", + " model.get_layer(\"text_encoder\").get_layer(\n", + " \"clip_encoder\"\n", + " ).resblocks.get_layer(\n", + " f\"residual_attention_{num_transformer_layers+i}\"\n", + " 
).ln_2.weights[\n", + " 0\n", + " ].assign(\n", + " hf_wts.pop(f\"text_model.encoder.layers.{i}.layer_norm2.weight\").numpy()\n", + " )\n", + " model.get_layer(\"text_encoder\").get_layer(\n", + " \"clip_encoder\"\n", + " ).resblocks.get_layer(\n", + " f\"residual_attention_{num_transformer_layers+i}\"\n", + " ).ln_2.weights[\n", + " 1\n", + " ].assign(\n", + " hf_wts.pop(f\"text_model.encoder.layers.{i}.layer_norm2.bias\").numpy()\n", + " )\n", + " model.get_layer(\"text_encoder\").get_layer(\n", + " \"clip_encoder\"\n", + " ).resblocks.get_layer(\n", + " f\"residual_attention_{num_transformer_layers+i}\"\n", + " ).mlp.get_layer(\n", + " \"c_fc\"\n", + " ).weights[\n", + " 0\n", + " ].assign(\n", + " hf_wts.pop(f\"text_model.encoder.layers.{i}.mlp.fc1.weight\")\n", + " .transpose(1, 0)\n", + " .numpy()\n", + " )\n", + " model.get_layer(\"text_encoder\").get_layer(\n", + " \"clip_encoder\"\n", + " ).resblocks.get_layer(\n", + " f\"residual_attention_{num_transformer_layers+i}\"\n", + " ).mlp.get_layer(\n", + " \"c_fc\"\n", + " ).weights[\n", + " 1\n", + " ].assign(\n", + " hf_wts.pop(f\"text_model.encoder.layers.{i}.mlp.fc1.bias\").numpy()\n", + " )\n", + " model.get_layer(\"text_encoder\").get_layer(\n", + " \"clip_encoder\"\n", + " ).resblocks.get_layer(\n", + " f\"residual_attention_{num_transformer_layers+i}\"\n", + " ).mlp.get_layer(\n", + " \"c_proj\"\n", + " ).weights[\n", + " 0\n", + " ].assign(\n", + " hf_wts.pop(f\"text_model.encoder.layers.{i}.mlp.fc2.weight\")\n", + " .transpose(1, 0)\n", + " .numpy()\n", + " )\n", + " model.get_layer(\"text_encoder\").get_layer(\n", + " \"clip_encoder\"\n", + " ).resblocks.get_layer(\n", + " f\"residual_attention_{num_transformer_layers+i}\"\n", + " ).mlp.get_layer(\n", + " \"c_proj\"\n", + " ).weights[\n", + " 1\n", + " ].assign(\n", + " hf_wts.pop(f\"text_model.encoder.layers.{i}.mlp.fc2.bias\").numpy()\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "Bgen7hxCCeZ7", + "outputId": "c777d6f1-4aa7-4f3e-8fd7-759364364c44" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "odict_keys([])" + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# verify that we copied all weights\n", + "hf_wts.keys()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "wlfDdO-mid62" + }, + "source": [ + "# save weights" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "QscCUUZFiqBV" + }, + "outputs": [], + "source": [ + "model.save_weights(\"clip-vit-base-patch32.weights.h5\")" + ] + } + ], + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/requirements-common.txt b/requirements-common.txt index fc21cc5f96..29f7ee9a19 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -13,4 +13,4 @@ isort black pytest build -namex \ No newline at end of file +namex
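After copying the weights, a quick numerical cross-check against the Hugging Face model is a useful final step: it would catch a mismatched checkpoint mapping or a missing transpose before the weights file is published. The sketch below is illustrative only, reusing the `model`, `image`, `text`, and HF `probs` objects created in the notebook above; the tolerance value and the use of `keras.ops` for the softmax are assumptions, not part of the original conversion notebook.

    import numpy as np
    import keras

    # Forward pass through the converted KerasCV model on the same image/text pair.
    keras_image_logits, _ = model(image, text)
    keras_probs = keras.ops.convert_to_numpy(
        keras.ops.softmax(keras_image_logits, axis=-1)
    )

    # `probs` is the torch tensor produced earlier by the Hugging Face model.
    hf_probs = probs.detach().numpy()

    print("KerasCV:", keras_probs)
    print("HF:     ", hf_probs)

    # Loose tolerance to absorb small numerical differences between backends.
    np.testing.assert_allclose(keras_probs, hf_probs, atol=1e-3)

If the two probability vectors disagree by more than the tolerance, the most likely culprits are the config-to-checkpoint mapping and the per-layer transposes in the copy loops above.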