This repository has been archived by the owner on Feb 12, 2022. It is now read-only.

cannot assign 'torch.cuda.FloatTensor' as parameter 'weight' #79

Open
shirishr opened this issue Oct 26, 2018 · 6 comments

@shirishr

I am working with PyTorch version 1.0.0.dev20181019 (build channel py3.6_cuda9.0.176_cudnn7.1.2_0).

When I run

python -u main.py --epochs 500 --data data/wikitext-2 --clip 0.25 --dropouti 0.4 --dropouth 0.2 --nhid 1550 --nlayers 4 --seed 4002 --model QRNN --wdrop 0.1 --batch_size 20 --save WT2.pt

I get this error:

TypeError: cannot assign 'torch.cuda.FloatTensor' as parameter 'weight' (torch.nn.Parameter or None expected)

Full traceback below:

Traceback (most recent call last):
File "main.py", line 240, in <module>
train()
File "main.py", line 196, in train
output, hidden, rnn_hs, dropped_rnn_hs = model(data, hidden, return_h=True)
File "/home/sam/anaconda3/envs/py36/lib/python3.6/site-packages/torch/nn/modules/module.py", line 477, in __call__
result = self.forward(*input, **kwargs)
File "/home/sam/Documents/NLP/awd-lstm-lm/model.py", line 81, in forward
raw_output, new_h = rnn(raw_output, hidden[l])
File "/home/sam/anaconda3/envs/py36/lib/python3.6/site-packages/torch/nn/modules/module.py", line 477, in __call__
result = self.forward(*input, **kwargs)
File "/home/sam/anaconda3/envs/py36/lib/python3.6/site-packages/torchqrnn/qrnn.py", line 70, in forward
Y = self.linear(source)
File "/home/sam/anaconda3/envs/py36/lib/python3.6/site-packages/torch/nn/modules/module.py", line 477, in __call__
result = self.forward(*input, **kwargs)
File "/home/sam/Documents/NLP/awd-lstm-lm/weight_drop.py", line 46, in forward
self._setweights()
File "/home/sam/Documents/NLP/awd-lstm-lm/weight_drop.py", line 43, in _setweights
setattr(self.module, name_w, w)
File "/home/sam/anaconda3/envs/py36/lib/python3.6/site-packages/torch/nn/modules/module.py", line 537, in __setattr__
.format(torch.typename(value), name))
TypeError: cannot assign 'torch.cuda.FloatTensor' as parameter 'weight' (torch.nn.Parameter or None expected)

@shirishr
Author

I found that the correct way to assign a tensor as a parameter is:

w = torch.nn.Parameter(........)

That means lines 43~46 in weight_drop.py should be:

            w = torch.nn.Parameter(mask.expand_as(raw_w) * raw_w)
        else:
            w = torch.nn.Parameter(torch.nn.functional.dropout(raw_w, p=self.dropout, training=self.training))

I am not closing this yet, though.
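
For reference, here is a sketch of the full patched _setweights with that change applied. It assumes the surrounding structure of this repo's weight_drop.py, so treat it as illustrative rather than an exact diff:

    def _setweights(self):
        for name_w in self.weights:
            raw_w = getattr(self.module, name_w + '_raw')
            if self.variational:
                # Variational dropout: one mask entry per output row,
                # expanded so the whole row shares the same mask value
                mask = torch.autograd.Variable(torch.ones(raw_w.size(0), 1))
                if raw_w.is_cuda: mask = mask.cuda()
                mask = torch.nn.functional.dropout(mask, p=self.dropout, training=True)
                w = torch.nn.Parameter(mask.expand_as(raw_w) * raw_w)
            else:
                # Standard dropout applied directly to the raw weight
                w = torch.nn.Parameter(torch.nn.functional.dropout(raw_w, p=self.dropout, training=self.training))
            setattr(self.module, name_w, w)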

@JACKHAHA363

This should be in a PR

@amirakazadeh

Hi,
I got this error too, and when I replaced line 43 I got this error:

Loading cached dataset...
Applying weight drop of 0.5 to weight_hh_l0
Applying weight drop of 0.5 to weight_hh_l0
Applying weight drop of 0.5 to weight_hh_l0
[WeightDrop(
(module): LSTM(400, 1150)
), WeightDrop(
(module): LSTM(1150, 1150)
), WeightDrop(
(module): LSTM(1150, 400)
)]
Using []
Args: Namespace(alpha=2, batch_size=20, beta=1, bptt=70, clip=0.25, cuda=True, data='data/penn', dropout=0.4, dropoute=0.1, dropouth=0.25, dropouti=0.4, emsize=400, epochs=500, log_interval=200, lr=30, model='LSTM', nhid=1150, nlayers=3, nonmono=5, optimizer='sgd', resume='', save='PTB.pt', seed=141, tied=True, wdecay=1.2e-06, wdrop=0.5, when=[-1])
Model total parameters: 24221600
Traceback (most recent call last):
File "/content/drive/My Drive/Colab Notebooks/awd-lstm-lm-master/main.py", line 240, in <module>
train()
File "/content/drive/My Drive/Colab Notebooks/awd-lstm-lm-master/main.py", line 196, in train
output, hidden, rnn_hs, dropped_rnn_hs = model(data, hidden, return_h=True)
File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 493, in __call__
result = self.forward(*input, **kwargs)
File "/content/drive/My Drive/Colab Notebooks/awd-lstm-lm-master/model.py", line 81, in forward
raw_output, new_h = rnn(raw_output, hidden[l])
File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 493, in __call__
result = self.forward(*input, **kwargs)
File "/content/drive/My Drive/Colab Notebooks/awd-lstm-lm-master/weight_drop.py", line 46, in forward
self._setweights()
File "/content/drive/My Drive/Colab Notebooks/awd-lstm-lm-master/weight_drop.py", line 40, in _setweights
w = torch.nn.Parameter(mask.expand_as(raw_w) * raw_w)
UnboundLocalError: local variable 'mask' referenced before assignment

Can anybody help me, please?

tonghuikang added a commit to tonghuikang/awd-lstm-lm that referenced this issue Jun 7, 2019
@benjaminfspector

@amirakazadeh

I had the same issue for a moment when my indentation was wrong. (The indentation provided in the previous answer does not match that of the repository.) Perhaps you just need to indent those lines once more?
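
To illustrate the failure mode: if the Parameter line is pasted one indentation level too shallow, it escapes the if self.variational: branch and runs unconditionally, while mask is only assigned inside that branch. A hypothetical mis-paste, for illustration:

    # Mis-indented paste (illustrative): the Parameter line sits outside the if-branch
    if self.variational:
        mask = torch.autograd.Variable(torch.ones(raw_w.size(0), 1))
        if raw_w.is_cuda: mask = mask.cuda()
        mask = torch.nn.functional.dropout(mask, p=self.dropout, training=True)
    w = torch.nn.Parameter(mask.expand_as(raw_w) * raw_w)  # UnboundLocalError when self.variational is False

Because mask is assigned somewhere in the function body, Python treats it as a local variable, so referencing it on the non-variational path raises UnboundLocalError rather than NameError.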

@acriptis

@shirishr Thanks, your fix rocks!

@todpole3

todpole3 commented Mar 26, 2020

I'm using PyTorch 1.1 and @shirishr's fix doesn't give the correct results for me. (The behavior of nn.Parameter and Tensor may have changed.)

The fix I ended up with is:

    def _setweights(self):
        for name_w in self.weights:
            raw_w = getattr(self.module, name_w + '_raw')
            if self.variational:
                mask = torch.autograd.Variable(torch.ones(raw_w.size(0), 1))
                if raw_w.is_cuda: mask = mask.cuda()
                mask = torch.nn.functional.dropout(mask, p=self.dropout, training=True)
                w = mask.expand_as(raw_w) * raw_w
            else:
                w = torch.nn.functional.dropout(raw_w, p=self.dropout, training=self.training)
            if not self.training:
                w = w.data
            setattr(self.module, name_w, w)

The root cause of this issue is that if you pass an nn.Parameter to torch.nn.functional.dropout, it returns a plain tensor during training (the result of the dropout operation), but the nn.Parameter itself during evaluation, since dropout is not applied. (Each time I evaluate the model on the dev set during training, I run into this issue.) I suspect this is a bug in PyTorch and will follow up with them.
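
A minimal sketch of that behavior (assuming the functional dropout semantics of PyTorch ~1.1; the names and values here are illustrative):

    import torch
    import torch.nn.functional as F

    weight = torch.nn.Parameter(torch.randn(4, 4))

    # training=True: dropout builds a new tensor, so the nn.Parameter
    # wrapper is lost and a plain Tensor comes back.
    print(type(F.dropout(weight, p=0.5, training=True)))   # <class 'torch.Tensor'>

    # training=False: dropout is a no-op that returns its input, so the
    # result is still the original nn.Parameter.
    print(type(F.dropout(weight, p=0.5, training=False)))  # <class 'torch.nn.parameter.Parameter'>

The if not self.training: w = w.data line in the fix above guards the evaluation case: it unwraps the returned nn.Parameter so that a plain tensor is assigned consistently in both modes.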
