Add Video Swin Transformer #2369
Merged
94 commits
f961e75 init video swin (innat)
578205a add: 3d window size computation (innat)
9817025 add: mlp layer (innat)
3343db1 add: patch embedding layer (innat)
7ab5cab add: patch merging layer (innat)
f70a61b add: window attention layer (innat)
5472fc6 add: basic layer for video swin (innat)
76d444b update: basic layer for video swin (innat)
715b8a3 add: swin blocks for video swin (innat)
3ca0042 create and add: video swin backbone (innat)
3d845c5 rename: video swin layers to model specific (innat)
1af8bd4 update module import (innat)
ed2864d update module import (innat)
bf70fa9 set class method to private usage (innat)
eca5023 set init params for backbone (innat)
420e229 rm redundant imports (innat)
f73e25b add video swin layer test cases (innat)
1ccf7ee add: videoswin backbone aliases (innat)
c5d5fa2 add: video swin backbone presets (innat)
27b6596 add: video swin backbone presets test (innat)
814db52 update: video swin backbone presets test (innat)
cc6ac21 add: video classifier task (innat)
d2d883d add: video swin classifier presets (innat)
125b2dc run formatters (innat)
9827302 rename module name/id" (innat)
89a715a add hard-coded normalization for include rescaling=true (innat)
36db030 add docstring for videoswin backbone (innat)
7aa27a4 update metadata: backbone presets no weights (innat)
62a8703 update: backbone presets no weights test (innat)
aad5661 update video swin aliases for no weights (innat)
048d85a add: video swin backbone presets with weights (innat)
1423e83 update: video swin aliases with weights presets (innat)
2eaf8b0 update video swin layer test cases (innat)
f713304 added patch merging test (innat)
44dae81 imported video swins presets to backbone presets list" (innat)
daca84f fix: typos" (innat)
b1a5427 run formatters" (innat)
c66673c fix: linting issue (innat)
84d4e03 fix: linting issue (innat)
d126b7c fix: video swin layer test cases" (innat)
61303be add: video swin backbone test (innat)
af5878c rm redundant code (innat)
ffe457c disable preset test temporary (innat)
f8d3e26 set include rescale to true (innat)
1d0ad36 add video swin components to __init__ (innat)
838a506 update docstrings: video siwn layers scripts (innat)
b4f1534 update copywrite status: video siwn layers test scripts (innat)
75c5b66 update copywrite status: video siwn backbone scripts (innat)
0b9808b bug fixes: video swin backbone layers (innat)
0a4e2cb update get config of video swin backbone (innat)
fb732d0 enable: video swin backbone test cases (innat)
4443335 update: video swin backbone test cases (innat)
f3411cb update: video swin backbone preset test cases (innat)
00c67ba run formatters (innat)
9d3ab2e fix typos: video swin backbone test cases (innat)
5bdc8b4 add: non implemented property for test reason (innat)
cb5da28 fix: typos (innat)
82a8497 add: video classifier test (innat)
e2f5056 update: video classifier test (innat)
146f32f update: video classifier test input shape (innat)
d25746b bug fix: mlp layer build method (innat)
9779ad4 updated: swin back layer build method (innat)
7fa3f83 bug fix: use tf.TensorShape in compute_output_shape method (innat)
c8aea50 update: video_classifier_test model.predict to model.call (innat)
8287395 update test cases and format the code (innat)
e9a3997 update docstrings and preset config (innat)
aab1a6c fix jax DynamicJaxprTrace issue for (innat)
ac78108 update config of backbone aliases (innat)
1dbded9 add can run in mixed precision test (innat)
42003a2 add can run on gray video (innat)
e731389 minor fix (innat)
77197c2 specify axis in keras.ops.take to match with tf.gather (innat)
aa20067 specify include rescaling to backbone class (innat)
11f33d7 remove shift size form get config of video basic layer (innat)
a2961b9 add support arbitrary input shape (innat)
49b074a minor updates to swin layers (innat)
204e4b1 test method update for swin layers (innat)
251495b update test method to swin backbone (innat)
599d481 remove unsed code (innat)
a849b38 bug fix in call method of patch embed layer (innat)
f611b0e fix typo in patch merging layer (innat)
b7d26e4 minor fix (innat)
e3e02dc fix keras.ops.cond issue with jax (innat)
a626b1f no test for jit compile in torch (innat)
c484445 reduce tensor size for forward test (innat)
45945c9 minor fix (innat)
f866d12 remove kcv export decorator (innat)
bfb62a4 update keras.Layer import (innat)
57f0012 remove unused layer import (innat)
7602052 replace keras.layers instead of layers (innat)
837286d update keras.Layer to keras.layers.Layer for keras2 (innat)
6d44eca add window_size param to aliases (innat)
f5dce04 move vide swin layer to model specific directory (innat)
0ba9fdf minor fix (innat)
@@ -0,0 +1,13 @@
# Copyright 2024 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
158 changes: 158 additions & 0 deletions
keras_cv/models/backbones/video_swin/video_swin_aliases.py
@@ -0,0 +1,158 @@
# Copyright 2024 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy

from keras_cv.models.backbones.video_swin.video_swin_backbone import (
    VideoSwinBackbone,
)
from keras_cv.models.backbones.video_swin.video_swin_backbone_presets import (
    backbone_presets,
)
from keras_cv.utils.python_utils import classproperty

ALIAS_DOCSTRING = """VideoSwin{size}Backbone model.

    Reference:
        - [Video Swin Transformer](https://arxiv.org/abs/2106.13230)
        - [Video Swin Transformer GitHub](https://github.com/SwinTransformer/Video-Swin-Transformer)

    For transfer learning use cases, make sure to read the
    [guide to transfer learning & fine-tuning](https://keras.io/guides/transfer_learning/).

    Examples:
    ```python
    input_data = np.ones(shape=(1, 32, 224, 224, 3))

    # Randomly initialized backbone
    model = VideoSwin{size}Backbone()
    output = model(input_data)
    ```
"""  # noqa: E501


class VideoSwinTBackbone(VideoSwinBackbone):
    def __new__(
        cls,
        embed_dim=96,
        depths=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 24],
        window_size=[8, 7, 7],
        include_rescaling=True,
        **kwargs,
    ):
        kwargs.update(
            {
                "embed_dim": embed_dim,
                "depths": depths,
                "num_heads": num_heads,
                "window_size": window_size,
                "include_rescaling": include_rescaling,
            }
        )
        return VideoSwinBackbone.from_preset("videoswin_tiny", **kwargs)

    @classproperty
    def presets(cls):
        """Dictionary of preset names and configurations."""
        return {
            "videoswin_tiny_kinetics400": copy.deepcopy(
                backbone_presets["videoswin_tiny_kinetics400"]
            ),
        }

    @classproperty
    def presets_with_weights(cls):
        """Dictionary of preset names and configurations that include
        weights."""
        return cls.presets


class VideoSwinSBackbone(VideoSwinBackbone):
    def __new__(
        cls,
        embed_dim=96,
        depths=[2, 2, 18, 2],
        num_heads=[3, 6, 12, 24],
        window_size=[8, 7, 7],
        include_rescaling=True,
        **kwargs,
    ):
        kwargs.update(
            {
                "embed_dim": embed_dim,
                "depths": depths,
                "num_heads": num_heads,
                "window_size": window_size,
                "include_rescaling": include_rescaling,
            }
        )
        return VideoSwinBackbone.from_preset("videoswin_small", **kwargs)

    @classproperty
    def presets(cls):
        """Dictionary of preset names and configurations."""
        return {
            "videoswin_small_kinetics400": copy.deepcopy(
                backbone_presets["videoswin_small_kinetics400"]
            ),
        }

    @classproperty
    def presets_with_weights(cls):
        """Dictionary of preset names and configurations that include
        weights."""
        return cls.presets


class VideoSwinBBackbone(VideoSwinBackbone):
    def __new__(
        cls,
        embed_dim=128,
        depths=[2, 2, 18, 2],
        num_heads=[4, 8, 16, 32],
        window_size=[8, 7, 7],
        include_rescaling=True,
        **kwargs,
    ):
        kwargs.update(
            {
                "embed_dim": embed_dim,
                "depths": depths,
                "num_heads": num_heads,
                "window_size": window_size,
                "include_rescaling": include_rescaling,
            }
        )
        return VideoSwinBackbone.from_preset("videoswin_base", **kwargs)

    @classproperty
    def presets(cls):
        """Dictionary of preset names and configurations."""
        return {
            "videoswin_base_kinetics400": copy.deepcopy(
                backbone_presets["videoswin_base_kinetics400"]
            ),
        }

    @classproperty
    def presets_with_weights(cls):
        """Dictionary of preset names and configurations that include
        weights."""
        return cls.presets


setattr(VideoSwinTBackbone, "__doc__", ALIAS_DOCSTRING.format(size="T"))
setattr(VideoSwinSBackbone, "__doc__", ALIAS_DOCSTRING.format(size="S"))
setattr(VideoSwinBBackbone, "__doc__", ALIAS_DOCSTRING.format(size="B"))
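
For readers unfamiliar with the alias pattern in this file: each alias class overrides `__new__` to return an instance of the parent class built from a preset, and the docstring is attached afterwards from a shared template. The following is a minimal, self-contained sketch of that pattern only. `ToyBackbone`, `ToyTinyBackbone`, `TOY_PRESETS`, and `DOC_TEMPLATE` are hypothetical stand-ins, not KerasCV names, and the sketch uses tuple defaults instead of the list defaults in the diff.

```python
import copy

# Hypothetical preset table; stands in for `backbone_presets` in the diff.
TOY_PRESETS = {
    "toy_tiny": {"embed_dim": 96, "depths": [2, 2, 6, 2]},
}

DOC_TEMPLATE = """Toy{size}Backbone model."""


class ToyBackbone:
    def __init__(self, embed_dim, depths):
        self.embed_dim = embed_dim
        self.depths = list(depths)

    @classmethod
    def from_preset(cls, name, **overrides):
        # Deep-copy so caller overrides never mutate the shared preset table.
        config = copy.deepcopy(TOY_PRESETS[name])
        config.update(overrides)
        return cls(**config)


class ToyTinyBackbone(ToyBackbone):
    def __new__(cls, embed_dim=96, depths=(2, 2, 6, 2), **kwargs):
        kwargs.update({"embed_dim": embed_dim, "depths": depths})
        # The returned object is a ToyBackbone, not a ToyTinyBackbone, so
        # Python does not call ToyTinyBackbone.__init__ on it afterwards.
        return ToyBackbone.from_preset("toy_tiny", **kwargs)


# Attach the docstring from the shared template, as the diff does.
ToyTinyBackbone.__doc__ = DOC_TEMPLATE.format(size="Tiny")

model = ToyTinyBackbone(embed_dim=128)
```

One consequence worth noting: because `__new__` returns the parent type, `isinstance(model, ToyTinyBackbone)` is false in this sketch; the alias acts as a factory rather than a distinct type.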
The base backbone model has more than one checkpoint. How can the `preset` method be made to work for all of them?
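
One possible direction, sketched under assumptions rather than taken from the PR: let each alias collect every preset whose name shares its size prefix, instead of listing a single key. In the sketch below, `presets_for` is a hypothetical helper, the registry is a toy dictionary, and only `videoswin_base_kinetics400` and `videoswin_tiny_kinetics400` are names that appear in the diff; `videoswin_base_example_dataset` is invented for illustration.

```python
import copy

# Toy registry standing in for `backbone_presets`. The
# "videoswin_base_example_dataset" key is made up for illustration.
backbone_presets = {
    "videoswin_tiny_kinetics400": {"config": {"embed_dim": 96}},
    "videoswin_base_kinetics400": {"config": {"embed_dim": 128}},
    "videoswin_base_example_dataset": {"config": {"embed_dim": 128}},
}


def presets_for(prefix, registry=backbone_presets):
    """Return deep copies of every preset whose name starts with `prefix`."""
    return {
        name: copy.deepcopy(preset)
        for name, preset in registry.items()
        if name.startswith(prefix)
    }


# All checkpoints for the base model, regardless of how many exist.
base_presets = presets_for("videoswin_base")
print(sorted(base_presets))
```

An alias class could return `presets_for("videoswin_base")` from its `presets` property, so adding a checkpoint to the registry would not require touching the alias.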