@@ -1,12 +1,13 @@
 import math
 from collections import OrderedDict
 from functools import partial
-from typing import Any, Callable, Optional
+from typing import Any, Callable, List, NamedTuple, Optional
 
 import torch
 import torch.nn as nn
 
 from .._internally_replaced_utils import load_state_dict_from_url
+from ..ops.misc import ConvNormActivation
 from ..utils import _log_api_usage_once
 
 __all__ = [
@@ -25,6 +26,14 @@
 }
 
 
+class ConvStemConfig(NamedTuple):
+    out_channels: int
+    kernel_size: int
+    stride: int
+    norm_layer: Callable[..., nn.Module] = nn.BatchNorm2d
+    activation_layer: Callable[..., nn.Module] = nn.ReLU
+
+
 class MLPBlock(nn.Sequential):
     """Transformer MLP block."""
 
@@ -134,6 +143,7 @@ def __init__(
         num_classes: int = 1000,
         representation_size: Optional[int] = None,
         norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
+        conv_stem_configs: Optional[List[ConvStemConfig]] = None,
     ):
         super().__init__()
         _log_api_usage_once(self)
@@ -148,11 +158,31 @@ def __init__(
         self.representation_size = representation_size
         self.norm_layer = norm_layer
 
-        input_channels = 3
-
-        # The conv_proj is a more efficient version of reshaping, permuting
-        # and projecting the input
-        self.conv_proj = nn.Conv2d(input_channels, hidden_dim, kernel_size=patch_size, stride=patch_size)
+        if conv_stem_configs is not None:
+            # As per https://arxiv.org/abs/2106.14881
+            seq_proj = nn.Sequential()
+            prev_channels = 3
+            for i, conv_stem_layer_config in enumerate(conv_stem_configs):
+                seq_proj.add_module(
+                    f"conv_bn_relu_{i}",
+                    ConvNormActivation(
+                        in_channels=prev_channels,
+                        out_channels=conv_stem_layer_config.out_channels,
+                        kernel_size=conv_stem_layer_config.kernel_size,
+                        stride=conv_stem_layer_config.stride,
+                        norm_layer=conv_stem_layer_config.norm_layer,
+                        activation_layer=conv_stem_layer_config.activation_layer,
+                    ),
+                )
+                prev_channels = conv_stem_layer_config.out_channels
+            seq_proj.add_module(
+                "conv_last", nn.Conv2d(in_channels=prev_channels, out_channels=hidden_dim, kernel_size=1)
+            )
+            self.conv_proj: nn.Module = seq_proj
+        else:
+            self.conv_proj = nn.Conv2d(
+                in_channels=3, out_channels=hidden_dim, kernel_size=patch_size, stride=patch_size
+            )
 
         seq_length = (image_size // patch_size) ** 2
 
@@ -184,9 +214,17 @@ def __init__(
         self._init_weights()
 
     def _init_weights(self):
-        fan_in = self.conv_proj.in_channels * self.conv_proj.kernel_size[0] * self.conv_proj.kernel_size[1]
-        nn.init.trunc_normal_(self.conv_proj.weight, std=math.sqrt(1 / fan_in))
-        nn.init.zeros_(self.conv_proj.bias)
+        if isinstance(self.conv_proj, nn.Conv2d):
+            # Init the patchify stem
+            fan_in = self.conv_proj.in_channels * self.conv_proj.kernel_size[0] * self.conv_proj.kernel_size[1]
+            nn.init.trunc_normal_(self.conv_proj.weight, std=math.sqrt(1 / fan_in))
+            nn.init.zeros_(self.conv_proj.bias)
+        else:
+            # Init the last 1x1 conv of the conv stem
+            nn.init.normal_(
+                self.conv_proj.conv_last.weight, mean=0.0, std=math.sqrt(2.0 / self.conv_proj.conv_last.out_channels)
+            )
+            nn.init.zeros_(self.conv_proj.conv_last.bias)
 
         if hasattr(self.heads, "pre_logits"):
             fan_in = self.heads.pre_logits.in_features
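
Below is a minimal usage sketch, not part of the commit, showing how the new conv_stem_configs argument could be passed when constructing the model. It assumes the enclosing class in this file is torchvision's VisionTransformer and that the constructor arguments not shown in the diff (image_size, patch_size, num_layers, num_heads, hidden_dim, mlp_dim) keep their usual meaning; the channel progression and strides below are illustrative, not prescribed by the change.

# Hypothetical usage sketch: assumes ConvStemConfig and VisionTransformer are
# importable from this module; the values below are illustrative only.
from torchvision.models.vision_transformer import ConvStemConfig, VisionTransformer

# Four 3x3 convs with stride 2 give an overall stride of 16, so with patch_size=16
# the conv stem produces the same token grid and the same
# seq_length = (image_size // patch_size) ** 2 that the positional embedding is built from.
stem = [
    ConvStemConfig(out_channels=64, kernel_size=3, stride=2),
    ConvStemConfig(out_channels=128, kernel_size=3, stride=2),
    ConvStemConfig(out_channels=256, kernel_size=3, stride=2),
    ConvStemConfig(out_channels=512, kernel_size=3, stride=2),
]

model = VisionTransformer(
    image_size=224,
    patch_size=16,
    num_layers=12,
    num_heads=12,
    hidden_dim=768,
    mlp_dim=3072,
    conv_stem_configs=stem,  # default None keeps the single patchify convolution
)

Leaving conv_stem_configs at its default of None preserves the original patchify behaviour, so existing callers are unaffected.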