Commit

fix private and public variable naming in ofa xception
richeekSony committed Jun 28, 2022
1 parent 81d4af2 commit 8aab0ff
Showing 3 changed files with 86 additions and 78 deletions.
16 changes: 8 additions & 8 deletions nnabla_nas/contrib/classification/ofa/networks/ofa_xception.py
@@ -134,7 +134,7 @@ def __init__(self,
         last_channel = width_list[-1]
 
         # list of max supported depth for each block in the middle flow
-        self.middle_flow_max_depth_list = [max(self._depth_list)] * OFAXceptionNet.NUM_MIDDLE_BLOCKS
+        self._middle_flow_max_depth_list = [max(self._depth_list)] * OFAXceptionNet.NUM_MIDDLE_BLOCKS
 
         # Entry flow
         # first conv layer
@@ -158,7 +158,7 @@ def __init__(self,
 
         # Middle flow blocks
         self.middleblocks = []
-        for depth in self.middle_flow_max_depth_list:
+        for depth in self._middle_flow_max_depth_list:
             # 8 blocks with each block having 1/2/3 layers of relu+sep_conv
             self.middleblocks.append(DynamicXPLayer(
                 in_channel_list=val2list(mid_block_width),
@@ -188,7 +188,7 @@ def __init__(self,
         self.set_bn_param(decay_rate=bn_param[0], eps=bn_param[1])
 
         # initialise the runtime depth of each block in the middle flow
-        self.middle_flow_runtime_depth_list = self.middle_flow_max_depth_list.copy()
+        self.middle_flow_runtime_depth_list = self._middle_flow_max_depth_list.copy()
 
         # set static/dynamic bn
         for _, m in self.get_modules():
@@ -213,7 +213,7 @@ def call(self, x):
         x = self.entryblocks[2](x)
         # xception has only one stage in the middle flow
         for middleblock, runtime_depth in zip(self.middleblocks, self.middle_flow_runtime_depth_list):
-            middleblock._runtime_depth = runtime_depth
+            middleblock.runtime_depth = runtime_depth
             x = middleblock(x)
         x = self.exitblocks[0](x)
         x = self.expand_block1(x)
@@ -242,7 +242,7 @@ def set_active_subnet(self, ks=None, e=None, d=None, **kwargs):
 
         for i, d in enumerate(depth):
             if d is not None:
-                self.middle_flow_runtime_depth_list[i] = min(self.middle_flow_max_depth_list[i], d)
+                self.middle_flow_runtime_depth_list[i] = min(self._middle_flow_max_depth_list[i], d)
 
     def sample_active_subnet(self):
         ks_candidates = self._ks_list
@@ -449,11 +449,11 @@ def __init__(self,
         preserve_weight = True if weights is not None else False
 
         blocks = []
-        input_channel = self.entryblocks[-1]._out_channels
+        input_channel = self.entryblocks[-1].out_channels
         for middleblock, runtime_depth in zip(self.middleblocks, self.middle_flow_runtime_depth_list):
-            middleblock._runtime_depth = runtime_depth
+            middleblock.runtime_depth = runtime_depth
             blocks.append(middleblock.get_active_subnet(input_channel, preserve_weight))
-            input_channel = blocks[-1]._out_channels
+            input_channel = blocks[-1].out_channels
 
         self.middleblocks = Mo.ModuleList(blocks)
 
100 changes: 50 additions & 50 deletions nnabla_nas/contrib/common/ofa/elastic_nn/modules/dynamic_layers.py
@@ -313,75 +313,75 @@ def __init__(self, in_channel_list, out_channel_list,
         self._kernel_size_list = val2list(kernel_size_list)
         self._expand_ratio_list = val2list(expand_ratio_list)
         self._stride = stride
-        self._runtime_depth = depth
 
         # build modules
         max_middle_channel = make_divisible(
             round(max(self._in_channel_list) * max(self._expand_ratio_list)))
 
-        self.depth_conv1 = Mo.Sequential(OrderedDict([
+        self._depth_conv1 = Mo.Sequential(OrderedDict([
             ('act', build_activation('relu')),
             ('dwconv', DynamicSeparableConv2d(max(self._in_channel_list), self._kernel_size_list, self._stride)),
         ]))
 
-        self.point_linear1 = Mo.Sequential(OrderedDict([
+        self._point_linear1 = Mo.Sequential(OrderedDict([
            ('ptconv', DynamicConv2d(max(self._in_channel_list), max_middle_channel)),
            ('bn', DynamicBatchNorm2d(max_middle_channel, 4))
         ]))
 
-        self.depth_conv2 = Mo.Sequential(OrderedDict([
+        self._depth_conv2 = Mo.Sequential(OrderedDict([
             ('act', build_activation('relu')),
             ('dwconv', DynamicSeparableConv2d(max_middle_channel, self._kernel_size_list, self._stride)),
         ]))
 
-        self.point_linear2 = Mo.Sequential(OrderedDict([
+        self._point_linear2 = Mo.Sequential(OrderedDict([
             ('ptconv', DynamicConv2d(max_middle_channel, max_middle_channel)),
             ('bn', DynamicBatchNorm2d(max_middle_channel, 4))
         ]))
 
-        self.depth_conv3 = Mo.Sequential(OrderedDict([
+        self._depth_conv3 = Mo.Sequential(OrderedDict([
             ('act', build_activation('relu')),
             ('dwconv', DynamicSeparableConv2d(max_middle_channel, self._kernel_size_list, self._stride)),
         ]))
 
-        self.point_linear3 = Mo.Sequential(OrderedDict([
+        self._point_linear3 = Mo.Sequential(OrderedDict([
             ('ptconv', DynamicConv2d(max_middle_channel, max(self._out_channel_list))),
             ('bn', DynamicBatchNorm2d(max(self._out_channel_list), 4))
         ]))
 
+        self.runtime_depth = depth
         self.active_kernel_size = max(self._kernel_size_list)
         self.active_expand_ratio = max(self._expand_ratio_list)
         self.active_out_channel = max(self._out_channel_list)
 
     def call(self, inp):
         in_channel = inp.shape[1]
 
-        self.depth_conv1.dwconv.active_kernel_size = self.active_kernel_size
-        self.point_linear1.ptconv.active_out_channel = \
+        self._depth_conv1.dwconv.active_kernel_size = self.active_kernel_size
+        self._point_linear1.ptconv.active_out_channel = \
             make_divisible(round(in_channel * self.active_expand_ratio))
 
-        self.depth_conv2.dwconv.active_kernel_size = self.active_kernel_size
-        self.point_linear2.ptconv.active_out_channel = \
+        self._depth_conv2.dwconv.active_kernel_size = self.active_kernel_size
+        self._point_linear2.ptconv.active_out_channel = \
             make_divisible(round(in_channel * self.active_expand_ratio))
 
-        self.depth_conv3.dwconv.active_kernel_size = self.active_kernel_size
-        self.point_linear3.ptconv.active_out_channel = self.active_out_channel
+        self._depth_conv3.dwconv.active_kernel_size = self.active_kernel_size
+        self._point_linear3.ptconv.active_out_channel = self.active_out_channel
 
-        if self._runtime_depth == 1:
-            self.point_linear1.ptconv.active_out_channel = self.active_out_channel
-        elif self._runtime_depth == 2:
-            self.point_linear2.ptconv.active_out_channel = self.active_out_channel
+        if self.runtime_depth == 1:
+            self._point_linear1.ptconv.active_out_channel = self.active_out_channel
+        elif self.runtime_depth == 2:
+            self._point_linear2.ptconv.active_out_channel = self.active_out_channel
 
-        x = self.depth_conv1(inp)
-        x = self.point_linear1(x)
+        x = self._depth_conv1(inp)
+        x = self._point_linear1(x)
 
-        if self._runtime_depth > 1:  # runtime depth
-            x = self.depth_conv2(x)
-            x = self.point_linear2(x)
+        if self.runtime_depth > 1:  # runtime depth
+            x = self._depth_conv2(x)
+            x = self._point_linear2(x)
 
-        if self._runtime_depth > 2:  # runtime depth
-            x = self.depth_conv3(x)
-            x = self.point_linear3(x)
+        if self.runtime_depth > 2:  # runtime depth
+            x = self._depth_conv3(x)
+            x = self._point_linear3(x)
 
         # Skip is a simple shortcut ->
         skip = inp
@@ -392,7 +392,7 @@ def extra_repr(self):
         return get_extra_repr(self)
 
     def re_organize_middle_weights(self, expand_ratio_stage=0):
-        importance = np.sum(np.abs(self.point_linear3.ptconv.conv._W.d), axis=(0, 2, 3))
+        importance = np.sum(np.abs(self._point_linear3.ptconv.conv._W.d), axis=(0, 2, 3))
         if expand_ratio_stage > 0:  # ranking channels
             sorted_expand_list = copy.deepcopy(self._expand_ratio_list)
             sorted_expand_list.sort(reverse=True)
@@ -410,26 +410,26 @@ def re_organize_middle_weights(self, expand_ratio_stage=0):
                 larger_stage = smaller_stage
 
         sorted_idx = np.argsort(-importance)
-        self.point_linear3.ptconv.conv._W.d = np.stack(
-            [self.point_linear3.ptconv.conv._W.d[:, idx, :, :] for idx in sorted_idx], axis=1)
-        adjust_bn_according_to_idx(self.point_linear3.bn.bn, sorted_idx)
+        self._point_linear3.ptconv.conv._W.d = np.stack(
+            [self._point_linear3.ptconv.conv._W.d[:, idx, :, :] for idx in sorted_idx], axis=1)
+        adjust_bn_according_to_idx(self._point_linear3.bn.bn, sorted_idx)
 
-        self.point_linear2.ptconv.conv._W.d = np.stack(
-            [self.point_linear2.ptconv.conv._W.d[:, idx, :, :] for idx in sorted_idx], axis=1)
-        adjust_bn_according_to_idx(self.point_linear2.bn.bn, sorted_idx)
+        self._point_linear2.ptconv.conv._W.d = np.stack(
+            [self._point_linear2.ptconv.conv._W.d[:, idx, :, :] for idx in sorted_idx], axis=1)
+        adjust_bn_according_to_idx(self._point_linear2.bn.bn, sorted_idx)
 
-        self.point_linear1.ptconv.conv._W.d = np.stack(
-            [self.point_linear1.ptconv.conv._W.d[:, idx, :, :] for idx in sorted_idx], axis=1)
-        adjust_bn_according_to_idx(self.point_linear1.bn.bn, sorted_idx)
+        self._point_linear1.ptconv.conv._W.d = np.stack(
+            [self._point_linear1.ptconv.conv._W.d[:, idx, :, :] for idx in sorted_idx], axis=1)
+        adjust_bn_according_to_idx(self._point_linear1.bn.bn, sorted_idx)
 
-        self.depth_conv3.dwconv.conv._W.d = np.stack(
-            [self.depth_conv3.dwconv.conv._W.d[idx, :, :, :] for idx in sorted_idx], axis=0)
+        self._depth_conv3.dwconv.conv._W.d = np.stack(
+            [self._depth_conv3.dwconv.conv._W.d[idx, :, :, :] for idx in sorted_idx], axis=0)
 
-        self.depth_conv2.dwconv.conv._W.d = np.stack(
-            [self.depth_conv2.dwconv.conv._W.d[idx, :, :, :] for idx in sorted_idx], axis=0)
+        self._depth_conv2.dwconv.conv._W.d = np.stack(
+            [self._depth_conv2.dwconv.conv._W.d[idx, :, :, :] for idx in sorted_idx], axis=0)
 
-        self.depth_conv1.dwconv.conv._W.d = np.stack(
-            [self.depth_conv1.dwconv.conv._W.d[idx, :, :, :] for idx in sorted_idx], axis=0)
+        self._depth_conv1.dwconv.conv._W.d = np.stack(
+            [self._depth_conv1.dwconv.conv._W.d[idx, :, :, :] for idx in sorted_idx], axis=0)
 
     @property
     def in_channels(self):
@@ -451,27 +451,27 @@ def get_active_subnet(self, in_channel, preserve_weight=True):
 
         middle_channel = self.active_middle_channel(in_channel)
 
-        active_filter = self.depth_conv1.dwconv.get_active_filter(in_channel, self.active_kernel_size)
+        active_filter = self._depth_conv1.dwconv.get_active_filter(in_channel, self.active_kernel_size)
         sub_layer.depth_conv1.dwconv._W.d = active_filter.d
 
-        active_filter = self.depth_conv2.dwconv.get_active_filter(middle_channel, self.active_kernel_size)
+        active_filter = self._depth_conv2.dwconv.get_active_filter(middle_channel, self.active_kernel_size)
         sub_layer.depth_conv2.dwconv._W.d = active_filter.d
 
-        active_filter = self.depth_conv3.dwconv.get_active_filter(middle_channel, self.active_kernel_size)
+        active_filter = self._depth_conv3.dwconv.get_active_filter(middle_channel, self.active_kernel_size)
         sub_layer.depth_conv3.dwconv._W.d = active_filter.d
 
-        copy_bn(sub_layer.point_linear1.bn, self.point_linear1.bn.bn)
-        copy_bn(sub_layer.point_linear2.bn, self.point_linear2.bn.bn)
-        copy_bn(sub_layer.point_linear3.bn, self.point_linear3.bn.bn)
+        copy_bn(sub_layer.point_linear1.bn, self._point_linear1.bn.bn)
+        copy_bn(sub_layer.point_linear2.bn, self._point_linear2.bn.bn)
+        copy_bn(sub_layer.point_linear3.bn, self._point_linear3.bn.bn)
 
         sub_layer.point_linear1.ptconv._W.d =\
-            self.point_linear1.ptconv.conv._W.d[:middle_channel, :in_channel, :, :]
+            self._point_linear1.ptconv.conv._W.d[:middle_channel, :in_channel, :, :]
 
         sub_layer.point_linear2.ptconv._W.d =\
-            self.point_linear2.ptconv.conv._W.d[:middle_channel, :middle_channel, :, :]
+            self._point_linear2.ptconv.conv._W.d[:middle_channel, :middle_channel, :, :]
 
         sub_layer.point_linear3.ptconv._W.d =\
-            self.point_linear3.ptconv.conv._W.d[:self.active_out_channel, :middle_channel, :, :]
+            self._point_linear3.ptconv.conv._W.d[:self.active_out_channel, :middle_channel, :, :]
 
         nn.set_auto_forward(False)
 
48 changes: 28 additions & 20 deletions nnabla_nas/contrib/common/ofa/layers.py
@@ -354,25 +354,25 @@ def __init__(
             use_bn=True, act_fn=None):
         super(SeparableConv, self).__init__()
 
-        self.conv1 = Mo.Conv(in_channels, in_channels, kernel, pad=pad, dilation=dilation,
-                             stride=stride, with_bias=False, group=in_channels)
-        self.pointwise = Mo.Conv(in_channels, out_channels, (1, 1), stride=(1, 1),
-                                 pad=(0, 0), dilation=(1, 1), group=1, with_bias=False)
+        self._conv1 = Mo.Conv(in_channels, in_channels, kernel, pad=pad, dilation=dilation,
+                              stride=stride, with_bias=False, group=in_channels)
+        self._pointwise = Mo.Conv(in_channels, out_channels, (1, 1), stride=(1, 1),
+                                  pad=(0, 0), dilation=(1, 1), group=1, with_bias=False)
 
-        self.use_bn = use_bn
+        self._use_bn = use_bn
         if use_bn:
-            self.bn = Mo.BatchNormalization(out_channels, 4)
-        self.act = build_activation(act_fn)
+            self._bn = Mo.BatchNormalization(out_channels, 4)
+        self._act = build_activation(act_fn)
 
     def call(self, x):
-        x = self.conv1(x)
-        x = self.pointwise(x)
+        x = self._conv1(x)
+        x = self._pointwise(x)
 
         if self.use_bn:
-            x = self.bn(x)
+            x = self._bn(x)
 
         if self.act is not None:
-            x = self.act(x)
+            x = self._act(x)
 
         return x
 
@@ -400,11 +400,11 @@ def __init__(
         self._out_channels = out_channels
 
         if out_channels != in_channels or stride != (1, 1):
-            self.skip = Mo.Conv(in_channels, out_channels,
-                                (1, 1), stride=stride, with_bias=False)
-            self.skipbn = Mo.BatchNormalization(out_channels, 4)
+            self._skip = Mo.Conv(in_channels, out_channels,
+                                 (1, 1), stride=stride, with_bias=False)
+            self._skipbn = Mo.BatchNormalization(out_channels, 4)
         else:
-            self.skip = None
+            self._skip = None
 
         rep = []
         mid_channels = out_channels if grow_first else in_channels
@@ -432,20 +432,28 @@ def __init__(
 
         if stride != (1, 1):
             rep.append(Mo.MaxPool((3, 3), stride=stride, pad=(1, 1)))
-        self.rep = Mo.Sequential(*rep)
+        self._rep = Mo.Sequential(*rep)
 
     def call(self, inp):
-        x = self.rep(inp)
+        x = self._rep(inp)
 
-        if self.skip is not None:
-            skip = self.skip(inp)
-            skip = self.skipbn(skip)
+        if self._skip is not None:
+            skip = self._skip(inp)
+            skip = self._skipbn(skip)
         else:
             skip = inp
 
         x += skip
         return x
 
+    @property
+    def in_channels(self):
+        return self._in_channels
+
+    @property
+    def out_channels(self):
+        return self._out_channels
+
     @staticmethod
     def build_from_config(config):
         return XceptionBlock(**config)
