Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@

## What's New

### Jan 30, 2012
* Add initial "Normalization Free" NF-RegNet-B* and NF-ResNet model definitions based on [paper](https://arxiv.org/abs/2101.08692)

### Jan 25, 2021
* Add ResNetV2 Big Transfer (BiT) models w/ ImageNet-1k and 21k weights from https://github.com/google-research/big_transfer
* Add official R50+ViT-B/16 hybrid models + weights from https://github.com/google-research/vision_transformer
Expand Down Expand Up @@ -164,6 +167,7 @@ A full version of the list below with source links can be found in the [document
* Inception-ResNet-V2 and Inception-V4 - https://arxiv.org/abs/1602.07261
* MobileNet-V3 (MBConvNet w/ Efficient Head) - https://arxiv.org/abs/1905.02244
* NASNet-A - https://arxiv.org/abs/1707.07012
* NF-RegNet / NF-ResNet - https://arxiv.org/abs/2101.08692
* PNasNet - https://arxiv.org/abs/1712.00559
* RegNet - https://arxiv.org/abs/2003.13678
* ResNet/ResNeXt
Expand Down
1 change: 1 addition & 0 deletions timm/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from .inception_v4 import *
from .mobilenetv3 import *
from .nasnet import *
from .nfnet import *
from .pnasnet import *
from .regnet import *
from .res2net import *
Expand Down
5 changes: 3 additions & 2 deletions timm/models/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@
from .conv2d_same import Conv2dSame, conv2d_same
from .conv_bn_act import ConvBnAct
from .create_act import create_act_layer, get_act_layer, get_act_fn
from .create_attn import create_attn
from .create_attn import get_attn, create_attn
from .create_conv2d import create_conv2d
from .create_norm_act import create_norm_act, get_norm_act_layer
from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path
from .eca import EcaModule, CecaModule
from .evo_norm import EvoNormBatch2d, EvoNormSample2d
from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple
from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible
from .inplace_abn import InplaceAbn
from .linear import Linear
from .mixed_conv2d import MixedConv2d
Expand All @@ -29,5 +29,6 @@
from .space_to_depth import SpaceToDepthModule
from .split_attn import SplitAttnConv2d
from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model
from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d
from .test_time_pool import TestTimePoolHead, apply_test_time_pool
from .weight_init import trunc_normal_
8 changes: 7 additions & 1 deletion timm/models/layers/create_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from .cbam import CbamModule, LightCbamModule


def create_attn(attn_type, channels, **kwargs):
def get_attn(attn_type):
module_cls = None
if attn_type is not None:
if isinstance(attn_type, str):
Expand All @@ -32,6 +32,12 @@ def create_attn(attn_type, channels, **kwargs):
module_cls = SEModule
else:
module_cls = attn_type
return module_cls


def create_attn(attn_type, channels, **kwargs):
module_cls = get_attn(attn_type)
if module_cls is not None:
# NOTE: it's expected the first (positional) argument of all attention layers is the # input channels
return module_cls(channels, **kwargs)
return None
10 changes: 7 additions & 3 deletions timm/models/layers/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ def parse(x):
to_ntuple = _ntuple





def make_divisible(v, divisor=8, min_value=None):
min_value = min_value or divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
# Make sure that round down does not go down by more than 10%.
if new_v < 0.9 * v:
new_v += divisor
return new_v
22 changes: 18 additions & 4 deletions timm/models/layers/se.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,27 @@
from torch import nn as nn
import torch.nn.functional as F

from .create_act import create_act_layer
from .helpers import make_divisible


class SEModule(nn.Module):

def __init__(self, channels, reduction=16, act_layer=nn.ReLU, min_channels=8, reduction_channels=None,
gate_layer='sigmoid'):
""" SE Module as defined in original SE-Nets with a few additions
Additions include:
* min_channels can be specified to keep reduced channel count at a minimum (default: 8)
* divisor can be specified to keep channels rounded to specified values (default: 1)
* reduction channels can be specified directly by arg (if reduction_channels is set)
* reduction channels can be specified by float ratio (if reduction_ratio is set)
"""
def __init__(self, channels, reduction=16, act_layer=nn.ReLU, gate_layer='sigmoid',
reduction_ratio=None, reduction_channels=None, min_channels=8, divisor=1):
super(SEModule, self).__init__()
reduction_channels = reduction_channels or max(channels // reduction, min_channels)
if reduction_channels is not None:
reduction_channels = reduction_channels # direct specification highest priority, no rounding/min done
elif reduction_ratio is not None:
reduction_channels = make_divisible(channels * reduction_ratio, divisor, min_channels)
else:
reduction_channels = make_divisible(channels // reduction, divisor, min_channels)
self.fc1 = nn.Conv2d(channels, reduction_channels, kernel_size=1, bias=True)
self.act = act_layer(inplace=True)
self.fc2 = nn.Conv2d(reduction_channels, channels, kernel_size=1, bias=True)
Expand Down
94 changes: 94 additions & 0 deletions timm/models/layers/std_conv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from .padding import get_padding
from .conv2d_same import conv2d_same


def get_weight(module):
std, mean = torch.std_mean(module.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
weight = (module.weight - mean) / (std + module.eps)
return weight


class StdConv2d(nn.Conv2d):
"""Conv2d with Weight Standardization. Used for BiT ResNet-V2 models.

Paper: `Micro-Batch Training with Batch-Channel Normalization and Weight Standardization` -
https://arxiv.org/abs/1903.10520v2
"""
def __init__(
self, in_channel, out_channels, kernel_size, stride=1,
padding=None, dilation=1, groups=1, bias=False, eps=1e-5):
if padding is None:
padding = get_padding(kernel_size, stride, dilation)
super().__init__(
in_channel, out_channels, kernel_size, stride=stride,
padding=padding, dilation=dilation, groups=groups, bias=bias)
self.eps = eps

def get_weight(self):
std, mean = torch.std_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
weight = (self.weight - mean) / (std + self.eps)
return weight

def forward(self, x):
x = F.conv2d(x, self.get_weight(), self.bias, self.stride, self.padding, self.dilation, self.groups)
return x


class StdConv2dSame(nn.Conv2d):
"""Conv2d with Weight Standardization. TF compatible SAME padding. Used for ViT Hybrid model.

Paper: `Micro-Batch Training with Batch-Channel Normalization and Weight Standardization` -
https://arxiv.org/abs/1903.10520v2
"""
def __init__(
self, in_channel, out_channels, kernel_size, stride=1, dilation=1, groups=1, bias=False, eps=1e-5):
super().__init__(
in_channel, out_channels, kernel_size, stride=stride,
padding=0, dilation=dilation, groups=groups, bias=bias)
self.eps = eps

def get_weight(self):
std, mean = torch.std_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
weight = (self.weight - mean) / (std + self.eps)
return weight

def forward(self, x):
x = conv2d_same(x, self.get_weight(), self.bias, self.stride, self.padding, self.dilation, self.groups)
return x


class ScaledStdConv2d(nn.Conv2d):
"""Conv2d layer with Scaled Weight Standardization.

Paper: `Characterizing signal propagation to close the performance gap in unnormalized ResNets` -
https://arxiv.org/abs/2101.08692
"""

def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=None, dilation=1, groups=1,
bias=True, gain=True, gamma=1.0, eps=1e-5, use_layernorm=False):
if padding is None:
padding = get_padding(kernel_size, stride, dilation)
super().__init__(
in_channels, out_channels, kernel_size, stride=stride,
padding=padding, dilation=dilation, groups=groups, bias=bias)
self.gain = nn.Parameter(torch.ones(self.out_channels, 1, 1, 1)) if gain else None
self.scale = gamma * self.weight[0].numel() ** -0.5 # gamma * 1 / sqrt(fan-in)
self.eps = eps ** 2 if use_layernorm else eps
self.use_layernorm = use_layernorm # experimental, slightly faster/less GPU memory use

def get_weight(self):
if self.use_layernorm:
weight = self.scale * F.layer_norm(self.weight, self.weight.shape[1:], eps=self.eps)
else:
std, mean = torch.std_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
weight = self.scale * (self.weight - mean) / (std + self.eps)
if self.gain is not None:
weight = weight * self.gain
return weight

def forward(self, x):
return F.conv2d(x, self.get_weight(), self.bias, self.stride, self.padding, self.dilation, self.groups)
Loading