Can't get (global) attention probs using Longformer #5646
Comments
Hey @k141303, Thanks a lot for the issue - I can reproduce!
Thanks a lot for your very clean issue + proposed solution. It makes it very easy to find the error and fix it :-) BTW, in cases like this issue when you see a clear fix to the bug, Pull Requests are very welcome as well!
I also thought this was the solution, but it turned out to create a new bug.
To reproduce
Steps to reproduce the behavior:
I confirmed it with the following code (applying the above solution by overriding):

import math

import torch
from torch.nn import functional as F

from transformers import LongformerModel, AutoTokenizer, AutoConfig
from transformers.modeling_longformer import LongformerSelfAttention


class MyLongformerSelfAttention(LongformerSelfAttention):
    def forward(
        self, hidden_states, attention_mask=None, output_attentions=False,
    ):
        attention_mask = attention_mask.squeeze(dim=2).squeeze(dim=1)

        # is index masked or global attention
        is_index_masked = attention_mask < 0
        is_index_global_attn = attention_mask > 0
        is_global_attn = any(is_index_global_attn.flatten())

        hidden_states = hidden_states.transpose(0, 1)

        # project hidden states
        query_vectors = self.query(hidden_states)
        key_vectors = self.key(hidden_states)
        value_vectors = self.value(hidden_states)

        seq_len, batch_size, embed_dim = hidden_states.size()
        assert (
            embed_dim == self.embed_dim
        ), f"hidden_states should have embed_dim = {self.embed_dim}, but has {embed_dim}"

        # normalize query
        query_vectors /= math.sqrt(self.head_dim)

        query_vectors = query_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1)
        key_vectors = key_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1)

        # attn_probs = (batch_size, seq_len, num_heads, window*2+1)
        attn_scores = self._sliding_chunks_query_key_matmul(
            query_vectors, key_vectors, self.one_sided_attn_window_size
        )

        # values to pad for attention probs
        remove_from_windowed_attention_mask = (attention_mask != 0).unsqueeze(dim=-1).unsqueeze(dim=-1)

        # cast to fp32/fp16 then replace 1's with -inf
        float_mask = remove_from_windowed_attention_mask.type_as(query_vectors).masked_fill(
            remove_from_windowed_attention_mask, -10000.0
        )

        # diagonal mask with zeros everywhere and -inf in place of padding
        diagonal_mask = self._sliding_chunks_query_key_matmul(
            float_mask.new_ones(size=float_mask.size()), float_mask, self.one_sided_attn_window_size
        )

        # pad local attention probs
        attn_scores += diagonal_mask

        assert list(attn_scores.size()) == [
            batch_size,
            seq_len,
            self.num_heads,
            self.one_sided_attn_window_size * 2 + 1,
        ], f"attn_probs should be of size ({batch_size}, {seq_len}, {self.num_heads}, {self.one_sided_attn_window_size * 2 + 1}), but is of size {attn_scores.size()}"

        # compute local attention probs from global attention keys and concat over window dim
        if is_global_attn:
            # compute global attn indices required throughout the forward fn
            (
                max_num_global_attn_indices,
                is_index_global_attn_nonzero,
                is_local_index_global_attn_nonzero,
                is_local_index_no_global_attn_nonzero,
            ) = self._get_global_attn_indices(is_index_global_attn)
            # calculate global attn probs from global key
            global_key_attn_scores = self._concat_with_global_key_attn_probs(
                query_vectors=query_vectors,
                key_vectors=key_vectors,
                max_num_global_attn_indices=max_num_global_attn_indices,
                is_index_global_attn_nonzero=is_index_global_attn_nonzero,
                is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
                is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero,
            )
            # concat to attn_probs
            # (batch_size, seq_len, num_heads, extra attention count + 2*window+1)
            attn_scores = torch.cat((global_key_attn_scores, attn_scores), dim=-1)

            # free memory
            del global_key_attn_scores

        attn_probs_fp32 = F.softmax(attn_scores, dim=-1, dtype=torch.float32)  # use fp32 for numerical stability
        attn_probs = attn_probs_fp32.type_as(attn_scores)

        # free memory
        del attn_probs_fp32

        # softmax sometimes inserts NaN if all positions are masked, replace them with 0
        attn_probs = torch.masked_fill(attn_probs, is_index_masked.unsqueeze(-1).unsqueeze(-1), 0.0)

        # apply dropout
        attn_probs = F.dropout(attn_probs, p=self.dropout, training=self.training)

        value_vectors = value_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1)

        # compute local attention output with global attention value and add
        if is_global_attn:
            # compute sum of global and local attn
            attn_output = self._compute_attn_output_with_global_indices(
                value_vectors=value_vectors,
                attn_probs=attn_probs,
                max_num_global_attn_indices=max_num_global_attn_indices,
                is_index_global_attn_nonzero=is_index_global_attn_nonzero,
                is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
            )
        else:
            # compute local attn only
            attn_output = self._sliding_chunks_matmul_attn_probs_value(
                attn_probs, value_vectors, self.one_sided_attn_window_size
            )

        assert attn_output.size() == (batch_size, seq_len, self.num_heads, self.head_dim), "Unexpected size"
        attn_output = attn_output.transpose(0, 1).reshape(seq_len, batch_size, embed_dim).contiguous()

        # compute value for global attention and overwrite to attention output
        # TODO: remove the redundant computation
        if is_global_attn:
            global_attn_output = self._compute_global_attn_output_from_hidden(
                hidden_states=hidden_states,
                max_num_global_attn_indices=max_num_global_attn_indices,
                is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
                is_index_global_attn_nonzero=is_index_global_attn_nonzero,
                is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero,
                is_index_masked=is_index_masked,
            )

            # get only non zero global attn output
            nonzero_global_attn_output = global_attn_output[
                is_local_index_global_attn_nonzero[0], :, is_local_index_global_attn_nonzero[1]
            ]

            # overwrite values with global attention
            attn_output[is_index_global_attn_nonzero[::-1]] = nonzero_global_attn_output.view(
                len(is_local_index_global_attn_nonzero[0]), -1
            )

        attn_output = attn_output.transpose(0, 1)

        if output_attentions:
            if is_global_attn:
                # With global attention, return global attention probabilities only,
                # batch_size x num_heads x max_num_global_attention_tokens x sequence_length,
                # which is the attention weights from tokens with global attention to all tokens.
                # It does not return local attention.
                # In case of a variable number of global attention tokens in the rows of a batch,
                # attn_probs are padded with -10000.0 attention scores
                # attn_probs = attn_probs.view(batch_size, self.num_heads, max_num_global_attn_indices, seq_len)
                attn_probs = attn_probs[:, :, :, :max_num_global_attn_indices]
                attn_probs = attn_probs.permute(0, 2, 1, 3)
            else:
                # without global attention, return local attention probabilities,
                # batch_size x num_heads x sequence_length x window_size,
                # which is the attention weights of every token attending to its neighbours
                attn_probs = attn_probs.permute(0, 2, 1, 3)

        outputs = (attn_output, attn_probs) if output_attentions else (attn_output,)
        return outputs


class MyLongformerModel(LongformerModel):
    def __init__(self, config):
        super().__init__(config)

        for i, layer in enumerate(self.encoder.layer):
            layer.attention.self = MyLongformerSelfAttention(config, i)

        self.init_weights()


if __name__ == '__main__':
    config = AutoConfig.from_pretrained("allenai/longformer-base-4096", output_attentions=True)
    model = MyLongformerModel.from_pretrained("allenai/longformer-base-4096", config=config)
    tokenizer = AutoTokenizer.from_pretrained("allenai/longformer-base-4096")

    token_ids = [[
        tokenizer.cls_token_id, 10, 11, 12,
        tokenizer.sep_token_id, 21, 22, 23,
        tokenizer.sep_token_id
    ]] * 2
    global_attention_mask = [[1, 1, 1, 1, 1, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1, 1, 0, 0]]

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
    model.to(device)
    if n_gpu > 1:
        model = torch.nn.DataParallel(model)
    print(f"DEVICE:{device} N_GPU:{n_gpu}")

    logit, *_, attention_probs = model(
        torch.LongTensor(token_ids),
        global_attention_mask=torch.LongTensor(global_attention_mask)
    )
    print(attention_probs[0].size())

Running this on a multi-GPU machine produces:

username@34dcdd033731:~/Python/temp$ python3 test_longformer.py
DEVICE:cuda N_GPU:4
Traceback (most recent call last):
File "test_longformer.py", line 194, in <module>
global_attention_mask=torch.LongTensor(global_attention_mask)
File "/uge_mnt/home/username/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 550, in __call__
result = self.forward(*input, **kwargs)
File "/uge_mnt/home/username/.local/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 156, in forward
return self.gather(outputs, self.output_device)
File "/uge_mnt/home/username/.local/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 168, in gather
return gather(outputs, output_device, dim=self.dim)
File "/uge_mnt/home/username/.local/lib/python3.6/site-packages/torch/nn/parallel/scatter_gather.py", line 68, in gather
res = gather_map(outputs)
File "/uge_mnt/home/username/.local/lib/python3.6/site-packages/torch/nn/parallel/scatter_gather.py", line 63, in gather_map
return type(out)(map(gather_map, zip(*outputs)))
File "/uge_mnt/home/username/.local/lib/python3.6/site-packages/torch/nn/parallel/scatter_gather.py", line 63, in gather_map
return type(out)(map(gather_map, zip(*outputs)))
File "/uge_mnt/home/username/.local/lib/python3.6/site-packages/torch/nn/parallel/scatter_gather.py", line 55, in gather_map
return Gather.apply(target_device, dim, *outputs)
File "/uge_mnt/home/username/.local/lib/python3.6/site-packages/torch/nn/parallel/_functions.py", line 68, in forward
return comm.gather(inputs, ctx.dim, ctx.target_device)
File "/uge_mnt/home/username/.local/lib/python3.6/site-packages/torch/cuda/comm.py", line 165, in gather
return torch._C._gather(tensors, dim, destination)
RuntimeError: Gather got an input of invalid size: got [1, 12, 512, 7], but expected [1, 12, 512, 5]

I think there are some possible solutions, but I'm sorry I can't suggest a specific one.
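For readers wondering why the multi-GPU run fails, here is a small standalone illustration (not part of the original report) of the size constraint that DataParallel's gather step enforces; the shapes mirror the error above.

import torch

# DataParallel gathers per-GPU outputs by concatenating them along the batch
# dimension, so every replica must return tensors whose remaining dimensions match.
# With a different number of global-attention tokens per replica, the last dimension
# of the returned attention probabilities differs and the gather fails.
a = torch.zeros(1, 12, 512, 7)   # replica whose chunk contained 7 global tokens
b = torch.zeros(1, 12, 512, 5)   # replica whose chunk contained 5 global tokens
try:
    torch.cat([a, b], dim=0)     # the same size check that Gather performs
except RuntimeError as err:
    print(err)                   # sizes must match except in dimension 0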
Thanks for the notification - will take a look next week :-)
Sorry, I forgot that today is Friday. For those facing the same problem, the following is an idea for a temporary workaround.
↓↓↓

# attn_probs = attn_probs.view(batch_size, self.num_heads, max_num_global_attn_indices, seq_len)
attn_probs = attn_probs[:, :, :, :max_num_global_attn_indices]
attn_probs = F.pad(
    attn_probs,
    (0, seq_len - max_num_global_attn_indices),
    "constant",
    0.0,
)
attn_probs = attn_probs.permute(0, 2, 1, 3)

$ python3 test.py
DEVICE:cuda N_GPU:4
torch.Size([2, 12, 512, 512])
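For reference, a standalone sketch (illustration only, not code from the thread) of what the padding does: it right-pads the sliced global-attention columns with zeros up to seq_len, so every replica returns the same shape and DataParallel can gather the outputs.

import torch
from torch.nn import functional as F

seq_len = 512
max_num_global_attn_indices = 7                      # e.g. 7 global tokens on this replica
attn_probs = torch.rand(1, 12, seq_len, max_num_global_attn_indices)
# pad the last dimension on the right with zeros up to seq_len
attn_probs = F.pad(attn_probs, (0, seq_len - max_num_global_attn_indices), "constant", 0.0)
print(attn_probs.size())                             # torch.Size([1, 12, 512, 512])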
@k141303 - thanks a lot for your proposed solution. Padding to the sequence length is actually a very clean solution.
So in this case the output would be torch.Size([1, 12, 512, 513]), which is the same as if only local attention had been used.
Also, I wonder if the output is correct. Add the following lines right after the minimum code of @k141303.
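A check in that spirit (a hypothetical sketch reusing attention_probs from the reproduction script above with the padded workaround applied, not necessarily the exact lines meant here) would print the row and column sums of the returned matrices:

probs = attention_probs[0].float()
print(probs.sum(dim=-1))   # rows of a proper attention distribution should sum to ~1
print(probs.sum(dim=-2))   # column sums, shown for comparison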
This shows that:
Therefore neither the rows nor the columns of the attention matrices can be correct.
@patrickvonplaten I think that the global attention that should be returned is a computation intermediate of the function that computes the global attention output (_compute_global_attn_output_from_hidden in the code above). If only global attention is to be returned, you could consider returning this intermediate together with the attention output of the layer. The dimensions of this intermediate differ from those of the local attention probabilities. Since the dimensions of global attention are intrinsically different from those of local attention, it's probably better to leave them as they are.
Hey @gui11aume, good point! I guess either way we do it, it's not perfect for Longformer... I think the cleanest solution would actually be to add a new output type. I opened an issue about it; feel free to open a PR if you want :-) It's not of very high priority for me at the moment, so I thought it might be a good issue to tackle for people who work with Longformer. If no one is interested in opening a PR, I'll eventually do it :-)
I didn't want to do a PR earlier because I wasn't sure about the interface you want. Having a separate field for the global attention probabilities would work well.
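As a rough illustration of what such a separate field could look like (the names below are hypothetical and are not the actual transformers output classes discussed here), local and global attention probabilities can live in separate fields so their intrinsically different shapes never have to be merged or padded:

from dataclasses import dataclass
from typing import Optional, Tuple

import torch

@dataclass
class LongformerOutputSketch:
    last_hidden_state: torch.Tensor
    attentions: Optional[Tuple[torch.Tensor, ...]] = None         # local attention, one tensor per layer
    global_attentions: Optional[Tuple[torch.Tensor, ...]] = None  # global attention, one tensor per layer

out = LongformerOutputSketch(last_hidden_state=torch.zeros(2, 512, 768))
print(out.attentions is None, out.global_attentions is None)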
🐛 Bug
Information
Model I am using: Longformer
Language I am using the model on: Japanese
The problem arises when using:
The task I am working on is:
To reproduce
Steps to reproduce the behavior:
The following is the minimum code to reproduce the error.
Expected behavior
The model can output attention probs for each attention head.
It would seem to work if I rewrite the target line as follows.
transformers/src/transformers/modeling_longformer.py, line 435 in 02a0b43
Environment info
transformers version: 3.0.2