Bug in forward of attention_augmented_conv.py #21

Open
dhananjaisharma10 opened this issue Dec 29, 2020 · 2 comments

@dhananjaisharma10

Hi! I think there's a bug at this line in the forward function. Specifically, suppose the attention tensor attn_out looks like the following for an input image of shape (channels, h=2, w=3) with dv = 2 self-attention channels:

# attention values of the 6 pixels
Att tensor([[-3.5002, -1.2102],
        [-4.3694, -1.5107],
        [-4.7621, -1.6465],
        [-4.9178, -1.7003],
        [-2.2335, -0.7722],
        [-5.0056, -1.7307]], grad_fn=<SliceBackward>)

you should not reshape it directly using

attn_out = torch.reshape(attn_out, (batch, Nh, dv // Nh, height, width)) # Method 1

but instead you should use

attn_out = torch.reshape(attn_out.permute(0, 1, 3, 2), (batch, Nh, dv // Nh, height, width)) # Method 2

The output difference:

# Method 1
Att tensor([[[-3.5002, -1.2102, -4.3694],
         [-1.5107, -4.7621, -1.6465]],

        [[-4.9178, -1.7003, -2.2335],
         [-0.7722, -5.0056, -1.7307]]], grad_fn=<SliceBackward>)

vs.

# Method 2
Att tensor([[[-3.5002, -4.3694, -4.7621],
         [-4.9178, -2.2335, -5.0056]],

        [[-1.2102, -1.5107, -1.6465],
         [-1.7003, -0.7722, -1.7307]]], grad_fn=<SliceBackward>)
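
For reference, here is a minimal standalone sketch that reproduces the difference (an illustration, not the repo's code; it assumes attn_out has shape (batch, Nh, H*W, dv // Nh) right before the reshape):

import torch

# Toy setup matching the example above: Nh = 1 head, dv = 2 channels, a 2x3 image.
batch, Nh, dv, height, width = 1, 1, 2, 2, 3
attn_out = torch.arange(height * width * dv, dtype=torch.float32)
attn_out = attn_out.reshape(batch, Nh, height * width, dv // Nh)  # one row per pixel

# Method 1: a direct reshape reads elements in row-major order, so values from
# different channels get interleaved across the (height, width) plane.
method1 = torch.reshape(attn_out, (batch, Nh, dv // Nh, height, width))

# Method 2: swap the pixel and channel axes first, so each channel's H*W values
# are consecutive and map cleanly onto (height, width).
method2 = torch.reshape(attn_out.permute(0, 1, 3, 2), (batch, Nh, dv // Nh, height, width))

print(method1[0, 0])  # mixes channel values within each spatial plane
print(method2[0, 0])  # each plane holds one channel's six per-pixel values in order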

Hope it helps!

@JonathanCMitchell

It looks like you are just moving the width and height around. What is the purpose behind this?

@JonathanCMitchell

I can confirm that this is a bug and that you solved it. You can disregard my original question. The issue is that the reshape needs H*W as the last dimension.
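
(A tiny illustration of why, as a sketch rather than the repo's code: reshape traverses elements in row-major order, so the flat pixel axis of length H*W has to be the trailing one before it is split into (H, W).)

import torch

x = torch.arange(6).reshape(6, 2)         # (H*W, channels) per head, as in the example above
wrong = x.reshape(2, 2, 3)                # channel values leak across the spatial plane
right = x.permute(1, 0).reshape(2, 2, 3)  # channels first, then split H*W into (height, width)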
