fix-bug: fix attn transpose bug #52

Open
wants to merge 1 commit into master

Conversation


@tqnwhz tqnwhz commented May 12, 2022

Hi, I seem to have found a bug in the code.

In the extract_features function of NgramTransformerDecoder, a transpose operation is applied to attn, which is an output of NgramTransformerDecoderLayer. The code snippet is as follows:

class NgramTransformerDecoder(FairseqIncrementalDecoder):
    def extract_features(self, prev_output_tokens, encoder_out=None, incremental_state=None, **unused):
        # ......
        # decoder layers
        for layer in self.layers:
            x, attn = layer(
                x,
                encoder_out['encoder_out'] if encoder_out is not None else None,
                encoder_out['encoder_padding_mask'] if encoder_out is not None else None,
                incremental_state,
                self_attn_mask=self_attn_mask,
                ngram_mask_matrix=ngram_mask_matrix,
                i_buckets_main_stream=i_buckets_main_stream,
                i_bucket_relative_stream=i_bucket_relative_stream,
                real_positions=real_positions
            )
            inner_states.append(x)
        # TODO [(1+ngram)*T, B, C] -> [B, (1+ngram)*T, C]
        x_list = x.transpose(0, 1).chunk(1 + self.ngram, 1)
        if attn is not None:
            attn_list = attn.transpose(0, 1).chunk(1 + self.ngram, 1)
        else:
            attn_list = None

        return x_list, {'attn': attn_list}

As can be seen from the code comment, its purpose is to change the dims from [(1+ngram)*T, B, C] to [B, (1+ngram)*T, C]. The variable attn, coming from NgramTransformerDecoderLayer, is the second result returned by its encoder_attn (fairseq.modules.MultiheadAttention).

In fairseq v0.9.0, the relevant snippet of MultiheadAttention's forward function is as follows:

class MultiheadAttention(nn.Module):
    def forward(
        self,
        # ...
    ):
        # ......
        if need_weights:
            attn_weights = attn_weights_float.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0)
            if not need_head_weights:
                # average attention weights over heads
                attn_weights = attn_weights.mean(dim=0)
        else:
            attn_weights = None

        return attn, attn_weights

It can be seen that the second result of the forward function (attn_weights) originally has the shape (bsz, self.num_heads, tgt_len, src_len). After the transpose and mean operations, it has the shape (bsz, tgt_len, src_len), which is the actual shape of the attn used in extract_features, rather than the [(1+ngram)*T, B, C] described in the comment. By the way, the shape and transpose of x in extract_features are correct, and attn is not actually used during training or inference, which I guess is why the bug has gone unnoticed for two years.
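To make the shape mismatch concrete, here is a small standalone illustration (the sizes are arbitrary and only for demonstration, not taken from the repo):

import torch

bsz, ngram, tgt_len, src_len, dim = 4, 1, 5, 6, 8

# x leaves the decoder layers as [(1+ngram)*T, B, C];
# transpose(0, 1) then gives [B, (1+ngram)*T, C] as intended.
x = torch.randn((1 + ngram) * tgt_len, bsz, dim)
x_list = x.transpose(0, 1).chunk(1 + ngram, 1)
print([tuple(t.shape) for t in x_list])    # [(4, 5, 8), (4, 5, 8)] -- per-stream chunks, correct

# attn is already [B, (1+ngram)*T, S]; the same transpose swaps batch and
# target dims, and chunk(1 + ngram, 1) then splits the batch dimension instead.
attn = torch.randn(bsz, (1 + ngram) * tgt_len, src_len)
attn_list = attn.transpose(0, 1).chunk(1 + ngram, 1)
print([tuple(t.shape) for t in attn_list]) # [(10, 2, 6), (10, 2, 6)] -- confusing shape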

But anyone who, like me, wants to make some modification and needs to use the variable attn will find that it has a confusing shape caused by the transpose. It took me some time to track down the bug.
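For reference, a minimal sketch of the kind of change intended (an illustration of the idea, not necessarily the exact diff in this PR; it assumes attn arrives head-averaged and batch-first as [B, (1+ngram)*T, S] and should be chunked along its target dimension):

# x still needs the transpose because it is [(1+ngram)*T, B, C]
x_list = x.transpose(0, 1).chunk(1 + self.ngram, 1)
if attn is not None:
    # attn is already batch-first, so chunk directly along dim 1 (the target dim)
    attn_list = attn.chunk(1 + self.ngram, 1)
else:
    attn_list = None

return x_list, {'attn': attn_list}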

I hope this PR can be merged.


ghost commented May 12, 2022

CLA assistant check
All CLA requirements met.
