fix models for PyTorch v0.4 (remove .data and add _ for the initializations …) #481
Conversation
torchvision/models/densenet.py
Outdated

```diff
         elif isinstance(m, nn.BatchNorm2d):
-            m.weight.data.fill_(1)
-            m.bias.data.zero_()
+            m.weight.fill_(1)
```
Using `.data` is equivalent to saying: don't record history while I do these operations. I think that behavior needs to be preserved — we don't want an autograd graph defined for (or backpropagation through) the initialization operations.
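A minimal sketch of the point above: wrapping the initialization in `torch.no_grad()` keeps those ops out of the autograd graph, which is what the old `.data` access achieved implicitly (the layer here is an arbitrary example, not code from the PR):

```python
import torch
import torch.nn as nn

m = nn.BatchNorm2d(8)

# Without no_grad, an in-place op on a leaf parameter that requires grad
# raises a RuntimeError; inside no_grad, no history is recorded.
with torch.no_grad():
    m.weight.fill_(1)
    m.bias.zero_()

# The parameters still require grad for training, but they remain leaf
# tensors with no recorded history from the initialization ops.
assert m.weight.requires_grad and m.weight.grad_fn is None
```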
the …
I think another option (instead of using …
no, let's not have to use …
I have another proposal: replace the hand-coded initialization with …
So …
@moskomule yes, precisely.
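For reference, `nn.init.kaiming_normal_(w, mode="fan_out")` — the replacement the PR ends up adopting — computes the same standard deviation `sqrt(2 / n)` as the hand-coded loop, where `n = kernel_h * kernel_w * out_channels`. A quick check, with an arbitrary layer shape:

```python
import math
import torch
import torch.nn as nn

conv = nn.Conv2d(3, 16, kernel_size=3)

# Hand-coded He initialization, as the old torchvision code did:
n = conv.kernel_size[0] * conv.kernel_size[1] * conv.out_channels
std_manual = math.sqrt(2.0 / n)

# kaiming_normal_ with mode="fan_out" uses
#   fan_out = out_channels * kernel_h * kernel_w,  gain(relu) = sqrt(2),
#   std = gain / sqrt(fan_out) = sqrt(2 / n)  -- the same value.
nn.init.kaiming_normal_(conv.weight, mode="fan_out", nonlinearity="relu")
std_kaiming = math.sqrt(2.0) / math.sqrt(n)

assert abs(std_manual - std_kaiming) < 1e-12
```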
Thanks for the modifications! I still have some comments that could simplify some parts of the code.
Also, could you revert the unnecessary line changes / spaces added?
torchvision/models/vgg.py
Outdated

```diff
@@ -48,15 +48,15 @@ def _initialize_weights(self):
     for m in self.modules():
         if isinstance(m, nn.Conv2d):
             n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
-            m.weight.data.normal_(0, math.sqrt(2. / n))
+            nn.init.normal_(m.weight, 0, math.sqrt(2. / n))
```
torchvision/models/squeezenet.py
Outdated

```diff
@@ -4,10 +4,8 @@
 import torch.nn.init as init
 import torch.utils.model_zoo as model_zoo
```
torchvision/models/resnet.py
Outdated

```diff
@@ -113,18 +111,18 @@ def __init__(self, block, layers, num_classes=1000):
     for m in self.modules():
         if isinstance(m, nn.Conv2d):
             n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
-            m.weight.data.normal_(0, math.sqrt(2. / n))
+            nn.init.normal_(m.weight, 0, math.sqrt(2. / n))
```
torchvision/models/resnet.py
Outdated

```diff
             nn.Conv2d(self.inplanes, planes * block.expansion,
                       kernel_size=1, stride=stride, bias=False),
             nn.BatchNorm2d(planes * block.expansion),
```
torchvision/models/inception.py
Outdated

```diff
 values = torch.Tensor(X.rvs(m.weight.numel()))
 values = values.view(m.weight.size())
-m.weight.data.copy_(values)
+m.weight.copy_(values)
```
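In this hunk, `X` is a `scipy.stats.truncnorm` distribution sampled into the weight. As a side note (not part of this PR), later PyTorch versions ship `nn.init.trunc_normal_`, which achieves the same truncated-normal initialization in-place without scipy; the layer shape and bounds below are illustrative:

```python
import torch
import torch.nn as nn

conv = nn.Conv2d(3, 8, kernel_size=3)

# trunc_normal_ samples from a normal truncated to [a, b], in-place;
# nn.init functions run under no_grad, so no autograd history is recorded.
nn.init.trunc_normal_(conv.weight, mean=0.0, std=0.1, a=-0.2, b=0.2)

assert conv.weight.abs().max().item() <= 0.2
assert conv.weight.grad_fn is None  # still a plain leaf parameter
```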
Thanks for reviewing. I've fixed the style issues. About the simplification, can you check the comments?
For …
Almost good. A few more comments
torchvision/models/resnet.py
Outdated

```diff
@@ -112,11 +112,10 @@ def __init__(self, block, layers, num_classes=1000):

     for m in self.modules():
         if isinstance(m, nn.Conv2d):
-            n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
-            m.weight.data.normal_(0, math.sqrt(2. / n))
+            nn.init.kaiming_normal_(m.weight, mode="fan_out")
```
torchvision/models/densenet.py
Outdated

```diff
@@ -130,11 +130,11 @@ def __init__(self, num_input_features, growth_rate, bn_size, drop_rate):
     self.add_module('norm1', nn.BatchNorm2d(num_input_features)),
     self.add_module('relu1', nn.ReLU(inplace=True)),
     self.add_module('conv1', nn.Conv2d(num_input_features, bn_size *
                     growth_rate, kernel_size=1, stride=1, bias=False)),
```
torchvision/models/vgg.py
Outdated

```diff
@@ -47,16 +47,15 @@ def forward(self, x):
 def _initialize_weights(self):
     for m in self.modules():
         if isinstance(m, nn.Conv2d):
-            n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
-            m.weight.data.normal_(0, math.sqrt(2. / n))
+            nn.init.kaiming_normal_(m.weight, mode="fan_out")
```
Looks great, thanks @moskomule! The build failures are unrelated.
…ations … (pytorch#481)

* fix for PyTorch v0.4 (remove .data and add _ for the initializations in nn.init)
* fix m.**.**() style to nn.init.**(**) style
* remove .idea
* fix lines and indents
* fix lines and indents
* change to use `kaming_normal_`
* add `.data` for safety
* add nonlinearity='relu' for sure
* fix indents
Hi, it's about #479. Some models warn about initialization because they use `nn.init.**(tensor)` instead of `nn.init.**_(tensor)`, so I moved them to `nn.init.**_()`. I also removed `Variable.data`.
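Concretely, the warnings come from the un-suffixed `nn.init` functions, which PyTorch 0.4 deprecated in favor of the in-place, trailing-underscore variants the PR switches to:

```python
import torch
import torch.nn as nn

w = torch.empty(3, 5)

# The trailing underscore marks the in-place variant; in PyTorch 0.4 the
# un-suffixed forms (e.g. nn.init.normal) emitted deprecation warnings.
nn.init.normal_(w, mean=0.0, std=0.01)
nn.init.constant_(w, 1.0)

assert torch.all(w == 1.0).item()
```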