Skip to content

Commit e9e3645

Browse files
committed
add top3 HF presets
1 parent d1c14ab commit e9e3645

File tree

5 files changed

+151
-47
lines changed

5 files changed

+151
-47
lines changed

keras_hub/src/models/mobilenet/mobilenet_backbone.py

+16-1
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,8 @@ class DepthwiseConvBlock(keras.layers.Layer):
142142
signal into before reexciting back out. If (>1) technically, it's an
143143
excite & squeeze layer. If this doesn't exist there is no
144144
SqueezeExcite layer.
145+
residual: bool, default False. If True, the input is added to the output
146+
(residual connection; shapes must match). If False, no residual is used.
145147
name: str, name of the layer
146148
dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype
147149
to use for the model's computations and weights.
@@ -161,6 +163,7 @@ def __init__(
161163
kernel_size=3,
162164
stride=2,
163165
squeeze_excite_ratio=None,
166+
residual=False,
164167
name=None,
165168
dtype=None,
166169
**kwargs,
@@ -171,6 +174,7 @@ def __init__(
171174
self.kernel_size = kernel_size
172175
self.stride = stride
173176
self.squeeze_excite_ratio = squeeze_excite_ratio
177+
self.residual = residual
174178
self.name = name
175179

176180
channel_axis = (
@@ -256,11 +260,15 @@ def call(self, inputs):
256260
x = self.batch_normalization1(x)
257261
x = self.activation1(x)
258262

259-
if self.se_layer:
263+
if self.squeeze_excite_ratio:
260264
x = self.se_layer(x)
261265

262266
x = self.conv2(x)
263267
x = self.batch_normalization2(x)
268+
269+
if self.residual:
270+
x = x + inputs
271+
264272
return x
265273

266274
def get_config(self):
@@ -272,6 +280,7 @@ def get_config(self):
272280
"kernel_size": self.kernel_size,
273281
"stride": self.stride,
274282
"squeeze_excite_ratio": self.squeeze_excite_ratio,
283+
"residual": self.residual,
275284
"name": self.name,
276285
}
277286
)
@@ -675,6 +684,8 @@ def __init__(
675684
stackwise_padding,
676685
output_num_filters,
677686
depthwise_filters,
687+
depthwise_stride,
688+
depthwise_residual,
678689
last_layer_filter,
679690
squeeze_and_excite=None,
680691
image_shape=(None, None, 3),
@@ -722,7 +733,9 @@ def __init__(
722733
x = DepthwiseConvBlock(
723734
input_num_filters,
724735
depthwise_filters,
736+
stride=depthwise_stride,
725737
squeeze_excite_ratio=squeeze_and_excite,
738+
residual=depthwise_residual,
726739
name="block_0",
727740
dtype=dtype,
728741
)(x)
@@ -768,6 +781,7 @@ def __init__(
768781
self.input_num_filters = input_num_filters
769782
self.output_num_filters = output_num_filters
770783
self.depthwise_filters = depthwise_filters
784+
self.depthwise_stride = depthwise_stride
771785
self.last_layer_filter = last_layer_filter
772786
self.squeeze_and_excite = squeeze_and_excite
773787
self.input_activation = input_activation
@@ -790,6 +804,7 @@ def get_config(self):
790804
"input_num_filters": self.input_num_filters,
791805
"output_num_filters": self.output_num_filters,
792806
"depthwise_filters": self.depthwise_filters,
807+
"depthwise_stride": self.depthwise_stride,
793808
"last_layer_filter": self.last_layer_filter,
794809
"squeeze_and_excite": self.squeeze_and_excite,
795810
"input_activation": self.input_activation,

keras_hub/src/models/mobilenet/mobilenet_image_classifier.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ def __init__(
1818
self,
1919
backbone,
2020
num_classes,
21+
num_features=1024,
2122
preprocessor=None,
2223
head_dtype=None,
2324
**kwargs,
@@ -33,7 +34,7 @@ def __init__(
3334
)
3435

3536
self.output_conv = keras.layers.Conv2D(
36-
filters=1024,
37+
filters=num_features,
3738
kernel_size=(1, 1),
3839
strides=(1, 1),
3940
use_bias=True,

keras_hub/src/utils/preset_utils.py

+3
Original file line numberDiff line numberDiff line change
@@ -622,6 +622,9 @@ def load_task(self, cls, load_weights, load_task_weights, **kwargs):
622622
kwargs["preprocessor"] = self.load_preprocessor(
623623
cls.preprocessor_cls,
624624
)
625+
if "num_features" not in kwargs and "num_features" in self.config:
626+
kwargs["num_features"] = self.config["num_features"]
627+
625628
return cls(**kwargs)
626629

627630
def load_preprocessor(

keras_hub/src/utils/timm/convert_mobilenet.py

+120-44
Original file line numberDiff line numberDiff line change
@@ -8,64 +8,135 @@
88
def convert_backbone_config(timm_config):
99
timm_architecture = timm_config["architecture"]
1010

11-
if "mobilenetv3_" in timm_architecture:
12-
input_activation = "hard_swish"
13-
output_activation = "hard_swish"
14-
else:
15-
input_activation = "relu6"
16-
output_activation = "relu6"
17-
18-
if timm_architecture == "mobilenetv3_small_050":
19-
stackwise_num_blocks = [2, 3, 2, 3]
20-
stackwise_expansion = [
11+
kwargs = {
12+
"stackwise_num_blocks": [2, 3, 2, 3],
13+
"stackwise_expansion": [
2114
[40, 56],
2215
[64, 144, 144],
2316
[72, 72],
2417
[144, 288, 288],
25-
]
26-
stackwise_num_filters = [[16, 16], [24, 24, 24], [24, 24], [48, 48, 48]]
27-
stackwise_kernel_size = [[3, 3], [5, 5, 5], [5, 5], [5, 5, 5]]
28-
stackwise_num_strides = [[2, 1], [2, 1, 1], [1, 1], [2, 1, 1]]
29-
stackwise_se_ratio = [
18+
],
19+
"stackwise_num_filters": [
20+
[16, 16],
21+
[24, 24, 24],
22+
[24, 24],
23+
[48, 48, 48],
24+
],
25+
"stackwise_kernel_size": [[3, 3], [5, 5, 5], [5, 5], [5, 5, 5]],
26+
"stackwise_num_strides": [[2, 1], [2, 1, 1], [1, 1], [2, 1, 1]],
27+
"stackwise_se_ratio": [
3028
[None, None],
3129
[0.25, 0.25, 0.25],
3230
[0.25, 0.25],
3331
[0.25, 0.25, 0.25],
34-
]
35-
stackwise_activation = [
32+
],
33+
"stackwise_activation": [
3634
["relu", "relu"],
3735
["hard_swish", "hard_swish", "hard_swish"],
3836
["hard_swish", "hard_swish"],
3937
["hard_swish", "hard_swish", "hard_swish"],
40-
]
41-
stackwise_padding = [[1, 1], [2, 2, 2], [2, 2], [2, 2, 2]]
42-
output_num_filters = 1024
43-
input_num_filters = 16
44-
depthwise_filters = 8
45-
squeeze_and_excite = 0.5
46-
last_layer_filter = 288
38+
],
39+
"stackwise_padding": [[1, 1], [2, 2, 2], [2, 2], [2, 2, 2]],
40+
"output_num_filters": 1024,
41+
"input_num_filters": 16,
42+
"depthwise_filters": 8,
43+
"depthwise_stride": 2,
44+
"depthwise_residual": False,
45+
"squeeze_and_excite": 0.5,
46+
"last_layer_filter": 288,
47+
"input_activation": "relu6",
48+
"output_activation": "relu6",
49+
}
50+
51+
if "mobilenetv3_" in timm_architecture:
52+
kwargs["input_activation"] = "hard_swish"
53+
kwargs["output_activation"] = "hard_swish"
54+
55+
if timm_architecture == "mobilenetv3_small_050":
56+
pass
57+
elif timm_architecture == "mobilenetv3_small_100":
58+
modified_kwargs = {
59+
"stackwise_expansion": [
60+
[72, 88],
61+
[96, 240, 240],
62+
[120, 144],
63+
[288, 576, 576],
64+
],
65+
"stackwise_num_filters": [
66+
[24, 24],
67+
[40, 40, 40],
68+
[48, 48],
69+
[96, 96, 96],
70+
],
71+
"depthwise_filters": 16,
72+
"last_layer_filter": 576,
73+
}
74+
kwargs.update(modified_kwargs)
75+
elif timm_architecture.startswith("mobilenetv3_large_100"):
76+
modified_kwargs = {
77+
"stackwise_num_blocks": [2, 3, 4, 2, 3],
78+
"stackwise_expansion": [
79+
[64, 72],
80+
[72, 120, 120],
81+
[240, 200, 184, 184],
82+
[480, 672],
83+
[672, 960, 960],
84+
],
85+
"stackwise_num_filters": [
86+
[24, 24],
87+
[40, 40, 40],
88+
[80, 80, 80, 80],
89+
[112, 112],
90+
[160, 160, 160],
91+
],
92+
"stackwise_kernel_size": [
93+
[3, 3],
94+
[5, 5, 5],
95+
[3, 3, 3, 3],
96+
[3, 3],
97+
[5, 5, 5],
98+
],
99+
"stackwise_num_strides": [
100+
[2, 1],
101+
[2, 1, 1],
102+
[2, 1, 1, 1],
103+
[1, 1],
104+
[2, 1, 1],
105+
],
106+
"stackwise_se_ratio": [
107+
[None, None],
108+
[0.25, 0.25, 0.25],
109+
[None, None, None, None],
110+
[0.25, 0.25],
111+
[0.25, 0.25, 0.25],
112+
],
113+
"stackwise_activation": [
114+
["relu", "relu"],
115+
["relu", "relu", "relu"],
116+
["hard_swish", "hard_swish", "hard_swish", "hard_swish"],
117+
["hard_swish", "hard_swish"],
118+
["hard_swish", "hard_swish", "hard_swish"],
119+
],
120+
"stackwise_padding": [
121+
[1, 1],
122+
[2, 2, 2],
123+
[1, 1, 1, 1],
124+
[1, 1],
125+
[2, 2, 2],
126+
],
127+
"depthwise_filters": 16,
128+
"depthwise_stride": 1,
129+
"depthwise_residual": True,
130+
"squeeze_and_excite": None,
131+
"last_layer_filter": 960,
132+
}
133+
kwargs.update(modified_kwargs)
47134
else:
48135
raise ValueError(
49136
f"Currently, the architecture {timm_architecture} is not supported."
50137
)
51138

52-
return dict(
53-
input_num_filters=input_num_filters,
54-
input_activation=input_activation,
55-
depthwise_filters=depthwise_filters,
56-
squeeze_and_excite=squeeze_and_excite,
57-
stackwise_num_blocks=stackwise_num_blocks,
58-
stackwise_expansion=stackwise_expansion,
59-
stackwise_num_filters=stackwise_num_filters,
60-
stackwise_kernel_size=stackwise_kernel_size,
61-
stackwise_num_strides=stackwise_num_strides,
62-
stackwise_se_ratio=stackwise_se_ratio,
63-
stackwise_activation=stackwise_activation,
64-
stackwise_padding=stackwise_padding,
65-
output_num_filters=output_num_filters,
66-
output_activation=output_activation,
67-
last_layer_filter=last_layer_filter,
68-
)
139+
return kwargs
69140

70141

71142
def convert_weights(backbone, loader, timm_config):
@@ -120,9 +191,14 @@ def port_batch_normalization(keras_layer, hf_weight_prefix):
120191
port_conv2d(stem_block.conv1, f"{hf_name}.conv_dw")
121192
port_batch_normalization(stem_block.batch_normalization1, f"{hf_name}.bn1")
122193

123-
stem_se_block = stem_block.se_layer
124-
port_conv2d(stem_se_block.conv_reduce, f"{hf_name}.se.conv_reduce", True)
125-
port_conv2d(stem_se_block.conv_expand, f"{hf_name}.se.conv_expand", True)
194+
if stem_block.squeeze_excite_ratio:
195+
stem_se_block = stem_block.se_layer
196+
port_conv2d(
197+
stem_se_block.conv_reduce, f"{hf_name}.se.conv_reduce", True
198+
)
199+
port_conv2d(
200+
stem_se_block.conv_expand, f"{hf_name}.se.conv_expand", True
201+
)
126202

127203
port_conv2d(stem_block.conv2, f"{hf_name}.conv_pw")
128204
port_batch_normalization(stem_block.batch_normalization2, f"{hf_name}.bn2")

tools/checkpoint_conversion/convert_mobilenet_checkpoints.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,13 @@
11
"""Convert mobilenet checkpoints.
22
33
python tools/checkpoint_conversion/convert_mobilenet_checkpoints.py \
4-
--preset mobilenetv3_small_050 --upload_uri kaggle://alexbutcher/mobilenet/keras/mobilenetv3_small_050
4+
--preset mobilenetv3_small_050 --upload_uri kaggle://keras/mobilenetv3/keras/mobilenet_v3_small_050_imagenet/1
5+
python tools/checkpoint_conversion/convert_mobilenet_checkpoints.py \
6+
--preset mobilenetv3_small_100 --upload_uri kaggle://keras/mobilenetv3/keras/mobilenet_v3_small_100_imagenet/1
7+
python tools/checkpoint_conversion/convert_mobilenet_checkpoints.py \
8+
--preset mobilenetv3_large_100.ra_in1k --upload_uri kaggle://keras/mobilenetv3/keras/mobilenet_v3_large_100_imagenet/1
9+
python tools/checkpoint_conversion/convert_mobilenet_checkpoints.py \
10+
--preset mobilenetv3_large_100.miil_in21k_ft_in1k --upload_uri kaggle://keras/mobilenetv3/keras/mobilenet_v3_large_100_imagenet_21k/1
511
"""
612

713
import os
@@ -19,6 +25,9 @@
1925

2026
PRESET_MAP = {
2127
"mobilenetv3_small_050": "timm/mobilenetv3_small_050.lamb_in1k",
28+
"mobilenetv3_small_100": "timm/mobilenetv3_small_100.lamb_in1k",
29+
"mobilenetv3_large_100.ra_in1k": "timm/mobilenetv3_large_100.ra_in1k",
30+
"mobilenetv3_large_100.miil_in21k_ft_in1k": "timm/mobilenetv3_large_100.miil_in21k_ft_in1k", # noqa: E501
2231
}
2332
FLAGS = flags.FLAGS
2433

0 commit comments

Comments
 (0)