
Remove device parameter from create_extended_attention_mask_for_decoder #16894

Conversation

pbelevich (Contributor)

What does this PR do?

This PR removes the redundant `device` parameter from `create_extended_attention_mask_for_decoder`, which may cause issues if the passed `device` is not equal to `attention_mask.device`; see modeling_utils.py#L610. Explanation: tracing the logic from line 610 back to the method signature:
causal_mask.device == attention_mask.device => seq_ids.device == attention_mask.device => device == attention_mask.device
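
As a rough illustration of that chain, here is a simplified sketch of the method around that line (not verbatim library code): `device` is only used to allocate `seq_ids`, whose device propagates to `causal_mask`, and line 610 multiplies `causal_mask` with `attention_mask`, so the two devices must already agree and `device` can simply be read from `attention_mask`.

# Simplified sketch; assumes `import torch` and input_shape == (batch_size, seq_length)
def create_extended_attention_mask_for_decoder(self, input_shape, attention_mask, device):
    batch_size, seq_length = input_shape
    # only place `device` is used: it decides where `seq_ids` (and hence `causal_mask`) lives
    seq_ids = torch.arange(seq_length, device=device)
    causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
    causal_mask = causal_mask.to(attention_mask.dtype)
    # L610: the elementwise product requires both masks to be on the same device,
    # so `device` must already equal `attention_mask.device`
    extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
    return extended_attention_mask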

@michaelbenayoun

Fixes # (issue)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@pbelevich force-pushed the remove_device_from_create_extended_attention_mask_for_decoder branch from 9eaea11 to 5d59df5 on April 22, 2022 15:10
@HuggingFaceDocBuilderDev commented Apr 22, 2022

The documentation is not available anymore as the PR was closed or merged.

@pbelevich marked this pull request as ready for review on April 22, 2022 15:27
@michaelbenayoun (Member)

This seems legit to me; pinging @LysandreJik, @sgugger and @ydshieh to comment on this.

@michaelbenayoun requested review from ydshieh, michaelbenayoun, LysandreJik and sgugger, and removed the request for ydshieh and michaelbenayoun on April 25, 2022 08:41
@ydshieh (Collaborator) commented Apr 25, 2022

LGTM, as it uses the device from the argument attention_mask.

def create_extended_attention_mask_for_decoder(self, input_shape, attention_mask):
    batch_size, seq_length = input_shape
    device = attention_mask.device

Thank you for reducing the potential issue!

(Please wait for approval from sgugger or LysandreJik before merging 🙏)

@sgugger (Collaborator) left a comment

Thanks for your PR. Removing arguments from public methods is a bit of a breaking change (even if I don't expect many users to use those directly), so since it's very easy to avoid that here and raise a proper deprecation warning instead, I would like this added before we merge.

Also, the first two research projects should not be touched.

@@ -152,7 +152,7 @@ def forward(

# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
-extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
+extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)

This example is pinned to Transformers == 3.5.1 so don't make any change there.

@@ -195,7 +195,7 @@ def forward(

# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
-extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
+extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)

Same here.

@@ -195,7 +195,7 @@ def forward(

# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
-extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
+extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)

Same here.

@@ -137,7 +137,7 @@ def embed_sentences_checkpointed(self, input_ids, attention_mask, checkpoint_bat
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
head_mask = [None] * self.sent_encoder.config.num_hidden_layers
extended_attention_mask: torch.Tensor = self.sent_encoder.get_extended_attention_mask(
-    attention_mask, input_shape, device
+    attention_mask, input_shape

This one can be updated as it's not pinned.

@@ -137,7 +137,7 @@ def embed_sentences_checkpointed(self, input_ids, attention_mask, checkpoint_bat
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
head_mask = [None] * self.sent_encoder.config.num_hidden_layers
extended_attention_mask: torch.Tensor = self.sent_encoder.get_extended_attention_mask(
-    attention_mask, input_shape, device
+    attention_mask, input_shape

This one can be updated as it's not pinned.

@@ -589,8 +589,9 @@ def invert_attention_mask(self, encoder_attention_mask: Tensor) -> Tensor:

return encoder_extended_attention_mask

-def create_extended_attention_mask_for_decoder(self, input_shape, attention_mask, device):
+def create_extended_attention_mask_for_decoder(self, input_shape, attention_mask):

Removing an argument from a public method like this is a breaking change, so we should continue to accept it with a default of None and raise a deprecation warning if we detect it's not None, telling the user that the argument is not used anymore and will be removed in v5 of Transformers.
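
A minimal sketch of the backward-compatible signature being requested here, assuming the exact warning text and placement are up to the author (the actual change appears in a later diff):

import warnings

def create_extended_attention_mask_for_decoder(self, input_shape, attention_mask, device=None):
    if device is not None:
        warnings.warn(
            "The `device` argument is deprecated and will be removed in v5 of Transformers.",
            FutureWarning,
        )
    # the device is now always derived from the mask itself
    device = attention_mask.device
    # ... rest of the mask construction is unchanged ...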

@@ -589,8 +589,9 @@ def invert_attention_mask(self, encoder_attention_mask: Tensor) -> Tensor:

return encoder_extended_attention_mask

-def create_extended_attention_mask_for_decoder(self, input_shape, attention_mask, device):
+def create_extended_attention_mask_for_decoder(self, input_shape, attention_mask):

Removing an argument from a public method like this is a breaking change, so we should continue to accept it with a default of None and raise a deprecation warning if we detect it's not None, telling the user that the argument is not used anymore and will be removed in v5 of Transformers.

@@ -610,7 +611,7 @@ def create_extended_attention_mask_for_decoder(self, input_shape, attention_mask
extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
return extended_attention_mask

-def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device) -> Tensor:
+def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int]) -> Tensor:

Same here.

@@ -610,7 +611,7 @@ def create_extended_attention_mask_for_decoder(self, input_shape, attention_mask
extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
return extended_attention_mask

-def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device) -> Tensor:
+def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int]) -> Tensor:

Same here.

@pbelevich force-pushed the remove_device_from_create_extended_attention_mask_for_decoder branch 3 times, most recently from 1326012 to 994597c on April 29, 2022 15:19
@@ -152,7 +152,7 @@ def forward(

# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
-extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
+extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)

This example is pinned to Transformers == 3.5.1 so don't make any change there.

-def create_extended_attention_mask_for_decoder(input_shape, attention_mask, device):
+def create_extended_attention_mask_for_decoder(self, input_shape, attention_mask, device=None):
+    if device is not None:
+        warnings.warn("`device` is deprecated and will be removed in v5 of Transformers.")

Suggested change
-warnings.warn("`device` is deprecated and will be removed in v5 of Transformers.")
+warnings.warn("The `device` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning)
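
Hypothetical call sites illustrating the effect of the deprecation (names are illustrative only): existing callers that still pass `device` keep working but now see a FutureWarning, while updated callers simply drop the argument.

# old call site: still accepted, emits a FutureWarning
extended_mask = model.get_extended_attention_mask(attention_mask, input_shape, device)
# new call site: the device is taken from `attention_mask` itself
extended_mask = model.get_extended_attention_mask(attention_mask, input_shape)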

Comment on lines 633 to 635
device: (`torch.device`):
**DEPRECATED**. `attention_mask.device` will be used instead in v5 of Transformers.
The device of the input to the model.

Remove the documentation entirely for a deprecated argument.

if not (attention_mask.dim() == 2 and self.config.is_decoder):
    # show warning only if it won't be shown in `create_extended_attention_mask_for_decoder`
    if device is not None:
        warnings.warn("`device` is deprecated and will be removed in v5 of Transformers.")

Suggested change
-warnings.warn("`device` is deprecated and will be removed in v5 of Transformers.")
+warnings.warn("The `device` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning)
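
For context on the guard above, here is a simplified sketch of the dispatch inside `get_extended_attention_mask` (not verbatim library code), under the assumption that the 2D decoder path delegates to `create_extended_attention_mask_for_decoder`, which already emits its own deprecation warning, so warning here as well would print it twice.

# Simplified dispatch sketch
if attention_mask.dim() == 3:
    extended_attention_mask = attention_mask[:, None, :, :]
elif attention_mask.dim() == 2:
    if self.config.is_decoder:
        # delegates; the delegate raises the deprecation warning itself
        extended_attention_mask = self.create_extended_attention_mask_for_decoder(
            input_shape, attention_mask, device
        )
    else:
        extended_attention_mask = attention_mask[:, None, None, :]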

@pbelevich force-pushed the remove_device_from_create_extended_attention_mask_for_decoder branch from 994597c to 91758cf on April 29, 2022 15:31
@pbelevich force-pushed the remove_device_from_create_extended_attention_mask_for_decoder branch from 91758cf to 209647b on April 29, 2022 15:43
@pbelevich (Contributor, Author)

@sgugger thanks for the code review! All comments have been addressed.

@sgugger (Collaborator) commented Apr 29, 2022

Thanks! Pinging @LysandreJik for final review :-)

@LysandreJik (Member) left a comment

LGTM, thanks @pbelevich!

@LysandreJik merged commit 39f8eaf into huggingface:main on May 3, 2022
stevhliu pushed a commit to stevhliu/transformers that referenced this pull request May 3, 2022
nandwalritik pushed a commit to nandwalritik/transformers that referenced this pull request May 4, 2022
Narsil pushed a commit to Narsil/transformers that referenced this pull request May 12, 2022
elusenji pushed a commit to elusenji/transformers that referenced this pull request Jun 12, 2022