Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ResNet model changes for PS paper #991

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 29 additions & 10 deletions models/official/resnet/resnet_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,7 +415,8 @@ def residual_block(inputs, filters, is_training, strides,
dropblock_keep_prob=None, dropblock_size=None,
pre_activation=False, norm_act_layer=LAYER_BN_RELU,
resnetd_shortcut=False, se_ratio=None,
drop_connect_rate=None, bn_momentum=MOVING_AVERAGE_DECAY):
drop_connect_rate=None, bn_momentum=MOVING_AVERAGE_DECAY,
second_conv=True):
"""Standard building block for residual networks with BN after convolutions.

Args:
Expand All @@ -442,6 +443,8 @@ def residual_block(inputs, filters, is_training, strides,
se_ratio: `float` or None. Squeeze-and-Excitation ratio for the SE layer.
drop_connect_rate: `float` or None. Drop connect rate for this block.
bn_momentum: `float` momentum for batch norm layer.
second_conv: Whether to apply second convolution:
Instead of conv(down) + norm + relu + conv + norm -> conv(down) + norm

Returns:
The output `Tensor` of the block.
Expand Down Expand Up @@ -469,12 +472,13 @@ def residual_block(inputs, filters, is_training, strides,
inputs = conv2d_fixed_padding(
inputs=inputs, filters=filters, kernel_size=3, strides=strides,
data_format=data_format)
inputs = norm_activation(inputs, is_training, data_format=data_format,
layer=norm_act_layer, bn_momentum=bn_momentum)
if second_conv:
inputs = norm_activation(inputs, is_training, data_format=data_format,
layer=norm_act_layer, bn_momentum=bn_momentum)

inputs = conv2d_fixed_padding(
inputs=inputs, filters=filters, kernel_size=3, strides=1,
data_format=data_format)
inputs = conv2d_fixed_padding(
inputs=inputs, filters=filters, kernel_size=3, strides=1,
data_format=data_format)
if pre_activation:
return inputs + shortcut
else:
Expand All @@ -490,7 +494,8 @@ def bottleneck_block(inputs, filters, is_training, strides,
dropblock_keep_prob=None, dropblock_size=None,
pre_activation=False, norm_act_layer=LAYER_BN_RELU,
resnetd_shortcut=False, se_ratio=None,
drop_connect_rate=None, bn_momentum=MOVING_AVERAGE_DECAY):
drop_connect_rate=None, bn_momentum=MOVING_AVERAGE_DECAY,
second_conv=True):
"""Bottleneck block variant for residual networks with BN after convolutions.

Args:
Expand All @@ -517,6 +522,8 @@ def bottleneck_block(inputs, filters, is_training, strides,
se_ratio: `float` or None. Squeeze-and-Excitation ratio for the SE layer.
drop_connect_rate: `float` or None. Drop connect rate for this block.
bn_momentum: `float` momentum for batch norm layer.
second_conv: Whether to apply second convolution:
Instead of conv(down) + norm + relu + conv + norm -> conv(down) + norm

Returns:
The output `Tensor` of the block.
Expand Down Expand Up @@ -629,6 +636,7 @@ def block_group(inputs, filters, block_fn, blocks, strides, is_training, name,
Returns:
The output `Tensor` of the block layer.
"""
second_conv = False if blocks == 0 else True
# Only the first block per block_group uses projection shortcut and strides.
inputs = block_fn(inputs, filters, is_training, strides,
use_projection=True, data_format=data_format,
Expand All @@ -639,7 +647,8 @@ def block_group(inputs, filters, block_fn, blocks, strides, is_training, name,
se_ratio=se_ratio,
resnetd_shortcut=resnetd_shortcut,
drop_connect_rate=drop_connect_rate,
bn_momentum=bn_momentum)
bn_momentum=bn_momentum,
second_conv=second_conv)

for _ in range(1, blocks):
inputs = block_fn(inputs, filters, is_training, 1,
Expand All @@ -651,7 +660,8 @@ def block_group(inputs, filters, block_fn, blocks, strides, is_training, name,
se_ratio=se_ratio,
resnetd_shortcut=resnetd_shortcut,
drop_connect_rate=drop_connect_rate,
bn_momentum=bn_momentum)
bn_momentum=bn_momentum,
second_conv=second_conv)

return tf.identity(inputs, name)

Expand Down Expand Up @@ -768,30 +778,38 @@ def model(inputs, is_training):
num_layers = len(layers) + 1
stride_c2 = 2 if skip_stem_max_pool else 1

endpoints = dict()
inputs = custom_block_group(
inputs=inputs, filters=64, block_fn=block_fn, blocks=layers[0],
strides=stride_c2, is_training=is_training, name='block_group1',
dropblock_keep_prob=dropblock_keep_probs[0],
drop_connect_rate=resnet_layers.get_drop_connect_rate(
drop_connect_rate, 2, num_layers))
endpoints['reduction_1'] = inputs

inputs = custom_block_group(
inputs=inputs, filters=128, block_fn=block_fn, blocks=layers[1],
strides=2, is_training=is_training, name='block_group2',
dropblock_keep_prob=dropblock_keep_probs[1],
drop_connect_rate=resnet_layers.get_drop_connect_rate(
drop_connect_rate, 3, num_layers))
endpoints['reduction_2'] = inputs

inputs = custom_block_group(
inputs=inputs, filters=256, block_fn=block_fn, blocks=layers[2],
strides=2, is_training=is_training, name='block_group3',
dropblock_keep_prob=dropblock_keep_probs[2],
drop_connect_rate=resnet_layers.get_drop_connect_rate(
drop_connect_rate, 4, num_layers))
endpoints['reduction_3'] = inputs

inputs = custom_block_group(
inputs=inputs, filters=512, block_fn=block_fn, blocks=layers[3],
strides=2, is_training=is_training, name='block_group4',
dropblock_keep_prob=dropblock_keep_probs[3],
drop_connect_rate=resnet_layers.get_drop_connect_rate(
drop_connect_rate, 5, num_layers))
endpoints['reduction_4'] = inputs

if pre_activation:
inputs = norm_activation(inputs, is_training, data_format=data_format,
Expand Down Expand Up @@ -820,7 +838,7 @@ def model(inputs, is_training):
units=num_classes,
kernel_initializer=tf.random_normal_initializer(stddev=.01))
inputs = tf.identity(inputs, 'final_dense')
return inputs
return inputs, endpoints

model.default_image_size = 224
return model
Expand All @@ -835,6 +853,7 @@ def resnet(resnet_depth, num_classes, data_format='channels_first',
bn_momentum=MOVING_AVERAGE_DECAY):
"""Returns the ResNet model for a given size and number of output classes."""
model_params = {
6: {'block': residual_block, 'layers': [0, 0, 0, 0]},
18: {'block': residual_block, 'layers': [2, 2, 2, 2]},
34: {'block': residual_block, 'layers': [3, 4, 6, 3]},
50: {'block': bottleneck_block, 'layers': [3, 4, 6, 3]},
Expand Down