Refactor CLIP to a functional model
update model input format

update golden values

update CLIP to functional model

update tests

code reformat

use dict instead of list

Update keras_cv/models/feature_extractor/clip/clip_model.py

Co-authored-by: Tirth Patel <tirthasheshpatel@gmail.com>

remove build and compute output shape

update model input format

update golden values

Refactor CLIP

Refactor includes:

- CLIPProcessor is now a Keras layer and uses utilities from KerasNLP to support all Python types and array inputs.
- CLIPImageEncoder, CLIPTextEncoder, and CLIPEncoder now implement a `.compute_output_shape` method (required for CLIP to work with the functional API).
- CLIPHead was added to remove raw variables from the CLIP `Task` model; keeping variables directly on a `keras.Model` is tricky since the functional API doesn't allow state.
- The CLIP checkpointing script has been updated to work with the new API; new weights will be uploaded to Kaggle.

TODO: attribute KerasNLP wherever relevant
TODO: upload new weights to Kaggle
TODO: refactor the CLIPProcessor class and the CLIP class to also pull tokenizer vocab and merges from Kaggle.
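
For reference, a condensed sketch of the new calling convention described above, adapted from the `CLIP` docstring in this diff. The vocab/merges paths and preset name are placeholders from that docstring, and the `CLIPProcessor` import path and constructor signature are assumptions based on the package layout:

```python
from keras_cv.models import CLIP
from keras_cv.models.feature_extractor.clip import CLIPProcessor  # assumed import path

# Placeholder paths; the exact CLIPProcessor signature may differ.
processor = CLIPProcessor("path_to_vocab.json", "path_to_merges.txt")

processed_image = processor.process_images(["cat.jpg"])
tokens = processor(["mountains", "cat on tortoise", "two cats"])

model = CLIP.from_preset("clip-vit-base-patch16")

# The functional model takes a dict of named inputs; since `outputs` is a
# dict in the functional constructor, calling it should return a dict too.
outputs = model(
    {
        "images": processed_image,
        "token_ids": tokens["token_ids"],
        "padding_mask": tokens["padding_mask"],
    }
)
image_logits = outputs["image_logits"]
text_logits = outputs["text_logits"]
```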

remove build and compute output shape

Some fixes for the refactor

Fix the tests, update presets

update to layers instead of models
Divyashree Sreepathihalli authored and tirthasheshpatel committed Apr 8, 2024
1 parent bfeba12 commit c6e84b9
Showing 10 changed files with 625 additions and 701 deletions.
2 changes: 2 additions & 0 deletions .kokoro/github/ubuntu/gpu/build.sh
@@ -69,6 +69,7 @@ then
keras_cv/models/object_detection/yolo_v8 \
keras_cv/models/object_detection_3d \
keras_cv/models/segmentation \
keras_cv/models/feature_extractor/clip \
keras_cv/models/stable_diffusion
else
pytest --cache-clear --check_gpu --run_large --durations 0 \
@@ -83,5 +84,6 @@ else
keras_cv/models/object_detection/yolo_v8 \
keras_cv/models/object_detection_3d \
keras_cv/models/segmentation \
keras_cv/models/feature_extractor/clip \
keras_cv/models/stable_diffusion
fi
10 changes: 6 additions & 4 deletions keras_cv/models/feature_extractor/clip/clip_encoder.py
@@ -156,9 +156,14 @@ def __init__(self, width, num_layers, heads, **kwargs):
]

def build(self, input_shape):
super().build(input_shape)
for block in self.resblocks:
block.build(input_shape)
self.built = True

def compute_output_shape(self, input_shape):
for block in self.resblocks:
input_shape = block.compute_output_shape(input_shape)
return input_shape

def call(
self,
@@ -174,9 +179,6 @@ def call(
)
return x

def compute_output_shape(self, inputs_shape):
return inputs_shape

def get_config(self):
config = super().get_config()
config.update(
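
The commit description notes that `compute_output_shape` is required for the encoders to work with the functional API. Below is a minimal sketch (not code from this repository, assuming Keras 3) of the chaining pattern `CLIPEncoder.compute_output_shape` uses, which lets Keras infer symbolic shapes without building the layer on concrete inputs:

```python
import keras


class BlockStack(keras.layers.Layer):
    """Toy stand-in for CLIPEncoder: a stack of sub-layers."""

    def __init__(self, widths, **kwargs):
        super().__init__(**kwargs)
        self.blocks = [keras.layers.Dense(w) for w in widths]

    def build(self, input_shape):
        # Build each block on the shape produced by the previous one.
        for block in self.blocks:
            block.build(input_shape)
            input_shape = block.compute_output_shape(input_shape)
        self.built = True

    def call(self, x):
        for block in self.blocks:
            x = block(x)
        return x

    def compute_output_shape(self, input_shape):
        # Same chaining pattern as CLIPEncoder.compute_output_shape above.
        for block in self.blocks:
            input_shape = block.compute_output_shape(input_shape)
        return input_shape


# Shape inference works without calling the layer on real data:
stack = BlockStack(widths=[64, 32])
print(stack.compute_output_shape((None, 77, 16)))  # (None, 77, 32)
```
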
15 changes: 12 additions & 3 deletions keras_cv/models/feature_extractor/clip/clip_image_model.py
@@ -19,7 +19,6 @@
from keras_cv.models.feature_extractor.clip.clip_encoder import get_initializer


@keras_cv_export("keras_cv.models.feature_extractor.CLIPPatchingAndEmbedding")
class CLIPPatchingAndEmbedding(keras.layers.Layer):
def __init__(
self, width, patch_size, input_resolution, output_dim, **kwargs
@@ -67,6 +66,13 @@ def build(self, input_shape):
name="patch_embed.positional_embedding",
)

def compute_output_shape(self, input_shape):
return [
None,
(self.input_resolution // self.patch_size) ** 2 + 1,
self.width,
]

def call(self, x):
batch_size = ops.shape(x)[0]
patch_embeddings = self.conv1(x) # shape = [*, grid, grid, channel]
@@ -143,12 +149,15 @@ def __init__(
)

def build(self, input_shape):
super().build(input_shape)
self.embeddings.build(input_shape)
self.pre_norm.build([None, None, self.width])
self.encoder.build(None)
self.post_norm.build([None, self.width])
self.image_projector.build([None, None, self.width])
self.image_projector.build([None, self.width])
self.built = True

def compute_output_shape(self, input_shape):
return [input_shape[0], self.output_dim]

def call(self, image):
x = self.embeddings(image)
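
As a quick sanity check of the new `compute_output_shape` methods in this file, the token count below uses typical ViT-B/16 numbers (224 resolution, patch size 16, width 768); these concrete values are assumptions and are not taken from this diff:

```python
# Assumed ViT-B/16-style configuration; not read from this diff.
input_resolution = 224
patch_size = 16
width = 768

# CLIPPatchingAndEmbedding.compute_output_shape returns
# [None, (input_resolution // patch_size) ** 2 + 1, width]:
num_tokens = (input_resolution // patch_size) ** 2 + 1  # 14 * 14 patches + 1 class token
print([None, num_tokens, width])  # [None, 197, 768]

# CLIPImageEncoder.compute_output_shape then reduces this to
# [batch_size, output_dim] after pooling and projection.
```
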
203 changes: 120 additions & 83 deletions keras_cv/models/feature_extractor/clip/clip_model.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import warnings

from keras_cv.api_export import keras_cv_export
from keras_cv.backend import keras
@@ -34,6 +35,41 @@
keras_nlp = None


class CLIPHead(keras.layers.Layer):
def __init__(self, **kwargs):
super().__init__(**kwargs)

def build(self, input_shape):
self.logit_scale = self.add_variable(
shape=(),
initializer=lambda *a, **kw: ops.log(1 / 0.07),
trainable=True,
dtype=self.variable_dtype,
name="logit_scale",
)
self.built = True

def call(self, image_embeddings, text_embeddings):
normalize_image_features = ops.sqrt(
ops.sum(ops.power(image_embeddings, 2), keepdims=True)
)
normalize_text_features = ops.sqrt(
ops.sum(ops.power(text_embeddings, 2), keepdims=True)
)
image_embeddings = image_embeddings / normalize_image_features
text_embeddings = text_embeddings / normalize_text_features
logit_scale = ops.exp(self.logit_scale)
image_logits = (
ops.matmul(
image_embeddings,
ops.transpose(text_embeddings),
)
* logit_scale
)
text_logits = ops.transpose(image_logits)
return image_logits, text_logits
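
For illustration, a small sketch of what `CLIPHead` produces when called on random embeddings (the import path follows this module's location in the diff; the shapes and dtype are arbitrary choices):

```python
import numpy as np

from keras_cv.models.feature_extractor.clip.clip_model import CLIPHead

batch, embed_dim = 4, 512
image_embeddings = np.random.randn(batch, embed_dim).astype("float32")
text_embeddings = np.random.randn(batch, embed_dim).astype("float32")

head = CLIPHead(name="clip_head")
image_logits, text_logits = head(image_embeddings, text_embeddings)

# image_logits has shape (batch, batch): scaled image-text dot products,
# with the learned temperature applied as exp(logit_scale).
# text_logits is simply its transpose.
```
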


@keras_cv_export(["keras_cv.models.CLIP"])
class CLIP(Task):
"""
Expand Down Expand Up @@ -61,25 +97,27 @@ class CLIP(Task):
transformer-based text encoder.
transformer_layers (int): The number of layers in the transformer-based
text encoder.
Example:
```python
processor = CLIPProcessor(
input_resolution=224,
"path_to_vocab.json",
"path_to_merges.txt"
)
input_resolution=224,
"path_to_vocab.json",
"path_to_merges.txt"
)
processed_image = processor.process_images(["cat.jpg"])
processed_text, attention_mask = processor.process_texts(
["mountains", "cat on tortoise", "two cats"]
)
tokens = processor(
["mountains", "cat on tortoise", "two cats"]
)
model = CLIP.from_preset("clip-vit-base-patch16")
image_logits, text_logits = model(
{
"image": processed_image,
"text": processed_text,
"attention_mask": attention_mask,
}
)
{
"images": processed_image,
"token_ids": tokens["token_ids"],
"padding_mask": tokens["padding_mask"],
}
)
```
"""

@@ -97,12 +135,77 @@ def __init__(
transformer_layers=12,
**kwargs,
):
super().__init__(**kwargs)
if keras_nlp is None:
raise ValueError(
"ClipTokenizer requires keras-nlp. Please install "
"using pip `pip install -U keras-nlp && pip install -U keras`"
)

if "dtype" in kwargs:
kwargs.pop("dtype")
# warnings.warn("Currently, CLIP doesn't support passing the "
# "`dtype` kwarg. Instead, use "
# "`keras.mixed_precision` to set dtype policies.",
# UserWarning)

vision_heads = vision_width // 64

images = keras.Input(
shape=[image_resolution, image_resolution, 3], name="images"
)
token_ids = keras.Input(
shape=[
context_length,
],
name="token_ids",
)
padding_mask = keras.Input(
shape=[
context_length,
],
name="padding_mask",
)

image_encoder = CLIPImageEncoder(
input_resolution=image_resolution,
patch_size=vision_patch_size,
width=vision_width,
num_layers=vision_layers,
heads=vision_heads,
output_dim=embed_dim,
name="image_encoder",
)
text_encoder = CLIPTextEncoder(
transformer_width=transformer_width,
transformer_layers=transformer_layers,
transformer_heads=transformer_heads,
vocab_size=vocab_size,
embed_dim=embed_dim,
context_length=context_length,
name="text_encoder",
)
clip_head = CLIPHead(name="clip_head")

image_embeddings = image_encoder(images)
text_embeddings = text_encoder(token_ids, attention_mask=padding_mask)
image_logits, text_logits = clip_head(image_embeddings, text_embeddings)

inputs = {
"images": images,
"token_ids": token_ids,
"padding_mask": padding_mask,
}
outputs = {
"image_logits": image_logits,
"text_logits": text_logits,
}

super().__init__(
inputs=inputs,
outputs=outputs,
**kwargs,
)

self.embed_dim = embed_dim
self.image_resolution = image_resolution
self.vision_layers = vision_layers
@@ -113,75 +216,9 @@ def __init__(
self.transformer_width = transformer_width
self.transformer_heads = transformer_heads
self.transformer_layers = transformer_layers

vision_heads = self.vision_width // 64
self.image_encoder = CLIPImageEncoder(
input_resolution=self.image_resolution,
patch_size=self.vision_patch_size,
width=self.vision_width,
num_layers=self.vision_layers,
heads=vision_heads,
output_dim=self.embed_dim,
name="image_encoder",
)
self.text_encoder = CLIPTextEncoder(
transformer_width=self.transformer_width,
transformer_layers=self.transformer_layers,
transformer_heads=self.transformer_heads,
vocab_size=self.vocab_size,
embed_dim=self.embed_dim,
context_length=self.context_length,
name="text_encoder",
)

self.logit_scale = keras.Variable(
ops.ones([]) * ops.log(1 / 0.07), name="logit_scale"
)
self.image_embeddings = None
self.text_embeddings = None

def build(self, input_shape):
super().build(input_shape)
self.text_encoder.build([None, self.context_length])
self.image_encoder.build(
[None, self.image_resolution, self.image_resolution, 3]
)

def encode_images(self, image):
return self.image_encoder(image)

def encode_text(self, text, attention_mask=None):
return self.text_encoder(text, attention_mask=attention_mask)

def call(self, inputs):
image, text = inputs["image"], inputs["text"]
if "attention_mask" in inputs:
attention_mask = inputs["attention_mask"]
else:
attention_mask = None
self.image_embeddings = self.encode_images(image)
self.text_embeddings = self.encode_text(
text, attention_mask=attention_mask
)
normalize_image_features = ops.sqrt(
ops.sum(ops.power(self.image_embeddings, 2), keepdims=True)
)
normalize_text_features = ops.sqrt(
ops.sum(ops.power(self.text_embeddings, 2), keepdims=True)
)
self.image_embeddings = self.image_embeddings / normalize_image_features
self.text_embeddings = self.text_embeddings / normalize_text_features
logit_scale = ops.exp(self.logit_scale)
logits_per_image = (
ops.matmul(
self.image_embeddings,
ops.transpose(self.text_embeddings),
)
* logit_scale
)
logits_per_text = ops.transpose(logits_per_image)

return logits_per_image, logits_per_text
self.image_encoder = image_encoder
self.text_encoder = text_encoder
self.clip_head = clip_head

@classproperty
def presets(cls):