Skip to content

Commit

Permalink
Fix deprecation warnings for int div (#15180)
Browse files Browse the repository at this point in the history
* Fix deprecation warnings for int div

Co-authored-by: mgoldey <matthew.goldey@gmail.com>

* Fix import

* ensure that tensor output is python scalar

* make backward compatible

* make code more readable

* adapt test functions

Co-authored-by: mgoldey <matthew.goldey@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
  • Loading branch information
3 people authored Jan 18, 2022
1 parent f6d3fee commit 531336b
Show file tree
Hide file tree
Showing 10 changed files with 43 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,8 @@ def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) ->
batch_size = batch["input_values"].shape[0]

mask_indices_seq_length = self.model._get_feat_extract_output_lengths(batch["input_values"].shape[-1])
# make sure masked sequence length is a Python scalar
mask_indices_seq_length = int(mask_indices_seq_length)

# make sure that no loss is computed on padded inputs
if batch.get("attention_mask") is not None:
Expand Down
11 changes: 11 additions & 0 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union

import torch
from packaging import version
from torch import Tensor, device, nn
from torch.nn import CrossEntropyLoss

Expand Down Expand Up @@ -2362,3 +2363,13 @@ def forward(self, hidden_states):
return torch.cat(output_chunks, dim=chunk_dim)

return forward_fn(*input_tensors)


def torch_int_div(tensor1, tensor2):
"""
A function that performs integer division across different versions of PyTorch.
"""
if version.parse(torch.__version__) < version.parse("1.8.0"):
return tensor1 // tensor2
else:
return torch.div(tensor1, tensor2, rounding_mode="floor")
4 changes: 2 additions & 2 deletions src/transformers/models/hubert/modeling_hubert.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
replace_return_docstrings,
)
from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput
from ...modeling_utils import PreTrainedModel
from ...modeling_utils import PreTrainedModel, torch_int_div
from ...utils import logging
from .configuration_hubert import HubertConfig

Expand Down Expand Up @@ -829,7 +829,7 @@ def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor
def _conv_out_length(input_length, kernel_size, stride):
# 1D convolutional layer output length formula taken
# from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
return (input_length - kernel_size) // stride + 1
return torch_int_div(input_length - kernel_size, stride) + 1

for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
Expand Down
4 changes: 2 additions & 2 deletions src/transformers/models/sew/modeling_sew.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from ...activations import ACT2FN
from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput
from ...modeling_utils import PreTrainedModel
from ...modeling_utils import PreTrainedModel, torch_int_div
from ...utils import logging
from .configuration_sew import SEWConfig

Expand Down Expand Up @@ -735,7 +735,7 @@ def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor
def _conv_out_length(input_length, kernel_size, stride):
# 1D convolutional layer output length formula taken
# from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
return (input_length - kernel_size) // stride + 1
return torch_int_div(input_length - kernel_size, stride) + 1

for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
Expand Down
4 changes: 2 additions & 2 deletions src/transformers/models/sew_d/modeling_sew_d.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from ...activations import ACT2FN
from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput
from ...modeling_utils import PreTrainedModel
from ...modeling_utils import PreTrainedModel, torch_int_div
from ...utils import logging
from .configuration_sew_d import SEWDConfig

Expand Down Expand Up @@ -1266,7 +1266,7 @@ def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor
def _conv_out_length(input_length, kernel_size, stride):
# 1D convolutional layer output length formula taken
# from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
return (input_length - kernel_size) // stride + 1
return torch_int_div(input_length - kernel_size, stride) + 1

for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
Expand Down
4 changes: 2 additions & 2 deletions src/transformers/models/unispeech/modeling_unispeech.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
replace_return_docstrings,
)
from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput
from ...modeling_utils import PreTrainedModel
from ...modeling_utils import PreTrainedModel, torch_int_div
from ...utils import logging
from .configuration_unispeech import UniSpeechConfig

Expand Down Expand Up @@ -969,7 +969,7 @@ def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor
def _conv_out_length(input_length, kernel_size, stride):
# 1D convolutional layer output length formula taken
# from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
return (input_length - kernel_size) // stride + 1
return torch_int_div(input_length - kernel_size, stride) + 1

for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
replace_return_docstrings,
)
from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput, TokenClassifierOutput
from ...modeling_utils import PreTrainedModel
from ...modeling_utils import PreTrainedModel, torch_int_div
from ...utils import logging
from .configuration_unispeech_sat import UniSpeechSatConfig

