Output global_attentions in Longformer models #7562

Merged · 13 commits · Nov 5, 2020
make style
patrickvonplaten committed Oct 6, 2020
commit 62bd7cc1d99ecb93727b86d5438f7ac6bc7c0c64
56 changes: 28 additions & 28 deletions src/transformers/modeling_longformer.py
@@ -16,30 +16,24 @@

import math
import warnings
from dataclasses import dataclass
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss, MSELoss
from torch.nn import functional as F

from dataclasses import dataclass
from typing import List, Optional, Tuple

from .activations import ACT2FN, gelu
from .configuration_longformer import LongformerConfig
from .file_utils import (
ModelOutput,
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_callable,
ModelOutput,
replace_return_docstrings,
)

from .modeling_outputs import (
MaskedLMOutput,
SequenceClassifierOutput,
TokenClassifierOutput,
)
from .modeling_outputs import MaskedLMOutput, SequenceClassifierOutput, TokenClassifierOutput
from .modeling_utils import (
PreTrainedModel,
apply_chunking_to_forward,
@@ -66,7 +60,7 @@

@dataclass
class LongformerBaseModelOutput(ModelOutput):
"""
"""
Base class for Longformer's outputs, with potential hidden states, local and global attentions.

Args:
@@ -112,7 +106,7 @@ class LongformerBaseModelOutput(ModelOutput):

@dataclass
class LongformerBaseModelOutputWithPooling(ModelOutput):
"""
"""
Base class for Longformer's outputs that also contains a pooling of the last hidden states.

Args:
@@ -164,7 +158,7 @@ class LongformerBaseModelOutputWithPooling(ModelOutput):

@dataclass
class LongformerMultipleChoiceModelOutput(ModelOutput):
"""
"""
Base class for outputs of multiple choice Longformer models.

Args:
@@ -215,7 +209,7 @@ class LongformerMultipleChoiceModelOutput(ModelOutput):

@dataclass
class LongformerQuestionAnsweringModelOutput(ModelOutput):
"""
"""
Base class for outputs of question answering Longformer models.

Args:
@@ -577,15 +571,19 @@ def forward(

if output_attentions:
if is_global_attn:
# The attention weights for tokens with global attention are
# just filler values, they were never used to compute the output.
# Fill with 0 now, the correct values are in 'global_attn_probs'.
local_attn_probs[is_index_global_attn_nonzero] = 0
# The attention weights for tokens with global attention are
# just filler values, they were never used to compute the output.
# Fill with 0 now, the correct values are in 'global_attn_probs'.
local_attn_probs[is_index_global_attn_nonzero] = 0
local_attn_probs = local_attn_probs.permute(0, 2, 1, 3)

outputs = (attn_output,) if not output_attentions \
else (attn_output, local_attn_probs, global_attn_probs) if is_global_attn \
else (attn_output, local_attn_probs)
outputs = (
(attn_output,)
if not output_attentions
else (attn_output, local_attn_probs, global_attn_probs)
if is_global_attn
else (attn_output, local_attn_probs)
)
return outputs

@staticmethod
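For context, the hunk above zeroes out the local-attention rows of tokens that attend globally: those rows were never used to compute the layer output, and the real probabilities are returned separately as global_attn_probs. A minimal sketch of that masking step, with illustrative tensor sizes (only the operations mirror the diff; the shapes here are assumptions):

import torch

# Assumed illustrative sizes; the real tensors come from LongformerSelfAttention.
batch_size, seq_len, num_heads, window = 1, 8, 2, 4
local_attn_probs = torch.rand(batch_size, seq_len, num_heads, window)

# Pretend the first two tokens have global attention.
is_index_global_attn = torch.zeros(batch_size, seq_len, dtype=torch.bool)
is_index_global_attn[:, :2] = True
is_index_global_attn_nonzero = is_index_global_attn.nonzero(as_tuple=True)

# Their local-attention rows are filler values, so they are zeroed out;
# the correct values live in global_attn_probs.
local_attn_probs[is_index_global_attn_nonzero] = 0
local_attn_probs = local_attn_probs.permute(0, 2, 1, 3)  # -> (batch, heads, seq_len, window)

assert torch.all(local_attn_probs[:, :, :2, :] == 0)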
@@ -935,9 +933,7 @@ def _compute_global_attn_output_from_hidden(
self.head_dim,
], f"global_attn_output tensor has the wrong size. Size should be {(batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim)}, but is {global_attn_output.size()}."

global_attn_probs = global_attn_probs.view(
batch_size, self.num_heads, max_num_global_attn_indices, seq_len
)
global_attn_probs = global_attn_probs.view(batch_size, self.num_heads, max_num_global_attn_indices, seq_len)
global_attn_output = global_attn_output.view(
batch_size, self.num_heads, max_num_global_attn_indices, self.head_dim
)
@@ -1081,7 +1077,7 @@ def forward(
return_dict=False,
):
all_hidden_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None # All local attentions.
all_attentions = () if output_attentions else None # All local attentions.
all_global_attentions = () if output_attentions else None
for i, layer_module in enumerate(self.layer):
if output_hidden_states:
@@ -1112,19 +1108,23 @@ def custom_forward(*inputs):
all_attentions = all_attentions + (layer_outputs[1],)
# Output global attentions if they exist.
if len(layer_outputs) > 2:
all_global_attentions = all_global_attentions + (layer_outputs[2],)
all_global_attentions = all_global_attentions + (layer_outputs[2],)

# Add last layer
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)

if not return_dict:
return tuple(v for v in [hidden_states, all_hidden_states, all_attentions, all_global_attentions or None] if v is not None)
return tuple(
v
for v in [hidden_states, all_hidden_states, all_attentions, all_global_attentions or None]
if v is not None
)
return LongformerBaseModelOutput(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
attentions=all_attentions,
global_attentions=all_global_attentions or None
global_attentions=all_global_attentions or None,
)
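Taken together, the encoder changes above mean that, when output_attentions=True, Longformer now returns the per-layer global attention probabilities alongside the local ones, either as the last element of the output tuple or as the new global_attentions field of the model output. A usage sketch, assuming a transformers build that includes this PR (the global-attention shape follows the reshape in _compute_global_attn_output_from_hidden):

import torch
from transformers import LongformerModel, LongformerTokenizer

model = LongformerModel.from_pretrained("allenai/longformer-base-4096")
tokenizer = LongformerTokenizer.from_pretrained("allenai/longformer-base-4096")

inputs = tokenizer("Hello world!", return_tensors="pt")

# Give the first token (<s>) global attention; the rest use the sliding window.
global_attention_mask = torch.zeros_like(inputs["input_ids"])
global_attention_mask[:, 0] = 1

outputs = model(
    **inputs,
    global_attention_mask=global_attention_mask,
    output_attentions=True,
    return_dict=True,
)

# attentions: per-layer local (windowed) attention probabilities.
# global_attentions: the field added by this PR, one tensor per layer of shape
# (batch_size, num_heads, max_num_global_attn_indices, seq_len).
print(len(outputs.attentions), outputs.attentions[0].shape)
print(len(outputs.global_attentions), outputs.global_attentions[0].shape)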


21 changes: 11 additions & 10 deletions tests/test_modeling_longformer.py
@@ -545,22 +545,24 @@ def test_layer_attn_probs(self):
attention_mask[0, :, :, -2:] = 10000.0
attention_mask[0, :, :, -1:] = -10000.0
attention_mask[1, :, :, 1:] = 10000.0
output_hidden_states, local_attentions, global_attentions = layer(hidden_states, attention_mask, output_attentions=True)
output_hidden_states, local_attentions, global_attentions = layer(
hidden_states, attention_mask, output_attentions=True
)

self.assertEqual(local_attentions.shape, (2, 2, 4, 8))
self.assertEqual(global_attentions.shape, (2, 2, 3, 4))

# All tokens with global attention have weight 0 in local attentions.
self.assertTrue(torch.all(local_attentions[0,:,2:4,:] == 0))
self.assertTrue(torch.all(local_attentions[1,:,1:4,:] == 0))
self.assertTrue(torch.all(local_attentions[0, :, 2:4, :] == 0))
self.assertTrue(torch.all(local_attentions[1, :, 1:4, :] == 0))

# The attention weights of tokens with global attention must sum to 1.
self.assertTrue(torch.all(torch.abs(global_attentions[0,:,:2,:].sum(dim=-1) - 1) < 1e-6))
self.assertTrue(torch.all(torch.abs(global_attentions[1,:,:1,:].sum(dim=-1) - 1) < 1e-6))
self.assertTrue(torch.all(torch.abs(global_attentions[0, :, :2, :].sum(dim=-1) - 1) < 1e-6))
self.assertTrue(torch.all(torch.abs(global_attentions[1, :, :1, :].sum(dim=-1) - 1) < 1e-6))

self.assertTrue(
torch.allclose(
local_attentions[0,0,0,:],
local_attentions[0, 0, 0, :],
torch.tensor(
[0.3328, 0.0000, 0.0000, 0.0000, 0.0000, 0.3355, 0.3318, 0.0000],
dtype=torch.float32,
@@ -572,7 +574,7 @@ def test_layer_attn_probs(self):

self.assertTrue(
torch.allclose(
local_attentions[1,0,0,:],
local_attentions[1, 0, 0, :],
torch.tensor(
[0.2492, 0.2502, 0.2502, 0.0000, 0.0000, 0.2505, 0.0000, 0.0000],
dtype=torch.float32,
@@ -587,7 +589,7 @@ def test_layer_attn_probs(self):

self.assertTrue(
torch.allclose(
global_attentions[0,0,1,:],
global_attentions[0, 0, 1, :],
torch.tensor(
[0.2500, 0.2500, 0.2500, 0.2500],
dtype=torch.float32,
@@ -599,7 +601,7 @@ def test_layer_attn_probs(self):

self.assertTrue(
torch.allclose(
global_attentions[1,0,0,:],
global_attentions[1, 0, 0, :],
torch.tensor(
[0.2497, 0.2500, 0.2499, 0.2504],
dtype=torch.float32,
@@ -609,7 +611,6 @@ def test_layer_attn_probs(self):
)
)


@slow
def test_inference_no_head(self):
model = LongformerModel.from_pretrained("allenai/longformer-base-4096")
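
A note for readers of test_layer_attn_probs above: the attention_mask handed directly to the layer is already in Longformer's internal encoding, where large positive values mark tokens with global attention, large negative values mark padding, and zeros mark ordinary local-attention tokens. A small standalone sketch of how the index masks follow from that encoding (the mask values are taken from the test; the derivation mirrors the conventions in modeling_longformer.py, simplified here to a 2-D mask):

import torch

# Batch 0: one global token (position 2) and one padding token (position 3);
# batch 1: global attention on positions 1..3.
attention_mask = torch.zeros(2, 4)
attention_mask[0, 2] = 10000.0
attention_mask[0, 3] = -10000.0
attention_mask[1, 1:] = 10000.0

is_index_masked = attention_mask < 0        # padding tokens
is_index_global_attn = attention_mask > 0   # tokens with global attention
is_global_attn = bool(is_index_global_attn.any())

print(is_index_global_attn)
# tensor([[False, False,  True, False],
#         [False,  True,  True,  True]])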