Skip to content

Commit

Permalink
Fix ResNet based models to work w/ norm layers w/o affine params. Ref…
Browse files Browse the repository at this point in the history
…ormat long arg lists into vertical form.
  • Loading branch information
rwightman committed Dec 30, 2022
1 parent d5aa17e commit 6902c48
Show file tree
Hide file tree
Showing 6 changed files with 397 additions and 86 deletions.
163 changes: 134 additions & 29 deletions timm/models/byobnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -962,9 +962,21 @@ class BasicBlock(nn.Module):
"""

def __init__(
self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), group_size=None, bottle_ratio=1.0,
downsample='avg', attn_last=True, linear_out=False, layers: LayerFn = None, drop_block=None,
drop_path_rate=0.):
self,
in_chs,
out_chs,
kernel_size=3,
stride=1,
dilation=(1, 1),
group_size=None,
bottle_ratio=1.0,
downsample='avg',
attn_last=True,
linear_out=False,
layers: LayerFn = None,
drop_block=None,
drop_path_rate=0.,
):
super(BasicBlock, self).__init__()
layers = layers or LayerFn()
mid_chs = make_divisible(out_chs * bottle_ratio)
Expand All @@ -983,7 +995,7 @@ def __init__(
self.act = nn.Identity() if linear_out else layers.act(inplace=True)

def init_weights(self, zero_init_last: bool = False):
if zero_init_last and self.shortcut is not None:
if zero_init_last and self.shortcut is not None and getattr(self.conv2_kxk.bn, 'weight', None) is not None:
nn.init.zeros_(self.conv2_kxk.bn.weight)
for attn in (self.attn, self.attn_last):
if hasattr(attn, 'reset_parameters'):
Expand All @@ -1005,9 +1017,23 @@ class BottleneckBlock(nn.Module):
"""

