Add type hints for several pytorch models (batch-3) #25705

Merged
12 changes: 6 additions & 6 deletions src/transformers/models/ernie_m/modeling_ernie_m.py
@@ -538,7 +538,7 @@ def forward(
output_hidden_states: Optional[bool] = None,
output_attentions: Optional[bool] = None,
return_dict: Optional[bool] = None,
- ):
+ ) -> Union[Tuple[torch.FloatTensor], BaseModelOutputWithPoolingAndCrossAttentions]:
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time.")

@@ -647,7 +647,7 @@ def forward(
output_attentions: Optional[bool] = None,
return_dict: Optional[bool] = True,
labels: Optional[torch.Tensor] = None,
- ):
+ ) -> Union[Tuple[torch.FloatTensor], SequenceClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
@@ -744,7 +744,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = True,
- ):
+ ) -> Union[Tuple[torch.FloatTensor], MultipleChoiceModelOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
@@ -837,7 +837,7 @@ def forward(
output_attentions: Optional[bool] = None,
return_dict: Optional[bool] = True,
labels: Optional[torch.Tensor] = None,
- ):
+ ) -> Union[Tuple[torch.FloatTensor], TokenClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
@@ -914,7 +914,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = True,
- ):
+ ) -> Union[Tuple[torch.FloatTensor], QuestionAnsweringModelOutput]:
r"""
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for position (index) of the start of the labelled span for computing the token classification loss.
@@ -1003,7 +1003,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = True,
- ):
+ ) -> Union[Tuple[torch.FloatTensor], QuestionAnsweringModelOutput]:
r"""
start_positions (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for position (index) for computing the start_positions loss. Position outside of the sequence are
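All six ernie_m return annotations follow the same library-wide convention: `forward` returns a plain tuple when `return_dict` is falsy and the corresponding `ModelOutput` subclass otherwise, which is why each hint is a `Union` of the two. A minimal sketch of that pattern for a classification head; `forward_sketch` and its arguments are illustrative, not code from the PR:

```python
from typing import Optional, Tuple, Union

import torch
from transformers.modeling_outputs import SequenceClassifierOutput


def forward_sketch(
    logits: torch.FloatTensor,
    loss: Optional[torch.FloatTensor] = None,
    return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.FloatTensor], SequenceClassifierOutput]:
    if not return_dict:
        # Tuple branch: positional outputs, loss first when it exists.
        output = (logits,)
        return ((loss,) + output) if loss is not None else output
    # Dataclass branch: named fields that downstream code reads by attribute.
    return SequenceClassifierOutput(loss=loss, logits=logits)
```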
4 changes: 2 additions & 2 deletions src/transformers/models/esm/modeling_esmfold.py
@@ -2086,11 +2086,11 @@ def _af2_to_esm_from_vocab_list(vocab_list: List[str]) -> torch.Tensor:
def forward(
self,
input_ids: torch.Tensor,
- attention_mask: torch.Tensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
masking_pattern: Optional[torch.Tensor] = None,
num_recycles: Optional[int] = None,
- ):
+ ) -> EsmForProteinFoldingOutput:
r"""
Returns:

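The `attention_mask` change here is the usual implicit-`Optional` cleanup: a parameter defaulting to `None` needs `Optional[...]` in its annotation for the hint to be correct, and recent mypy releases flag the bare form by default. A hypothetical pair of signatures (not from the PR) showing the difference:

```python
from typing import Optional

import torch


def mask_hint_bare(attention_mask: torch.Tensor = None):
    # Flagged by strict type checkers: None is not a torch.Tensor.
    ...


def mask_hint_optional(attention_mask: Optional[torch.Tensor] = None):
    # Correct: the annotation admits None explicitly.
    ...
```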
4 changes: 2 additions & 2 deletions src/transformers/models/graphormer/modeling_graphormer.py
@@ -816,8 +816,8 @@ def forward(
out_degree: torch.LongTensor,
spatial_pos: torch.LongTensor,
attn_edge_type: torch.LongTensor,
- perturb=None,
- masked_tokens=None,
+ perturb: Optional[torch.FloatTensor] = None,
+ masked_tokens: None = None,
Contributor Author (comment on lines +819 to +820):
For `perturb` I took what microsoft/graphormer uses.

`masked_tokens`, on the other hand, raises `NotImplementedError` whenever it receives anything that is not `None`.

Member (reply):
Haa, that's funny. Congrats on our first correct type hint of an explicit None!

return_dict: Optional[bool] = None,
**unused,
) -> Union[Tuple[torch.LongTensor], BaseModelOutputWithNoAttention]:
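The `masked_tokens: None = None` hint is unusual but accurate, per the thread above: the parameter is accepted only for interface compatibility, and any non-`None` value is rejected. A sketch of that pattern, assuming the behaviour described in the comments rather than the actual Graphormer source (`decoder_head_sketch` is a made-up name):

```python
import torch


def decoder_head_sketch(hidden_states: torch.Tensor, masked_tokens: None = None) -> torch.Tensor:
    # The argument mirrors an upstream signature, but no non-None value is
    # supported, so None really is its only valid type.
    if masked_tokens is not None:
        raise NotImplementedError
    return hidden_states
```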
26 changes: 13 additions & 13 deletions src/transformers/models/instructblip/modeling_instructblip.py
@@ -1124,19 +1124,19 @@ def get_extended_attention_mask(

def forward(
self,
- input_ids,
- attention_mask=None,
- position_ids=None,
- query_embeds=None,
- head_mask=None,
- encoder_hidden_states=None,
- encoder_attention_mask=None,
- past_key_values=None,
- use_cache=None,
- output_attentions=None,
- output_hidden_states=None,
- return_dict=None,
- ):
+ input_ids: torch.LongTensor,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ query_embeds: Optional[torch.Tensor] = None,
Contributor Author (comment on the `query_embeds` annotation):
I was unsure of the dtype (I couldn't find any reference), so I left the generic `torch.Tensor`.

+ head_mask: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple[torch.FloatTensor], BaseModelOutputWithPoolingAndCrossAttentions]:
r"""
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
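Keeping the generic `torch.Tensor` for `query_embeds` is the safe choice when the dtype is undocumented: `torch.FloatTensor` in these hints conventionally means float32, which would be wrong if the embeddings arrive in half precision. A small illustration (the shape is made up):

```python
import torch

query_embeds = torch.zeros(1, 32, 768)              # float32 by default
query_embeds_half = query_embeds.to(torch.float16)  # still a torch.Tensor, but no longer float32
print(query_embeds.dtype, query_embeds_half.dtype)  # torch.float32 torch.float16
```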
4 changes: 2 additions & 2 deletions src/transformers/models/layoutlm/modeling_layoutlm.py
@@ -891,8 +891,8 @@ def forward(
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
- encoder_hidden_states=None,
- encoder_attention_mask=None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
2 changes: 1 addition & 1 deletion src/transformers/models/luke/modeling_luke.py
@@ -1664,7 +1664,7 @@ def __init__(self, config):
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
- attention_mask=None,
+ attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
entity_ids: Optional[torch.LongTensor] = None,
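One way to sanity-check a batch like this is to read the annotations back at runtime with `inspect.signature`. A hedged sketch, assuming a transformers install that already includes these changes; the class is only an example, any of the touched models works the same way:

```python
import inspect

from transformers import LukeModel

sig = inspect.signature(LukeModel.forward)
print(sig.parameters["attention_mask"].annotation)  # the parameter hint as written in the source
print(sig.return_annotation)                        # the return hint, if one is declared
```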