# ___Keras ResNet Family___
***

## Summary
- Keras implementation of [Github facebookresearch/ResNeXt](https://github.com/facebookresearch/ResNeXt). Paper [PDF 1611.05431 Aggregated Residual Transformations for Deep Neural Networks](https://arxiv.org/pdf/1611.05431.pdf).
- Model weights reloaded from [Tensorflow keras/applications](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/applications/resnet.py).
***

## Models
| Model              | Params | Image resolution | Top1 Acc | Download |
| ------------------ | ------ | ---------------- | -------- | -------- |
| resnext50 (32x4d)  | 25M    | 224              | 77.74    | [resnext50.h5](https://github.com/leondgarse/keras_cv_attention_models/releases/download/resnext/resnext50.h5) |
| resnext101 (32x4d) | 42M    | 224              | 78.73    | [resnext101.h5](https://github.com/leondgarse/keras_cv_attention_models/releases/download/resnext/resnext101.h5) |
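The `Download` column points at the released `.h5` weight files. Passing `pretrained="imagenet"` (see Usage below) fetches and loads them automatically; as a hedged sketch, loading one manually would look roughly like this, mirroring the `by_name` / `skip_mismatch` loading used elsewhere in this commit:
```py
from tensorflow import keras
from keras_cv_attention_models import resnext

# Fetch the released weight file listed in the table above.
url = "https://github.com/leondgarse/keras_cv_attention_models/releases/download/resnext/resnext50.h5"
weight_file = keras.utils.get_file("resnext50.h5", url, cache_subdir="models")

# Build the model without pretrained weights, then load the file by layer name.
mm = resnext.ResNeXt50(pretrained=None)
mm.load_weights(weight_file, by_name=True, skip_mismatch=True)
```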
## Usage
```py
import tensorflow as tf
from tensorflow import keras
from keras_cv_attention_models import resnext

# Will download and load pretrained imagenet weights.
mm = resnext.ResNeXt50(pretrained="imagenet")

# Run prediction
from skimage.data import chelsea
imm = keras.applications.imagenet_utils.preprocess_input(chelsea(), mode='tf')  # Chelsea the cat
pred = mm(tf.expand_dims(tf.image.resize(imm, mm.input_shape[1:3]), 0)).numpy()
print(keras.applications.imagenet_utils.decode_predictions(pred)[0])
# [('n02124075', 'Egyptian_cat', 0.98292357),
#  ('n02123045', 'tabby', 0.009655442),
#  ('n02123159', 'tiger_cat', 0.0057404325),
#  ('n02127052', 'lynx', 0.00089362176),
#  ('n04209239', 'shower_curtain', 0.00013918217)]
```
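The `mode='tf'` preprocessing above rescales pixel values to `[-1, 1]`; a minimal sketch of the equivalent manual scaling:
```py
import numpy as np
from skimage.data import chelsea

# Same effect as keras.applications.imagenet_utils.preprocess_input(..., mode='tf'):
# map uint8 pixels in [0, 255] to floats in [-1, 1].
imm = chelsea().astype("float32") / 127.5 - 1.0
print(imm.min(), imm.max())  # approximately -1.0, 1.0
```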
**Set new input resolution**
```py
import numpy as np
from keras_cv_attention_models import resnext

mm = resnext.ResNeXt101(input_shape=(320, 320, 3), num_classes=0)
print(mm(np.ones([1, 320, 320, 3])).shape)
# (1, 10, 10, 2048)

mm = resnext.ResNeXt101(input_shape=(512, 512, 3), num_classes=0)
print(mm(np.ones([1, 512, 512, 3])).shape)
# (1, 16, 16, 2048)
```
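With `num_classes=0` the output is the last convolutional feature map, so the model can serve as a backbone. A minimal sketch of attaching a new classification head; the pooling layer and the 10-class `Dense` output below are assumptions for illustration:
```py
from tensorflow import keras
from keras_cv_attention_models import resnext

# Backbone without the top layers (num_classes=0 excludes them).
backbone = resnext.ResNeXt101(input_shape=(320, 320, 3), num_classes=0, pretrained="imagenet")

# Hypothetical new head: pool the (10, 10, 2048) feature map and classify into 10 classes.
nn = keras.layers.GlobalAveragePooling2D()(backbone.output)
nn = keras.layers.Dense(10, activation="softmax")(nn)
model = keras.models.Model(backbone.input, nn)
model.summary()
```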
***
Docstrings added for the ResNeXt models and the `groups_depthwise` helper in `keras_cv_attention_models.resnet_family`:
```py
from keras_cv_attention_models.resnet_family.resnext import ResNeXt, ResNeXt50, ResNeXt101, groups_depthwise


__head_doc__ = """
Keras implementation of [Github facebookresearch/ResNeXt](https://github.com/facebookresearch/ResNeXt).
Paper [PDF 1611.05431 Aggregated Residual Transformations for Deep Neural Networks](https://arxiv.org/pdf/1611.05431.pdf).
"""

__tail_doc__ = """  strides: a `number` or `list`. If a number, it is the strides for the last stack
      and will be expanded as `[1, 2, 2, strides]`; if a list, it gives the strides for all stacks.
  out_channels: output channels for each stack, default `[128, 256, 512, 1024]`.
  stem_width: output dimension of the stem block.
  deep_stem: boolean, whether to use a deep stem.
  stem_downsample: boolean, whether to add a `MaxPooling2D` layer after the stem block.
  cardinality: controls channel expansion in each block; the bigger, the wider.
      Also determines the `groups` number for `groups_depthwise` in each block: a bigger `cardinality` leads to fewer `groups`.
  input_shape: should have exactly 3 input channels, default `(224, 224, 3)`.
  num_classes: number of classes to classify images into. Set `0` to exclude the top layers.
  activation: activation used in the whole model, default `relu`.
  classifier_activation: a `str` or callable. The activation function to use on the "top" layer if `num_classes > 0`.
      Set `classifier_activation=None` to return the logits of the "top" layer.
      Default is `softmax`.
  pretrained: one of `None` (random initialization) or `'imagenet'` (pre-training on ImageNet).
      Will try to download and load pre-trained model weights if not `None`.
  **kwargs: other parameters if available.

Returns:
    A `keras.Model` instance.
"""

ResNeXt.__doc__ = __head_doc__ + """
Args:
  num_blocks: number of blocks in each stack.
  model_name: string, model name.
""" + __tail_doc__ + """
Model architectures:
  | Model      | Params | Image resolution | Top1 Acc |
  | ---------- | ------ | ---------------- | -------- |
  | resnext50  | 25M    | 224              | 77.8     |
  | resnext101 | 42M    | 224              | 80.9     |
"""
```
```py
ResNeXt50.__doc__ = __head_doc__ + """
Args:
""" + __tail_doc__

ResNeXt101.__doc__ = ResNeXt50.__doc__

groups_depthwise.__doc__ = __head_doc__ + """
Grouped depthwise convolution. A callable function, NOT defined as a layer.

Args:
  inputs: input tensor.
  groups: number of groups the `DepthwiseConv2D` result is split into.
  kernel_size: kernel size for `DepthwiseConv2D`.
  strides: strides for `DepthwiseConv2D`.
  padding: padding for `DepthwiseConv2D`.

Examples:

>>> from keras_cv_attention_models import attention_layers
>>> inputs = keras.layers.Input([28, 28, 192])
>>> nn = attention_layers.groups_depthwise(inputs, groups=32)
>>> dd = keras.models.Model(inputs, nn)
>>> dd.output_shape
(None, 28, 28, 192)

>>> dd.summary()
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_2 (InputLayer)         [(None, 28, 28, 192)]     0
_________________________________________________________________
zero_padding2d (ZeroPadding2 (None, 30, 30, 192)       0
_________________________________________________________________
depthwise_conv2d (DepthwiseC (None, 28, 28, 1152)      10368
_________________________________________________________________
reshape (Reshape)            (None, 28, 28, 32, 6, 6)  0
_________________________________________________________________
tf.math.reduce_sum (TFOpLamb (None, 28, 28, 32, 6)     0
_________________________________________________________________
reshape_1 (Reshape)          (None, 28, 28, 192)       0
=================================================================
Total params: 10,368
Trainable params: 10,368
Non-trainable params: 0
_________________________________________________________________
"""
```
The commit also adds ResNet-D model definitions built on `AotNet`:
```py
from keras_cv_attention_models.aotnet import AotNet
from tensorflow import keras
import os


def ResNetD(num_blocks, input_shape=(224, 224, 3), pretrained="imagenet", deep_stem=True, stem_width=32, strides=2, **kwargs):
    strides = strides if isinstance(strides, (list, tuple)) else [1, 2, 2, strides]
    model = AotNet(num_blocks, input_shape=input_shape, deep_stem=deep_stem, stem_width=stem_width, strides=strides, **kwargs)
    reload_model_weights(model, input_shape, pretrained)
    return model


def reload_model_weights(model, input_shape=(224, 224, 3), pretrained="imagenet"):
    pretrained_dd = {
        "resnet50d": ["imagenet"],
    }
    if model.name not in pretrained_dd or pretrained not in pretrained_dd[model.name]:
        print(">>>> No pretrained weights available, model will be randomly initialized")
        return

    pre_url = "https://github.com/leondgarse/keras_cv_attention_models/releases/download/resnet_family/{}_{}.h5"
    url = pre_url.format(model.name, pretrained)
    file_name = os.path.basename(url)
    try:
        pretrained_model = keras.utils.get_file(file_name, url, cache_subdir="models")
    except Exception:
        print("[Error] will not load weights, url not found or download failed:", url)
        return
    else:
        print(">>>> Load pretrained weights from:", pretrained_model)
        model.load_weights(pretrained_model, by_name=True, skip_mismatch=True)


def ResNet50D(input_shape=(224, 224, 3), num_classes=1000, activation="relu", classifier_activation="softmax", pretrained="imagenet", **kwargs):
    num_blocks = [3, 4, 6, 3]
    # **locals() forwards the named arguments above to ResNetD; extra options come from **kwargs.
    return ResNetD(**locals(), model_name="resnet50d", **kwargs)


def ResNet101D(input_shape=(224, 224, 3), num_classes=1000, activation="relu", classifier_activation="softmax", pretrained="imagenet", **kwargs):
    num_blocks = [3, 4, 23, 3]
    return ResNetD(**locals(), model_name="resnet101d", **kwargs)


def ResNet152D(input_shape=(224, 224, 3), num_classes=1000, activation="relu", classifier_activation="softmax", pretrained="imagenet", **kwargs):
    num_blocks = [3, 8, 36, 3]
    return ResNetD(**locals(), model_name="resnet152d", **kwargs)


def ResNet200D(input_shape=(224, 224, 3), num_classes=1000, activation="relu", classifier_activation="softmax", pretrained="imagenet", **kwargs):
    num_blocks = [3, 24, 36, 3]
    return ResNetD(**locals(), model_name="resnet200d", **kwargs)
```
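A short usage sketch for these definitions, mirroring the README usage above. The assumption here is that `ResNet50D` and the other variants are re-exported from the `resnet_family` package, which this diff does not show:
```py
import numpy as np
from keras_cv_attention_models import resnet_family  # assumed export path

# ResNet50D has released "imagenet" weights per reload_model_weights above;
# the other depths fall back to random initialization with a printed notice.
mm = resnet_family.ResNet50D(pretrained="imagenet")
print(mm.name, mm.count_params())

# As a feature extractor, mirroring the README's num_classes=0 usage.
mm = resnet_family.ResNet200D(input_shape=(256, 256, 3), num_classes=0, pretrained=None)
print(mm(np.ones([1, 256, 256, 3])).shape)
# expected (1, 8, 8, 2048)
```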