Torch.cat error #389

Open
yimjinkyu1 opened this issue Jul 11, 2019 · 2 comments

Comments

@yimjinkyu1

I want to add a residual layer to my neural network,
but I get the error below when running under amp.

Out += torch.cat((shortcut, padding),1)

RuntimeError: Expected object of scalar type Half but got scalar type Float for sequence element 1 in sequence argument at position #1 'tensors'

@ptrblck
Collaborator

ptrblck commented Jul 11, 2019

Hi @yimjinkyu1,

could you post a small reproducible code snippet so that we can have a look?
This dummy example of a residual connection seems to work:

import torch
import torch.nn as nn

from apex import amp

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv = nn.Conv2d(3, 3, 3, 1, 1)

    def forward(self, x):
        identity = x
        out = self.conv(x)
        out += identity

        return out

device = 'cuda'
model = MyModel().to(device)
x = torch.randn(1, 3, 24, 24, device=device)

model = amp.initialize(model, opt_level='O1')
output = model(x)

@mcarilli
Contributor

I think it's the torch.cat call that's failing, not the residual connection. Is that correct?
torch.cat is on amp's sequence_casts list, so amp should take all of its inputs and cast them to the widest type among them (fp32 in this case).
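
For illustration, here is a minimal sketch of the "cast everything to the widest type" behavior described above, written as a manual workaround. `cat_widest` is a hypothetical helper, not part of apex or PyTorch, and the tensor shapes are invented stand-ins for `shortcut` and `padding`, since the original model code was not posted. It runs on CPU and does not use amp at all.

```python
from functools import reduce

import torch


def cat_widest(tensors, dim=0):
    """Hypothetical helper: cast all inputs to the widest dtype, then concatenate."""
    # torch.promote_types returns the smallest dtype that can hold both arguments,
    # e.g. (float16, float32) -> float32.
    widest = reduce(torch.promote_types, (t.dtype for t in tensors))
    return torch.cat([t.to(widest) for t in tensors], dim=dim)


# Assumed stand-ins for the tensors in the report (shapes are made up).
shortcut = torch.randn(2, 3, 4, 4).half()  # fp16, like an activation under O1
padding = torch.zeros(2, 5, 4, 4)          # fp32, the default dtype of torch.zeros

merged = cat_widest((shortcut, padding), 1)
print(merged.dtype, tuple(merged.shape))
```

If the padding tensor is created inside `forward`, another option may be to build it with `dtype=shortcut.dtype` and `device=shortcut.device` so that both inputs to `torch.cat` already match. Whether that applies here depends on how `padding` is actually constructed.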
