-
Notifications
You must be signed in to change notification settings - Fork 1.8k
/
Copy pathefficientnet_builder.py
331 lines (281 loc) · 11.7 KB
/
efficientnet_builder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Model Builder for EfficientNet."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
import os
import re
from absl import logging
import numpy as np
import six
import tensorflow.compat.v1 as tf
import efficientnet_model
import utils
# RGB mean and standard-deviation constants: each fractional value is scaled
# by 255.
MEAN_RGB = list(fraction * 255 for fraction in (0.485, 0.456, 0.406))
STDDEV_RGB = list(fraction * 255 for fraction in (0.229, 0.224, 0.225))
def efficientnet_params(model_name):
    """Looks up the scaling coefficients of a predefined EfficientNet variant.

    Args:
      model_name: string, one of the predefined names such as
        'efficientnet-b0'.

    Returns:
      A tuple (width_coefficient, depth_coefficient, resolution, dropout_rate).

    Raises:
      KeyError: if model_name is not one of the predefined names.
    """
    # Each row: (name, width_coefficient, depth_coefficient, resolution,
    # dropout_rate).
    variants = (
        ('efficientnet-b0', 1.0, 1.0, 224, 0.2),
        ('efficientnet-b1', 1.0, 1.1, 240, 0.2),
        ('efficientnet-b2', 1.1, 1.2, 260, 0.3),
        ('efficientnet-b3', 1.2, 1.4, 300, 0.3),
        ('efficientnet-b4', 1.4, 1.8, 380, 0.4),
        ('efficientnet-b5', 1.6, 2.2, 456, 0.4),
        ('efficientnet-b6', 1.8, 2.6, 528, 0.5),
        ('efficientnet-b7', 2.0, 3.1, 600, 0.5),
        ('efficientnet-b8', 2.2, 3.6, 672, 0.5),
        ('efficientnet-l2', 4.3, 5.3, 800, 0.5),
    )
    coefficients_by_name = dict((row[0], row[1:]) for row in variants)
    return coefficients_by_name[model_name]
