Remove device parameter from create_extended_attention_mask_for_decoder #16894
Conversation
Force-pushed from 9eaea11 to 5d59df5
The documentation is not available anymore as the PR was closed or merged.
This seems legit to me; pinging @LysandreJik, @sgugger and @ydshieh to comment on this.
LGTM, as it uses the device from the argument (see transformers/src/transformers/modeling_utils.py, lines 592 to 594 in 5d59df5).
Thank you for reducing the potential issue! (Please wait for approvals from sgugger or LysandreJik before merging 🙏)
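For reference, a rough sketch of what the referenced lines do (a reconstruction, not the exact snippet from 5d59df5): the causal mask is now built on `attention_mask.device`, so the mask and its device can never disagree.

```python
import torch

# Reconstruction sketch: the device comes from the `attention_mask` argument itself
# when the causal mask is built, so no separate `device` argument is needed.
attention_mask = torch.ones(2, 5)  # [batch_size, seq_length]
batch_size, seq_length = attention_mask.shape
seq_ids = torch.arange(seq_length, device=attention_mask.device)
causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
assert causal_mask.device == attention_mask.device
```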
Thanks for your PR. Removing arguments from public methods is a bit of a breaking change (even if I don't expect many users to call those directly), and since it's very easy to avoid that here and raise a proper deprecation warning instead, I would like this added before we merge.
Also, the first two research projects should not be touched.
@@ -152,7 +152,7 @@ def forward(
         # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
         # ourselves in which case we just need to make it broadcastable to all heads.
-        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
+        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
This example is pinned to Transformers == 3.5.1 so don't make any change there.
@@ -195,7 +195,7 @@ def forward(
         # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
         # ourselves in which case we just need to make it broadcastable to all heads.
-        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
+        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
Same here.
@@ -137,7 +137,7 @@ def embed_sentences_checkpointed(self, input_ids, attention_mask, checkpoint_bat
         token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
         head_mask = [None] * self.sent_encoder.config.num_hidden_layers
         extended_attention_mask: torch.Tensor = self.sent_encoder.get_extended_attention_mask(
-            attention_mask, input_shape, device
+            attention_mask, input_shape
This one can be updated as it's not pinned.
src/transformers/modeling_utils.py (outdated)
@@ -589,8 +589,9 @@ def invert_attention_mask(self, encoder_attention_mask: Tensor) -> Tensor:
         return encoder_extended_attention_mask

-    def create_extended_attention_mask_for_decoder(self, input_shape, attention_mask, device):
+    def create_extended_attention_mask_for_decoder(self, input_shape, attention_mask):
Removing an argument from a public method like this is a breaking change, so we should continue to accept it with a default of None, then raise a deprecation warning if we detect it's not None, telling the user that the argument is not used anymore and will be removed in v5 of Transformers.
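For reference, a minimal sketch of the deprecation pattern being requested here (it mirrors the change that appears further down in the thread; the mask-building body is elided):

```python
import warnings


def create_extended_attention_mask_for_decoder(self, input_shape, attention_mask, device=None):
    # Backward-compatible signature: `device` is still accepted, but no longer used.
    if device is not None:
        warnings.warn(
            "The `device` argument is deprecated and will be removed in v5 of Transformers.",
            FutureWarning,
        )
    # ... build the causal mask on attention_mask.device instead of the passed `device` ...
```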
src/transformers/modeling_utils.py (outdated)
@@ -610,7 +611,7 @@ def create_extended_attention_mask_for_decoder(self, input_shape, attention_mask
         extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
         return extended_attention_mask

-    def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device) -> Tensor:
+    def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int]) -> Tensor:
Same here.
Force-pushed from 1326012 to 994597c
@@ -152,7 +152,7 @@ def forward(
         # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
         # ourselves in which case we just need to make it broadcastable to all heads.
-        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
+        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
This example is pinned to Transformers == 3.5.1 so don't make any change there.
src/transformers/modeling_utils.py (outdated)
-    def create_extended_attention_mask_for_decoder(input_shape, attention_mask, device):
+    def create_extended_attention_mask_for_decoder(self, input_shape, attention_mask, device=None):
+        if device is not None:
+            warnings.warn("`device` is deprecated and will be removed in v5 of Transformers.")
Suggested change:
-            warnings.warn("`device` is deprecated and will be removed in v5 of Transformers.")
+            warnings.warn("The `device` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning)
src/transformers/modeling_utils.py (outdated)
            device: (`torch.device`):
                **DEPRECATED**. `attention_mask.device` will be used instead in v5 of Transformers.
                The device of the input to the model.
Remove the documentation entirely for a deprecated argument.
src/transformers/modeling_utils.py (outdated)
        if not (attention_mask.dim() == 2 and self.config.is_decoder):
            # show warning only if it won't be shown in `create_extended_attention_mask_for_decoder`
            if device is not None:
                warnings.warn("`device` is deprecated and will be removed in v5 of Transformers.")
Suggested change:
-                warnings.warn("`device` is deprecated and will be removed in v5 of Transformers.")
+                warnings.warn("The `device` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning)
Force-pushed from 994597c to 91758cf
Force-pushed from 91758cf to 209647b
@sgugger thanks for the code review! All comments have been addressed.
Thanks! Pinging @LysandreJik for final review :-)
LGTM, thanks @pbelevich!
What does this PR do?
This PR removes the redundant `device` parameter from `create_extended_attention_mask_for_decoder`, which may cause potential issues if the passed `device` is not equal to `attention_mask.device`; see modeling_utils.py#L610. Explanation: tracing the logic from line 610 back to the method signature, `causal_mask.device` == `attention_mask.device` => `seq_ids.device` == `attention_mask.device` => `device` == `attention_mask.device`, so the parameter can only ever duplicate `attention_mask.device`.
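To make the redundancy concrete, here is a hypothetical sketch (not code from the PR, and it assumes a CUDA device is available) of what goes wrong when the passed `device` disagrees with `attention_mask.device`:

```python
import torch

# Hypothetical illustration of the mismatch (assumes CUDA is available).
attention_mask = torch.ones(2, 5)  # lives on the CPU
device = torch.device("cuda")      # caller passes a different device
batch_size, seq_length = attention_mask.shape
seq_ids = torch.arange(seq_length, device=device)
causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
# The multiplication below mixes a CUDA tensor with a CPU tensor and raises a RuntimeError,
# which is why `device` must always equal `attention_mask.device` anyway.
extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
```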
@michaelbenayoun
Fixes # (issue)
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.