Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add CLIP to KerasCV #2331

Merged
Merged
Show file tree
Hide file tree
Changes from 35 commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
67b2796
clip refactor
divyashreepathihalli Feb 2, 2024
88ae6a4
code cleanup and reformat
Feb 2, 2024
3aa5c6c
update encoder name
Feb 2, 2024
1f648b3
update clip encoder name
Feb 2, 2024
3c4743d
update clip encoder name in image encoder
Feb 2, 2024
54ec6e5
add weights conversion script
Feb 2, 2024
286d0c2
update setup to install keras-nlp
Feb 2, 2024
209e5da
new black formatting
Feb 2, 2024
91e6ea9
add preset file
Feb 2, 2024
2219bc2
update array
Feb 3, 2024
957b6c8
update clip prests kaggle handle
Feb 3, 2024
160d2a9
update text model
Feb 7, 2024
3c391ed
Merge branch 'keras-team:master' into clip_refactor_sub
divyashreepathihalli Feb 7, 2024
681120c
update text encoder
Feb 8, 2024
df73f23
update position embeddings
Feb 8, 2024
80bde9c
update positonal embeddings
Feb 8, 2024
5f7b23b
add attention masks
Feb 8, 2024
7530eed
update expanded mask
Feb 8, 2024
0211bd4
revert previous commit
Feb 8, 2024
d488b75
change causal masks
Feb 8, 2024
d9d1264
undo previous commit
Feb 8, 2024
64d66b5
update attention masks
Feb 8, 2024
de0be19
update clip encoder
Feb 8, 2024
4b8c1ef
add print statements
Feb 9, 2024
54f02e8
update the pooler output
Feb 9, 2024
f831638
remove print statements
Feb 9, 2024
79de15d
Merge pull request #2 from divyashreepathihalli/clip_refactor_sub
divyashreepathihalli Feb 9, 2024
3868bb5
add tests and preset
Feb 9, 2024
719417e
Merge pull request #3 from divyashreepathihalli/clip_refactor_sub
divyashreepathihalli Feb 9, 2024
39ccb18
Merge branch 'keras-team:master' into CLIP_refactor
divyashreepathihalli Feb 9, 2024
95d9e10
cleanup and reformat
Feb 13, 2024
d4c7e16
update build
Feb 14, 2024
305fb0a
add copywrite to presets file
Feb 14, 2024
9e6ff3b
fix build state errors
Feb 14, 2024
1c88b7e
update github actions and add preprocessor test
Feb 16, 2024
eb2bd44
incorporate review comments
Feb 16, 2024
38e00b7
add modifications from review
Feb 20, 2024
8eeb88e
change import checks
Feb 21, 2024
d5b2534
update keras_nlp import check
Feb 21, 2024
9a66464
update kokoro tests
Feb 21, 2024
a0b8e30
update kaggle preset version
Feb 21, 2024
fe2ac12
update install instructions for keras-nlp
Feb 21, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .github/workflows/actions.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ jobs:
pip install torch>=2.0.1+cpu
pip install "jax[cpu]"
pip install keras-core
pip install keras-nlp-nightly --no-deps
pip install tensorflow-text==2.15
divyashreepathihalli marked this conversation as resolved.
Show resolved Hide resolved
pip install -e ".[tests]" --progress-bar off --upgrade
- name: Test with pytest
env:
Expand Down Expand Up @@ -75,6 +77,7 @@ jobs:
run: |
pip install -r requirements.txt
pip install -e ".[tests]" --progress-bar off --upgrade
pip install keras-nlp-nightly
divyashreepathihalli marked this conversation as resolved.
Show resolved Hide resolved
- name: Test with pytest
env:
TEST_CUSTOM_OPS: false # TODO(ianstenbit): test custom ops, or figure out what our story is here
Expand Down
1 change: 1 addition & 0 deletions keras_cv/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@
from keras_cv.models.backbones.vit_det.vit_det_aliases import ViTDetLBackbone
from keras_cv.models.backbones.vit_det.vit_det_backbone import ViTDetBackbone
from keras_cv.models.classification.image_classifier import ImageClassifier
from keras_cv.models.feature_extractor.clip import CLIP
from keras_cv.models.object_detection.retinanet.retinanet import RetinaNet
from keras_cv.models.object_detection.yolo_v8.yolo_v8_backbone import (
YOLOV8Backbone,
Expand Down
13 changes: 13 additions & 0 deletions keras_cv/models/feature_extractor/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright 2023 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
23 changes: 23 additions & 0 deletions keras_cv/models/feature_extractor/clip/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Copyright 2023 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from keras_cv.models.feature_extractor.clip.clip_image_model import (
CLIPImageEncoder,
)
from keras_cv.models.feature_extractor.clip.clip_model import CLIP
from keras_cv.models.feature_extractor.clip.clip_processor import CLIPProcessor
from keras_cv.models.feature_extractor.clip.clip_text_model import (
CLIPTextEncoder,
)
from keras_cv.models.feature_extractor.clip.clip_tokenizer import CLIPTokenizer
318 changes: 318 additions & 0 deletions keras_cv/models/feature_extractor/clip/clip_encoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,318 @@
# Copyright 2023 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np

from keras_cv.backend import keras
from keras_cv.backend import ops


def get_initializer(initializer_range=0.02):
"""
Creates a `keras.initializers.TruncatedNormal` with the given range.

Args:
initializer_range (*float*, defaults to 0.02): Standard deviation of the
initializer range.

Returns:
`keras.initializers.TruncatedNormal`: The truncated normal initializer.
"""
return keras.initializers.TruncatedNormal(stddev=initializer_range)


class QuickGELU(keras.layers.Layer):
divyashreepathihalli marked this conversation as resolved.
Show resolved Hide resolved
def __init__(self, **kwargs):
super().__init__(**kwargs)

def call(self, x):
return x * ops.sigmoid(1.702 * x)


class ResidualAttention(keras.layers.Layer):
def __init__(
self,
proj_dim,
num_heads,
num_hidden_layers,
**kwargs,
):
super().__init__(**kwargs)
self.proj_dim = proj_dim
self.num_heads = num_heads
self.num_hidden_layers = num_hidden_layers
self.fc_std = np.power(2 * self.proj_dim, -0.5) * 0.02

self.in_proj_std = (
np.power(self.proj_dim, -0.5)
divyashreepathihalli marked this conversation as resolved.
Show resolved Hide resolved
* (np.power(2 * self.num_hidden_layers, -0.5))
* 0.02
)
self.attn = CLIPAttention(
self.proj_dim,
self.num_heads,
self.num_hidden_layers,
name="multi_head_attention",
)
self.ln_1 = keras.layers.LayerNormalization(epsilon=1e-5, name="ln_1")
self.mlp_dense_1 = keras.layers.Dense(
self.proj_dim * 4,
name="c_fc",
)
self.mlp_activation = QuickGELU(name="gelu")
self.mlp_dense_2 = keras.layers.Dense(
self.proj_dim,
name="c_proj",
)
self.ln_2 = keras.layers.LayerNormalization(epsilon=1e-5, name="ln_2")

def attention(self, x, causal_attention_mask=None, attention_mask=None):
mask = None
if causal_attention_mask is not None:
mask = (
ops.cast(causal_attention_mask, dtype=x.dtype)
if causal_attention_mask is not None
else None
)
if attention_mask is not None:
attention_mask = (
ops.cast(attention_mask, dtype=x.dtype)
if attention_mask is not None
else None
)
mask = ops.add(causal_attention_mask, attention_mask)
divyashreepathihalli marked this conversation as resolved.
Show resolved Hide resolved

return self.attn(
x,
attention_mask=mask,
)[0]

def build(self, input_shape):
super().build(input_shape)
self.attn.build(None)
self.ln_1.build([None, None, self.proj_dim])
self.mlp_dense_1.build([None, None, self.proj_dim])
self.mlp_dense_2.build([None, None, self.proj_dim * 4])
self.ln_2.build([None, None, self.proj_dim])

def call(self, x, causal_attention_mask=None, attention_mask=None):
attn_x = x + self.attention(
divyashreepathihalli marked this conversation as resolved.
Show resolved Hide resolved
self.ln_1(x),
causal_attention_mask=causal_attention_mask,
attention_mask=attention_mask,
)
x = self.mlp_dense_1(self.ln_2(attn_x))
x = self.mlp_activation(x)
x = self.mlp_dense_2(x)
x = attn_x + x
return x

def compute_output_shape(self, inputs_shape):
return inputs_shape

def get_config(self):
config = super().get_config()
config.update(
{
"proj_dim": self.proj_dim,
"num_heads": self.num_heads,
"num_hidden_layers": self.num_hidden_layers,
}
)
return config


class CLIPEncoder(keras.layers.Layer):
def __init__(self, width, num_layers, heads, **kwargs):
super().__init__(**kwargs)
self.width = width
self.num_layers = num_layers
self.heads = heads
self.resblocks = [
ResidualAttention(
self.width,
self.heads,
self.num_layers,
)
for _ in range(self.num_layers)
]

def build(self, input_shape):
super().build(input_shape)
for block in self.resblocks:
block.build(input_shape)

def call(
self,
x,
causal_attention_mask=None,
attention_mask=None,
):
for block in self.resblocks:
x = block(
x,
causal_attention_mask=causal_attention_mask,
attention_mask=attention_mask,
)
return x

def compute_output_shape(self, inputs_shape):
return inputs_shape

def get_config(self):
config = super().get_config()
config.update(
{
"width": self.width,
"num_layers": self.num_layers,
"heads": self.heads,
}
)
return config


class CLIPAttention(keras.layers.Layer):
"""
- Documentation page: https://huggingface.co/docs/transformers/model_doc/clip # noqa: E501
divyashreepathihalli marked this conversation as resolved.
Show resolved Hide resolved
- Implementation: https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/modeling_clip.py # noqa: E501
"""

def __init__(
self, proj_dim, num_heads, num_hidden_layers, dropout=0.0, **kwargs
):
super().__init__(**kwargs)

self.proj_dim = proj_dim
self.num_heads = num_heads
self.num_hidden_layers = num_hidden_layers
self.dropout = dropout
self.head_dim = self.proj_dim // self.num_heads
if self.head_dim * self.num_heads != self.proj_dim:
raise ValueError(
f"proj_dim must be divisible by num_heads (got `proj_dim`"
f": {self.proj_dim} and `num_heads`:"
f" {self.num_heads})."
)

self.scale = self.head_dim**-0.5
in_proj_std = (
(self.proj_dim**-0.5)
* ((2 * self.num_hidden_layers) ** -0.5)
* 0.02
)
out_proj_std = (self.proj_dim**-0.5) * 0.02
self.q_proj = keras.layers.Dense(
units=self.proj_dim,
kernel_initializer=get_initializer(in_proj_std),
name="q_proj",
)
self.k_proj = keras.layers.Dense(
units=self.proj_dim,
kernel_initializer=get_initializer(in_proj_std),
name="k_proj",
)
self.v_proj = keras.layers.Dense(
units=self.proj_dim,
kernel_initializer=get_initializer(in_proj_std),
name="v_proj",
)
self.out_proj = keras.layers.Dense(
units=self.proj_dim,
kernel_initializer=get_initializer(out_proj_std),
name="out_proj",
)

def build(self, input_shape):
super().build(input_shape)
self.q_proj.build([None, None, self.proj_dim])
self.k_proj.build([None, None, self.proj_dim])
self.v_proj.build([None, None, self.proj_dim])
self.out_proj.build([None, None, self.proj_dim])

def _transpose_for_scores(self, tensor, batch_size):
"""
Copied from https://github.com/huggingface/transformers/blob/8e164c5400b7b413c7b8fb32e35132001effc970/src/transformers/models/bert/modeling_tf_bert.py#L252 # noqa: E501
"""
# [batch_size, seq_len, all_head_dim] ->
# [batch_size, seq_len, num_heads, head_dim]
tensor = ops.reshape(
tensor, (batch_size, -1, self.num_heads, self.head_dim)
)
# [batch_size, seq_len, num_heads, head_dim] ->
# [batch_size, num_heads, seq_len, head_dim]
return ops.transpose(tensor, axes=[0, 2, 1, 3])

def call(
self,
x,
attention_mask=None,
output_attentions=None,
training=False,
):
batch_size = ops.shape(x)[0]
mixed_query_layer = self.q_proj(inputs=x)
mixed_key_layer = self.k_proj(inputs=x)
mixed_value_layer = self.v_proj(inputs=x)
query_layer = self._transpose_for_scores(mixed_query_layer, batch_size)
key_layer = self._transpose_for_scores(mixed_key_layer, batch_size)
value_layer = self._transpose_for_scores(mixed_value_layer, batch_size)

# Scaled dot product between key and query = raw attention scores.
attention_scores = ops.matmul(
query_layer, ops.transpose(key_layer, axes=[0, 1, 3, 2])
)
dk = ops.cast(ops.sqrt(self.head_dim), dtype=attention_scores.dtype)
attention_scores = ops.divide(
attention_scores, dk
) # (batch_size, num_heads, seq_len_q, seq_len_k)

if attention_mask is not None:
# Apply the attention mask (precomputed for all layers in the
# call() function)
attention_scores = ops.add(attention_scores, attention_mask)
sampathweb marked this conversation as resolved.
Show resolved Hide resolved

# Normalize the attention scores to probabilities.
_attention_probs = ops.softmax(attention_scores + 1e-9, axis=-1)
divyashreepathihalli marked this conversation as resolved.
Show resolved Hide resolved

# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = keras.layers.Dropout(self.dropout)(
inputs=_attention_probs, training=training
divyashreepathihalli marked this conversation as resolved.
Show resolved Hide resolved
)

attn_output = ops.matmul(attention_probs, value_layer)
attn_output = ops.transpose(attn_output, axes=[0, 2, 1, 3])

# (batch_size, seq_len_q, proj_dim)
attn_output = ops.reshape(attn_output, (batch_size, -1, self.proj_dim))

attn_output = self.out_proj(attn_output, training=training)
outputs = (
(attn_output, _attention_probs)
if output_attentions
else (attn_output,)
)

return outputs

def get_config(self):
config = super().get_config()
config.update(
{
"proj_dim": self.proj_dim,
"num_heads": self.num_heads,
"num_hidden_layers": self.num_hidden_layers,
"dropout": self.dropout,
}
)
return config
Loading
Loading