class BlockDecoder(object):
    """Translates between block-string notation and BlockArgs tuples.

    A block string is a sequence of options joined by '_', for example
    'r1_k3_s11_e1_i32_o16_se0.25'. Options read by the decoder:

      r<int>: num_repeat
      k<int>: kernel_size
      s<d><d>: strides, exactly two characters (one integer per dimension)
      e<int>: expand_ratio
      i<int>: input_filters
      o<int>: output_filters
      se<float>: se_ratio (optional, None when absent)
      c<int>: conv_type (optional, 0 when absent)
      f<int>: fused_conv (optional, 0 when absent)
      d<int>: space2depth (optional, 0 when absent)
      a<int>: activation_fn; 0 selects tf.nn.relu, any other value selects
        tf.nn.swish (optional, None when absent)
      noskip: sets id_skip to False
      cc: sets condconv to True

    The encoder does not emit the 'a' option, so activation_fn is not
    preserved by encode().
    """

    def _decode_block_string(self, block_string):
        """Gets a block through a string notation of arguments.

        Args:
          block_string: string in the notation described on the class.

        Returns:
          An efficientnet_model.BlockArgs built from the parsed options.

        Raises:
          AssertionError: if block_string is not a string.
          ValueError: if the 's' option is missing or is not two characters.
          KeyError: if one of the k, r, i, o or e options is missing.
        """
        assert isinstance(block_string, six.string_types)

        options = {}
        for op in block_string.split('_'):
            # Split at the first digit: 'se0.25' -> ['se', '0.25', ''].
            # Options without a digit ('noskip', 'cc') give a single element
            # and are detected below with substring checks instead.
            splits = re.split(r'(\d.*)', op)
            if len(splits) >= 2:
                key, value = splits[:2]
                options[key] = value

        if 's' not in options or len(options['s']) != 2:
            raise ValueError('Strides options should be a pair of integers.')

        if 'a' in options:
            activation_fn = (
                tf.nn.relu if int(options['a']) == 0 else tf.nn.swish)
        else:
            activation_fn = None

        return efficientnet_model.BlockArgs(
            kernel_size=int(options['k']),
            num_repeat=int(options['r']),
            input_filters=int(options['i']),
            output_filters=int(options['o']),
            expand_ratio=int(options['e']),
            id_skip=('noskip' not in block_string),
            se_ratio=float(options['se']) if 'se' in options else None,
            strides=[int(options['s'][0]),
                     int(options['s'][1])],
            conv_type=int(options['c']) if 'c' in options else 0,
            fused_conv=int(options['f']) if 'f' in options else 0,
            space2depth=int(options['d']) if 'd' in options else 0,
            condconv=('cc' in block_string),
            activation_fn=activation_fn)

    def _encode_block_string(self, block):
        """Encodes a block to a string.

        Args:
          block: an object exposing the BlockArgs attributes num_repeat,
            kernel_size, strides, expand_ratio, input_filters, output_filters,
            conv_type, fused_conv, space2depth, se_ratio, id_skip and condconv.

        Returns:
          The block-string notation for the block.
        """
        args = [
            'r%d' % block.num_repeat,
            'k%d' % block.kernel_size,
            's%d%d' % (block.strides[0], block.strides[1]),
            'e%s' % block.expand_ratio,
            'i%d' % block.input_filters,
            'o%d' % block.output_filters,
            'c%d' % block.conv_type,
            'f%d' % block.fused_conv,
            'd%d' % block.space2depth,
        ]
        # The decoder yields se_ratio=None when the 'se' option is absent;
        # comparing None with a number raises TypeError on Python 3, so guard
        # against it before the range check.
        se_ratio = block.se_ratio
        if se_ratio is not None and 0 < se_ratio <= 1:
            args.append('se%s' % se_ratio)
        if block.id_skip is False:  # pylint: disable=g-bool-id-comparison
            args.append('noskip')
        if block.condconv:
            args.append('cc')
        return '_'.join(args)

    def decode(self, string_list):
        """Decodes a list of string notations to specify blocks inside the network.

        Args:
          string_list: a list of strings, each string is a notation of block.

        Returns:
          A list of namedtuples to represent blocks arguments.
        """
        assert isinstance(string_list, list)
        return [
            self._decode_block_string(block_string)
            for block_string in string_list
        ]

    def encode(self, blocks_args):
        """Encodes a list of Blocks to a list of strings.

        Args:
          blocks_args: A list of namedtuples to represent blocks arguments.

        Returns:
          a list of strings, each string is a notation of block.
        """
        return [self._encode_block_string(block) for block in blocks_args]
def swish(features, use_native=True, use_hard=False):
    """Computes the Swish activation function.

    Three alternatives are provided:
    - Native tf.nn.swish, which uses less memory during training than the
      composable swish.
    - Quantization-friendly hard swish.
    - A composable swish, equivalent to tf.nn.swish, but more general for
      finetuning and TF-Hub.

    Args:
      features: A `Tensor` representing preactivation values.
      use_native: Whether to use the native swish from tf.nn that uses a custom
        gradient to reduce memory usage, or to use customized swish that uses
        default TensorFlow gradient computation.
      use_hard: Whether to use quantization-friendly hard swish.

    Returns:
      The activation value.

    Raises:
      ValueError: if both use_native and use_hard are set.
    """
    if use_native:
        if use_hard:
            raise ValueError('Cannot specify both use_native and use_hard.')
        return tf.nn.swish(features)

    if use_hard:
        # Hard swish: x * relu6(x + 3) / 6.
        gate = tf.nn.relu6(features + np.float32(3))
        return features * gate * (1. / 6.)

    # Composable swish: x * sigmoid(x).
    features = tf.convert_to_tensor(features, name='features')
    return features * tf.nn.sigmoid(features)
