diff --git a/torchvision/models/densenet.py b/torchvision/models/densenet.py index e95018a970a..82863811bbc 100644 --- a/torchvision/models/densenet.py +++ b/torchvision/models/densenet.py @@ -209,7 +209,7 @@ def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), # Official init from torch repo. for m in self.modules(): if isinstance(m, nn.Conv2d): - nn.init.kaiming_normal(m.weight.data) + nn.init.kaiming_normal_(m.weight.data) elif isinstance(m, nn.BatchNorm2d): m.weight.data.fill_(1) m.bias.data.zero_()