Skip to content

Commit 99b82ae

Browse files
authored
Merge pull request #389 from rwightman/norm_free_models
Normalizer-Free RegNet and ResNet impl
2 parents 9a38416 + f0e65e3 commit 99b82ae

File tree

12 files changed

+612
-59
lines changed

12 files changed

+612
-59
lines changed

README.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22

33
## What's New
44

5+
### Jan 30, 2012
6+
* Add initial "Normalization Free" NF-RegNet-B* and NF-ResNet model definitions based on [paper](https://arxiv.org/abs/2101.08692)
7+
58
### Jan 25, 2021
69
* Add ResNetV2 Big Transfer (BiT) models w/ ImageNet-1k and 21k weights from https://github.com/google-research/big_transfer
710
* Add official R50+ViT-B/16 hybrid models + weights from https://github.com/google-research/vision_transformer
@@ -164,6 +167,7 @@ A full version of the list below with source links can be found in the [document
164167
* Inception-ResNet-V2 and Inception-V4 - https://arxiv.org/abs/1602.07261
165168
* MobileNet-V3 (MBConvNet w/ Efficient Head) - https://arxiv.org/abs/1905.02244
166169
* NASNet-A - https://arxiv.org/abs/1707.07012
170+
* NF-RegNet / NF-ResNet - https://arxiv.org/abs/2101.08692
167171
* PNasNet - https://arxiv.org/abs/1712.00559
168172
* RegNet - https://arxiv.org/abs/2003.13678
169173
* ResNet/ResNeXt

timm/models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from .inception_v4 import *
1212
from .mobilenetv3 import *
1313
from .nasnet import *
14+
from .nfnet import *
1415
from .pnasnet import *
1516
from .regnet import *
1617
from .res2net import *

timm/models/layers/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,13 @@
1010
from .conv2d_same import Conv2dSame, conv2d_same
1111
from .conv_bn_act import ConvBnAct
1212
from .create_act import create_act_layer, get_act_layer, get_act_fn
13-
from .create_attn import create_attn
13+
from .create_attn import get_attn, create_attn
1414
from .create_conv2d import create_conv2d
1515
from .create_norm_act import create_norm_act, get_norm_act_layer
1616
from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path
1717
from .eca import EcaModule, CecaModule
1818
from .evo_norm import EvoNormBatch2d, EvoNormSample2d
19-
from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple
19+
from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible
2020
from .inplace_abn import InplaceAbn
2121
from .linear import Linear
2222
from .mixed_conv2d import MixedConv2d
@@ -29,5 +29,6 @@
2929
from .space_to_depth import SpaceToDepthModule
3030
from .split_attn import SplitAttnConv2d
3131
from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model
32+
from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d
3233
from .test_time_pool import TestTimePoolHead, apply_test_time_pool
3334
from .weight_init import trunc_normal_

timm/models/layers/create_attn.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from .cbam import CbamModule, LightCbamModule
99

1010

11-
def create_attn(attn_type, channels, **kwargs):
11+
def get_attn(attn_type):
1212
module_cls = None
1313
if attn_type is not None:
1414
if isinstance(attn_type, str):
@@ -32,6 +32,12 @@ def create_attn(attn_type, channels, **kwargs):
3232
module_cls = SEModule
3333
else:
3434
module_cls = attn_type
35+
return module_cls
36+
37+
38+
def create_attn(attn_type, channels, **kwargs):
39+
module_cls = get_attn(attn_type)
3540
if module_cls is not None:
41+
# NOTE: it's expected the first (positional) argument of all attention layers is the # input channels
3642
return module_cls(channels, **kwargs)
3743
return None

timm/models/layers/helpers.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,10 @@ def parse(x):
2222
to_ntuple = _ntuple
2323

2424

25-
26-
27-
25+
def make_divisible(v, divisor=8, min_value=None):
26+
min_value = min_value or divisor
27+
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
28+
# Make sure that round down does not go down by more than 10%.
29+
if new_v < 0.9 * v:
30+
new_v += divisor
31+
return new_v

timm/models/layers/se.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,27 @@
11
from torch import nn as nn
2+
import torch.nn.functional as F
3+
24
from .create_act import create_act_layer
5+
from .helpers import make_divisible
36

47

58
class SEModule(nn.Module):
6-
7-
def __init__(self, channels, reduction=16, act_layer=nn.ReLU, min_channels=8, reduction_channels=None,
8-
gate_layer='sigmoid'):
9+
""" SE Module as defined in original SE-Nets with a few additions
10+
Additions include:
11+
* min_channels can be specified to keep reduced channel count at a minimum (default: 8)
12+
* divisor can be specified to keep channels rounded to specified values (default: 1)
13+
* reduction channels can be specified directly by arg (if reduction_channels is set)
14+
* reduction channels can be specified by float ratio (if reduction_ratio is set)
15+
"""
16+
def __init__(self, channels, reduction=16, act_layer=nn.ReLU, gate_layer='sigmoid',
17+
reduction_ratio=None, reduction_channels=None, min_channels=8, divisor=1):
918
super(SEModule, self).__init__()
10-
reduction_channels = reduction_channels or max(channels // reduction, min_channels)
19+
if reduction_channels is not None:
20+
reduction_channels = reduction_channels # direct specification highest priority, no rounding/min done
21+
elif reduction_ratio is not None:
22+
reduction_channels = make_divisible(channels * reduction_ratio, divisor, min_channels)
23+
else:
24+
reduction_channels = make_divisible(channels // reduction, divisor, min_channels)
1125
self.fc1 = nn.Conv2d(channels, reduction_channels, kernel_size=1, bias=True)
1226
self.act = act_layer(inplace=True)
1327
self.fc2 = nn.Conv2d(reduction_channels, channels, kernel_size=1, bias=True)

timm/models/layers/std_conv.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
import torch
2+
import torch.nn as nn
3+
import torch.nn.functional as F
4+
5+
from .padding import get_padding
6+
from .conv2d_same import conv2d_same
7+
8+
9+
def get_weight(module):
10+
std, mean = torch.std_mean(module.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
11+
weight = (module.weight - mean) / (std + module.eps)
12+
return weight
13+
14+
15+
class StdConv2d(nn.Conv2d):
16+
"""Conv2d with Weight Standardization. Used for BiT ResNet-V2 models.
17+
18+
Paper: `Micro-Batch Training with Batch-Channel Normalization and Weight Standardization` -
19+
https://arxiv.org/abs/1903.10520v2
20+
"""
21+
def __init__(
22+
self, in_channel, out_channels, kernel_size, stride=1,
23+
padding=None, dilation=1, groups=1, bias=False, eps=1e-5):
24+
if padding is None:
25+
padding = get_padding(kernel_size, stride, dilation)
26+
super().__init__(
27+
in_channel, out_channels, kernel_size, stride=stride,
28+
padding=padding, dilation=dilation, groups=groups, bias=bias)
29+
self.eps = eps
30+
31+
def get_weight(self):
32+
std, mean = torch.std_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
33+
weight = (self.weight - mean) / (std + self.eps)
34+
return weight
35+
36+
def forward(self, x):
37+
x = F.conv2d(x, self.get_weight(), self.bias, self.stride, self.padding, self.dilation, self.groups)
38+
return x
39+
40+
41+
class StdConv2dSame(nn.Conv2d):
42+
"""Conv2d with Weight Standardization. TF compatible SAME padding. Used for ViT Hybrid model.
43+
44+
Paper: `Micro-Batch Training with Batch-Channel Normalization and Weight Standardization` -
45+
https://arxiv.org/abs/1903.10520v2
46+
"""
47+
def __init__(
48+
self, in_channel, out_channels, kernel_size, stride=1, dilation=1, groups=1, bias=False, eps=1e-5):
49+
super().__init__(
50+
in_channel, out_channels, kernel_size, stride=stride,
51+
padding=0, dilation=dilation, groups=groups, bias=bias)
52+
self.eps = eps
53+
54+
def get_weight(self):
55+
std, mean = torch.std_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
56+
weight = (self.weight - mean) / (std + self.eps)
57+
return weight
58+
59+
def forward(self, x):
60+
x = conv2d_same(x, self.get_weight(), self.bias, self.stride, self.padding, self.dilation, self.groups)
61+
return x
62+
63+
64+
class ScaledStdConv2d(nn.Conv2d):
65+
"""Conv2d layer with Scaled Weight Standardization.
66+
67+
Paper: `Characterizing signal propagation to close the performance gap in unnormalized ResNets` -
68+
https://arxiv.org/abs/2101.08692
69+
"""
70+
71+
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=None, dilation=1, groups=1,
72+
bias=True, gain=True, gamma=1.0, eps=1e-5, use_layernorm=False):
73+
if padding is None:
74+
padding = get_padding(kernel_size, stride, dilation)
75+
super().__init__(
76+
in_channels, out_channels, kernel_size, stride=stride,
77+
padding=padding, dilation=dilation, groups=groups, bias=bias)
78+
self.gain = nn.Parameter(torch.ones(self.out_channels, 1, 1, 1)) if gain else None
79+
self.scale = gamma * self.weight[0].numel() ** -0.5 # gamma * 1 / sqrt(fan-in)
80+
self.eps = eps ** 2 if use_layernorm else eps
81+
self.use_layernorm = use_layernorm # experimental, slightly faster/less GPU memory use
82+
83+
def get_weight(self):
84+
if self.use_layernorm:
85+
weight = self.scale * F.layer_norm(self.weight, self.weight.shape[1:], eps=self.eps)
86+
else:
87+
std, mean = torch.std_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
88+
weight = self.scale * (self.weight - mean) / (std + self.eps)
89+
if self.gain is not None:
90+
weight = weight * self.gain
91+
return weight
92+
93+
def forward(self, x):
94+
return F.conv2d(x, self.get_weight(), self.bias, self.stride, self.padding, self.dilation, self.groups)

0 commit comments

Comments
 (0)