[FeatureExtractorSavingUtils] Refactor PretrainedFeatureExtractor #10594

20 changes: 14 additions & 6 deletions docs/source/main_classes/feature_extractor.rst
@@ -14,16 +14,24 @@
Feature Extractor
-----------------------------------------------------------------------------------------------------------------------

A feature extractor is in charge of preparing read-in audio files for a speech model. This includes feature extraction,
such as processing audio files to, *e.g.*, Log-Mel Spectrogram features, but also padding, normalization, and
conversion to Numpy, PyTorch, and TensorFlow tensors.
A feature extractor is in charge of preparing input features for a multi-modal model. This includes feature extraction
from sequences, *e.g.*, pre-processing audio files to Log-Mel Spectrogram features, feature extraction from images,
*e.g.*, cropping image files, but also padding, normalization, and conversion to NumPy, PyTorch, and TensorFlow
tensors.
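
As a rough illustration of the normalization step mentioned above, the following sketch scales one raw sequence to zero mean and unit variance in plain Python. The helper name `zero_mean_unit_var` and the `eps` default are assumptions made for this example; they are not taken from the library.

```python
import math
from typing import List


def zero_mean_unit_var(values: List[float], eps: float = 1e-5) -> List[float]:
    """Scale one sequence to zero mean and (approximately) unit variance.

    Hypothetical helper for illustration only; `eps` guards against
    division by zero for constant sequences.
    """
    mean = sum(values) / len(values)
    var = sum((v - mean) ** 2 for v in values) / len(values)
    return [(v - mean) / math.sqrt(var + eps) for v in values]


# Made-up raw audio samples.
raw_sequence = [0.1, 0.4, -0.2]
normalized = zero_mean_unit_var(raw_sequence)

# Recompute the statistics of the normalized sequence.
norm_mean = sum(normalized) / len(normalized)
norm_var = sum((v - norm_mean) ** 2 for v in normalized) / len(normalized)
```

After this step `norm_mean` is close to zero and `norm_var` is close to one, up to the small bias introduced by `eps`.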


PreTrainedFeatureExtractor
FeatureExtractionMixin
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.PreTrainedFeatureExtractor
:members: from_pretrained, save_pretrained, pad
.. autoclass:: transformers.feature_extraction_utils.FeatureExtractionMixin
:members: from_pretrained, save_pretrained


SequenceFeatureExtractor
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.SequenceFeatureExtractor
:members: pad
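
The split documented above (a base class that only handles saving and loading, and a sequence subclass that adds ``pad``) can be pictured with the following minimal sketch. The class names `SavingMixinSketch` and `SequenceExtractorSketch` and the file name `sketch_config.json` are hypothetical; the code shows the shape of the design, not the library's implementation.

```python
import json
import os
import tempfile
from typing import List


class SavingMixinSketch:
    """Hypothetical stand-in for a base class that only saves and loads."""

    config_name = "sketch_config.json"  # hypothetical file name

    def __init__(self, **kwargs):
        # Store every keyword argument as an attribute.
        for key, value in kwargs.items():
            setattr(self, key, value)

    def to_dict(self) -> dict:
        return dict(self.__dict__)

    def save_pretrained(self, save_directory: str) -> None:
        os.makedirs(save_directory, exist_ok=True)
        path = os.path.join(save_directory, self.config_name)
        with open(path, "w", encoding="utf-8") as f:
            json.dump(self.to_dict(), f, indent=2, sort_keys=True)

    @classmethod
    def from_pretrained(cls, save_directory: str):
        path = os.path.join(save_directory, cls.config_name)
        with open(path, encoding="utf-8") as f:
            return cls(**json.load(f))


class SequenceExtractorSketch(SavingMixinSketch):
    """Hypothetical sequence subclass: adds padding on top of save/load."""

    def __init__(self, feature_size=1, sampling_rate=16000, padding_value=0.0, **kwargs):
        self.feature_size = feature_size
        self.sampling_rate = sampling_rate
        self.padding_value = padding_value
        super().__init__(**kwargs)

    def pad(self, sequences: List[List[float]]) -> List[List[float]]:
        # Pad every sequence on the right up to the longest one.
        longest = max(len(seq) for seq in sequences)
        return [list(seq) + [self.padding_value] * (longest - len(seq)) for seq in sequences]


extractor = SequenceExtractorSketch(feature_size=1, sampling_rate=16000, padding_value=0.0)
with tempfile.TemporaryDirectory() as tmp_dir:
    extractor.save_pretrained(tmp_dir)
    reloaded = SequenceExtractorSketch.from_pretrained(tmp_dir)
```

Keeping ``pad`` out of the base class means an extractor for another modality can reuse the save/load round trip without inheriting sequence-specific behaviour.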


BatchFeature
4 changes: 2 additions & 2 deletions src/transformers/__init__.py
@@ -246,7 +246,7 @@
"SpecialTokensMixin",
"TokenSpan",
],
"feature_extraction_utils": ["PreTrainedFeatureExtractor", "BatchFeature"],
"feature_extraction_sequence_utils": ["SequenceFeatureExtractor", "BatchFeature"],
"trainer_callback": [
"DefaultFlowCallback",
"EarlyStoppingCallback",
@@ -1250,7 +1250,7 @@
)

# Feature Extractor
from .feature_extraction_utils import BatchFeature, PreTrainedFeatureExtractor
from .feature_extraction_utils import BatchFeature, SequenceFeatureExtractor

# Files and general utilities
from .file_utils import (
317 changes: 317 additions & 0 deletions src/transformers/feature_extraction_sequence_utils.py

Large diffs are not rendered by default.

355 changes: 38 additions & 317 deletions src/transformers/feature_extraction_utils.py

Large diffs are not rendered by default.

@@ -20,15 +20,16 @@

import numpy as np

from ...feature_extraction_utils import BatchFeature, PreTrainedFeatureExtractor
from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
from ...feature_extraction_utils import BatchFeature
from ...file_utils import PaddingStrategy, TensorType
from ...utils import logging


logger = logging.get_logger(__name__)


class Wav2Vec2FeatureExtractor(PreTrainedFeatureExtractor):
class Wav2Vec2FeatureExtractor(SequenceFeatureExtractor):
r"""
Constructs a Wav2Vec2 feature extractor.

13 changes: 7 additions & 6 deletions src/transformers/models/wav2vec2/processing_wav2vec2.py
@@ -59,7 +59,8 @@ def save_pretrained(self, save_directory):

.. note::

This class method is simply calling :meth:`~transformers.PreTrainedFeatureExtractor.save_pretrained` and
This class method is simply calling
:meth:`~transformers.feature_extraction_utils.FeatureExtractionMixin.save_pretrained` and
:meth:`~transformers.tokenization_utils_base.PreTrainedTokenizer.save_pretrained`. Please refer to the
docstrings of the methods above for more information.

@@ -80,9 +81,9 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
.. note::

This class method is simply calling Wav2Vec2FeatureExtractor's
:meth:`~transformers.PreTrainedFeatureExtractor.from_pretrained` and Wav2Vec2CTCTokenizer's
:meth:`~transformers.tokenization_utils_base.PreTrainedTokenizer.from_pretrained`. Please refer to the
docstrings of the methods above for more information.
:meth:`~transformers.feature_extraction_utils.FeatureExtractionMixin.from_pretrained` and
Wav2Vec2CTCTokenizer's :meth:`~transformers.tokenization_utils_base.PreTrainedTokenizer.from_pretrained`.
Please refer to the docstrings of the methods above for more information.

Args:
pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):
@@ -92,12 +93,12 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
huggingface.co. Valid model ids can be located at the root-level, like ``bert-base-uncased``, or
namespaced under a user or organization name, like ``dbmdz/bert-base-german-cased``.
- a path to a `directory` containing a feature extractor file saved using the
:meth:`~transformers.PreTrainedFeatureExtractor.save_pretrained` method, e.g.,
:meth:`~transformers.SequenceFeatureExtractor.save_pretrained` method, e.g.,
``./my_model_directory/``.
- a path or url to a saved feature extractor JSON `file`, e.g.,
``./my_model_directory/feature_extraction_config.json``.
**kwargs
Additional keyword arguments passed along to both :class:`~transformers.PreTrainedFeatureExtractor` and
Additional keyword arguments passed along to both :class:`~transformers.SequenceFeatureExtractor` and
:class:`~transformers.PreTrainedTokenizer`
"""
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(pretrained_model_name_or_path, **kwargs)
3 changes: 1 addition & 2 deletions src/transformers/tokenization_utils_base.py
@@ -727,8 +727,7 @@ def to(self, device: Union[str, "torch.device"]) -> "BatchEncoding":
device (:obj:`str` or :obj:`torch.device`): The device to put the tensors on.

Returns:
:class:`~transformers.BatchEncoding`: The same instance of :class:`~transformers.BatchEncoding` after
modification.
:class:`~transformers.BatchEncoding`: The same instance after modification.
"""

# This check catches things like APEX blindly calling "to" on all inputs to a module
236 changes: 1 addition & 235 deletions tests/test_feature_extraction_common.py
@@ -18,28 +18,8 @@
import os
import tempfile

import numpy as np

from transformers import BatchFeature
from transformers.testing_utils import require_tf, require_torch


class FeatureExtractionMixin:

# to be overridden in feature-extractor-specific tests
feat_extract_tester = None
feature_extraction_class = None

@property
def feat_extract_dict(self):
return self.feat_extract_tester.prepare_feat_extract_dict()

def test_feat_extract_common_properties(self):
feat_extract = self.feature_extraction_class(**self.feat_extract_dict)
self.assertTrue(hasattr(feat_extract, "feature_size"))
self.assertTrue(hasattr(feat_extract, "sampling_rate"))
self.assertTrue(hasattr(feat_extract, "padding_value"))

class FeatureExtractionSavingTestMixin:
def test_feat_extract_to_json_string(self):
feat_extract = self.feature_extraction_class(**self.feat_extract_dict)
obj = json.loads(feat_extract.to_json_string())
@@ -68,217 +48,3 @@ def test_feat_extract_from_and_save_pretrained(self):
def test_init_without_params(self):
feat_extract = self.feature_extraction_class()
self.assertIsNotNone(feat_extract)

def test_batch_feature(self):
speech_inputs = self.feat_extract_tester.prepare_inputs_for_common()
feat_extract = self.feature_extraction_class(**self.feat_extract_dict)
input_name = feat_extract.model_input_names[0]

processed_features = BatchFeature({input_name: speech_inputs})

self.assertTrue(all(len(x) == len(y) for x, y in zip(speech_inputs, processed_features[input_name])))

speech_inputs = self.feat_extract_tester.prepare_inputs_for_common(equal_length=True)
processed_features = BatchFeature({input_name: speech_inputs}, tensor_type="np")

batch_features_input = processed_features[input_name]

if len(batch_features_input.shape) < 3:
batch_features_input = batch_features_input[:, :, None]

self.assertTrue(
batch_features_input.shape
== (self.feat_extract_tester.batch_size, len(speech_inputs[0]), self.feat_extract_tester.feature_size)
)

@require_torch
def test_batch_feature_pt(self):
speech_inputs = self.feat_extract_tester.prepare_inputs_for_common(equal_length=True)
feat_extract = self.feature_extraction_class(**self.feat_extract_dict)
input_name = feat_extract.model_input_names[0]

processed_features = BatchFeature({input_name: speech_inputs}, tensor_type="pt")

batch_features_input = processed_features[input_name]

if len(batch_features_input.shape) < 3:
batch_features_input = batch_features_input[:, :, None]

self.assertTrue(
batch_features_input.shape
== (self.feat_extract_tester.batch_size, len(speech_inputs[0]), self.feat_extract_tester.feature_size)
)

@require_tf
def test_batch_feature_tf(self):
speech_inputs = self.feat_extract_tester.prepare_inputs_for_common(equal_length=True)
feat_extract = self.feature_extraction_class(**self.feat_extract_dict)
input_name = feat_extract.model_input_names[0]

processed_features = BatchFeature({input_name: speech_inputs}, tensor_type="tf")

batch_features_input = processed_features[input_name]

if len(batch_features_input.shape) < 3:
batch_features_input = batch_features_input[:, :, None]

self.assertTrue(
batch_features_input.shape
== (self.feat_extract_tester.batch_size, len(speech_inputs[0]), self.feat_extract_tester.feature_size)
)

def _check_padding(self, numpify=False):
def _inputs_have_equal_length(input):
length = len(input[0])
for input_slice in input[1:]:
if len(input_slice) != length:
return False
return True

def _inputs_are_equal(input_1, input_2):
if len(input_1) != len(input_2):
return False

for input_slice_1, input_slice_2 in zip(input_1, input_2):
if not np.allclose(np.asarray(input_slice_1), np.asarray(input_slice_2), atol=1e-3):
return False
return True

feat_extract = self.feature_extraction_class(**self.feat_extract_dict)
speech_inputs = self.feat_extract_tester.prepare_inputs_for_common(numpify=numpify)
input_name = feat_extract.model_input_names[0]

processed_features = BatchFeature({input_name: speech_inputs})

pad_diff = self.feat_extract_tester.seq_length_diff
pad_max_length = self.feat_extract_tester.max_seq_length + pad_diff
pad_min_length = self.feat_extract_tester.min_seq_length
batch_size = self.feat_extract_tester.batch_size
feature_size = self.feat_extract_tester.feature_size

# test padding for List[int] + numpy
input_1 = feat_extract.pad(processed_features, padding=False)[input_name]
input_2 = feat_extract.pad(processed_features, padding="longest")[input_name]
input_3 = feat_extract.pad(processed_features, padding="max_length", max_length=len(speech_inputs[-1]))[
input_name
]
input_4 = feat_extract.pad(processed_features, padding="longest", return_tensors="np")[input_name]

# max_length parameter has to be provided when setting `padding="max_length"`
with self.assertRaises(ValueError):
feat_extract.pad(processed_features, padding="max_length")[input_name]

input_5 = feat_extract.pad(
processed_features, padding="max_length", max_length=pad_max_length, return_tensors="np"
)[input_name]

self.assertFalse(_inputs_have_equal_length(input_1))
self.assertTrue(_inputs_have_equal_length(input_2))
self.assertTrue(_inputs_have_equal_length(input_3))
self.assertTrue(_inputs_are_equal(input_2, input_3))
self.assertTrue(len(input_1[0]) == pad_min_length)
self.assertTrue(len(input_1[1]) == pad_min_length + pad_diff)
self.assertTrue(input_4.shape[:2] == (batch_size, len(input_3[0])))
self.assertTrue(input_5.shape[:2] == (batch_size, pad_max_length))

if feature_size > 1:
self.assertTrue(input_4.shape[2] == input_5.shape[2] == feature_size)

# test padding for `pad_to_multiple_of` for List[int] + numpy
input_6 = feat_extract.pad(processed_features, pad_to_multiple_of=10)[input_name]
input_7 = feat_extract.pad(processed_features, padding="longest", pad_to_multiple_of=10)[input_name]
input_8 = feat_extract.pad(
processed_features, padding="max_length", pad_to_multiple_of=10, max_length=pad_max_length
)[input_name]
input_9 = feat_extract.pad(
processed_features,
padding="max_length",
pad_to_multiple_of=10,
max_length=pad_max_length,
return_tensors="np",
)[input_name]

self.assertTrue(all(len(x) % 10 == 0 for x in input_6))
self.assertTrue(_inputs_are_equal(input_6, input_7))

expected_mult_pad_length = pad_max_length if pad_max_length % 10 == 0 else (pad_max_length // 10 + 1) * 10
self.assertTrue(all(len(x) == expected_mult_pad_length for x in input_8))
self.assertTrue(input_9.shape[:2] == (batch_size, expected_mult_pad_length))

if feature_size > 1:
self.assertTrue(input_9.shape[2] == feature_size)

# Check padding value is correct
padding_vector_sum = (np.ones(self.feat_extract_tester.feature_size) * feat_extract.padding_value).sum()
self.assertTrue(
abs(np.asarray(input_2[0])[pad_min_length:].sum() - padding_vector_sum * (pad_max_length - pad_min_length))
< 1e-3
)
self.assertTrue(
abs(
np.asarray(input_2[1])[pad_min_length + pad_diff :].sum()
- padding_vector_sum * (pad_max_length - pad_min_length - pad_diff)
)
< 1e-3
)
self.assertTrue(
abs(
np.asarray(input_2[2])[pad_min_length + 2 * pad_diff :].sum()
- padding_vector_sum * (pad_max_length - pad_min_length - 2 * pad_diff)
)
< 1e-3
)
self.assertTrue(
abs(input_5[0, pad_min_length:].sum() - padding_vector_sum * (pad_max_length - pad_min_length)) < 1e-3
)
self.assertTrue(
abs(input_9[0, pad_min_length:].sum() - padding_vector_sum * (expected_mult_pad_length - pad_min_length))
< 1e-3
)

def test_padding_from_list(self):
self._check_padding(numpify=False)

def test_padding_from_array(self):
self._check_padding(numpify=True)

@require_torch
def test_padding_accepts_tensors_pt(self):
feat_extract = self.feature_extraction_class(**self.feat_extract_dict)
speech_inputs = self.feat_extract_tester.prepare_inputs_for_common()
input_name = feat_extract.model_input_names[0]

processed_features = BatchFeature({input_name: speech_inputs})

input_np = feat_extract.pad(processed_features, padding="longest", return_tensors="np")[input_name]
input_pt = feat_extract.pad(processed_features, padding="longest", return_tensors="pt")[input_name]

self.assertTrue(abs(input_np.sum() - input_pt.numpy().sum()) < 1e-2)

@require_tf
def test_padding_accepts_tensors_tf(self):
feat_extract = self.feature_extraction_class(**self.feat_extract_dict)
speech_inputs = self.feat_extract_tester.prepare_inputs_for_common()
input_name = feat_extract.model_input_names[0]

processed_features = BatchFeature({input_name: speech_inputs})

input_np = feat_extract.pad(processed_features, padding="longest", return_tensors="np")[input_name]
input_tf = feat_extract.pad(processed_features, padding="longest", return_tensors="tf")[input_name]

self.assertTrue(abs(input_np.sum() - input_tf.numpy().sum()) < 1e-2)

def test_attention_mask(self):
feat_dict = self.feat_extract_dict
feat_dict["return_attention_mask"] = True
feat_extract = self.feature_extraction_class(**feat_dict)
speech_inputs = self.feat_extract_tester.prepare_inputs_for_common()
input_lenghts = [len(x) for x in speech_inputs]
input_name = feat_extract.model_input_names[0]

processed = BatchFeature({input_name: speech_inputs})

processed = feat_extract.pad(processed, padding="longest", return_tensors="np")
self.assertIn("attention_mask", processed)
self.assertListEqual(list(processed.attention_mask.shape), list(processed[input_name].shape[:2]))
self.assertListEqual(processed.attention_mask.sum(-1).tolist(), input_lenghts)
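
The padding behaviour exercised by the tests above (``"longest"`` versus ``"max_length"``, rounding up with ``pad_to_multiple_of``, and the attention mask) can be approximated with the plain-Python sketch below. The helper names `target_length` and `pad_batch` and the output key `input_values` are assumptions made for illustration; this is not the library's ``pad`` method.

```python
from typing import Dict, List, Optional


def target_length(
    lengths: List[int],
    padding: str,
    max_length: Optional[int] = None,
    pad_to_multiple_of: Optional[int] = None,
) -> int:
    """Return the length every sequence should be padded to (hypothetical helper)."""
    if padding == "longest":
        length = max(lengths)
    elif padding == "max_length":
        if max_length is None:
            # Mirrors the expectation in the tests: max_length is required here.
            raise ValueError("max_length must be set when padding='max_length'")
        length = max_length
    else:
        raise ValueError(f"unsupported padding strategy: {padding!r}")

    # Round up to the next multiple, same arithmetic as in the test above.
    if pad_to_multiple_of is not None and length % pad_to_multiple_of != 0:
        length = (length // pad_to_multiple_of + 1) * pad_to_multiple_of
    return length


def pad_batch(
    batch: List[List[float]],
    padding: str = "longest",
    max_length: Optional[int] = None,
    pad_to_multiple_of: Optional[int] = None,
    padding_value: float = 0.0,
) -> Dict[str, list]:
    """Right-pad each sequence and build a 1/0 attention mask."""
    lengths = [len(seq) for seq in batch]
    length = target_length(lengths, padding, max_length, pad_to_multiple_of)
    padded, mask = [], []
    for seq in batch:
        n_pad = max(0, length - len(seq))  # no truncation in this sketch
        padded.append(list(seq) + [padding_value] * n_pad)
        mask.append([1] * len(seq) + [0] * n_pad)
    return {"input_values": padded, "attention_mask": mask}


# Made-up batch with sequence lengths 3, 5 and 7.
batch = [[0.5] * 3, [0.5] * 5, [0.5] * 7]
longest = pad_batch(batch, padding="longest")
rounded = pad_batch(batch, padding="max_length", max_length=12, pad_to_multiple_of=10)
```

In this sketch the row sums of ``attention_mask`` equal the original sequence lengths, which is the same property ``test_attention_mask`` checks.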
4 changes: 2 additions & 2 deletions tests/test_feature_extraction_wav2vec2.py
@@ -23,7 +23,7 @@
from transformers import WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST, Wav2Vec2Config, Wav2Vec2FeatureExtractor
from transformers.testing_utils import slow

from .test_feature_extraction_common import FeatureExtractionMixin
from .test_sequence_feature_extraction_common import SequenceFeatureExtractionTestMixin


global_rng = random.Random()
@@ -94,7 +94,7 @@ def _flatten(list_of_lists):
return speech_inputs


class Wav2Vec2FeatureExtractionTest(FeatureExtractionMixin, unittest.TestCase):
class Wav2Vec2FeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.TestCase):

feature_extraction_class = Wav2Vec2FeatureExtractor
