diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py index bb70db7ec11956..818423b423df28 100755 --- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py +++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py @@ -192,6 +192,57 @@ def _look_back(tensor, block_length, window_size, pad_value=0, is_key_value=True padded_tensor = padded_tensor.transpose(-2, -1) return padded_tensor + @staticmethod + def _split_seq_length_dim_to(tensors, dim_factor_1, dim_factor_2): + """ + Splits sequence length dim of tensors into `dim_factor_1` and `dim_factor_2` dims + """ + batch_size = tensors.shape[0] + split_dim_shape = (batch_size, dim_factor_1, dim_factor_2) + + if len(tensors.shape) == 3: + return torch.reshape(tensors, split_dim_shape + (-1,)) + elif len(tensors.shape) == 2: + return torch.reshape(tensors, split_dim_shape) + else: + raise ValueError(f"Input vector rank should be one of [2, 3], but is: {len(tensors.shape)}") + + @staticmethod + def create_local_attention_mask(batch_size, seq_length, window_size, device, attention_mask=None): + block_length, num_blocks = GPTNeoAttentionMixin._get_block_length_and_num_blocks(seq_length, window_size) + indices = torch.arange(seq_length, dtype=torch.long, device=device).repeat(batch_size, 1) + + query_indices = GPTNeoAttentionMixin._split_seq_length_dim_to(indices, num_blocks, block_length) + key_indices = GPTNeoAttentionMixin._look_back(indices, block_length, window_size, is_key_value=False) + + # create mask tensor such that each block contains a causal_mask for that block + causal_mask = torch.ge(query_indices.unsqueeze(-1), key_indices.unsqueeze(-2)) + + if attention_mask is None: + attention_mask = torch.ones(batch_size, seq_length, dtype=torch.long, device=device) + + # A block can also be padded becuase of the _look_back operation + # look back into the attention_block such that it will also get padded the same way + # and have 0s in the padded position + attention_mask = GPTNeoAttentionMixin._look_back(attention_mask, block_length, window_size, is_key_value=False) + attention_mask = attention_mask.unsqueeze(-2) # Add an extra dimension to account for hidden_dim + + # Multiply the causal_mask with attention_mask so the padded positions (by _look_back operation) + # will contain 0s. + # This also makes sure that other positions ignored by the attention_mask will also be ignored + # in the causal_mask. + causal_mask = causal_mask * attention_mask + + # In GPT Neo's local attention each window can attend to at most window_size tokens + # rest of the tokens should be ignored. + relative_position = key_indices.unsqueeze(-2) - query_indices.unsqueeze(-1) + visible = torch.gt(relative_position, -window_size) + + causal_mask = causal_mask * visible + causal_mask = causal_mask.unsqueeze(-3).bool() # Add an extra dimension to account for num_heads + + return causal_mask + def _split_heads(self, tensor, num_heads, attn_head_size): """ Splits hidden_size dim into attn_head_size and num_heads @@ -218,20 +269,6 @@ def _merge_heads(self, tensor, num_heads, attn_head_size): new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,) return tensor.view(new_shape) - def _split_seq_length_dim_to(self, tensors, dim_factor_1, dim_factor_2, hidden_size): - """ - Splits sequence length dim of tensors into `dim_factor_1` and `dim_factor_2` dims - """ - batch_size = tensors.shape[0] - split_dim_shape = (batch_size, dim_factor_1, dim_factor_2) - - if len(tensors.shape) == 3: - return torch.reshape(tensors, split_dim_shape + (hidden_size,)) - elif len(tensors.shape) == 2: - return torch.reshape(tensors, split_dim_shape) - else: - raise ValueError(f"Input vector rank should be one of [2, 3], but is: {len(tensors.shape)}") - def _attn(self, query, key, value, causal_mask, masked_bias, attn_dropout, attention_mask=None, head_mask=None): # Keep the attention weights computation in fp32 to avoid overflow issues query = query.to(torch.float32) @@ -289,8 +326,8 @@ def __init__(self, config): def forward( self, hidden_states, - layer_past=None, attention_mask=None, + layer_past=None, head_mask=None, use_cache=False, output_attentions=False, @@ -357,45 +394,11 @@ def __init__(self, config): self.window_size = config.window_size - def _create_attention_mask(self, batch_size, seq_length, num_blocks, block_length, device, attention_mask=None): - indices = torch.arange(seq_length, dtype=torch.long, device=device).repeat(batch_size, 1) - - query_indices = self._split_seq_length_dim_to(indices, num_blocks, block_length, self.embed_dim) - key_indices = self._look_back(indices, block_length, self.window_size, is_key_value=False) - - # create mask tensor such that each block contains a causal_mask for that block - causal_mask = torch.ge(query_indices.unsqueeze(-1), key_indices.unsqueeze(-2)) - - if attention_mask is None: - attention_mask = torch.ones(batch_size, seq_length, dtype=torch.long, device=device) - - # A block can also be padded becuase of the _look_back operation - # look back into the attention_block such that it will also get padded the same way - # and have 0s in the padded position - attention_mask = self._look_back(attention_mask, block_length, self.window_size, is_key_value=False) - attention_mask = attention_mask.unsqueeze(-2) # Add an extra dimension to account for hidden_dim - - # Multiply the causal_mask with attention_mask so the padded positions (by _look_back operation) - # will contain 0s. - # This also makes sure that other positions ignored by the attention_mask will also be ignored - # in the causal_mask. - causal_mask = causal_mask * attention_mask - - # In GPT Neo's local attention each window can attend to at most window_size tokens - # rest of the tokens should be ignored. - relative_position = key_indices.unsqueeze(-2) - query_indices.unsqueeze(-1) - visible = torch.gt(relative_position, -self.window_size) - - causal_mask = causal_mask * visible - causal_mask = causal_mask.unsqueeze(-3).bool() # Add an extra dimension to account for num_heads - - return causal_mask - def forward( self, hidden_states, + attention_mask, layer_past=None, - attention_mask=None, head_mask=None, use_cache=False, output_attentions=False, @@ -421,9 +424,9 @@ def forward( # create buckets if layer_past is not None: # we just need 1 block with block_length 1 when caching is enabled - query = self._split_seq_length_dim_to(query, 1, 1, self.embed_dim) + query = self._split_seq_length_dim_to(query, 1, 1) else: - query = self._split_seq_length_dim_to(query, num_blocks, block_length, self.embed_dim) + query = self._split_seq_length_dim_to(query, num_blocks, block_length) key = self._look_back(key, block_length, self.window_size) value = self._look_back(value, block_length, self.window_size) @@ -437,18 +440,16 @@ def forward( key = self._split_heads(key, self.num_heads, self.head_dim) value = self._split_heads(value, self.num_heads, self.head_dim) - mask = self._create_attention_mask( - batch_size, full_seq_length, num_blocks, block_length, hidden_states.device, attention_mask - ) if layer_past is not None: - mask = mask[:, -1:, :, -1:, :] # only take the mask for the last block + # only take the mask for the last block + attention_mask = attention_mask[:, -1:, :, -1:, :] # attn attn_output, attn_weights = self._attn( query, key, value, - causal_mask=mask, + causal_mask=attention_mask, masked_bias=self.masked_bias, attn_dropout=self.attn_dropout, head_mask=head_mask, @@ -495,8 +496,8 @@ def forward( ): outputs = self.attention( hidden_states, - layer_past=layer_past, attention_mask=attention_mask, + layer_past=layer_past, head_mask=head_mask, use_cache=use_cache, output_attentions=output_attentions, @@ -765,8 +766,9 @@ def forward( past_key_values = tuple([None] * len(self.h)) else: past_length = past_key_values[0][0].size(-2) + + device = input_ids.device if input_ids is not None else inputs_embeds.device if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) @@ -791,6 +793,13 @@ def forward( else: global_attention_mask = None + # Local causal attention mask + batch_size, seq_length = input_shape + full_seq_length = seq_length + past_length + local_attention_mask = GPTNeoAttentionMixin.create_local_attention_mask( + batch_size, full_seq_length, self.config.window_size, device, attention_mask + ) + # Prepare head mask if needed # 1.0 in head_mask indicate we keep the head # attention_probs has shape bsz x num_headss x N x N @@ -815,7 +824,7 @@ def forward( all_hidden_states = () if output_hidden_states else None for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): attn_type = self.config.attention_layers[i] - attn_mask = global_attention_mask if attn_type == "global" else attention_mask + attn_mask = global_attention_mask if attn_type == "global" else local_attention_mask if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) diff --git a/tests/test_modeling_gpt_neo.py b/tests/test_modeling_gpt_neo.py index 14d966d61b4bce..ccf63c5e241be3 100644 --- a/tests/test_modeling_gpt_neo.py +++ b/tests/test_modeling_gpt_neo.py @@ -36,7 +36,7 @@ GPTNeoForCausalLM, GPTNeoModel, ) - from transformers.models.gpt_neo.modeling_gpt_neo import GPTNeoAttentionMixin, GPTNeoLocalSelfAttention + from transformers.models.gpt_neo.modeling_gpt_neo import GPTNeoAttentionMixin class GPTNeoModelTester: @@ -497,12 +497,14 @@ def test_look_back(self): def test_create_attention_mask(self): config = GPTNeoConfig.from_pretrained("valhalla/gpt-neo-random-tiny") - layer = GPTNeoLocalSelfAttention(config) window_size = config.window_size batch_size, seq_length = 8, 1 block_length, num_blocks = GPTNeoAttentionMixin._get_block_length_and_num_blocks(seq_length, window_size) - causal_mask = layer._create_attention_mask(batch_size, seq_length, num_blocks, block_length, torch_device) + # causal_mask = layer._create_attention_mask(batch_size, seq_length, num_blocks, block_length, torch_device) + causal_mask = GPTNeoAttentionMixin.create_local_attention_mask( + batch_size, seq_length, config.window_size, torch_device + ) # check shapes expected_shape = [batch_size, num_blocks, 1, block_length, window_size + block_length] self.assertListEqual(list(causal_mask.shape), expected_shape) @@ -516,8 +518,11 @@ def test_create_attention_mask(self): attention_mask = torch.ones(batch_size, seq_length, dtype=torch.long, device=torch_device) attention_mask[:, -3:] = 0 # don't attend last 3 tokens - causal_mask = layer._create_attention_mask( - batch_size, seq_length, num_blocks, block_length, torch_device, attention_mask + # causal_mask = layer._create_attention_mask( + # batch_size, seq_length, num_blocks, block_length, torch_device, attention_mask + # ) + causal_mask = GPTNeoAttentionMixin.create_local_attention_mask( + batch_size, seq_length, config.window_size, torch_device, attention_mask ) # last 3 tokens will be in the last block and shoul have 0s in causal_mask self.assertTrue(torch.all(causal_mask[:, -1, :, :, -3:] == 0)) @@ -539,8 +544,11 @@ def test_local_attn_probs(self): mask_tokens = 3 attention_mask = torch.ones(batch_size, seq_length, device=torch_device, dtype=torch.long) attention_mask[:, -mask_tokens:] = 0 # dont atten last mask_tokens + local_causal_mask = GPTNeoAttentionMixin.create_local_attention_mask( + batch_size, seq_length, model.config.window_size, torch_device, attention_mask + ) - _, attn_probs = layer(hidden_states, attention_mask=attention_mask, output_attentions=True) + _, attn_probs = layer(hidden_states, attention_mask=local_causal_mask, output_attentions=True) # the last 3 tokens will be in the last block, and should have 0 attn_probs self.assertTrue(torch.all(attn_probs[:, -1, :, -mask_tokens:, -mask_tokens:] == 0))