Skip to content

Commit

Permalink
Merge branch 'fix_clip_jax' of github.com:divyashreepathihalli/keras-…
Browse files Browse the repository at this point in the history
…cv into clip_refactor
  • Loading branch information
tirthasheshpatel committed Apr 8, 2024
2 parents e923687 + 30f5209 commit 8fdabd5
Show file tree
Hide file tree
Showing 16 changed files with 2,184 additions and 4 deletions.
13 changes: 13 additions & 0 deletions keras_cv/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,11 +179,24 @@
ResNetV2Backbone,
)
from keras_cv.models.backbones.vgg16.vgg16_backbone import VGG16Backbone
from keras_cv.models.backbones.video_swin.video_swin_aliases import (
VideoSwinBBackbone,
)
from keras_cv.models.backbones.video_swin.video_swin_aliases import (
VideoSwinSBackbone,
)
from keras_cv.models.backbones.video_swin.video_swin_aliases import (
VideoSwinTBackbone,
)
from keras_cv.models.backbones.video_swin.video_swin_backbone import (
VideoSwinBackbone,
)
from keras_cv.models.backbones.vit_det.vit_det_aliases import ViTDetBBackbone
from keras_cv.models.backbones.vit_det.vit_det_aliases import ViTDetHBackbone
from keras_cv.models.backbones.vit_det.vit_det_aliases import ViTDetLBackbone
from keras_cv.models.backbones.vit_det.vit_det_backbone import ViTDetBackbone
from keras_cv.models.classification.image_classifier import ImageClassifier
from keras_cv.models.classification.video_classifier import VideoClassifier
from keras_cv.models.feature_extractor.clip import CLIP
from keras_cv.models.object_detection.retinanet.retinanet import RetinaNet
from keras_cv.models.object_detection.yolo_v8.yolo_v8_backbone import (
Expand Down
3 changes: 3 additions & 0 deletions keras_cv/models/backbones/backbone_presets.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from keras_cv.models.backbones.mobilenet_v3 import mobilenet_v3_backbone_presets
from keras_cv.models.backbones.resnet_v1 import resnet_v1_backbone_presets
from keras_cv.models.backbones.resnet_v2 import resnet_v2_backbone_presets
from keras_cv.models.backbones.video_swin import video_swin_backbone_presets
from keras_cv.models.backbones.vit_det import vit_det_backbone_presets
from keras_cv.models.object_detection.yolo_v8 import yolo_v8_backbone_presets

Expand All @@ -42,6 +43,7 @@
**efficientnet_lite_backbone_presets.backbone_presets_no_weights,
**yolo_v8_backbone_presets.backbone_presets_no_weights,
**vit_det_backbone_presets.backbone_presets_no_weights,
**video_swin_backbone_presets.backbone_presets_no_weights,
}

backbone_presets_with_weights = {
Expand All @@ -55,6 +57,7 @@
**efficientnet_lite_backbone_presets.backbone_presets_with_weights,
**yolo_v8_backbone_presets.backbone_presets_with_weights,
**vit_det_backbone_presets.backbone_presets_with_weights,
**video_swin_backbone_presets.backbone_presets_with_weights,
}

backbone_presets = {
Expand Down
13 changes: 13 additions & 0 deletions keras_cv/models/backbones/video_swin/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright 2024 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
158 changes: 158 additions & 0 deletions keras_cv/models/backbones/video_swin/video_swin_aliases.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
# Copyright 2024 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy

from keras_cv.models.backbones.video_swin.video_swin_backbone import (
VideoSwinBackbone,
)
from keras_cv.models.backbones.video_swin.video_swin_backbone_presets import (
backbone_presets,
)
from keras_cv.utils.python_utils import classproperty

ALIAS_DOCSTRING = """VideoSwin{size}Backbone model.
Reference:
- [Video Swin Transformer](https://arxiv.org/abs/2106.13230)
- [Video Swin Transformer GitHub](https://github.com/SwinTransformer/Video-Swin-Transformer)
For transfer learning use cases, make sure to read the
[guide to transfer learning & fine-tuning](https://keras.io/guides/transfer_learning/).
Examples:
```python
input_data = np.ones(shape=(1, 32, 224, 224, 3))
# Randomly initialized backbone
model = VideoSwin{size}Backbone()
output = model(input_data)
```
""" # noqa: E501


class VideoSwinTBackbone(VideoSwinBackbone):
def __new__(
cls,
embed_dim=96,
depths=[2, 2, 6, 2],
num_heads=[3, 6, 12, 24],
window_size=[8, 7, 7],
include_rescaling=True,
**kwargs,
):
kwargs.update(
{
"embed_dim": embed_dim,
"depths": depths,
"num_heads": num_heads,
"window_size": window_size,
"include_rescaling": include_rescaling,
}
)
return VideoSwinBackbone.from_preset("videoswin_tiny", **kwargs)

@classproperty
def presets(cls):
"""Dictionary of preset names and configurations."""
return {
"videoswin_tiny_kinetics400": copy.deepcopy(
backbone_presets["videoswin_tiny_kinetics400"]
),
}

@classproperty
def presets_with_weights(cls):
"""Dictionary of preset names and configurations that include
weights."""
return cls.presets


class VideoSwinSBackbone(VideoSwinBackbone):
def __new__(
cls,
embed_dim=96,
depths=[2, 2, 18, 2],
num_heads=[3, 6, 12, 24],
window_size=[8, 7, 7],
include_rescaling=True,
**kwargs,
):
kwargs.update(
{
"embed_dim": embed_dim,
"depths": depths,
"num_heads": num_heads,
"window_size": window_size,
"include_rescaling": include_rescaling,
}
)
return VideoSwinBackbone.from_preset("videoswin_small", **kwargs)

@classproperty
def presets(cls):
"""Dictionary of preset names and configurations."""
return {
"videoswin_small_kinetics400": copy.deepcopy(
backbone_presets["videoswin_small_kinetics400"]
),
}

@classproperty
def presets_with_weights(cls):
"""Dictionary of preset names and configurations that include
weights."""
return cls.presets


class VideoSwinBBackbone(VideoSwinBackbone):
def __new__(
cls,
embed_dim=128,
depths=[2, 2, 18, 2],
num_heads=[4, 8, 16, 32],
window_size=[8, 7, 7],
include_rescaling=True,
**kwargs,
):
kwargs.update(
{
"embed_dim": embed_dim,
"depths": depths,
"num_heads": num_heads,
"window_size": window_size,
"include_rescaling": include_rescaling,
}
)
return VideoSwinBackbone.from_preset("videoswin_base", **kwargs)

@classproperty
def presets(cls):
"""Dictionary of preset names and configurations."""
return {
"videoswin_base_kinetics400": copy.deepcopy(
backbone_presets["videoswin_base_kinetics400"]
),
}

@classproperty
def presets_with_weights(cls):
"""Dictionary of preset names and configurations that include
weights."""
return cls.presets


setattr(VideoSwinTBackbone, "__doc__", ALIAS_DOCSTRING.format(size="T"))
setattr(VideoSwinSBackbone, "__doc__", ALIAS_DOCSTRING.format(size="S"))
setattr(VideoSwinBBackbone, "__doc__", ALIAS_DOCSTRING.format(size="B"))
Loading

0 comments on commit 8fdabd5

Please sign in to comment.