|
def convert_backbone_config(timm_config):
    """Build the backbone keyword arguments for a timm MobileNetV3 config.

    The configuration for ``mobilenetv3_small_050`` is used as the baseline;
    the other supported architectures only override the entries that differ
    from it.

    Args:
        timm_config: Mapping with an ``"architecture"`` entry holding the
            timm architecture name.

    Returns:
        A freshly built ``dict`` of backbone settings (presumably forwarded
        as keyword arguments to the backbone constructor -- confirm against
        the caller). A new dict is created on every call, so callers may
        mutate the result freely.

    Raises:
        ValueError: If the architecture is not one of the supported
            MobileNetV3 variants.
    """
    timm_architecture = timm_config["architecture"]

    # Baseline settings: these are exactly the values returned for
    # "mobilenetv3_small_050".
    kwargs = {
        "stackwise_num_blocks": [2, 3, 2, 3],
        "stackwise_expansion": [
            [40, 56],
            [64, 144, 144],
            [72, 72],
            [144, 288, 288],
        ],
        "stackwise_num_filters": [
            [16, 16],
            [24, 24, 24],
            [24, 24],
            [48, 48, 48],
        ],
        "stackwise_kernel_size": [[3, 3], [5, 5, 5], [5, 5], [5, 5, 5]],
        "stackwise_num_strides": [[2, 1], [2, 1, 1], [1, 1], [2, 1, 1]],
        "stackwise_se_ratio": [
            [None, None],
            [0.25, 0.25, 0.25],
            [0.25, 0.25],
            [0.25, 0.25, 0.25],
        ],
        "stackwise_activation": [
            ["relu", "relu"],
            ["hard_swish", "hard_swish", "hard_swish"],
            ["hard_swish", "hard_swish"],
            ["hard_swish", "hard_swish", "hard_swish"],
        ],
        "stackwise_padding": [[1, 1], [2, 2, 2], [2, 2], [2, 2, 2]],
        "output_num_filters": 1024,
        "input_num_filters": 16,
        "depthwise_filters": 8,
        "depthwise_stride": 2,
        "depthwise_residual": False,
        "squeeze_and_excite": 0.5,
        "last_layer_filter": 288,
        "input_activation": "relu6",
        "output_activation": "relu6",
    }

    # Every architecture accepted below contains "mobilenetv3_", so the
    # "relu6" defaults above only matter if a non-v3 variant is added later.
    if "mobilenetv3_" in timm_architecture:
        kwargs["input_activation"] = "hard_swish"
        kwargs["output_activation"] = "hard_swish"

    if timm_architecture == "mobilenetv3_small_100":
        # Same layout as the baseline, with wider expansions and filters.
        kwargs.update(
            {
                "stackwise_expansion": [
                    [72, 88],
                    [96, 240, 240],
                    [120, 144],
                    [288, 576, 576],
                ],
                "stackwise_num_filters": [
                    [24, 24],
                    [40, 40, 40],
                    [48, 48],
                    [96, 96, 96],
                ],
                "depthwise_filters": 16,
                "last_layer_filter": 576,
            }
        )
    elif timm_architecture.startswith("mobilenetv3_large_100"):
        # Prefix match: any name that extends "mobilenetv3_large_100" with a
        # suffix selects this branch as well. The large variant has five
        # stacks instead of four, so every stackwise_* list is replaced.
        kwargs.update(
            {
                "stackwise_num_blocks": [2, 3, 4, 2, 3],
                "stackwise_expansion": [
                    [64, 72],
                    [72, 120, 120],
                    [240, 200, 184, 184],
                    [480, 672],
                    [672, 960, 960],
                ],
                "stackwise_num_filters": [
                    [24, 24],
                    [40, 40, 40],
                    [80, 80, 80, 80],
                    [112, 112],
                    [160, 160, 160],
                ],
                "stackwise_kernel_size": [
                    [3, 3],
                    [5, 5, 5],
                    [3, 3, 3, 3],
                    [3, 3],
                    [5, 5, 5],
                ],
                "stackwise_num_strides": [
                    [2, 1],
                    [2, 1, 1],
                    [2, 1, 1, 1],
                    [1, 1],
                    [2, 1, 1],
                ],
                "stackwise_se_ratio": [
                    [None, None],
                    [0.25, 0.25, 0.25],
                    [None, None, None, None],
                    [0.25, 0.25],
                    [0.25, 0.25, 0.25],
                ],
                "stackwise_activation": [
                    ["relu", "relu"],
                    ["relu", "relu", "relu"],
                    ["hard_swish", "hard_swish", "hard_swish", "hard_swish"],
                    ["hard_swish", "hard_swish"],
                    ["hard_swish", "hard_swish", "hard_swish"],
                ],
                "stackwise_padding": [
                    [1, 1],
                    [2, 2, 2],
                    [1, 1, 1, 1],
                    [1, 1],
                    [2, 2, 2],
                ],
                "depthwise_filters": 16,
                "depthwise_stride": 1,
                "depthwise_residual": True,
                "squeeze_and_excite": None,
                "last_layer_filter": 960,
            }
        )
    elif timm_architecture != "mobilenetv3_small_050":
        # "mobilenetv3_small_050" needs no overrides (it is the baseline);
        # anything else is unsupported.
        raise ValueError(
            f"Currently, the architecture {timm_architecture} is not supported."
        )

    return kwargs
69 | 140 |
|
70 | 141 |
|
71 | 142 | def convert_weights(backbone, loader, timm_config):
|
@@ -120,9 +191,14 @@ def port_batch_normalization(keras_layer, hf_weight_prefix):
|
120 | 191 | port_conv2d(stem_block.conv1, f"{hf_name}.conv_dw")
|
121 | 192 | port_batch_normalization(stem_block.batch_normalization1, f"{hf_name}.bn1")
|
122 | 193 |
|
123 |
| - stem_se_block = stem_block.se_layer |
124 |
| - port_conv2d(stem_se_block.conv_reduce, f"{hf_name}.se.conv_reduce", True) |
125 |
| - port_conv2d(stem_se_block.conv_expand, f"{hf_name}.se.conv_expand", True) |
| 194 | + if stem_block.squeeze_excite_ratio: |
| 195 | + stem_se_block = stem_block.se_layer |
| 196 | + port_conv2d( |
| 197 | + stem_se_block.conv_reduce, f"{hf_name}.se.conv_reduce", True |
| 198 | + ) |
| 199 | + port_conv2d( |
| 200 | + stem_se_block.conv_expand, f"{hf_name}.se.conv_expand", True |
| 201 | + ) |
126 | 202 |
|
127 | 203 | port_conv2d(stem_block.conv2, f"{hf_name}.conv_pw")
|
128 | 204 | port_batch_normalization(stem_block.batch_normalization2, f"{hf_name}.bn2")
|
|
0 commit comments