1
+ # Copyright (c) Open-MMLab. All rights reserved.
2
+ # Source: https://github.com/open-mmlab/mmcv/blob/master/mmcv/cnn/vgg.py
3
+ from mmcv .runner import load_checkpoint
4
+ import torch .nn as nn
5
+ from .weight_init import constant_init , kaiming_init , normal_init
6
+
7
+
8
def conv3x3(in_planes, out_planes, dilation=1):
    """Build a 3x3 convolution whose padding equals its dilation.

    Args:
        in_planes (int): Number of input channels.
        out_planes (int): Number of output channels.
        dilation (int): Dilation rate. The same value is used as the
            padding, so with the default stride of 1 the spatial size of
            the input is preserved.

    Returns:
        nn.Conv2d: The configured convolution layer.
    """
    conv_kwargs = dict(kernel_size=3, padding=dilation, dilation=dilation)
    return nn.Conv2d(in_planes, out_planes, **conv_kwargs)
16
+
17
+
18
def make_vgg_layer(inplanes,
                   planes,
                   num_blocks,
                   dilation=1,
                   with_bn=False,
                   ceil_mode=False):
    """Assemble the modules of one VGG stage as a flat list.

    The stage is ``num_blocks`` repetitions of conv3x3 -> (BatchNorm2d) ->
    ReLU, followed by a single 2x2 max-pooling layer with stride 2.

    Args:
        inplanes (int): Input channels of the first convolution.
        planes (int): Output channels of every convolution in the stage.
        num_blocks (int): Number of conv units in the stage.
        dilation (int): Dilation forwarded to every convolution.
        with_bn (bool): Insert a BatchNorm2d after each convolution.
        ceil_mode (bool): ``ceil_mode`` of the trailing max-pooling layer.

    Returns:
        list[nn.Module]: The modules of the stage, in execution order.
    """
    modules = []
    in_channels = inplanes
    for _ in range(num_blocks):
        unit = [conv3x3(in_channels, planes, dilation)]
        if with_bn:
            unit.append(nn.BatchNorm2d(planes))
        unit.append(nn.ReLU(inplace=True))
        modules.extend(unit)
        # Only the first convolution sees the stage input width.
        in_channels = planes

    pool = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=ceil_mode)
    modules.append(pool)
    return modules
