Can't get (global) attention probs using Longformer #5646

Closed
k141303 opened this issue Jul 10, 2020 · 12 comments · Fixed by #5659

Comments

@k141303
k141303 commented Jul 10, 2020

🐛 Bug

Information

Model I am using: Longformer

Language I am using the model on: Japanese

The problem arises when using:

  • the official example scripts: (give details below)
  • my own modified scripts: (give details below)

The task I am working on is:

  • an official GLUE/SQuAD task: (give the name)
  • my own task or dataset: (give details below)

To reproduce

Steps to reproduce the behavior:

  1. Set config.output_attentions=True
  2. Use global attention (sum(global_attention_mask)>0)

The following is the minimal code needed to reproduce the error.

import torch
from transformers import AutoModel, AutoTokenizer, AutoConfig

if __name__ == '__main__':
    config = AutoConfig.from_pretrained("allenai/longformer-base-4096", output_attentions=True)
    model = AutoModel.from_pretrained("allenai/longformer-base-4096", config=config)
    tokenizer = AutoTokenizer.from_pretrained("allenai/longformer-base-4096")
    token_ids = [[
        tokenizer.cls_token_id, 10, 11, 12,
        tokenizer.sep_token_id, 21, 22, 23,
        tokenizer.sep_token_id
    ]]
    global_attention_mask = [[1,1,1,1,1,0,0,0,0]]
    logit, *_, attention_probs = model(
        torch.LongTensor(token_ids),
        global_attention_mask=torch.LongTensor(global_attention_mask)
    )
    print(attention_probs[0].size())
$ python3 test.py
Traceback (most recent call last):
  File "test_longformer.py", line 16, in <module>
    global_attention_mask=torch.LongTensor(global_attention_mask)
  File "/uge_mnt/home/username/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/uge_mnt/home/username/.local/lib/python3.6/site-packages/transformers/modeling_longformer.py", line 1004, in forward
    output_hidden_states=output_hidden_states,
  File "/uge_mnt/home/username/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/uge_mnt/home/username/.local/lib/python3.6/site-packages/transformers/modeling_longformer.py", line 695, in forward
    layer_outputs = layer_module(hidden_states, attention_mask, output_attentions,)
  File "/uge_mnt/home/username/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/uge_mnt/home/username/.local/lib/python3.6/site-packages/transformers/modeling_longformer.py", line 658, in forward
    self_attn_outputs = self.attention(hidden_states, attention_mask, output_attentions=output_attentions,)
  File "/uge_mnt/home/username/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/uge_mnt/home/username/.local/lib/python3.6/site-packages/transformers/modeling_longformer.py", line 642, in forward
    self_outputs = self.self(hidden_states, attention_mask, output_attentions,)
  File "/uge_mnt/home/username/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/uge_mnt/home/username/.local/lib/python3.6/site-packages/transformers/modeling_longformer.py", line 435, in forward
    attn_probs = attn_probs.view(batch_size, self.num_heads, max_num_global_attn_indices, seq_len)
RuntimeError: shape '[1, 12, 5, 512]' is invalid for input of size 3182592

Expected behavior

The model should output the attention probs for each attention head.

$ python3 test.py
torch.Size([1, 12, 4096, 5])

It seems to work if I rewrite the target line as follows.

attn_probs = attn_probs.view(batch_size, self.num_heads, max_num_global_attn_indices, seq_len)

↓↓↓

#attn_probs = attn_probs.view(batch_size, self.num_heads, max_num_global_attn_indices, seq_len)
attn_probs = attn_probs[:, :, :, :max_num_global_attn_indices]
attn_probs = attn_probs.permute(0, 2, 1, 3)

Environment info

  • transformers version: 3.0.2
  • Platform: Ubuntu 18.04.4 LTS
  • Python version: Python 3.6.9 :: Anaconda, Inc.
  • PyTorch version (GPU?): 1.5.1 (Yes)
  • Tensorflow version (GPU?):
  • Using GPU in script?: Yes
  • Using distributed or parallel set-up in script?: Yes
@patrickvonplaten
Contributor

Hey @k141303,

Thanks a lot for the issue - I can reproduce!

@patrickvonplaten
Contributor

Thanks a lot for your very clean issue + proposed solution. It makes it very easy to find the error and fix it :-)

BTW, in cases like this, where you see a clear fix for the bug, Pull Requests are very welcome as well!

@k141303
Author

k141303 commented Jul 10, 2020

Hi, @patrickvonplaten

I also thought this was the solution, but it turned out to create a new bug.

To reproduce

Steps to reproduce the behavior:

  1. Set config.output_attentions=True
  2. Use global attention (sum(global_attention_mask)>0)
  3. Use multiple GPUs
  4. max_num_global_attn_indices differs between rows of the batch (and therefore between GPU replicas)

I confirmed it with the following code (which applies the above fix by overriding LongformerSelfAttention).

import math
import torch
from torch.nn import functional as F
from transformers import LongformerModel, AutoTokenizer, AutoConfig
from transformers.modeling_longformer import LongformerSelfAttention

class MyLongformerSelfAttention(LongformerSelfAttention):
    def forward(
        self, hidden_states, attention_mask=None, output_attentions=False,
    ):

        attention_mask = attention_mask.squeeze(dim=2).squeeze(dim=1)

        # is index masked or global attention
        is_index_masked = attention_mask < 0
        is_index_global_attn = attention_mask > 0
        is_global_attn = any(is_index_global_attn.flatten())

        hidden_states = hidden_states.transpose(0, 1)

        # project hidden states
        query_vectors = self.query(hidden_states)
        key_vectors = self.key(hidden_states)
        value_vectors = self.value(hidden_states)

        seq_len, batch_size, embed_dim = hidden_states.size()
        assert (
            embed_dim == self.embed_dim
        ), f"hidden_states should have embed_dim = {self.embed_dim}, but has {embed_dim}"

        # normalize query
        query_vectors /= math.sqrt(self.head_dim)

        query_vectors = query_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1)
        key_vectors = key_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1)

        # attn_probs = (batch_size, seq_len, num_heads, window*2+1)
        attn_scores = self._sliding_chunks_query_key_matmul(
            query_vectors, key_vectors, self.one_sided_attn_window_size
        )

        # values to pad for attention probs
        remove_from_windowed_attention_mask = (attention_mask != 0).unsqueeze(dim=-1).unsqueeze(dim=-1)

        # cast to fp32/fp16 then replace 1's with -inf
        float_mask = remove_from_windowed_attention_mask.type_as(query_vectors).masked_fill(
            remove_from_windowed_attention_mask, -10000.0
        )
        # diagonal mask with zeros everywhere and -inf inplace of padding
        diagonal_mask = self._sliding_chunks_query_key_matmul(
            float_mask.new_ones(size=float_mask.size()), float_mask, self.one_sided_attn_window_size
        )

        # pad local attention probs
        attn_scores += diagonal_mask

        assert list(attn_scores.size()) == [
            batch_size,
            seq_len,
            self.num_heads,
            self.one_sided_attn_window_size * 2 + 1,
        ], f"attn_probs should be of size ({batch_size}, {seq_len}, {self.num_heads}, {self.one_sided_attn_window_size * 2 + 1}), but is of size {attn_scores.size()}"

        # compute local attention probs from global attention keys and concat over window dim
        if is_global_attn:
            # compute global attn indices required throughout forward fn
            (
                max_num_global_attn_indices,
                is_index_global_attn_nonzero,
                is_local_index_global_attn_nonzero,
                is_local_index_no_global_attn_nonzero,
            ) = self._get_global_attn_indices(is_index_global_attn)
            # calculate global attn probs from global key
            global_key_attn_scores = self._concat_with_global_key_attn_probs(
                query_vectors=query_vectors,
                key_vectors=key_vectors,
                max_num_global_attn_indices=max_num_global_attn_indices,
                is_index_global_attn_nonzero=is_index_global_attn_nonzero,
                is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
                is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero,
            )
            # concat to attn_probs
            # (batch_size, seq_len, num_heads, extra attention count + 2*window+1)
            attn_scores = torch.cat((global_key_attn_scores, attn_scores), dim=-1)

            # free memory
            del global_key_attn_scores

        attn_probs_fp32 = F.softmax(attn_scores, dim=-1, dtype=torch.float32)  # use fp32 for numerical stability
        attn_probs = attn_probs_fp32.type_as(attn_scores)

        # free memory
        del attn_probs_fp32

        # softmax sometimes inserts NaN if all positions are masked, replace them with 0
        attn_probs = torch.masked_fill(attn_probs, is_index_masked.unsqueeze(-1).unsqueeze(-1), 0.0)

        # apply dropout
        attn_probs = F.dropout(attn_probs, p=self.dropout, training=self.training)

        value_vectors = value_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1)

        # compute local attention output with global attention value and add
        if is_global_attn:
            # compute sum of global and local attn
            attn_output = self._compute_attn_output_with_global_indices(
                value_vectors=value_vectors,
                attn_probs=attn_probs,
                max_num_global_attn_indices=max_num_global_attn_indices,
                is_index_global_attn_nonzero=is_index_global_attn_nonzero,
                is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
            )
        else:
            # compute local attn only
            attn_output = self._sliding_chunks_matmul_attn_probs_value(
                attn_probs, value_vectors, self.one_sided_attn_window_size
            )

        assert attn_output.size() == (batch_size, seq_len, self.num_heads, self.head_dim), "Unexpected size"
        attn_output = attn_output.transpose(0, 1).reshape(seq_len, batch_size, embed_dim).contiguous()

        # compute value for global attention and overwrite to attention output
        # TODO: remove the redundant computation
        if is_global_attn:
            global_attn_output = self._compute_global_attn_output_from_hidden(
                hidden_states=hidden_states,
                max_num_global_attn_indices=max_num_global_attn_indices,
                is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
                is_index_global_attn_nonzero=is_index_global_attn_nonzero,
                is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero,
                is_index_masked=is_index_masked,
            )

            # get only non zero global attn output
            nonzero_global_attn_output = global_attn_output[
                is_local_index_global_attn_nonzero[0], :, is_local_index_global_attn_nonzero[1]
            ]
            # overwrite values with global attention
            attn_output[is_index_global_attn_nonzero[::-1]] = nonzero_global_attn_output.view(
                len(is_local_index_global_attn_nonzero[0]), -1
            )

        attn_output = attn_output.transpose(0, 1)

        if output_attentions:
            if is_global_attn:
                # With global attention, return global attention probabilities only
                # batch_size x num_heads x max_num_global_attention_tokens x sequence_length
                # which is the attention weights from tokens with global attention to all tokens
                # It does not return local attention
                # In case of a variable number of global attention tokens in the rows of a batch,
                # attn_probs are padded with -10000.0 attention scores

                #attn_probs = attn_probs.view(batch_size, self.num_heads, max_num_global_attn_indices, seq_len)
                attn_probs = attn_probs[:,:,:,:max_num_global_attn_indices]
                attn_probs = attn_probs.permute(0, 2, 1, 3)
            else:
                # without global attention, return local attention probabilities
                # batch_size x num_heads x sequence_length x window_size
                # which is the attention weights of every token attending to its neighbours
                attn_probs = attn_probs.permute(0, 2, 1, 3)

        outputs = (attn_output, attn_probs) if output_attentions else (attn_output,)
        return outputs

class MyLongformerModel(LongformerModel):
    def __init__(self, config):
        super().__init__(config)
        for i, layer in enumerate(self.encoder.layer):
            layer.attention.self = MyLongformerSelfAttention(config, i)
        self.init_weights()

if __name__ == '__main__':
    config = AutoConfig.from_pretrained("allenai/longformer-base-4096", output_attentions=True)
    model = MyLongformerModel.from_pretrained("allenai/longformer-base-4096", config=config)
    tokenizer = AutoTokenizer.from_pretrained("allenai/longformer-base-4096")

    token_ids = [[
        tokenizer.cls_token_id, 10, 11, 12,
        tokenizer.sep_token_id, 21, 22, 23,
        tokenizer.sep_token_id
    ]]*2
    global_attention_mask = [[1,1,1,1,1,0,0,0,0], [1,1,1,1,1,1,1,0,0]]

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
    model.to(device)
    if n_gpu > 1:
        model = torch.nn.DataParallel(model)
    print(f"DEVICE:{device} N_GPU:{n_gpu}")

    logit, *_, attention_probs = model(
        torch.LongTensor(token_ids),
        global_attention_mask=torch.LongTensor(global_attention_mask)
    )

    print(attention_probs[0].size())
username@34dcdd033731:~/Python/temp$ python3 test_longformer.py
DEVICE:cuda N_GPU:4
Traceback (most recent call last):
  File "test_longformer.py", line 194, in <module>
    global_attention_mask=torch.LongTensor(global_attention_mask)
  File "/uge_mnt/home/username/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/uge_mnt/home/username/.local/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 156, in forward
    return self.gather(outputs, self.output_device)
  File "/uge_mnt/home/username/.local/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 168, in gather
    return gather(outputs, output_device, dim=self.dim)
  File "/uge_mnt/home/username/.local/lib/python3.6/site-packages/torch/nn/parallel/scatter_gather.py", line 68, in gather
    res = gather_map(outputs)
  File "/uge_mnt/home/username/.local/lib/python3.6/site-packages/torch/nn/parallel/scatter_gather.py", line 63, in gather_map
    return type(out)(map(gather_map, zip(*outputs)))
  File "/uge_mnt/home/username/.local/lib/python3.6/site-packages/torch/nn/parallel/scatter_gather.py", line 63, in gather_map
    return type(out)(map(gather_map, zip(*outputs)))
  File "/uge_mnt/home/username/.local/lib/python3.6/site-packages/torch/nn/parallel/scatter_gather.py", line 55, in gather_map
    return Gather.apply(target_device, dim, *outputs)
  File "/uge_mnt/home/username/.local/lib/python3.6/site-packages/torch/nn/parallel/_functions.py", line 68, in forward
    return comm.gather(inputs, ctx.dim, ctx.target_device)
  File "/uge_mnt/home/username/.local/lib/python3.6/site-packages/torch/cuda/comm.py", line 165, in gather
    return torch._C._gather(tensors, dim, destination)
RuntimeError: Gather got an input of invalid size: got [1, 12, 512, 7], but expected [1, 12, 512, 5]

I think there are a few possible solutions.

For example:

  • Share max_num_global_attn_indices between GPUs (a rough sketch of this idea follows below).
  • Define max_num_global_attn_indices in config.

I'm sorry I can't suggest a specific solution.
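
For illustration, here is a minimal sketch of the first idea. It assumes a torch.distributed / DistributedDataParallel setup (with plain nn.DataParallel the maximum would instead have to be computed over the full batch before it is scattered), and the helper name is made up:

import torch
import torch.distributed as dist

def sync_max_num_global_attn_indices(is_index_global_attn: torch.Tensor) -> int:
    """Hypothetical helper: make every replica agree on the same maximum."""
    # largest number of global-attention tokens in any row of the local sub-batch
    local_max = is_index_global_attn.long().sum(dim=1).max().reshape(1)
    if dist.is_available() and dist.is_initialized():
        # take the maximum across processes so the attention tensors returned by
        # each replica have matching shapes (NCCL requires a CUDA tensor here)
        dist.all_reduce(local_max, op=dist.ReduceOp.MAX)
    return int(local_max.item())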

@patrickvonplaten
Contributor

Thanks for the notification - will take a look next week :-)

@k141303
Author

k141303 commented Jul 10, 2020

Sorry, I forgot that today is Friday.
Have a good weekend :-)

For those facing the same problem, here is an idea for a temporary workaround.
It might be helpful.

attn_probs = attn_probs.view(batch_size, self.num_heads, max_num_global_attn_indices, seq_len)

↓↓↓

#attn_probs = attn_probs.view(batch_size, self.num_heads, max_num_global_attn_indices, seq_len)
attn_probs = attn_probs[:, :, :, :max_num_global_attn_indices]
attn_probs = F.pad(
    attn_probs,
    (0, seq_len - max_num_global_attn_indices),
    "constant",
    0.0,
)
attn_probs = attn_probs.permute(0, 2, 1, 3)
$ python3 test.py
DEVICE:cuda N_GPU:4
torch.Size([2, 12, 512, 512])

@patrickvonplaten
Contributor

patrickvonplaten commented Jul 13, 2020

@k141303 - thanks a lot for your proposed solution. Padding to the sequence length is actually very clean.
Since we are only returning global attention probs, I think it also makes sense logically to pad the other values with 0.0, since they weren't attended to for global attention => so we'll go for this here.
Instead of seq_len we will pad to window_size so that local and global attention always have the same output dimension. I think this has a slight advantage in that the output signature is more consistent.
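
A rough sketch of what that padding could look like, assuming attn_probs at this point has shape (batch_size, seq_len, num_heads, max_num_global_attn_indices + 2*window + 1) as in the override above; the function and argument names are made up, and this is an illustration rather than the exact code that ends up in transformers:

import torch
from torch.nn import functional as F

def reshape_global_attn_probs(attn_probs, max_num_global_attn_indices, one_sided_window):
    # keep only the columns that correspond to the global-attention keys
    probs = attn_probs[:, :, :, :max_num_global_attn_indices]
    # zero-pad the last dimension to 2*window + 1 so the shape matches the
    # local-attention output and is identical on every DataParallel replica
    pad = 2 * one_sided_window + 1 - max_num_global_attn_indices
    probs = F.pad(probs, (0, pad), "constant", 0.0)
    # (batch, seq_len, heads, 2w+1) -> (batch, heads, seq_len, 2w+1)
    return probs.permute(0, 2, 1, 3)

For longformer-base (one-sided window 256) the last dimension becomes 513, i.e. the shape given in the next comment.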

@patrickvonplaten
Contributor

So in this case the output would be:

torch.Size([1, 12, 512, 513])

which is the same as if only local attention had been used.

@gui11aume
Contributor

gui11aume commented Sep 23, 2020

@patrickvonplaten It seems that the code causing the error in commit 02a0b43 (fixed by commit 7096e47) was reintroduced at some point. The code of current commit df53643 looks like 02a0b43 instead of 7096e47.

@gui11aume
Contributor

Also, I wonder if the output is correct. Add the following lines right after @k141303's minimal reproduction code above.

print(attention_probs[0][0,0,:5,:].sum(dim=1))
print(attention_probs[0][0,0,:,:5].sum(dim=0))

This shows that:

  1. For each head (showing only the first), the rows corresponding to tokens with global attention do not sum to 1.
  2. For each head (showing only the first), the columns corresponding to tokens with global attention do not sum to 1.

Therefore neither the rows nor the columns of the attention matrices can be the attention weights from tokens with global attention to all tokens. As far as I understand from the code, the columns are actually the attention weights from all tokens to the tokens with global attention, but this is not really useful, is it? For instance, it would be more useful to know where CLS puts attention rather than which tokens pay attention to CLS.
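
To make the check explicit, here is a small sketch that can be run right after the same reproduction code (the indices assume the same toy input; the False results reflect the observation above):

row_sums = attention_probs[0][0, 0, :5, :].sum(dim=1)  # rows of the 5 global tokens
col_sums = attention_probs[0][0, 0, :, :5].sum(dim=0)  # columns of the 5 global tokens

print(torch.allclose(row_sums, torch.ones_like(row_sums)))  # False: rows do not sum to 1
print(torch.allclose(col_sums, torch.ones_like(col_sums)))  # False: columns do not sum to 1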

@gui11aume
Contributor

gui11aume commented Sep 23, 2020

@patrickvonplaten I think that the global attention that should be returned is a computation intermediate of the function _compute_global_attn_output_from_hidden. It is called global_attn_probs (or global_attn_probs_float before dropout is applied).

If only global attention is to be returned, you could consider returning this intermediate together with the attention output of _compute_global_attn_output_from_hidden. If you assign it to attn_probs in the forward function, then you are almost done (otherwise you have to recompute it).

The dimensions of this intermediate are (H, G, L), where H is the number of attention heads, G is the number of tokens with global attention and L is the text length (a multiple of attention_window, which I will write W for short). If you want the output to have dimensions (H, L, W) to be congruent with the local attention, you would have to transpose it before padding. This may be very confusing, because the rows of the local attention would sum to 1, whereas the first G columns of the global attention would sum to 1 and all the others would sum to 0.

Since the dimensions of global attention are intrinsically different from those of local attention, it's probably better to leave them as (H, G, L). You could output a tuple with local attention (H, L, W) and global attention (H, G, L) instead of a single tensor. Unfortunately, reconstituting full attention matrices (H, L, L) is a no-go: you need Longformer precisely because this does not fit in memory.
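
As a toy illustration of the shapes being discussed (with a batch dimension in front; all sizes are made up, and this only sketches the proposed two-tensor output, not the transformers API):

import torch

# illustrative sizes: batch, heads, seq_len, window columns (2w+1), global tokens
B, H, L, W, G = 2, 12, 512, 513, 5

local_attn_probs = torch.rand(B, H, L, W).softmax(dim=-1)   # each token over its local window
global_attn_probs = torch.rand(B, H, G, L).softmax(dim=-1)  # each global token over all tokens

# keep the two tensors separate instead of forcing them into one shape
outputs = (local_attn_probs, global_attn_probs)
print(outputs[0].shape, outputs[1].shape)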

@patrickvonplaten
Contributor

Hey @gui11aume , good point!

I guess, either way we do it, it's not perfect for Longformer... I think the cleanest solution would actually be to add a new output type called global_attentions and output both attentions and global_attentions. This is more or less the same idea as the tuple output you proposed.

Opened an issue about it here -> feel free to open a PR if you want :-) It's not a very high priority for me at the moment, so I thought it might be a good issue to tackle for people who work with Longformer. If no one is interested in opening a PR, I'll eventually do it :-)

@gui11aume
Contributor

I didn't want to open a PR earlier because I wasn't sure about the interface you wanted. Having a separate global_attentions field is much cleaner. I should be able to propose something soon, and I'll continue the discussion on issue #7514.
