[GPTNeo] create local attention mask once #11335

Merged · 2 commits · Apr 20, 2021
129 changes: 69 additions & 60 deletions src/transformers/models/gpt_neo/modeling_gpt_neo.py
@@ -192,6 +192,57 @@ def _look_back(tensor, block_length, window_size, pad_value=0, is_key_value=True
padded_tensor = padded_tensor.transpose(-2, -1)
return padded_tensor

@staticmethod
def _split_seq_length_dim_to(tensors, dim_factor_1, dim_factor_2):
"""
Splits sequence length dim of tensors into `dim_factor_1` and `dim_factor_2` dims
"""
batch_size = tensors.shape[0]
split_dim_shape = (batch_size, dim_factor_1, dim_factor_2)

if len(tensors.shape) == 3:
return torch.reshape(tensors, split_dim_shape + (-1,))
elif len(tensors.shape) == 2:
return torch.reshape(tensors, split_dim_shape)
else:
raise ValueError(f"Input vector rank should be one of [2, 3], but is: {len(tensors.shape)}")

@staticmethod
def create_local_attention_mask(batch_size, seq_length, window_size, device, attention_mask=None):
block_length, num_blocks = GPTNeoAttentionMixin._get_block_length_and_num_blocks(seq_length, window_size)
indices = torch.arange(seq_length, dtype=torch.long, device=device).repeat(batch_size, 1)

query_indices = GPTNeoAttentionMixin._split_seq_length_dim_to(indices, num_blocks, block_length)
key_indices = GPTNeoAttentionMixin._look_back(indices, block_length, window_size, is_key_value=False)

# create mask tensor such that each block contains a causal_mask for that block
causal_mask = torch.ge(query_indices.unsqueeze(-1), key_indices.unsqueeze(-2))

if attention_mask is None:
attention_mask = torch.ones(batch_size, seq_length, dtype=torch.long, device=device)

# A block can also be padded because of the _look_back operation
# look back into the attention_mask so that it gets padded the same way
# and has 0s in the padded positions
attention_mask = GPTNeoAttentionMixin._look_back(attention_mask, block_length, window_size, is_key_value=False)
attention_mask = attention_mask.unsqueeze(-2) # Add an extra dimension to account for hidden_dim

# Multiply the causal_mask with attention_mask so the padded positions (by _look_back operation)
# will contain 0s.
# This also makes sure that other positions ignored by the attention_mask will also be ignored
# in the causal_mask.
causal_mask = causal_mask * attention_mask

# In GPT Neo's local attention, each window can attend to at most window_size tokens;
# the rest of the tokens should be ignored.
relative_position = key_indices.unsqueeze(-2) - query_indices.unsqueeze(-1)
visible = torch.gt(relative_position, -window_size)

causal_mask = causal_mask * visible
causal_mask = causal_mask.unsqueeze(-3).bool() # Add an extra dimension to account for num_heads

return causal_mask

def _split_heads(self, tensor, num_heads, attn_head_size):
"""
Splits hidden_size dim into attn_head_size and num_heads
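A quick standalone sketch (plain PyTorch, re-implemented here rather than imported, so it runs without transformers) of the reshape the new static `_split_seq_length_dim_to` performs:

```python
import torch

def split_seq_length_dim_to(tensors, dim_factor_1, dim_factor_2):
    # Mirrors the static method above: split the sequence dim into
    # (num_blocks, block_length); rank-3 inputs keep their hidden dim via -1.
    batch_size = tensors.shape[0]
    split_dim_shape = (batch_size, dim_factor_1, dim_factor_2)
    if tensors.dim() == 3:
        return torch.reshape(tensors, split_dim_shape + (-1,))
    elif tensors.dim() == 2:
        return torch.reshape(tensors, split_dim_shape)
    raise ValueError(f"Input vector rank should be one of [2, 3], but is: {tensors.dim()}")

hidden_states = torch.randn(2, 8, 16)                      # (batch, seq_length, hidden)
print(split_seq_length_dim_to(hidden_states, 2, 4).shape)  # torch.Size([2, 2, 4, 16])

indices = torch.arange(8).repeat(2, 1)                     # (batch, seq_length)
print(split_seq_length_dim_to(indices, 2, 4).shape)        # torch.Size([2, 2, 4])
```

Dropping the explicit `hidden_size` argument in favor of `-1` is part of what lets the helper become a `@staticmethod`.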
@@ -218,20 +269,6 @@ def _merge_heads(self, tensor, num_heads, attn_head_size):
new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
return tensor.view(new_shape)

def _split_seq_length_dim_to(self, tensors, dim_factor_1, dim_factor_2, hidden_size):
"""
Splits sequence length dim of tensors into `dim_factor_1` and `dim_factor_2` dims
"""
batch_size = tensors.shape[0]
split_dim_shape = (batch_size, dim_factor_1, dim_factor_2)

if len(tensors.shape) == 3:
return torch.reshape(tensors, split_dim_shape + (hidden_size,))
elif len(tensors.shape) == 2:
return torch.reshape(tensors, split_dim_shape)
else:
raise ValueError(f"Input vector rank should be one of [2, 3], but is: {len(tensors.shape)}")

def _attn(self, query, key, value, causal_mask, masked_bias, attn_dropout, attention_mask=None, head_mask=None):
# Keep the attention weights computation in fp32 to avoid overflow issues
query = query.to(torch.float32)
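For context on the fp32 note in `_attn`, here is a hedged sketch of the upcast pattern (not the exact method body, which is collapsed in this diff): scores are computed in float32 even when the model runs in fp16, then cast back after the softmax.

```python
import torch

def attn_probs_fp32(query, key, causal_mask, masked_bias):
    # Sketch only: upcast Q/K to float32 so the matmul and softmax do not overflow
    # in fp16, mask out disallowed positions, then return to the original dtype.
    q = query.to(torch.float32)
    k = key.to(torch.float32)
    scores = torch.matmul(q, k.transpose(-1, -2))
    scores = torch.where(causal_mask, scores, masked_bias.to(scores.dtype))
    return torch.softmax(scores, dim=-1).to(query.dtype)
```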
@@ -289,8 +326,8 @@ def __init__(self, config):
def forward(
self,
hidden_states,
layer_past=None,
attention_mask=None,
layer_past=None,
head_mask=None,
use_cache=False,
output_attentions=False,
@@ -357,45 +394,11 @@ def __init__(self, config):

self.window_size = config.window_size

def _create_attention_mask(self, batch_size, seq_length, num_blocks, block_length, device, attention_mask=None):
indices = torch.arange(seq_length, dtype=torch.long, device=device).repeat(batch_size, 1)

query_indices = self._split_seq_length_dim_to(indices, num_blocks, block_length, self.embed_dim)
key_indices = self._look_back(indices, block_length, self.window_size, is_key_value=False)

# create mask tensor such that each block contains a causal_mask for that block
causal_mask = torch.ge(query_indices.unsqueeze(-1), key_indices.unsqueeze(-2))

if attention_mask is None:
attention_mask = torch.ones(batch_size, seq_length, dtype=torch.long, device=device)

# A block can also be padded because of the _look_back operation
# look back into the attention_block such that it will also get padded the same way
# and have 0s in the padded position
attention_mask = self._look_back(attention_mask, block_length, self.window_size, is_key_value=False)
attention_mask = attention_mask.unsqueeze(-2) # Add an extra dimension to account for hidden_dim

# Multiply the causal_mask with attention_mask so the padded positions (by _look_back operation)
# will contain 0s.
# This also makes sure that other positions ignored by the attention_mask will also be ignored
# in the causal_mask.
causal_mask = causal_mask * attention_mask

# In GPT Neo's local attention each window can attend to at most window_size tokens
# rest of the tokens should be ignored.
relative_position = key_indices.unsqueeze(-2) - query_indices.unsqueeze(-1)
visible = torch.gt(relative_position, -self.window_size)

causal_mask = causal_mask * visible
causal_mask = causal_mask.unsqueeze(-3).bool() # Add an extra dimension to account for num_heads

return causal_mask

def forward(
self,
hidden_states,
attention_mask,
layer_past=None,
attention_mask=None,
head_mask=None,
use_cache=False,
output_attentions=False,
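The deleted `_create_attention_mask` above and the new static `create_local_attention_mask` build the same mask from two pieces: a per-block causal mask and a window-size restriction. A toy illustration with made-up sizes:

```python
import torch

window_size = 4
query_indices = torch.arange(4, 8)   # positions of the queries in one block
key_indices = torch.arange(0, 8)     # positions visible after _look_back (window + block)

# causal: a query may look at keys at or before its own position
causal = torch.ge(query_indices.unsqueeze(-1), key_indices.unsqueeze(-2))
# window: a query may look back strictly less than window_size positions
relative_position = key_indices.unsqueeze(-2) - query_indices.unsqueeze(-1)
visible = torch.gt(relative_position, -window_size)

print((causal & visible).int())
# Each row keeps exactly window_size key positions: the query itself plus the
# window_size - 1 tokens before it.
```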
@@ -421,9 +424,9 @@
# create buckets
if layer_past is not None:
# we just need 1 block with block_length 1 when caching is enabled
query = self._split_seq_length_dim_to(query, 1, 1, self.embed_dim)
query = self._split_seq_length_dim_to(query, 1, 1)
else:
query = self._split_seq_length_dim_to(query, num_blocks, block_length, self.embed_dim)
query = self._split_seq_length_dim_to(query, num_blocks, block_length)

key = self._look_back(key, block_length, self.window_size)
value = self._look_back(value, block_length, self.window_size)
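The body of `_look_back` is collapsed in this diff; as a rough illustration (an assumption about the mechanics, not the exact implementation), its effect on a rank-2 index tensor is to pad `window_size` positions on the left and then unfold into overlapping chunks of `window_size + block_length`:

```python
import torch
import torch.nn.functional as F

window_size, block_length = 4, 4
indices = torch.arange(8).unsqueeze(0)              # (batch=1, seq_length=8)

padded = F.pad(indices, (window_size, 0), value=0)  # left-pad with zeros
chunks = padded.unfold(1, window_size + block_length, block_length)
print(chunks)
# tensor([[[0, 0, 0, 0, 0, 1, 2, 3],
#          [0, 1, 2, 3, 4, 5, 6, 7]]])
# The left-padded slots are exactly why the attention_mask is run through the same
# look-back: it gets 0s in those slots, so they are masked out of the attention.
```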
@@ -437,18 +440,16 @@ def forward(
key = self._split_heads(key, self.num_heads, self.head_dim)
value = self._split_heads(value, self.num_heads, self.head_dim)

mask = self._create_attention_mask(
batch_size, full_seq_length, num_blocks, block_length, hidden_states.device, attention_mask
)
if layer_past is not None:
mask = mask[:, -1:, :, -1:, :] # only take the mask for the last block
# only take the mask for the last block
attention_mask = attention_mask[:, -1:, :, -1:, :]

# attn
attn_output, attn_weights = self._attn(
query,
key,
value,
causal_mask=mask,
causal_mask=attention_mask,
masked_bias=self.masked_bias,
attn_dropout=self.attn_dropout,
head_mask=head_mask,
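When `layer_past` is present there is a single new query, so only the last row of the last block of the precomputed mask is needed. A quick shape check with assumed sizes:

```python
import torch

# (batch, num_blocks, 1-for-heads, block_length, window_size + block_length)
mask = torch.ones(2, 3, 1, 4, 8, dtype=torch.bool)
print(mask[:, -1:, :, -1:, :].shape)   # torch.Size([2, 1, 1, 1, 8])
```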
@@ -495,8 +496,8 @@ def forward(
):
outputs = self.attention(
hidden_states,
layer_past=layer_past,
attention_mask=attention_mask,
layer_past=layer_past,
head_mask=head_mask,
use_cache=use_cache,
output_attentions=output_attentions,
@@ -765,8 +766,9 @@ def forward(
past_key_values = tuple([None] * len(self.h))
else:
past_length = past_key_values[0][0].size(-2)

device = input_ids.device if input_ids is not None else inputs_embeds.device
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
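The `device` assignment is hoisted out of the `if position_ids is None:` branch because the new local-mask code below also needs it. For reference, the default `position_ids` simply continue from `past_length`; a minimal example:

```python
import torch

past_length, new_tokens = 6, 2     # 6 cached tokens, 2 new input ids
position_ids = torch.arange(past_length, new_tokens + past_length, dtype=torch.long)
position_ids = position_ids.unsqueeze(0).view(-1, new_tokens)
print(position_ids)                # tensor([[6, 7]])
```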

@@ -791,6 +793,13 @@
else:
global_attention_mask = None

# Local causal attention mask
batch_size, seq_length = input_shape
full_seq_length = seq_length + past_length
local_attention_mask = GPTNeoAttentionMixin.create_local_attention_mask(
batch_size, full_seq_length, self.config.window_size, device, attention_mask
)

# Prepare head mask if needed
# 1.0 in head_mask indicates we keep the head
# attention_probs has shape bsz x num_heads x N x N
@@ -815,7 +824,7 @@
all_hidden_states = () if output_hidden_states else None
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
attn_type = self.config.attention_layers[i]
attn_mask = global_attention_mask if attn_type == "global" else attention_mask
attn_mask = global_attention_mask if attn_type == "global" else local_attention_mask

if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
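Putting the modeling changes together, a hedged usage sketch (assuming a transformers version that includes this PR; the helper may have been moved or refactored in later releases) of building the local mask once and picking a mask per layer:

```python
import torch
from transformers import GPTNeoConfig
from transformers.models.gpt_neo.modeling_gpt_neo import GPTNeoAttentionMixin

config = GPTNeoConfig()                 # default config alternates global/local layers
batch_size, seq_length = 2, 16
attention_mask = torch.ones(batch_size, seq_length, dtype=torch.long)

# Built once in GPTNeoModel.forward (with full_seq_length = seq_length + past_length)
# and reused by every local-attention layer.
local_attention_mask = GPTNeoAttentionMixin.create_local_attention_mask(
    batch_size, seq_length, config.window_size, torch.device("cpu"), attention_mask
)
print(local_attention_mask.shape)
# (batch, num_blocks, 1, block_length, window_size + block_length), per the test below

# Per-layer selection mirrors the loop in the diff; the real forward passes an
# extended 4-D global mask to "global" layers, kept as the raw 2-D mask here for brevity.
for attn_type in config.attention_layers:
    attn_mask = attention_mask if attn_type == "global" else local_attention_mask
```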
20 changes: 14 additions & 6 deletions tests/test_modeling_gpt_neo.py
@@ -36,7 +36,7 @@
GPTNeoForCausalLM,
GPTNeoModel,
)
from transformers.models.gpt_neo.modeling_gpt_neo import GPTNeoAttentionMixin, GPTNeoLocalSelfAttention
from transformers.models.gpt_neo.modeling_gpt_neo import GPTNeoAttentionMixin


class GPTNeoModelTester:
@@ -497,12 +497,14 @@ def test_look_back(self):

def test_create_attention_mask(self):
config = GPTNeoConfig.from_pretrained("valhalla/gpt-neo-random-tiny")
layer = GPTNeoLocalSelfAttention(config)
window_size = config.window_size
batch_size, seq_length = 8, 1
block_length, num_blocks = GPTNeoAttentionMixin._get_block_length_and_num_blocks(seq_length, window_size)

causal_mask = layer._create_attention_mask(batch_size, seq_length, num_blocks, block_length, torch_device)
causal_mask = GPTNeoAttentionMixin.create_local_attention_mask(
batch_size, seq_length, config.window_size, torch_device
)
# check shapes
expected_shape = [batch_size, num_blocks, 1, block_length, window_size + block_length]
self.assertListEqual(list(causal_mask.shape), expected_shape)
@@ -516,8 +518,11 @@ def test_create_attention_mask(self):
attention_mask = torch.ones(batch_size, seq_length, dtype=torch.long, device=torch_device)
attention_mask[:, -3:] = 0 # don't attend last 3 tokens

causal_mask = layer._create_attention_mask(
batch_size, seq_length, num_blocks, block_length, torch_device, attention_mask
causal_mask = GPTNeoAttentionMixin.create_local_attention_mask(
batch_size, seq_length, config.window_size, torch_device, attention_mask
)
# last 3 tokens will be in the last block and should have 0s in causal_mask
self.assertTrue(torch.all(causal_mask[:, -1, :, :, -3:] == 0))
@@ -539,8 +544,11 @@ def test_local_attn_probs(self):
mask_tokens = 3
attention_mask = torch.ones(batch_size, seq_length, device=torch_device, dtype=torch.long)
attention_mask[:, -mask_tokens:] = 0  # don't attend last mask_tokens
local_causal_mask = GPTNeoAttentionMixin.create_local_attention_mask(
batch_size, seq_length, model.config.window_size, torch_device, attention_mask
)

_, attn_probs = layer(hidden_states, attention_mask=attention_mask, output_attentions=True)
_, attn_probs = layer(hidden_states, attention_mask=local_causal_mask, output_attentions=True)

# the last 3 tokens will be in the last block, and should have 0 attn_probs
self.assertTrue(torch.all(attn_probs[:, -1, :, -mask_tokens:, -mask_tokens:] == 0))
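For completeness, the behaviour the updated tests exercise, written as a hedged standalone snippet (the window size is hard-coded to GPT-Neo's usual 256 instead of loading the tiny hub config used by the test):

```python
import torch
from transformers.models.gpt_neo.modeling_gpt_neo import GPTNeoAttentionMixin

window_size = 256
batch_size, seq_length = 2, 8
attention_mask = torch.ones(batch_size, seq_length, dtype=torch.long)
attention_mask[:, -3:] = 0   # don't attend the last 3 tokens

causal_mask = GPTNeoAttentionMixin.create_local_attention_mask(
    batch_size, seq_length, window_size, torch.device("cpu"), attention_mask
)
# The last 3 tokens land at the end of the last block's key dimension and are fully masked.
assert torch.all(causal_mask[:, -1, :, :, -3:] == 0)
```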