Remove device parameter from create_extended_attention_mask_for_decoder #16894

Merged
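This PR deprecates the `device` argument of `get_extended_attention_mask` and `create_extended_attention_mask_for_decoder` in `modeling_utils.py` (the argument now defaults to `None`, and when omitted the device is taken from `attention_mask.device`), and removes the argument from the call sites in the model files below. A minimal before/after sketch of the call-site change (the `model` variable and the shapes are placeholders, not code from this diff):

    # before: callers passed the device explicitly
    extended_attention_mask = model.get_extended_attention_mask(attention_mask, input_shape, device)

    # after: the device is inferred from attention_mask.device
    extended_attention_mask = model.get_extended_attention_mask(attention_mask, input_shape)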
2 changes: 1 addition & 1 deletion examples/research_projects/longform-qa/eli5_utils.py
@@ -137,7 +137,7 @@ def embed_sentences_checkpointed(self, input_ids, attention_mask, checkpoint_bat
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
head_mask = [None] * self.sent_encoder.config.num_hidden_layers
extended_attention_mask: torch.Tensor = self.sent_encoder.get_extended_attention_mask(
- attention_mask, input_shape, device
+ attention_mask, input_shape
)
Collaborator: This one can be updated as it's not pinned.

# define function for checkpointing
20 changes: 16 additions & 4 deletions src/transformers/modeling_utils.py
@@ -594,7 +594,13 @@ def invert_attention_mask(self, encoder_attention_mask: Tensor) -> Tensor:
return encoder_extended_attention_mask

@staticmethod
- def create_extended_attention_mask_for_decoder(input_shape, attention_mask, device):
+ def create_extended_attention_mask_for_decoder(input_shape, attention_mask, device=None):
+ if device is not None:
+ warnings.warn(
+ "The `device` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning
+ )
+ else:
+ device = attention_mask.device
batch_size, seq_length = input_shape
seq_ids = torch.arange(seq_length, device=device)
causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
@@ -615,7 +621,9 @@ def create_extended_attention_mask_for_decoder(input_shape, attention_mask, devi
extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
return extended_attention_mask

- def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device) -> Tensor:
+ def get_extended_attention_mask(
+ self, attention_mask: Tensor, input_shape: Tuple[int], device: device = None
+ ) -> Tensor:
"""
Makes broadcastable attention and causal masks so that future and masked tokens are ignored.

@@ -624,12 +632,16 @@ def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple
Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
input_shape (`Tuple[int]`):
The shape of the input to the model.
- device: (`torch.device`):
- The device of the input to the model.

Returns:
`torch.Tensor` The extended attention mask, with a the same dtype as `attention_mask.dtype`.
"""
+ if not (attention_mask.dim() == 2 and self.config.is_decoder):
+ # show warning only if it won't be shown in `create_extended_attention_mask_for_decoder`
+ if device is not None:
+ warnings.warn(
+ "The `device` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning
+ )
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
if attention_mask.dim() == 3:
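The deprecation path above can be tried on any model that inherits from `PreTrainedModel`; a small usage sketch (the checkpoint name and tensor shapes are illustrative only, not part of this PR):

    import torch
    from transformers import BertModel

    model = BertModel.from_pretrained("bert-base-uncased")
    attention_mask = torch.ones(1, 8, dtype=torch.long)

    # new style: no device argument, it is read from attention_mask.device
    mask = model.get_extended_attention_mask(attention_mask, (1, 8))

    # passing a device still works for now but emits a FutureWarning
    mask = model.get_extended_attention_mask(attention_mask, (1, 8), attention_mask.device)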
2 changes: 1 addition & 1 deletion src/transformers/models/bert/modeling_bert.py
@@ -982,7 +982,7 @@ def forward(

# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
- extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)

# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
@@ -364,9 +364,7 @@ def forward(
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask = None
if not use_cache:
- extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
- attention_mask, input_shape, device
- )
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)

# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
4 changes: 1 addition & 3 deletions src/transformers/models/big_bird/modeling_big_bird.py
@@ -2112,9 +2112,7 @@ def forward(
to_mask = None
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
- extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
- attention_mask, input_shape, device
- )
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
else:
raise ValueError(
f"attention_type can either be original_full or block_sparse, but is {self.attention_type}"
4 changes: 2 additions & 2 deletions src/transformers/models/canine/modeling_canine.py
@@ -1130,12 +1130,12 @@ def forward(

# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
- extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
molecule_attention_mask = self._downsample_attention_mask(
attention_mask, downsampling_rate=self.config.downsampling_rate
)
extended_molecule_attention_mask: torch.Tensor = self.get_extended_attention_mask(
- molecule_attention_mask, (batch_size, molecule_attention_mask.shape[-1]), device
+ molecule_attention_mask, (batch_size, molecule_attention_mask.shape[-1])
)

# Prepare head mask if needed
2 changes: 1 addition & 1 deletion src/transformers/models/convbert/modeling_convbert.py
@@ -833,7 +833,7 @@ def forward(
else:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

- extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, device)
+ extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

hidden_states = self.embeddings(
@@ -820,7 +820,7 @@ def forward(

# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
- extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)

# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
2 changes: 1 addition & 1 deletion src/transformers/models/electra/modeling_electra.py
@@ -882,7 +882,7 @@ def forward(
else:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

- extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, device)
+ extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)

# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
2 changes: 1 addition & 1 deletion src/transformers/models/ibert/modeling_ibert.py
@@ -814,7 +814,7 @@ def forward(

# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
- extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)

# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
@@ -1692,7 +1692,7 @@ def forward(

# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
- extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)[
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)[
:, 0, 0, :
]

@@ -940,7 +940,7 @@ def forward(

# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
- extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)

# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
2 changes: 1 addition & 1 deletion src/transformers/models/mmbt/modeling_mmbt.py
@@ -268,7 +268,7 @@ def forward(
[torch.ones(input_modal_shape, device=device), encoder_attention_mask], dim=1
)

- extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, self.device)
+ extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

4 changes: 1 addition & 3 deletions src/transformers/models/mobilebert/modeling_mobilebert.py
@@ -875,9 +875,7 @@ def forward(

# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
- extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
- attention_mask, input_shape, self.device
- )
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)

# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
2 changes: 1 addition & 1 deletion src/transformers/models/mpnet/modeling_mpnet.py
@@ -547,7 +547,7 @@ def forward(

if attention_mask is None:
attention_mask = torch.ones(input_shape, device=device)
- extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)

head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
embedding_output = self.embeddings(input_ids=input_ids, position_ids=position_ids, inputs_embeds=inputs_embeds)
@@ -624,7 +624,7 @@ def forward(

# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
- extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)

# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
2 changes: 1 addition & 1 deletion src/transformers/models/qdqbert/modeling_qdqbert.py
@@ -952,7 +952,7 @@ def forward(

# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
- extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)

# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
2 changes: 1 addition & 1 deletion src/transformers/models/realm/modeling_realm.py
@@ -1078,7 +1078,7 @@ def forward(

# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
- extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)

# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
2 changes: 1 addition & 1 deletion src/transformers/models/rembert/modeling_rembert.py
@@ -857,7 +857,7 @@ def forward(

# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
- extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)

# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
2 changes: 1 addition & 1 deletion src/transformers/models/retribert/modeling_retribert.py
@@ -117,7 +117,7 @@ def embed_sentences_checkpointed(
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
head_mask = [None] * sent_encoder.config.num_hidden_layers
extended_attention_mask: torch.Tensor = sent_encoder.get_extended_attention_mask(
- attention_mask, input_shape, device
+ attention_mask, input_shape
)

# define function for checkpointing
2 changes: 1 addition & 1 deletion src/transformers/models/roberta/modeling_roberta.py
@@ -817,7 +817,7 @@ def forward(

# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
- extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)

# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
2 changes: 1 addition & 1 deletion src/transformers/models/roformer/modeling_roformer.py
@@ -900,7 +900,7 @@ def forward(

# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
- extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)

# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
2 changes: 1 addition & 1 deletion src/transformers/models/splinter/modeling_splinter.py
@@ -710,7 +710,7 @@ def forward(

# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
- extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)

# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
@@ -612,7 +612,7 @@ def forward(
if token_type_ids is None:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

- extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, device)
+ extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
2 changes: 1 addition & 1 deletion src/transformers/models/t5/modeling_t5.py
@@ -957,7 +957,7 @@ def forward(

# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
- extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, inputs_embeds.device)
+ extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)

# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
2 changes: 1 addition & 1 deletion src/transformers/models/tapas/modeling_tapas.py
@@ -954,7 +954,7 @@ def forward(

# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
- extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)

# If a 2D ou 3D attention mask is provided for the cross-attention
# we need to make broadcastabe to [batch_size, num_heads, seq_length, seq_length]
2 changes: 1 addition & 1 deletion src/transformers/models/vilt/modeling_vilt.py
@@ -843,7 +843,7 @@ def forward(

# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
- extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)

encoder_outputs = self.encoder(
embedding_output,
4 changes: 2 additions & 2 deletions src/transformers/models/visual_bert/modeling_visual_bert.py
@@ -794,12 +794,12 @@ def forward(
if visual_embeds is not None:
combined_attention_mask = torch.cat((attention_mask, visual_attention_mask), dim=-1)
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
- combined_attention_mask, [batch_size, input_shape + visual_input_shape], device
+ combined_attention_mask, (batch_size, input_shape + visual_input_shape)
)

else:
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
- attention_mask, [batch_size, input_shape], device
+ attention_mask, (batch_size, input_shape)
)

# Prepare head mask if needed
@@ -788,7 +788,7 @@ def forward(

# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
- extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)

# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
2 changes: 1 addition & 1 deletion src/transformers/models/yoso/modeling_yoso.py
@@ -816,7 +816,7 @@ def forward(

# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
- extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)

# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head