# Default block configuration, one block string per entry. The strings are
# parsed with BlockDecoder.decode when model parameters are assembled.
_DEFAULT_BLOCKS_ARGS = [
    'r1_k3_s11_e1_i32_o16_se0.25',
    'r2_k3_s22_e6_i16_o24_se0.25',
    'r2_k5_s22_e6_i24_o40_se0.25',
    'r3_k3_s22_e6_i40_o80_se0.25',
    'r3_k5_s11_e6_i80_o112_se0.25',
    'r4_k5_s22_e6_i112_o192_se0.25',
    'r1_k3_s11_e6_i192_o320_se0.25',
]
def efficientnet(width_coefficient=None,
                 depth_coefficient=None,
                 dropout_rate=0.2,
                 survival_prob=0.8):
    """Assembles the default GlobalParams for an EfficientNet.

    Args:
      width_coefficient: value stored in GlobalParams.width_coefficient.
      depth_coefficient: value stored in GlobalParams.depth_coefficient.
      dropout_rate: value stored in GlobalParams.dropout_rate.
      survival_prob: value stored in GlobalParams.survival_prob.

    Returns:
      An efficientnet_model.GlobalParams holding the arguments above together
      with the fixed defaults listed in the body.
    """
    return efficientnet_model.GlobalParams(
        # Settings supplied by the caller.
        width_coefficient=width_coefficient,
        depth_coefficient=depth_coefficient,
        dropout_rate=dropout_rate,
        survival_prob=survival_prob,
        # Fixed defaults.
        blocks_args=_DEFAULT_BLOCKS_ARGS,
        batch_norm_momentum=0.99,
        batch_norm_epsilon=1e-3,
        data_format='channels_last',
        num_classes=1000,
        depth_divisor=8,
        min_depth=None,
        relu_fn=tf.nn.swish,
        # The default is TPU-specific batch norm.
        # The alternative is tf.layers.BatchNormalization.
        batch_norm=utils.train_batch_norm,  # TPU-specific requirement.
        use_se=True,
        clip_projection_output=False)
def get_model_params(model_name, override_params):
    """Returns the decoded block args and global params for a model name.

    Args:
      model_name: string, the predefined model name; must start with
        'efficientnet'.
      override_params: optional dict of GlobalParams fields to replace. A
        falsy value leaves the defaults untouched.

    Returns:
      A (blocks_args, global_params) pair, where blocks_args is the decoded
      form of global_params.blocks_args.

    Raises:
      NotImplementedError: if model_name does not start with 'efficientnet'.
      ValueError: if override_params has fields not included in global_params.
    """
    if not model_name.startswith('efficientnet'):
        raise NotImplementedError(
            'model name is not pre-defined: %s' % model_name)

    # The resolution entry of the tuple is not needed here.
    width, depth, _, dropout = efficientnet_params(model_name)
    global_params = efficientnet(width, depth, dropout)

    if override_params:
        # ValueError will be raised here if override_params has fields not
        # included in global_params.
        global_params = global_params._replace(**override_params)

    blocks_args = BlockDecoder().decode(global_params.blocks_args)
    logging.info('global_params= %s', global_params)
    return blocks_args, global_params