def __init__(
self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1., group_size=None,
downsample='avg', attn_last=False, linear_out=False, extra_conv=False, bottle_in=False,
layers: LayerFn = None, drop_block=None, drop_path_rate=0.):
self,
in_chs,
out_chs,
kernel_size=3,
stride=1,
dilation=(1, 1),
bottle_ratio=1.,
group_size=None,
downsample='avg',
attn_last=False,
linear_out=False,
extra_conv=False,
bottle_in=False,
layers: LayerFn = None,
drop_block=None,
drop_path_rate=0.,
):
super(BottleneckBlock, self).__init__()
layers = layers or LayerFn()
mid_chs = make_divisible((in_chs if bottle_in else out_chs) * bottle_ratio)
Expand All @@ -1031,7 +1057,7 @@ def __init__(
self.act = nn.Identity() if linear_out else layers.act(inplace=True)

def init_weights(self, zero_init_last: bool = False):
if zero_init_last and self.shortcut is not None:
if zero_init_last and self.shortcut is not None and getattr(self.conv3_1x1.bn, 'weight', None) is not None:
nn.init.zeros_(self.conv3_1x1.bn.weight)
for attn in (self.attn, self.attn_last):
if hasattr(attn, 'reset_parameters'):
Expand Down Expand Up @@ -1063,9 +1089,21 @@ class DarkBlock(nn.Module):
"""

def __init__(
self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1.0, group_size=None,
downsample='avg', attn_last=True, linear_out=False, layers: LayerFn = None, drop_block=None,
drop_path_rate=0.):
self,
in_chs,
out_chs,
kernel_size=3,
stride=1,
dilation=(1, 1),
bottle_ratio=1.0,
group_size=None,
downsample='avg',
attn_last=True,
linear_out=False,
layers: LayerFn = None,
drop_block=None,
drop_path_rate=0.,
):
super(DarkBlock, self).__init__()
layers = layers or LayerFn()
mid_chs = make_divisible(out_chs * bottle_ratio)
Expand All @@ -1085,7 +1123,7 @@ def __init__(
self.act = nn.Identity() if linear_out else layers.act(inplace=True)

def init_weights(self, zero_init_last: bool = False):
if zero_init_last and self.shortcut is not None:
if zero_init_last and self.shortcut is not None and getattr(self.conv2_kxk.bn, 'weight', None) is not None:
nn.init.zeros_(self.conv2_kxk.bn.weight)
for attn in (self.attn, self.attn_last):
if hasattr(attn, 'reset_parameters'):
Expand Down Expand Up @@ -1114,9 +1152,21 @@ class EdgeBlock(nn.Module):
"""

def __init__(
self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1.0, group_size=None,
downsample='avg', attn_last=False, linear_out=False, layers: LayerFn = None,
drop_block=None, drop_path_rate=0.):
self,
in_chs,
out_chs,
kernel_size=3,
stride=1,
dilation=(1, 1),
bottle_ratio=1.0,
group_size=None,
downsample='avg',
attn_last=False,
linear_out=False,
layers: LayerFn = None,
drop_block=None,
drop_path_rate=0.,
):
super(EdgeBlock, self).__init__()
layers = layers or LayerFn()
mid_chs = make_divisible(out_chs * bottle_ratio)
Expand All @@ -1135,7 +1185,7 @@ def __init__(
self.act = nn.Identity() if linear_out else layers.act(inplace=True)

def init_weights(self, zero_init_last: bool = False):
if zero_init_last and self.shortcut is not None:
if zero_init_last and self.shortcut is not None and getattr(self.conv2_1x1.bn, 'weight', None) is not None:
nn.init.zeros_(self.conv2_1x1.bn.weight)
for attn in (self.attn, self.attn_last):
if hasattr(attn, 'reset_parameters'):
Expand All @@ -1162,8 +1212,19 @@ class RepVggBlock(nn.Module):
"""

def __init__(
self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1.0, group_size=None,
downsample='', layers: LayerFn = None, drop_block=None, drop_path_rate=0.):
self,
in_chs,
out_chs,
kernel_size=3,
stride=1,
dilation=(1, 1),
bottle_ratio=1.0,
group_size=None,
downsample='',
layers: LayerFn = None,
drop_block=None,
drop_path_rate=0.,
):
super(RepVggBlock, self).__init__()
layers = layers or LayerFn()
groups = num_groups(group_size, in_chs)
Expand Down Expand Up @@ -1204,9 +1265,24 @@ class SelfAttnBlock(nn.Module):
"""

def __init__(
self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1., group_size=None,
downsample='avg', extra_conv=False, linear_out=False, bottle_in=False, post_attn_na=True,
feat_size=None, layers: LayerFn = None, drop_block=None, drop_path_rate=0.):
self,
in_chs,
out_chs,
kernel_size=3,
stride=1,
dilation=(1, 1),
bottle_ratio=1.,
group_size=None,
downsample='avg',
extra_conv=False,
linear_out=False,
bottle_in=False,
post_attn_na=True,
feat_size=None,
layers: LayerFn = None,
drop_block=None,
drop_path_rate=0.,
):
super(SelfAttnBlock, self).__init__()
assert layers is not None
mid_chs = make_divisible((in_chs if bottle_in else out_chs) * bottle_ratio)
Expand All @@ -1233,7 +1309,7 @@ def __init__(
self.act = nn.Identity() if linear_out else layers.act(inplace=True)

def init_weights(self, zero_init_last: bool = False):
if zero_init_last and self.shortcut is not None:
if zero_init_last and self.shortcut is not None and getattr(self.conv3_1x1.bn, 'weight', None) is not None:
nn.init.zeros_(self.conv3_1x1.bn.weight)
if hasattr(self.self_attn, 'reset_parameters'):
self.self_attn.reset_parameters()
Expand Down Expand Up @@ -1274,8 +1350,17 @@ def create_block(block: Union[str, nn.Module], **kwargs):
class Stem(nn.Sequential):

def __init__(
self, in_chs, out_chs, kernel_size=3, stride=4, pool='maxpool',
num_rep=3, num_act=None, chs_decay=0.5, layers: LayerFn = None):
self,
in_chs,
out_chs,
kernel_size=3,
stride=4,
pool='maxpool',
num_rep=3,
num_act=None,
chs_decay=0.5,
layers: LayerFn = None,
):
super().__init__()
assert stride in (2, 4)
layers = layers or LayerFn()
Expand Down Expand Up @@ -1319,7 +1404,14 @@ def __init__(
assert curr_stride == stride


def create_byob_stem(in_chs, out_chs, stem_type='', pool_type='', feat_prefix='stem', layers: LayerFn = None):
def create_byob_stem(
in_chs,
out_chs,
stem_type='',
pool_type='',
feat_prefix='stem',
layers: LayerFn = None,
):
layers = layers or LayerFn()
assert stem_type in ('', 'quad', 'quad2', 'tiered', 'deep', 'rep', '7x7', '3x3')
if 'quad' in stem_type:
Expand Down Expand Up @@ -1407,10 +1499,14 @@ def update_block_kwargs(block_kwargs: Dict[str, Any], block_cfg: ByoBlockCfg, mo


def create_byob_stages(
cfg: ByoModelCfg, drop_path_rate: float, output_stride: int, stem_feat: Dict[str, Any],
cfg: ByoModelCfg,
drop_path_rate: float,
output_stride: int,
stem_feat: Dict[str, Any],
feat_size: Optional[int] = None,
layers: Optional[LayerFn] = None,
block_kwargs_fn: Optional[Callable] = update_block_kwargs):
block_kwargs_fn: Optional[Callable] = update_block_kwargs,
):

layers = layers or LayerFn()
feature_info = []
Expand Down Expand Up @@ -1485,8 +1581,17 @@ class ByobNet(nn.Module):
Current assumption is that both stem and blocks are in conv-bn-act order (w/ block ending in act).
"""
def __init__(
self, cfg: ByoModelCfg, num_classes=1000, in_chans=3, global_pool='avg', output_stride=32,
zero_init_last=True, img_size=None, drop_rate=0., drop_path_rate=0.):
self,
cfg: ByoModelCfg,
num_classes=1000,
in_chans=3,
global_pool='avg',
output_stride=32,
zero_init_last=True,
img_size=None,
drop_rate=0.,
drop_path_rate=0.,
):
super().__init__()
self.num_classes = num_classes
self.drop_rate = drop_rate
Expand Down
21 changes: 17 additions & 4 deletions timm/models/res2net.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,21 @@ class Bottle2neck(nn.Module):
expansion = 4

def __init__(
self, inplanes, planes, stride=1, downsample=None,
cardinality=1, base_width=26, scale=4, dilation=1, first_dilation=None,
act_layer=nn.ReLU, norm_layer=None, attn_layer=None, **_):
self,
inplanes,
planes,
stride=1,
downsample=None,
cardinality=1,
base_width=26,
scale=4,
dilation=1,
first_dilation=None,
act_layer=nn.ReLU,
norm_layer=None,
attn_layer=None,
**_,
):
super(Bottle2neck, self).__init__()
self.scale = scale
self.is_first = stride > 1 or downsample is not None
Expand Down Expand Up @@ -89,7 +101,8 @@ def __init__(
self.downsample = downsample

def zero_init_last(self):
nn.init.zeros_(self.bn3.weight)
if getattr(self.bn3, 'weight', None) is not None:
nn.init.zeros_(self.bn3.weight)

def forward(self, x):
shortcut = x
Expand Down
28 changes: 23 additions & 5 deletions timm/models/resnest.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,27 @@ class ResNestBottleneck(nn.Module):
expansion = 4

def __init__(
self, inplanes, planes, stride=1, downsample=None,
radix=1, cardinality=1, base_width=64, avd=False, avd_first=False, is_first=False,
reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
attn_layer=None, aa_layer=None, drop_block=None, drop_path=None):
self,
inplanes,
planes,
stride=1,
downsample=None,
radix=1,
cardinality=1,
base_width=64,
avd=False,
avd_first=False,
is_first=False,
reduce_first=1,
dilation=1,
first_dilation=None,
act_layer=nn.ReLU,
norm_layer=nn.BatchNorm2d,
attn_layer=None,
aa_layer=None,
drop_block=None,
drop_path=None,
):
super(ResNestBottleneck, self).__init__()
assert reduce_first == 1 # not supported
assert attn_layer is None # not supported
Expand Down Expand Up @@ -103,7 +120,8 @@ def __init__(
self.downsample = downsample

def zero_init_last(self):
nn.init.zeros_(self.bn3.weight)
if getattr(self.bn3, 'weight', None) is not None:
nn.init.zeros_(self.bn3.weight)

def forward(self, x):
shortcut = x
Expand Down
Loading

0 comments on commit 6902c48

Please sign in to comment.