34
+
35
+
36
class VGG(nn.Module):
    """VGG backbone.

    Args:
        depth (int): Depth of vgg, from {11, 13, 16, 19}.
        with_bn (bool): Use BatchNorm or not.
        num_classes (int): Number of classes for classification. The
            classifier head is only built (and applied in ``forward``) when
            this is positive.
        num_stages (int): VGG stages, normally 5. Must be within [1, 5].
        dilations (Sequence[int]): Dilation of each stage. Its length must
            equal ``num_stages``.
        out_indices (Sequence[int]): Output from which stages.
        frozen_stages (int): Number of leading stages to freeze (modules put
            in eval mode and all their parameters set to
            ``requires_grad=False``). A negative value or 0 freezes nothing;
            a value larger than the number of stages freezes every stage.
        bn_eval (bool): Whether to set BN layers as eval mode, namely, freeze
            running stats (mean and var). Stored as ``self.bn_eval``; no
            method of this class reads it.
        bn_frozen (bool): Whether to freeze weight and bias of BN layers.
            Stored as ``self.bn_frozen``; no method of this class reads it.
        ceil_mode (bool): ``ceil_mode`` passed to the max-pooling layer of
            every stage.
        with_last_pool (bool): Whether to keep the max-pooling layer that
            ends the last stage.
    """

    # Number of conv units per stage for each supported depth.
    arch_settings = {
        11: (1, 1, 2, 2, 2),
        13: (2, 2, 2, 2, 2),
        16: (2, 2, 3, 3, 3),
        19: (2, 2, 4, 4, 4)
    }

    def __init__(self,
                 depth,
                 with_bn=False,
                 num_classes=-1,
                 num_stages=5,
                 dilations=(1, 1, 1, 1, 1),
                 out_indices=(0, 1, 2, 3, 4),
                 frozen_stages=-1,
                 bn_eval=True,
                 bn_frozen=False,
                 ceil_mode=False,
                 with_last_pool=True):
        super(VGG, self).__init__()
        if depth not in self.arch_settings:
            raise KeyError('invalid depth {} for vgg'.format(depth))
        assert num_stages >= 1 and num_stages <= 5
        stage_blocks = self.arch_settings[depth]
        self.stage_blocks = stage_blocks[:num_stages]
        assert len(dilations) == num_stages
        # NOTE(review): this admits an index equal to num_stages, which
        # forward() never emits (stages are numbered 0..num_stages-1).
        # Kept as-is for backward compatibility; confirm whether `<` was
        # intended.
        assert max(out_indices) <= num_stages

        self.num_classes = num_classes
        self.out_indices = out_indices
        self.frozen_stages = frozen_stages
        self.bn_eval = bn_eval
        self.bn_frozen = bn_frozen

        self.inplanes = 3
        vgg_layers = []
        # [start, end) index pairs into the flat `features` Sequential, one
        # pair per stage.
        self.range_sub_modules = []
        for i, num_blocks in enumerate(self.stage_blocks):
            planes = 64 * 2**i if i < 4 else 512
            vgg_layer = make_vgg_layer(
                self.inplanes,
                planes,
                num_blocks,
                dilation=dilations[i],
                with_bn=with_bn,
                ceil_mode=ceil_mode)
            # Derive the range from the modules actually produced instead of
            # re-computing the count arithmetically from `with_bn`; this
            # keeps the bookkeeping correct for any truthy/falsy `with_bn`.
            start_idx = len(vgg_layers)
            vgg_layers.extend(vgg_layer)
            self.range_sub_modules.append([start_idx, len(vgg_layers)])
            self.inplanes = planes
        if not with_last_pool:
            # Drop the trailing max-pool and shrink the last stage's range
            # accordingly.
            vgg_layers.pop(-1)
            self.range_sub_modules[-1][1] -= 1
        self.module_name = 'features'
        self.add_module(self.module_name, nn.Sequential(*vgg_layers))

        if self.num_classes > 0:
            self.classifier = nn.Sequential(
                nn.Linear(512 * 7 * 7, 4096),
                nn.ReLU(True),
                nn.Dropout(),
                nn.Linear(4096, 4096),
                nn.ReLU(True),
                nn.Dropout(),
                nn.Linear(4096, num_classes),
            )

        # initialize the model by random
        self.init_weights()
        # Optionally freeze (requires_grad=False) parts of the backbone
        self._freeze_backbone(self.frozen_stages)

    def _freeze_backbone(self, freeze_at):
        """Freeze the first ``freeze_at`` stages.

        Every module of the selected stages is switched to eval mode and all
        of its parameters get ``requires_grad = False``.

        Args:
            freeze_at (int): Number of leading stages to freeze. Negative
                values and 0 freeze nothing; values beyond the number of
                stages freeze all of them.
        """
        if freeze_at < 0:
            return

        vgg_layers = getattr(self, self.module_name)
        # Slicing clamps `freeze_at` to the number of stages, so an
        # over-large value freezes everything instead of raising IndexError.
        for start, end in self.range_sub_modules[:freeze_at]:
            for j in range(start, end):
                mod = vgg_layers[j]
                mod.eval()
                for param in mod.parameters():
                    param.requires_grad = False

    def init_weights(self, pretrained=None):
        """Initialize weights randomly or load them from a checkpoint.

        Args:
            pretrained (str | None): When a string, it is passed to
                ``load_checkpoint`` with ``strict=False``. When None, conv
                layers use ``kaiming_init``, BatchNorm layers
                ``constant_init(m, 1)`` and linear layers
                ``normal_init(m, std=0.01)``.

        Raises:
            TypeError: If ``pretrained`` is neither a str nor None.
        """
        if isinstance(pretrained, str):
            load_checkpoint(self, pretrained, strict=False)
        elif pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    kaiming_init(m)
                elif isinstance(m, nn.BatchNorm2d):
                    constant_init(m, 1)
                elif isinstance(m, nn.Linear):
                    normal_init(m, std=0.01)
        else:
            raise TypeError('pretrained must be a str or None')

    def forward(self, x):
        """Run the backbone and collect the requested stage outputs.

        Args:
            x: Input passed through the stages in order.

        Returns:
            The outputs of the stages listed in ``out_indices`` followed,
            when ``num_classes > 0``, by the classifier output. A single
            collected output is returned directly; otherwise a tuple is
            returned.
        """
        outs = []
        vgg_layers = getattr(self, self.module_name)
        for i, (start, end) in enumerate(self.range_sub_modules):
            for j in range(start, end):
                x = vgg_layers[j](x)
            if i in self.out_indices:
                outs.append(x)
        if self.num_classes > 0:
            # Flatten everything but the batch dimension for the linear head.
            x = x.view(x.size(0), -1)
            x = self.classifier(x)
            outs.append(x)
        if len(outs) == 1:
            return outs[0]
        else:
            return tuple(outs)
0 commit comments