Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions torchvision/models/densenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ class DenseNet(nn.Module):
drop_rate (float) - dropout rate after each dense layer
num_classes (int) - number of classification classes
"""

def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16),
num_init_features=64, bn_size=4, drop_rate=0, num_classes=1000):

Expand Down Expand Up @@ -209,12 +210,12 @@ def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16),
# Official init from torch repo.
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal(m.weight.data)
nn.init.kaiming_normal_(m.weight)
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
m.bias.data.zero_()
nn.init.constant_(m.bias, 0)

def forward(self, x):
features = self.features(x)
Expand Down
8 changes: 4 additions & 4 deletions torchvision/models/inception.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,12 @@ def __init__(self, num_classes=1000, aux_logits=True, transform_input=False):
import scipy.stats as stats
stddev = m.stddev if hasattr(m, 'stddev') else 0.1
X = stats.truncnorm(-2, 2, scale=stddev)
values = torch.Tensor(X.rvs(m.weight.data.numel()))
values = values.view(m.weight.data.size())
values = torch.Tensor(X.rvs(m.weight.numel()))
values = values.view(m.weight.size())
m.weight.data.copy_(values)
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)

def forward(self, x):
if self.transform_input:
Expand Down
7 changes: 3 additions & 4 deletions torchvision/models/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,11 +112,10 @@ def __init__(self, block, layers, num_classes=1000):

for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)

def _make_layer(self, block, planes, blocks, stride=1):
downsample = None
Expand Down
6 changes: 3 additions & 3 deletions torchvision/models/squeezenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,11 +89,11 @@ def __init__(self, version=1.0, num_classes=1000):
for m in self.modules():
if isinstance(m, nn.Conv2d):
if m is final_conv:
init.normal(m.weight.data, mean=0.0, std=0.01)
init.normal_(m.weight, mean=0.0, std=0.01)
else:
init.kaiming_uniform(m.weight.data)
init.kaiming_uniform_(m.weight)
if m.bias is not None:
m.bias.data.zero_()
init.constant_(m.bias, 0)

def forward(self, x):
x = self.features(x)
Expand Down
13 changes: 6 additions & 7 deletions torchvision/models/vgg.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,16 +47,15 @@ def forward(self, x):
def _initialize_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
if m.bias is not None:
m.bias.data.zero_()
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
m.weight.data.normal_(0, 0.01)
m.bias.data.zero_()
nn.init.normal_(m.weight, 0, 0.01)
nn.init.constant_(m.bias, 0)


def make_layers(cfg, batch_norm=False):
Expand Down