Expand Down Expand Up @@ -1003,7 +1003,7 @@ def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor
def _conv_out_length(input_length, kernel_size, stride):
# 1D convolutional layer output length formula taken
# from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
return (input_length - kernel_size) // stride + 1
return torch_int_div(input_length - kernel_size, stride) + 1

for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
Expand Down
4 changes: 2 additions & 2 deletions src/transformers/models/wav2vec2/modeling_wav2vec2.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
SequenceClassifierOutput,
TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
from ...modeling_utils import PreTrainedModel, torch_int_div
from ...utils import logging
from .configuration_wav2vec2 import Wav2Vec2Config

Expand Down Expand Up @@ -1104,7 +1104,7 @@ def _get_feat_extract_output_lengths(
def _conv_out_length(input_length, kernel_size, stride):
# 1D convolutional layer output length formula taken
# from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
return (input_length - kernel_size) // stride + 1
return torch_int_div(input_length - kernel_size, stride) + 1

for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
Expand Down
4 changes: 2 additions & 2 deletions src/transformers/models/wavlm/modeling_wavlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
add_start_docstrings_to_model_forward,
)
from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput, TokenClassifierOutput
from ...modeling_utils import PreTrainedModel
from ...modeling_utils import PreTrainedModel, torch_int_div
from ...utils import logging
from .configuration_wavlm import WavLMConfig

Expand Down Expand Up @@ -1057,7 +1057,7 @@ def _get_feat_extract_output_lengths(
def _conv_out_length(input_length, kernel_size, stride):
# 1D convolutional layer output length formula taken
# from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
return (input_length - kernel_size) // stride + 1
return torch_int_div(input_length - kernel_size, stride) + 1

for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
Expand Down
32 changes: 16 additions & 16 deletions tests/test_modeling_wav2vec2.py
Original file line number Diff line number Diff line change
Expand Up @@ -794,10 +794,10 @@ def test_model_for_pretraining(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = Wav2Vec2ForPreTraining(config).to(torch_device)

features_shape = (
inputs_dict["input_values"].shape[0],
model._get_feat_extract_output_lengths(inputs_dict["input_values"].shape[1]),
)
batch_size = inputs_dict["input_values"].shape[0]
feature_seq_length = int(model._get_feat_extract_output_lengths(inputs_dict["input_values"].shape[1]))

features_shape = (batch_size, feature_seq_length)

mask_time_indices = _compute_mask_indices(
features_shape,
Expand Down Expand Up @@ -1158,10 +1158,10 @@ def test_inference_integration(self):

inputs_dict = feature_extractor(input_speech, return_tensors="pt", padding=True)

features_shape = (
inputs_dict["input_values"].shape[0],
model._get_feat_extract_output_lengths(torch.tensor(inputs_dict["input_values"].shape[1])),
)
batch_size = inputs_dict["input_values"].shape[0]
feature_seq_length = int(model._get_feat_extract_output_lengths(inputs_dict["input_values"].shape[1]))

features_shape = (batch_size, feature_seq_length)

np.random.seed(4)
mask_time_indices = _compute_mask_indices(
Expand Down Expand Up @@ -1208,10 +1208,10 @@ def test_inference_pretrained(self):

inputs_dict = feature_extractor(input_speech, return_tensors="pt", padding=True)

features_shape = (
inputs_dict["input_values"].shape[0],
model._get_feat_extract_output_lengths(torch.tensor(inputs_dict["input_values"].shape[1])),
)
batch_size = inputs_dict["input_values"].shape[0]
feature_seq_length = int(model._get_feat_extract_output_lengths(inputs_dict["input_values"].shape[1]))

features_shape = (batch_size, feature_seq_length)

torch.manual_seed(0)
mask_time_indices = _compute_mask_indices(
Expand Down Expand Up @@ -1279,10 +1279,10 @@ def test_loss_pretraining(self):

inputs_dict = feature_extractor(input_speech, return_tensors="pt", padding=True)

features_shape = (
inputs_dict["input_values"].shape[0],
model._get_feat_extract_output_lengths(inputs_dict["input_values"].shape[1]),
)
batch_size = inputs_dict["input_values"].shape[0]
feature_seq_length = int(model._get_feat_extract_output_lengths(inputs_dict["input_values"].shape[1]))

features_shape = (batch_size, feature_seq_length)

torch.manual_seed(0)
np.random.seed(0)
Expand Down

0 comments on commit 531336b

Please sign in to comment.