def build_model(images,
                model_name,
                training,
                override_params=None,
                model_dir=None,
                fine_tuning=False,
                features_only=False,
                pooled_features_only=False):
    """A helper function to create a model and return predicted logits.

    Args:
      images: input images tensor.
      model_name: string, the predefined model name.
      training: boolean, whether the model is constructed for training.
      override_params: A dictionary of params for overriding. Fields must exist
        in efficientnet_model.GlobalParams. The dictionary is copied, so the
        caller's object is never modified.
      model_dir: string, optional model dir for saving configs.
      fine_tuning: boolean, whether the model is used for finetuning.
      features_only: build the base feature network only (excluding final
        1x1 conv layer, global pooling, dropout and fc head).
      pooled_features_only: build the base network for features extraction
        (after 1x1 conv layer and global pooling, but before dropout and fc
        head).

    Returns:
      logits: the logits tensor of classes.
      endpoints: the endpoints for each layer.

    Raises:
      When model_name specified an undefined model, raises NotImplementedError.
      When override_params has invalid fields, raises ValueError.
      When images is not a tf.Tensor, or both features_only and
      pooled_features_only are set, raises AssertionError.
    """
    assert isinstance(images, tf.Tensor)
    assert not (features_only and pooled_features_only)

    # Work on a private copy: the entries added below (survival_prob,
    # batch_norm, relu_fn) must not leak into the caller's dict, otherwise a
    # dict reused for a later call would silently carry them over.
    override_params = dict(override_params) if override_params else {}

    # For backward compatibility.
    # NOTE(review): 'drop_connect_rate' itself stays in the dict and is
    # forwarded to GlobalParams._replace by get_model_params - confirm that
    # GlobalParams accepts that field.
    if override_params.get('drop_connect_rate', None):
        override_params['survival_prob'] = (
            1 - override_params['drop_connect_rate'])

    if not training or fine_tuning:
        override_params['batch_norm'] = utils.eval_batch_norm
        if fine_tuning:
            override_params['relu_fn'] = functools.partial(
                swish, use_native=False)

    blocks_args, global_params = get_model_params(model_name, override_params)

    if model_dir:
        # Record the resolved configuration once; an existing file is kept.
        param_file = os.path.join(model_dir, 'model_params.txt')
        if not tf.gfile.Exists(param_file):
            if not tf.gfile.Exists(model_dir):
                tf.gfile.MakeDirs(model_dir)
            with tf.gfile.GFile(param_file, 'w') as f:
                logging.info('writing to %s', param_file)
                f.write('model_name= %s\n\n' % model_name)
                f.write('global_params= %s\n\n' % str(global_params))
                f.write('blocks_args= %s\n\n' % str(blocks_args))

    with tf.variable_scope(model_name):
        model = efficientnet_model.Model(blocks_args, global_params)
        outputs = model(
            images,
            training=training,
            features_only=features_only,
            pooled_features_only=pooled_features_only)

    # Give the output tensor a name that reflects what was built.
    if features_only:
        outputs = tf.identity(outputs, 'features')
    elif pooled_features_only:
        outputs = tf.identity(outputs, 'pooled_features')
    else:
        outputs = tf.identity(outputs, 'logits')
    return outputs, model.endpoints
def build_model_base(images, model_name, training, override_params=None):
    """Create a base feature network and return the features before pooling.

    Args:
      images: input images tensor.
      model_name: string, the predefined model name.
      training: boolean, whether the model is constructed for training.
      override_params: A dictionary of params for overriding. Fields must exist
        in efficientnet_model.GlobalParams. The dictionary is copied, so the
        caller's object is never modified.

    Returns:
      features: base features before pooling.
      endpoints: the endpoints for each layer.

    Raises:
      When model_name specified an undefined model, raises NotImplementedError.
      When override_params has invalid fields, raises ValueError.
      When images is not a tf.Tensor, raises AssertionError.
    """
    assert isinstance(images, tf.Tensor)

    # Work on a private copy so the 'survival_prob' entry added below does not
    # appear in the caller's dict.
    override_params = dict(override_params) if override_params else {}

    # For backward compatibility.
    # NOTE(review): 'drop_connect_rate' itself stays in the dict and is
    # forwarded to GlobalParams._replace by get_model_params - confirm that
    # GlobalParams accepts that field.
    if override_params.get('drop_connect_rate', None):
        override_params['survival_prob'] = (
            1 - override_params['drop_connect_rate'])

    blocks_args, global_params = get_model_params(model_name, override_params)

    with tf.variable_scope(model_name):
        model = efficientnet_model.Model(blocks_args, global_params)
        features = model(images, training=training, features_only=True)

    features = tf.identity(features, 'features')
    return features, model.endpoints