From 6902c48a5f0637c8155c1c4bc10ad35930f3e772 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Thu, 29 Dec 2022 16:32:26 -0800
Subject: [PATCH] Fix ResNet based models to work w/ norm layers w/o affine
 params. Reformat long arg lists into vertical form.

---
 timm/models/byobnet.py  | 163 +++++++++++++++++++++++++++++++++-------
 timm/models/res2net.py  |  21 +++++-
 timm/models/resnest.py  |  28 +++++--
 timm/models/resnet.py   | 122 ++++++++++++++++++++++++------
 timm/models/resnetv2.py | 101 ++++++++++++++++++++-----
 timm/models/sknet.py    |  48 ++++++++++--
 6 files changed, 397 insertions(+), 86 deletions(-)

diff --git a/timm/models/byobnet.py b/timm/models/byobnet.py
index 0e5c9c7fa9..15f78044b0 100644
--- a/timm/models/byobnet.py
+++ b/timm/models/byobnet.py
@@ -962,9 +962,21 @@ class BasicBlock(nn.Module):
     """

     def __init__(
-            self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), group_size=None, bottle_ratio=1.0,
-            downsample='avg', attn_last=True, linear_out=False, layers: LayerFn = None, drop_block=None,
-            drop_path_rate=0.):
+            self,
+            in_chs,
+            out_chs,
+            kernel_size=3,
+            stride=1,
+            dilation=(1, 1),
+            group_size=None,
+            bottle_ratio=1.0,
+            downsample='avg',
+            attn_last=True,
+            linear_out=False,
+            layers: LayerFn = None,
+            drop_block=None,
+            drop_path_rate=0.,
+    ):
         super(BasicBlock, self).__init__()
         layers = layers or LayerFn()
         mid_chs = make_divisible(out_chs * bottle_ratio)
@@ -983,7 +995,7 @@ def __init__(
         self.act = nn.Identity() if linear_out else layers.act(inplace=True)

     def init_weights(self, zero_init_last: bool = False):
-        if zero_init_last and self.shortcut is not None:
+        if zero_init_last and self.shortcut is not None and getattr(self.conv2_kxk.bn, 'weight', None) is not None:
             nn.init.zeros_(self.conv2_kxk.bn.weight)
         for attn in (self.attn, self.attn_last):
             if hasattr(attn, 'reset_parameters'):
@@ -1005,9 +1017,23 @@ class BottleneckBlock(nn.Module):
     """

     def __init__(
-            self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1., group_size=None,
-            downsample='avg', attn_last=False, linear_out=False, extra_conv=False, bottle_in=False,
-            layers: LayerFn = None, drop_block=None, drop_path_rate=0.):
+            self,
+            in_chs,
+            out_chs,
+            kernel_size=3,
+            stride=1,
+            dilation=(1, 1),
+            bottle_ratio=1.,
+            group_size=None,
+            downsample='avg',
+            attn_last=False,
+            linear_out=False,
+            extra_conv=False,
+            bottle_in=False,
+            layers: LayerFn = None,
+            drop_block=None,
+            drop_path_rate=0.,
+    ):
         super(BottleneckBlock, self).__init__()
         layers = layers or LayerFn()
         mid_chs = make_divisible((in_chs if bottle_in else out_chs) * bottle_ratio)
@@ -1031,7 +1057,7 @@ def __init__(
         self.act = nn.Identity() if linear_out else layers.act(inplace=True)

     def init_weights(self, zero_init_last: bool = False):
-        if zero_init_last and self.shortcut is not None:
+        if zero_init_last and self.shortcut is not None and getattr(self.conv3_1x1.bn, 'weight', None) is not None:
             nn.init.zeros_(self.conv3_1x1.bn.weight)
         for attn in (self.attn, self.attn_last):
             if hasattr(attn, 'reset_parameters'):
@@ -1063,9 +1089,21 @@ class DarkBlock(nn.Module):
     """

     def __init__(
-            self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1.0, group_size=None,
-            downsample='avg', attn_last=True, linear_out=False, layers: LayerFn = None, drop_block=None,
-            drop_path_rate=0.):
+            self,
+            in_chs,
+            out_chs,
+            kernel_size=3,
+            stride=1,
+            dilation=(1, 1),
+            bottle_ratio=1.0,
+            group_size=None,
+            downsample='avg',
+            attn_last=True,
+            linear_out=False,
+            layers: LayerFn = None,
+            drop_block=None,
+            drop_path_rate=0.,
+    ):
         super(DarkBlock, self).__init__()
         layers = layers or LayerFn()
         mid_chs = make_divisible(out_chs * bottle_ratio)
@@ -1085,7 +1123,7 @@ def __init__(
         self.act = nn.Identity() if linear_out else layers.act(inplace=True)

     def init_weights(self, zero_init_last: bool = False):
-        if zero_init_last and self.shortcut is not None:
+        if zero_init_last and self.shortcut is not None and getattr(self.conv2_kxk.bn, 'weight', None) is not None:
             nn.init.zeros_(self.conv2_kxk.bn.weight)
         for attn in (self.attn, self.attn_last):
             if hasattr(attn, 'reset_parameters'):
@@ -1114,9 +1152,21 @@ class EdgeBlock(nn.Module):
     """

     def __init__(
-            self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1.0, group_size=None,
-            downsample='avg', attn_last=False, linear_out=False, layers: LayerFn = None,
-            drop_block=None, drop_path_rate=0.):
+            self,
+            in_chs,
+            out_chs,
+            kernel_size=3,
+            stride=1,
+            dilation=(1, 1),
+            bottle_ratio=1.0,
+            group_size=None,
+            downsample='avg',
+            attn_last=False,
+            linear_out=False,
+            layers: LayerFn = None,
+            drop_block=None,
+            drop_path_rate=0.,
+    ):
         super(EdgeBlock, self).__init__()
         layers = layers or LayerFn()
         mid_chs = make_divisible(out_chs * bottle_ratio)
@@ -1135,7 +1185,7 @@ def __init__(
         self.act = nn.Identity() if linear_out else layers.act(inplace=True)

     def init_weights(self, zero_init_last: bool = False):
-        if zero_init_last and self.shortcut is not None:
+        if zero_init_last and self.shortcut is not None and getattr(self.conv2_1x1.bn, 'weight', None) is not None:
             nn.init.zeros_(self.conv2_1x1.bn.weight)
         for attn in (self.attn, self.attn_last):
             if hasattr(attn, 'reset_parameters'):
@@ -1162,8 +1212,19 @@ class RepVggBlock(nn.Module):
     """

     def __init__(
-            self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1.0, group_size=None,
-            downsample='', layers: LayerFn = None, drop_block=None, drop_path_rate=0.):
+            self,
+            in_chs,
+            out_chs,
+            kernel_size=3,
+            stride=1,
+            dilation=(1, 1),
+            bottle_ratio=1.0,
+            group_size=None,
+            downsample='',
+            layers: LayerFn = None,
+            drop_block=None,
+            drop_path_rate=0.,
+    ):
         super(RepVggBlock, self).__init__()
         layers = layers or LayerFn()
         groups = num_groups(group_size, in_chs)
@@ -1204,9 +1265,24 @@ class SelfAttnBlock(nn.Module):
     """

     def __init__(
-            self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1., group_size=None,
-            downsample='avg', extra_conv=False, linear_out=False, bottle_in=False, post_attn_na=True,
-            feat_size=None, layers: LayerFn = None, drop_block=None, drop_path_rate=0.):
+            self,
+            in_chs,
+            out_chs,
+            kernel_size=3,
+            stride=1,
+            dilation=(1, 1),
+            bottle_ratio=1.,
+            group_size=None,
+            downsample='avg',
+            extra_conv=False,
+            linear_out=False,
+            bottle_in=False,
+            post_attn_na=True,
+            feat_size=None,
+            layers: LayerFn = None,
+            drop_block=None,
+            drop_path_rate=0.,
+    ):
         super(SelfAttnBlock, self).__init__()
         assert layers is not None
         mid_chs = make_divisible((in_chs if bottle_in else out_chs) * bottle_ratio)
@@ -1233,7 +1309,7 @@ def __init__(
         self.act = nn.Identity() if linear_out else layers.act(inplace=True)

     def init_weights(self, zero_init_last: bool = False):
-        if zero_init_last and self.shortcut is not None:
+        if zero_init_last and self.shortcut is not None and getattr(self.conv3_1x1.bn, 'weight', None) is not None:
             nn.init.zeros_(self.conv3_1x1.bn.weight)
         if hasattr(self.self_attn, 'reset_parameters'):
             self.self_attn.reset_parameters()
@@ -1274,8 +1350,17 @@ def create_block(block: Union[str, nn.Module], **kwargs):
 class Stem(nn.Sequential):

     def __init__(
-            self, in_chs, out_chs, kernel_size=3, stride=4, pool='maxpool',
-            num_rep=3, num_act=None, chs_decay=0.5, layers: LayerFn = None):
+            self,
+            in_chs,
+            out_chs,
+            kernel_size=3,
+            stride=4,
+            pool='maxpool',
+            num_rep=3,
+            num_act=None,
+            chs_decay=0.5,
+            layers: LayerFn = None,
+    ):
         super().__init__()
         assert stride in (2, 4)
         layers = layers or LayerFn()
@@ -1319,7 +1404,14 @@ def __init__(
         assert curr_stride == stride


-def create_byob_stem(in_chs, out_chs, stem_type='', pool_type='', feat_prefix='stem', layers: LayerFn = None):
+def create_byob_stem(
+        in_chs,
+        out_chs,
+        stem_type='',
+        pool_type='',
+        feat_prefix='stem',
+        layers: LayerFn = None,
+):
     layers = layers or LayerFn()
     assert stem_type in ('', 'quad', 'quad2', 'tiered', 'deep', 'rep', '7x7', '3x3')
     if 'quad' in stem_type:
@@ -1407,10 +1499,14 @@ def update_block_kwargs(block_kwargs: Dict[str, Any], block_cfg: ByoBlockCfg, mo


 def create_byob_stages(
-        cfg: ByoModelCfg, drop_path_rate: float, output_stride: int, stem_feat: Dict[str, Any],
+        cfg: ByoModelCfg,
+        drop_path_rate: float,
+        output_stride: int,
+        stem_feat: Dict[str, Any],
         feat_size: Optional[int] = None,
         layers: Optional[LayerFn] = None,
-        block_kwargs_fn: Optional[Callable] = update_block_kwargs):
+        block_kwargs_fn: Optional[Callable] = update_block_kwargs,
+):
     layers = layers or LayerFn()

     feature_info = []
@@ -1485,8 +1581,17 @@ class ByobNet(nn.Module):
     Current assumption is that both stem and blocks are in conv-bn-act order (w/ block ending in act).
     """
     def __init__(
-            self, cfg: ByoModelCfg, num_classes=1000, in_chans=3, global_pool='avg', output_stride=32,
-            zero_init_last=True, img_size=None, drop_rate=0., drop_path_rate=0.):
+            self,
+            cfg: ByoModelCfg,
+            num_classes=1000,
+            in_chans=3,
+            global_pool='avg',
+            output_stride=32,
+            zero_init_last=True,
+            img_size=None,
+            drop_rate=0.,
+            drop_path_rate=0.,
+    ):
         super().__init__()
         self.num_classes = num_classes
         self.drop_rate = drop_rate
diff --git a/timm/models/res2net.py b/timm/models/res2net.py
index 4724df2a90..607ba722e3 100644
--- a/timm/models/res2net.py
+++ b/timm/models/res2net.py
@@ -51,9 +51,21 @@ class Bottle2neck(nn.Module):
     expansion = 4

     def __init__(
-            self, inplanes, planes, stride=1, downsample=None,
-            cardinality=1, base_width=26, scale=4, dilation=1, first_dilation=None,
-            act_layer=nn.ReLU, norm_layer=None, attn_layer=None, **_):
+            self,
+            inplanes,
+            planes,
+            stride=1,
+            downsample=None,
+            cardinality=1,
+            base_width=26,
+            scale=4,
+            dilation=1,
+            first_dilation=None,
+            act_layer=nn.ReLU,
+            norm_layer=None,
+            attn_layer=None,
+            **_,
+    ):
         super(Bottle2neck, self).__init__()
         self.scale = scale
         self.is_first = stride > 1 or downsample is not None
@@ -89,7 +101,8 @@ def __init__(
         self.downsample = downsample

     def zero_init_last(self):
-        nn.init.zeros_(self.bn3.weight)
+        if getattr(self.bn3, 'weight', None) is not None:
+            nn.init.zeros_(self.bn3.weight)

     def forward(self, x):
         shortcut = x
diff --git a/timm/models/resnest.py b/timm/models/resnest.py
index 3b001c7bfd..853ee1d020 100644
--- a/timm/models/resnest.py
+++ b/timm/models/resnest.py
@@ -57,10 +57,27 @@ class ResNestBottleneck(nn.Module):
     expansion = 4

     def __init__(
-            self, inplanes, planes, stride=1, downsample=None,
-            radix=1, cardinality=1, base_width=64, avd=False, avd_first=False, is_first=False,
-            reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
-            attn_layer=None, aa_layer=None, drop_block=None, drop_path=None):
+            self,
+            inplanes,
+            planes,
+            stride=1,
+            downsample=None,
+            radix=1,
+            cardinality=1,
+            base_width=64,
+            avd=False,
+            avd_first=False,
+            is_first=False,
+            reduce_first=1,
+            dilation=1,
+            first_dilation=None,
+            act_layer=nn.ReLU,
+            norm_layer=nn.BatchNorm2d,
+            attn_layer=None,
+            aa_layer=None,
+            drop_block=None,
+            drop_path=None,
+    ):
         super(ResNestBottleneck, self).__init__()
         assert reduce_first == 1  # not supported
         assert attn_layer is None  # not supported
@@ -103,7 +120,8 @@ def __init__(
         self.downsample = downsample

     def zero_init_last(self):
-        nn.init.zeros_(self.bn3.weight)
+        if getattr(self.bn3, 'weight', None) is not None:
+            nn.init.zeros_(self.bn3.weight)

     def forward(self, x):
         shortcut = x
diff --git a/timm/models/resnet.py b/timm/models/resnet.py
index 508490178b..2976c1f98f 100644
--- a/timm/models/resnet.py
+++ b/timm/models/resnet.py
@@ -337,9 +337,23 @@ class BasicBlock(nn.Module):
     expansion = 1

     def __init__(
-            self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64,
-            reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
-            attn_layer=None, aa_layer=None, drop_block=None, drop_path=None):
+            self,
+            inplanes,
+            planes,
+            stride=1,
+            downsample=None,
+            cardinality=1,
+            base_width=64,
+            reduce_first=1,
+            dilation=1,
+            first_dilation=None,
+            act_layer=nn.ReLU,
+            norm_layer=nn.BatchNorm2d,
+            attn_layer=None,
+            aa_layer=None,
+            drop_block=None,
+            drop_path=None,
+    ):
         super(BasicBlock, self).__init__()

         assert cardinality == 1, 'BasicBlock only supports cardinality of 1'
@@ -370,7 +384,8 @@ def __init__(
         self.drop_path = drop_path

     def zero_init_last(self):
-        nn.init.zeros_(self.bn2.weight)
+        if getattr(self.bn2, 'weight', None) is not None:
+            nn.init.zeros_(self.bn2.weight)

     def forward(self, x):
         shortcut = x
@@ -402,9 +417,23 @@ class Bottleneck(nn.Module):
     expansion = 4

     def __init__(
-            self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64,
-            reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
-            attn_layer=None, aa_layer=None, drop_block=None, drop_path=None):
+            self,
+            inplanes,
+            planes,
+            stride=1,
+            downsample=None,
+            cardinality=1,
+            base_width=64,
+            reduce_first=1,
+            dilation=1,
+            first_dilation=None,
+            act_layer=nn.ReLU,
+            norm_layer=nn.BatchNorm2d,
+            attn_layer=None,
+            aa_layer=None,
+            drop_block=None,
+            drop_path=None,
+    ):
         super(Bottleneck, self).__init__()

         width = int(math.floor(planes * (base_width / 64)) * cardinality)
@@ -437,7 +466,8 @@ def __init__(
         self.drop_path = drop_path

     def zero_init_last(self):
-        nn.init.zeros_(self.bn3.weight)
+        if getattr(self.bn3, 'weight', None) is not None:
+            nn.init.zeros_(self.bn3.weight)

     def forward(self, x):
         shortcut = x
@@ -508,8 +538,18 @@ def drop_blocks(drop_prob=0.):


 def make_blocks(
-        block_fn, channels, block_repeats, inplanes, reduce_first=1, output_stride=32,
-        down_kernel_size=1, avg_down=False, drop_block_rate=0., drop_path_rate=0., **kwargs):
+        block_fn,
+        channels,
+        block_repeats,
+        inplanes,
+        reduce_first=1,
+        output_stride=32,
+        down_kernel_size=1,
+        avg_down=False,
+        drop_block_rate=0.,
+        drop_path_rate=0.,
+        **kwargs,
+):
     stages = []
     feature_info = []
     net_num_blocks = sum(block_repeats)
@@ -528,8 +568,14 @@ def make_blocks(
         downsample = None
         if stride != 1 or inplanes != planes * block_fn.expansion:
             down_kwargs = dict(
-                in_channels=inplanes, out_channels=planes * block_fn.expansion, kernel_size=down_kernel_size,
-                stride=stride, dilation=dilation, first_dilation=prev_dilation, norm_layer=kwargs.get('norm_layer'))
+                in_channels=inplanes,
+                out_channels=planes * block_fn.expansion,
+                kernel_size=down_kernel_size,
+                stride=stride,
+                dilation=dilation,
+                first_dilation=prev_dilation,
+                norm_layer=kwargs.get('norm_layer'),
+            )
             downsample = downsample_avg(**down_kwargs) if avg_down else downsample_conv(**down_kwargs)

         block_kwargs = dict(reduce_first=reduce_first, dilation=dilation, drop_block=db, **kwargs)
@@ -609,10 +655,30 @@ class ResNet(nn.Module):
     """

     def __init__(
-            self, block, layers, num_classes=1000, in_chans=3, output_stride=32, global_pool='avg',
-            cardinality=1, base_width=64, stem_width=64, stem_type='', replace_stem_pool=False, block_reduce_first=1,
-            down_kernel_size=1, avg_down=False, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None,
-            drop_rate=0.0, drop_path_rate=0., drop_block_rate=0., zero_init_last=True, block_args=None):
+            self,
+            block,
+            layers,
+            num_classes=1000,
+            in_chans=3,
+            output_stride=32,
+            global_pool='avg',
+            cardinality=1,
+            base_width=64,
+            stem_width=64,
+            stem_type='',
+            replace_stem_pool=False,
+            block_reduce_first=1,
+            down_kernel_size=1,
+            avg_down=False,
+            act_layer=nn.ReLU,
+            norm_layer=nn.BatchNorm2d,
+            aa_layer=None,
+            drop_rate=0.0,
+            drop_path_rate=0.,
+            drop_block_rate=0.,
+            zero_init_last=True,
+            block_args=None,
+    ):
         super(ResNet, self).__init__()
         block_args = block_args or dict()
         assert output_stride in (8, 16, 32)
@@ -663,10 +729,23 @@ def __init__(
         # Feature Blocks
         channels = [64, 128, 256, 512]
         stage_modules, stage_feature_info = make_blocks(
-            block, channels, layers, inplanes, cardinality=cardinality, base_width=base_width,
-            output_stride=output_stride, reduce_first=block_reduce_first, avg_down=avg_down,
-            down_kernel_size=down_kernel_size, act_layer=act_layer, norm_layer=norm_layer, aa_layer=aa_layer,
-            drop_block_rate=drop_block_rate, drop_path_rate=drop_path_rate, **block_args)
+            block,
+            channels,
+            layers,
+            inplanes,
+            cardinality=cardinality,
+            base_width=base_width,
+            output_stride=output_stride,
+            reduce_first=block_reduce_first,
+            avg_down=avg_down,
+            down_kernel_size=down_kernel_size,
+            act_layer=act_layer,
+            norm_layer=norm_layer,
+            aa_layer=aa_layer,
+            drop_block_rate=drop_block_rate,
+            drop_path_rate=drop_path_rate,
+            **block_args,
+        )
         for stage in stage_modules:
             self.add_module(*stage)  # layer1, layer2, etc
         self.feature_info.extend(stage_feature_info)
@@ -687,9 +766,6 @@ def init_weights(self, zero_init_last=True):
         for n, m in self.named_modules():
             if isinstance(m, nn.Conv2d):
                 nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
-            elif isinstance(m, nn.BatchNorm2d):
-                nn.init.ones_(m.weight)
-                nn.init.zeros_(m.bias)
         if zero_init_last:
             for m in self.modules():
                 if hasattr(m, 'zero_init_last'):
diff --git a/timm/models/resnetv2.py b/timm/models/resnetv2.py
index f8c4298b24..a55f48ac0f 100644
--- a/timm/models/resnetv2.py
+++ b/timm/models/resnetv2.py
@@ -155,8 +155,20 @@ class PreActBottleneck(nn.Module):
     """

     def __init__(
-            self, in_chs, out_chs=None, bottle_ratio=0.25, stride=1, dilation=1, first_dilation=None, groups=1,
-            act_layer=None, conv_layer=None, norm_layer=None, proj_layer=None, drop_path_rate=0.):
+            self,
+            in_chs,
+            out_chs=None,
+            bottle_ratio=0.25,
+            stride=1,
+            dilation=1,
+            first_dilation=None,
+            groups=1,
+            act_layer=None,
+            conv_layer=None,
+            norm_layer=None,
+            proj_layer=None,
+            drop_path_rate=0.,
+    ):
         super().__init__()
         first_dilation = first_dilation or dilation
         conv_layer = conv_layer or StdConv2d
@@ -202,8 +214,20 @@ class Bottleneck(nn.Module):
     """Non Pre-activation bottleneck block, equiv to V1.5/V1b Bottleneck. Used for ViT.
     """
     def __init__(
-            self, in_chs, out_chs=None, bottle_ratio=0.25, stride=1, dilation=1, first_dilation=None, groups=1,
-            act_layer=None, conv_layer=None, norm_layer=None, proj_layer=None, drop_path_rate=0.):
+            self,
+            in_chs,
+            out_chs=None,
+            bottle_ratio=0.25,
+            stride=1,
+            dilation=1,
+            first_dilation=None,
+            groups=1,
+            act_layer=None,
+            conv_layer=None,
+            norm_layer=None,
+            proj_layer=None,
+            drop_path_rate=0.,
+    ):
         super().__init__()
         first_dilation = first_dilation or dilation
         act_layer = act_layer or nn.ReLU
@@ -229,7 +253,8 @@ def __init__(
         self.act3 = act_layer(inplace=True)

     def zero_init_last(self):
-        nn.init.zeros_(self.norm3.weight)
+        if getattr(self.norm3, 'weight', None) is not None:
+            nn.init.zeros_(self.norm3.weight)

     def forward(self, x):
         # shortcut branch
@@ -283,9 +308,22 @@ def forward(self, x):
 class ResNetStage(nn.Module):
     """ResNet Stage."""
     def __init__(
-            self, in_chs, out_chs, stride, dilation, depth, bottle_ratio=0.25, groups=1,
-            avg_down=False, block_dpr=None, block_fn=PreActBottleneck,
-            act_layer=None, conv_layer=None, norm_layer=None, **block_kwargs):
+            self,
+            in_chs,
+            out_chs,
+            stride,
+            dilation,
+            depth,
+            bottle_ratio=0.25,
+            groups=1,
+            avg_down=False,
+            block_dpr=None,
+            block_fn=PreActBottleneck,
+            act_layer=None,
+            conv_layer=None,
+            norm_layer=None,
+            **block_kwargs,
+    ):
         super(ResNetStage, self).__init__()
         first_dilation = 1 if dilation in (1, 2) else 2
         layer_kwargs = dict(act_layer=act_layer, conv_layer=conv_layer, norm_layer=norm_layer)
@@ -313,8 +351,13 @@ def is_stem_deep(stem_type):


 def create_resnetv2_stem(
-        in_chs, out_chs=64, stem_type='', preact=True,
-        conv_layer=StdConv2d, norm_layer=partial(GroupNormAct, num_groups=32)):
+        in_chs,
+        out_chs=64,
+        stem_type='',
+        preact=True,
+        conv_layer=StdConv2d,
+        norm_layer=partial(GroupNormAct, num_groups=32),
+):
     stem = OrderedDict()
     assert stem_type in ('', 'fixed', 'same', 'deep', 'deep_fixed', 'deep_same', 'tiered')

@@ -357,11 +400,25 @@ class ResNetV2(nn.Module):
     """

     def __init__(
-            self, layers, channels=(256, 512, 1024, 2048),
-            num_classes=1000, in_chans=3, global_pool='avg', output_stride=32,
-            width_factor=1, stem_chs=64, stem_type='', avg_down=False, preact=True,
-            act_layer=nn.ReLU, conv_layer=StdConv2d, norm_layer=partial(GroupNormAct, num_groups=32),
-            drop_rate=0., drop_path_rate=0., zero_init_last=False):
+            self,
+            layers,
+            channels=(256, 512, 1024, 2048),
+            num_classes=1000,
+            in_chans=3,
+            global_pool='avg',
+            output_stride=32,
+            width_factor=1,
+            stem_chs=64,
+            stem_type='',
+            avg_down=False,
+            preact=True,
+            act_layer=nn.ReLU,
+            conv_layer=StdConv2d,
+            norm_layer=partial(GroupNormAct, num_groups=32),
+            drop_rate=0.,
+            drop_path_rate=0.,
+            zero_init_last=False,
+    ):
         super().__init__()
         self.num_classes = num_classes
         self.drop_rate = drop_rate
@@ -387,8 +444,18 @@ def __init__(
                 dilation *= stride
                 stride = 1
             stage = ResNetStage(
-                prev_chs, out_chs, stride=stride, dilation=dilation, depth=d, avg_down=avg_down,
-                act_layer=act_layer, conv_layer=conv_layer, norm_layer=norm_layer, block_dpr=bdpr, block_fn=block_fn)
+                prev_chs,
+                out_chs,
+                stride=stride,
+                dilation=dilation,
+                depth=d,
+                avg_down=avg_down,
+                act_layer=act_layer,
+                conv_layer=conv_layer,
+                norm_layer=norm_layer,
+                block_dpr=bdpr,
+                block_fn=block_fn,
+            )
             prev_chs = out_chs
             curr_stride *= stride
             self.feature_info += [dict(num_chs=prev_chs, reduction=curr_stride, module=f'stages.{stage_idx}')]
diff --git a/timm/models/sknet.py b/timm/models/sknet.py
index 5a29b9a43e..425bd7c219 100644
--- a/timm/models/sknet.py
+++ b/timm/models/sknet.py
@@ -47,9 +47,24 @@ class SelectiveKernelBasic(nn.Module):
     expansion = 1

     def __init__(
-            self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64,
-            sk_kwargs=None, reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU,
-            norm_layer=nn.BatchNorm2d, attn_layer=None, aa_layer=None, drop_block=None, drop_path=None):
+            self,
+            inplanes,
+            planes,
+            stride=1,
+            downsample=None,
+            cardinality=1,
+            base_width=64,
+            sk_kwargs=None,
+            reduce_first=1,
+            dilation=1,
+            first_dilation=None,
+            act_layer=nn.ReLU,
+            norm_layer=nn.BatchNorm2d,
+            attn_layer=None,
+            aa_layer=None,
+            drop_block=None,
+            drop_path=None,
+    ):
         super(SelectiveKernelBasic, self).__init__()

         sk_kwargs = sk_kwargs or {}
@@ -71,7 +86,8 @@ def __init__(
         self.drop_path = drop_path

     def zero_init_last(self):
-        nn.init.zeros_(self.conv2.bn.weight)
+        if getattr(self.conv2.bn, 'weight', None) is not None:
+            nn.init.zeros_(self.conv2.bn.weight)

     def forward(self, x):
         shortcut = x
@@ -92,9 +108,24 @@ class SelectiveKernelBottleneck(nn.Module):
     expansion = 4

     def __init__(
-            self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64, sk_kwargs=None,
-            reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
-            attn_layer=None, aa_layer=None, drop_block=None, drop_path=None):
+            self,
+            inplanes,
+            planes,
+            stride=1,
+            downsample=None,
+            cardinality=1,
+            base_width=64,
+            sk_kwargs=None,
+            reduce_first=1,
+            dilation=1,
+            first_dilation=None,
+            act_layer=nn.ReLU,
+            norm_layer=nn.BatchNorm2d,
+            attn_layer=None,
+            aa_layer=None,
+            drop_block=None,
+            drop_path=None,
+    ):
         super(SelectiveKernelBottleneck, self).__init__()

         sk_kwargs = sk_kwargs or {}
@@ -115,7 +146,8 @@ def __init__(
         self.drop_path = drop_path

     def zero_init_last(self):
-        nn.init.zeros_(self.conv3.bn.weight)
+        if getattr(self.conv3.bn, 'weight', None) is not None:
+            nn.init.zeros_(self.conv3.bn.weight)

     def forward(self, x):
         shortcut = x
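
For illustration (not part of the diff above): every zero_init_last / init_weights change in this patch applies the same guard, zeroing the residual branch's final norm gamma only when the norm layer actually carries affine parameters. A norm layer built with affine disabled, e.g. nn.BatchNorm2d(num_features, affine=False), exposes weight and bias as None, so the previous unconditional nn.init.zeros_ call would fail. A minimal, self-contained PyTorch sketch of the failure mode and the guard; the helper name zero_init_last below is illustrative and only torch.nn APIs are assumed:

import torch.nn as nn

bn_affine = nn.BatchNorm2d(64)                    # weight (gamma) is a learnable Parameter
bn_no_affine = nn.BatchNorm2d(64, affine=False)   # weight is None

def zero_init_last(norm: nn.Module) -> None:
    # Guarded zero-init, mirroring the getattr(..., 'weight', None) checks added above.
    if getattr(norm, 'weight', None) is not None:
        nn.init.zeros_(norm.weight)

zero_init_last(bn_affine)      # zeros gamma so the residual branch starts near identity
zero_init_last(bn_no_affine)   # no-op instead of failing on the missing affine weight

The hunk removed from ResNet.init_weights drops an unconditional ones_/zeros_ re-init of BatchNorm2d parameters for the same reason; when affine parameters do exist, BatchNorm2d's default initialization already sets weight=1 and bias=0, so nothing is lost.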