fix models for PyTorch v0.4 (remove .data and add _ for the initializations … #481

Merged: 9 commits into pytorch:master on Apr 30, 2018

Conversation

moskomule (Contributor)

Hi, this is about #479.
Some models emit warnings at initialization because they call nn.init.**(tensor) instead of the in-place nn.init.**_(tensor), so I moved them to nn.init.**_(). I also removed the Variable.data accesses.
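
(For context, a minimal illustration of the rename in v0.4, with an arbitrary layer; the layer and initializer here are only examples:)

import torch.nn as nn

fc = nn.Linear(4, 4)
nn.init.xavier_uniform(fc.weight)   # deprecated in v0.4, emits a deprecation warning
nn.init.xavier_uniform_(fc.weight)  # in-place underscore variant, no warning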

 elif isinstance(m, nn.BatchNorm2d):
-    m.weight.data.fill_(1)
-    m.bias.data.zero_()
+    m.weight.fill_(1)

soumith (Member) left a comment

The .data access is the equivalent of saying "don't record history while I do these operations."
I think it needs to be preserved. We don't want an autograd graph defined over (or backpropagated through) the initialization operations.
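
(For reference, a minimal sketch of the two ways to keep initialization out of the autograd graph in v0.4; the layer and constants are illustrative:)

import torch
import torch.nn as nn

conv = nn.Conv2d(3, 8, kernel_size=3)

# going through .data bypasses history recording
conv.weight.data.normal_(0, 0.01)

# torch.no_grad() does the same thing explicitly
with torch.no_grad():
    conv.weight.normal_(0, 0.01)

# without either of the above, a bare in-place op on a leaf tensor that
# requires grad raises a RuntimeError in v0.4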

soumith (Member) commented Apr 27, 2018

The nn.init.**_() changes are great. Please revert the .data changes.

fmassa (Member) commented Apr 27, 2018

I think another option (instead of using .data) would be to use with torch.no_grad():. Would that be more in line with best practices for v0.4?

soumith (Member) commented Apr 27, 2018

No, let's not have to use with torch.no_grad at the cost of verbosity / readability in these simple cases.

fmassa (Member) commented Apr 27, 2018

I have another proposal: replace the hand-coded initialization with the torch.nn.init functions. They use torch.no_grad() internally, so we can also remove the .data calls from the code. What do you think?

moskomule (Contributor, Author)

So, for example, m.weight.fill_(1) would become nn.init.constant_(m.weight, 1)?

fmassa (Member) commented Apr 28, 2018

@moskomule yes, precisely
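
(A minimal sketch of what that looks like for the BatchNorm branch of the weight-initialization loop, assuming the same constants as the existing code; nn.init.constant_ runs under torch.no_grad() internally, so neither .data nor an explicit no_grad block is needed:)

for m in self.modules():
    if isinstance(m, nn.BatchNorm2d):
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)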

fmassa (Member) left a comment

Thanks for the modifications! I still have a few comments that could simplify parts of the code.
Also, could you revert the unnecessary line changes / added spaces?

@@ -48,15 +48,15 @@ def _initialize_weights(self):
         for m in self.modules():
             if isinstance(m, nn.Conv2d):
                 n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
-                m.weight.data.normal_(0, math.sqrt(2. / n))
+                nn.init.normal_(m.weight, 0, math.sqrt(2. / n))

@@ -4,10 +4,8 @@
import torch.nn.init as init
import torch.utils.model_zoo as model_zoo


@@ -113,18 +111,18 @@ def __init__(self, block, layers, num_classes=1000):
         for m in self.modules():
             if isinstance(m, nn.Conv2d):
                 n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
-                m.weight.data.normal_(0, math.sqrt(2. / n))
+                nn.init.normal_(m.weight, 0, math.sqrt(2. / n))

            nn.Conv2d(self.inplanes, planes * block.expansion,
                      kernel_size=1, stride=stride, bias=False),
            nn.BatchNorm2d(planes * block.expansion),
            nn.Conv2d(self.inplanes, planes * block.expansion,

-                m.weight.data.copy_(values)
+                values = torch.Tensor(X.rvs(m.weight.numel()))
+                values = values.view(m.weight.size())
+                m.weight.copy_(values)

fmassa mentioned this pull request on Apr 29, 2018
moskomule (Contributor, Author)

Thanks for reviewing. I've fixed the style issues. Regarding the simplifications, could you check my replies to the comments?

moskomule (Contributor, Author)

For inception.py, it would be simpler if we had a truncated_normal initializer in nn.init.
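
(For reference, a rough sketch of what a scipy-based truncated-normal helper could look like, mirroring the hand-coded logic in inception.py; the function name, default stddev, and truncation bounds are only illustrative:)

import torch
import scipy.stats as stats

def truncated_normal_(tensor, stddev=0.1):
    # sample from a normal truncated to [-2, 2] standard deviations,
    # then copy into the parameter without recording autograd history
    X = stats.truncnorm(-2, 2, scale=stddev)
    values = torch.Tensor(X.rvs(tensor.numel())).view(tensor.size())
    with torch.no_grad():
        tensor.copy_(values)
    return tensor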

fmassa (Member) left a comment

Almost good. A few more comments

@@ -112,11 +112,10 @@ def __init__(self, block, layers, num_classes=1000):

         for m in self.modules():
             if isinstance(m, nn.Conv2d):
-                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
-                m.weight.data.normal_(0, math.sqrt(2. / n))
+                nn.init.kaiming_normal_(m.weight, mode="fan_out")
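
(For a ReLU network this draws from the same distribution as the hand-coded version: kaiming_normal_ with mode="fan_out" samples from N(0, 2 / fan_out), and for a Conv2d fan_out equals kernel_size[0] * kernel_size[1] * out_channels, i.e. exactly the n computed by hand. A quick illustrative check, with an arbitrary layer:)

import math
import torch.nn as nn

conv = nn.Conv2d(16, 32, kernel_size=3)
n = conv.kernel_size[0] * conv.kernel_size[1] * conv.out_channels  # fan_out

nn.init.kaiming_normal_(conv.weight, mode="fan_out", nonlinearity="relu")
# the empirical std should be close to the hand-coded target
print(conv.weight.std().item(), math.sqrt(2. / n))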

@@ -130,11 +130,11 @@ def __init__(self, num_input_features, growth_rate, bn_size, drop_rate):
         self.add_module('norm1', nn.BatchNorm2d(num_input_features)),
         self.add_module('relu1', nn.ReLU(inplace=True)),
         self.add_module('conv1', nn.Conv2d(num_input_features, bn_size *
-                growth_rate, kernel_size=1, stride=1, bias=False)),
+                        growth_rate, kernel_size=1, stride=1, bias=False)),

@@ -47,16 +47,15 @@ def forward(self, x):
     def _initialize_weights(self):
         for m in self.modules():
             if isinstance(m, nn.Conv2d):
-                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
-                m.weight.data.normal_(0, math.sqrt(2. / n))
+                nn.init.kaiming_normal_(m.weight, mode="fan_out")

fmassa merged commit f87a896 into pytorch:master on Apr 30, 2018
fmassa (Member) commented Apr 30, 2018

Looks great, thanks @moskomule! The build failures are unrelated.

varunagrawal pushed a commit to varunagrawal/vision that referenced this pull request on Jul 23, 2018:
fix models for PyTorch v0.4 (remove .data and add _ for the initializations … (pytorch#481)

* fix for PyTorch v0.4 (remove .data and add _ for the initializations in nn.init)

* fix m.**.**() style to nn.init.**(**) style

* remove .idea

* fix lines and indents

* fix lines and indents

* change to use `kaming_normal_`

* add `.data` for safety

* add nonlinearity='relu' for sure

* fix indents