[FeatureExtractorSavingUtils] Refactor PretrainedFeatureExtractor (#10594)

* save first version

* finish refactor

* finish refactor

* correct naming

* correct naming

* shorter names

* Update src/transformers/feature_extraction_common_utils.py

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* change name

* finish

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
patrickvonplaten and LysandreJik authored Mar 9, 2021
1 parent b6a28e9 commit 9a06b6b
Showing 10 changed files with 638 additions and 572 deletions.
20 changes: 14 additions & 6 deletions docs/source/main_classes/feature_extractor.rst
@@ -14,16 +14,24 @@
Feature Extractor
-----------------------------------------------------------------------------------------------------------------------

A feature extractor is in charge of preparing read-in audio files for a speech model. This includes feature extraction,
such as processing audio files to, *e.g.*, Log-Mel Spectrogram features, but also padding, normalization, and
conversion to Numpy, PyTorch, and TensorFlow tensors.
A feature extractor is in charge of preparing input features for a multi-modal model. This includes feature extraction
from sequences, *e.g.*, pre-processing audio files to Log-Mel Spectrogram features, feature extraction from images,
*e.g.*, cropping image files, but also padding, normalization, and conversion to Numpy, PyTorch, and TensorFlow
tensors.


PreTrainedFeatureExtractor
FeatureExtractionMixin
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.PreTrainedFeatureExtractor
:members: from_pretrained, save_pretrained, pad
.. autoclass:: transformers.feature_extraction_utils.FeatureExtractionMixin
:members: from_pretrained, save_pretrained


SequenceFeatureExtractor
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.SequenceFeatureExtractor
:members: pad


BatchFeature
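Taken together, the documentation change above means serialization and padding now come from two different base classes. A minimal usage sketch along those lines, using the Wav2Vec2 extractor touched later in this diff (the local directory and the toy waveforms are illustrative, not part of the commit):

    from transformers import BatchFeature, Wav2Vec2FeatureExtractor

    feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=16_000, padding_value=0.0)

    # save_pretrained / from_pretrained come from FeatureExtractionMixin
    feature_extractor.save_pretrained("./my_feature_extractor")
    reloaded = Wav2Vec2FeatureExtractor.from_pretrained("./my_feature_extractor")

    # pad comes from SequenceFeatureExtractor
    features = BatchFeature({"input_values": [[0.1, 0.2, 0.3], [0.4, 0.5]]})
    padded = reloaded.pad(features, padding="longest", return_tensors="np")
    print(padded["input_values"].shape)  # (2, 3): padded to the longest sequence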
4 changes: 2 additions & 2 deletions src/transformers/__init__.py
@@ -246,7 +246,7 @@
"SpecialTokensMixin",
"TokenSpan",
],
"feature_extraction_utils": ["PreTrainedFeatureExtractor", "BatchFeature"],
"feature_extraction_sequence_utils": ["SequenceFeatureExtractor", "BatchFeature"],
"trainer_callback": [
"DefaultFlowCallback",
"EarlyStoppingCallback",
@@ -1257,7 +1257,7 @@
)

# Feature Extractor
from .feature_extraction_utils import BatchFeature, PreTrainedFeatureExtractor
from .feature_extraction_utils import BatchFeature, SequenceFeatureExtractor

# Files and general utilities
from .file_utils import (
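The net effect of the two __init__.py hunks is a rename of the exported symbol. A quick sanity check of the new public surface (module and class names are taken from this diff and the documentation hunk above; the snippet itself is a sketch, not part of the commit):

    from transformers import BatchFeature, SequenceFeatureExtractor
    from transformers.feature_extraction_utils import FeatureExtractionMixin

    # padding-capable extractors build on the serialization mixin
    assert issubclass(SequenceFeatureExtractor, FeatureExtractionMixin)
    assert hasattr(FeatureExtractionMixin, "from_pretrained") and hasattr(FeatureExtractionMixin, "save_pretrained")
    assert hasattr(SequenceFeatureExtractor, "pad")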
317 changes: 317 additions & 0 deletions src/transformers/feature_extraction_sequence_utils.py

Large diffs are not rendered by default.

355 changes: 38 additions & 317 deletions src/transformers/feature_extraction_utils.py

Large diffs are not rendered by default.

src/transformers/models/wav2vec2/feature_extraction_wav2vec2.py
@@ -20,15 +20,16 @@

import numpy as np

from ...feature_extraction_utils import BatchFeature, PreTrainedFeatureExtractor
from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
from ...feature_extraction_utils import BatchFeature
from ...file_utils import PaddingStrategy, TensorType
from ...utils import logging


logger = logging.get_logger(__name__)


class Wav2Vec2FeatureExtractor(PreTrainedFeatureExtractor):
class Wav2Vec2FeatureExtractor(SequenceFeatureExtractor):
r"""
Constructs a Wav2Vec2 feature extractor.
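Wav2Vec2FeatureExtractor itself only swaps its base class here. A hedged sketch of what the same pattern would look like for a hypothetical custom extractor (the class name, its defaults, and the trivial __call__ are illustrative; a real extractor would compute features before padding):

    from transformers import BatchFeature, SequenceFeatureExtractor


    class MySpeechFeatureExtractor(SequenceFeatureExtractor):
        model_input_names = ["input_values"]

        def __init__(self, feature_size=1, sampling_rate=16_000, padding_value=0.0, **kwargs):
            super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs)

        def __call__(self, raw_speech, padding=False, return_tensors=None, **kwargs):
            # a real extractor would turn raw_speech into model features here;
            # this sketch just wraps the sequences and lets the base class pad them
            features = BatchFeature({"input_values": raw_speech})
            return self.pad(features, padding=padding, return_tensors=return_tensors)


    # usage sketch: MySpeechFeatureExtractor()([[0.1, 0.2], [0.3]], padding="longest", return_tensors="np")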
13 changes: 7 additions & 6 deletions src/transformers/models/wav2vec2/processing_wav2vec2.py
@@ -59,7 +59,8 @@ def save_pretrained(self, save_directory):
.. note::
This class method is simply calling :meth:`~transformers.PreTrainedFeatureExtractor.save_pretrained` and
This class method is simply calling
:meth:`~transformers.feature_extraction_utils.FeatureExtractionMixin.save_pretrained` and
:meth:`~transformers.tokenization_utils_base.PreTrainedTokenizer.save_pretrained`. Please refer to the
docstrings of the methods above for more information.
@@ -80,9 +81,9 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
.. note::
This class method is simply calling Wav2Vec2FeatureExtractor's
:meth:`~transformers.PreTrainedFeatureExtractor.from_pretrained` and Wav2Vec2CTCTokenizer's
:meth:`~transformers.tokenization_utils_base.PreTrainedTokenizer.from_pretrained`. Please refer to the
docstrings of the methods above for more information.
:meth:`~transformers.feature_extraction_utils.FeatureExtractionMixin.from_pretrained` and
Wav2Vec2CTCTokenizer's :meth:`~transformers.tokenization_utils_base.PreTrainedTokenizer.from_pretrained`.
Please refer to the docstrings of the methods above for more information.
Args:
pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):
@@ -92,12 +93,12 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
huggingface.co. Valid model ids can be located at the root-level, like ``bert-base-uncased``, or
namespaced under a user or organization name, like ``dbmdz/bert-base-german-cased``.
- a path to a `directory` containing a feature extractor file saved using the
:meth:`~transformers.PreTrainedFeatureExtractor.save_pretrained` method, e.g.,
:meth:`~transformers.SequenceFeatureExtractor.save_pretrained` method, e.g.,
``./my_model_directory/``.
- a path or url to a saved feature extractor JSON `file`, e.g.,
``./my_model_directory/feature_extraction_config.json``.
**kwargs
Additional keyword arguments passed along to both :class:`~transformers.PreTrainedFeatureExtractor` and
Additional keyword arguments passed along to both :class:`~transformers.SequenceFeatureExtractor` and
:class:`~transformers.PreTrainedTokenizer`
"""
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(pretrained_model_name_or_path, **kwargs)
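Usage of the processor itself is unchanged by the docstring updates above: save_pretrained writes both the feature extractor and tokenizer files, and from_pretrained restores them. A short round-trip sketch (the checkpoint name and target directory are illustrative):

    from transformers import Wav2Vec2Processor

    processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
    processor.save_pretrained("./wav2vec2-processor")
    reloaded = Wav2Vec2Processor.from_pretrained("./wav2vec2-processor")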
3 changes: 1 addition & 2 deletions src/transformers/tokenization_utils_base.py
@@ -727,8 +727,7 @@ def to(self, device: Union[str, "torch.device"]) -> "BatchEncoding":
device (:obj:`str` or :obj:`torch.device`): The device to put the tensors on.
Returns:
:class:`~transformers.BatchEncoding`: The same instance of :class:`~transformers.BatchEncoding` after
modification.
:class:`~transformers.BatchEncoding`: The same instance after modification.
"""

# This check catches things like APEX blindly calling "to" on all inputs to a module
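The shortened docstring still describes the same behaviour: BatchEncoding.to moves the contained tensors and returns the very same instance. A tiny illustration (requires PyTorch; the input ids are arbitrary):

    import torch
    from transformers import BatchEncoding

    encoding = BatchEncoding({"input_ids": torch.tensor([[101, 2023, 102]])})
    assert encoding.to("cpu") is encoding
    assert encoding["input_ids"].device.type == "cpu"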
236 changes: 1 addition & 235 deletions tests/test_feature_extraction_common.py
@@ -18,28 +18,8 @@
import os
import tempfile

import numpy as np

from transformers import BatchFeature
from transformers.testing_utils import require_tf, require_torch


class FeatureExtractionMixin:

# to overwrite at feature extractactor specific tests
feat_extract_tester = None
feature_extraction_class = None

@property
def feat_extract_dict(self):
return self.feat_extract_tester.prepare_feat_extract_dict()

def test_feat_extract_common_properties(self):
feat_extract = self.feature_extraction_class(**self.feat_extract_dict)
self.assertTrue(hasattr(feat_extract, "feature_size"))
self.assertTrue(hasattr(feat_extract, "sampling_rate"))
self.assertTrue(hasattr(feat_extract, "padding_value"))

class FeatureExtractionSavingTestMixin:
def test_feat_extract_to_json_string(self):
feat_extract = self.feature_extraction_class(**self.feat_extract_dict)
obj = json.loads(feat_extract.to_json_string())
@@ -68,217 +48,3 @@ def test_feat_extract_from_and_save_pretrained(self):
def test_init_without_params(self):
feat_extract = self.feature_extraction_class()
self.assertIsNotNone(feat_extract)

def test_batch_feature(self):
speech_inputs = self.feat_extract_tester.prepare_inputs_for_common()
feat_extract = self.feature_extraction_class(**self.feat_extract_dict)
input_name = feat_extract.model_input_names[0]

processed_features = BatchFeature({input_name: speech_inputs})

self.assertTrue(all(len(x) == len(y) for x, y in zip(speech_inputs, processed_features[input_name])))

speech_inputs = self.feat_extract_tester.prepare_inputs_for_common(equal_length=True)
processed_features = BatchFeature({input_name: speech_inputs}, tensor_type="np")

batch_features_input = processed_features[input_name]

if len(batch_features_input.shape) < 3:
batch_features_input = batch_features_input[:, :, None]

self.assertTrue(
batch_features_input.shape
== (self.feat_extract_tester.batch_size, len(speech_inputs[0]), self.feat_extract_tester.feature_size)
)

@require_torch
def test_batch_feature_pt(self):
speech_inputs = self.feat_extract_tester.prepare_inputs_for_common(equal_length=True)
feat_extract = self.feature_extraction_class(**self.feat_extract_dict)
input_name = feat_extract.model_input_names[0]

processed_features = BatchFeature({input_name: speech_inputs}, tensor_type="pt")

batch_features_input = processed_features[input_name]

if len(batch_features_input.shape) < 3:
batch_features_input = batch_features_input[:, :, None]

self.assertTrue(
batch_features_input.shape
== (self.feat_extract_tester.batch_size, len(speech_inputs[0]), self.feat_extract_tester.feature_size)
)

@require_tf
def test_batch_feature_tf(self):
speech_inputs = self.feat_extract_tester.prepare_inputs_for_common(equal_length=True)
feat_extract = self.feature_extraction_class(**self.feat_extract_dict)
input_name = feat_extract.model_input_names[0]

processed_features = BatchFeature({input_name: speech_inputs}, tensor_type="tf")

batch_features_input = processed_features[input_name]

if len(batch_features_input.shape) < 3:
batch_features_input = batch_features_input[:, :, None]

self.assertTrue(
batch_features_input.shape
== (self.feat_extract_tester.batch_size, len(speech_inputs[0]), self.feat_extract_tester.feature_size)
)

def _check_padding(self, numpify=False):
def _inputs_have_equal_length(input):
length = len(input[0])
for input_slice in input[1:]:
if len(input_slice) != length:
return False
return True

def _inputs_are_equal(input_1, input_2):
if len(input_1) != len(input_2):
return False

for input_slice_1, input_slice_2 in zip(input_1, input_2):
if not np.allclose(np.asarray(input_slice_1), np.asarray(input_slice_2), atol=1e-3):
return False
return True

feat_extract = self.feature_extraction_class(**self.feat_extract_dict)
speech_inputs = self.feat_extract_tester.prepare_inputs_for_common(numpify=numpify)
input_name = feat_extract.model_input_names[0]

processed_features = BatchFeature({input_name: speech_inputs})

pad_diff = self.feat_extract_tester.seq_length_diff
pad_max_length = self.feat_extract_tester.max_seq_length + pad_diff
pad_min_length = self.feat_extract_tester.min_seq_length
batch_size = self.feat_extract_tester.batch_size
feature_size = self.feat_extract_tester.feature_size

# test padding for List[int] + numpy
input_1 = feat_extract.pad(processed_features, padding=False)[input_name]
input_2 = feat_extract.pad(processed_features, padding="longest")[input_name]
input_3 = feat_extract.pad(processed_features, padding="max_length", max_length=len(speech_inputs[-1]))[
input_name
]
input_4 = feat_extract.pad(processed_features, padding="longest", return_tensors="np")[input_name]

# max_length parameter has to be provided when setting `padding="max_length"`
with self.assertRaises(ValueError):
feat_extract.pad(processed_features, padding="max_length")[input_name]

input_5 = feat_extract.pad(
processed_features, padding="max_length", max_length=pad_max_length, return_tensors="np"
)[input_name]

self.assertFalse(_inputs_have_equal_length(input_1))
self.assertTrue(_inputs_have_equal_length(input_2))
self.assertTrue(_inputs_have_equal_length(input_3))
self.assertTrue(_inputs_are_equal(input_2, input_3))
self.assertTrue(len(input_1[0]) == pad_min_length)
self.assertTrue(len(input_1[1]) == pad_min_length + pad_diff)
self.assertTrue(input_4.shape[:2] == (batch_size, len(input_3[0])))
self.assertTrue(input_5.shape[:2] == (batch_size, pad_max_length))

if feature_size > 1:
self.assertTrue(input_4.shape[2] == input_5.shape[2] == feature_size)

# test padding for `pad_to_multiple_of` for List[int] + numpy
input_6 = feat_extract.pad(processed_features, pad_to_multiple_of=10)[input_name]
input_7 = feat_extract.pad(processed_features, padding="longest", pad_to_multiple_of=10)[input_name]
input_8 = feat_extract.pad(
processed_features, padding="max_length", pad_to_multiple_of=10, max_length=pad_max_length
)[input_name]
input_9 = feat_extract.pad(
processed_features,
padding="max_length",
pad_to_multiple_of=10,
max_length=pad_max_length,
return_tensors="np",
)[input_name]

self.assertTrue(all(len(x) % 10 == 0 for x in input_6))
self.assertTrue(_inputs_are_equal(input_6, input_7))

expected_mult_pad_length = pad_max_length if pad_max_length % 10 == 0 else (pad_max_length // 10 + 1) * 10
self.assertTrue(all(len(x) == expected_mult_pad_length for x in input_8))
self.assertTrue(input_9.shape[:2], (batch_size, expected_mult_pad_length))

if feature_size > 1:
self.assertTrue(input_9.shape[2] == feature_size)

# Check padding value is correct
padding_vector_sum = (np.ones(self.feat_extract_tester.feature_size) * feat_extract.padding_value).sum()
self.assertTrue(
abs(np.asarray(input_2[0])[pad_min_length:].sum() - padding_vector_sum * (pad_max_length - pad_min_length))
< 1e-3
)
self.assertTrue(
abs(
np.asarray(input_2[1])[pad_min_length + pad_diff :].sum()
- padding_vector_sum * (pad_max_length - pad_min_length - pad_diff)
)
< 1e-3
)
self.assertTrue(
abs(
np.asarray(input_2[2])[pad_min_length + 2 * pad_diff :].sum()
- padding_vector_sum * (pad_max_length - pad_min_length - 2 * pad_diff)
)
< 1e-3
)
self.assertTrue(
abs(input_5[0, pad_min_length:].sum() - padding_vector_sum * (pad_max_length - pad_min_length)) < 1e-3
)
self.assertTrue(
abs(input_9[0, pad_min_length:].sum() - padding_vector_sum * (expected_mult_pad_length - pad_min_length))
< 1e-3
)

def test_padding_from_list(self):
self._check_padding(numpify=False)

def test_padding_from_array(self):
self._check_padding(numpify=True)

@require_torch
def test_padding_accepts_tensors_pt(self):
feat_extract = self.feature_extraction_class(**self.feat_extract_dict)
speech_inputs = self.feat_extract_tester.prepare_inputs_for_common()
input_name = feat_extract.model_input_names[0]

processed_features = BatchFeature({input_name: speech_inputs})

input_np = feat_extract.pad(processed_features, padding="longest", return_tensors="np")[input_name]
input_pt = feat_extract.pad(processed_features, padding="longest", return_tensors="pt")[input_name]

self.assertTrue(abs(input_np.sum() - input_pt.numpy().sum()) < 1e-2)

@require_tf
def test_padding_accepts_tensors_tf(self):
feat_extract = self.feature_extraction_class(**self.feat_extract_dict)
speech_inputs = self.feat_extract_tester.prepare_inputs_for_common()
input_name = feat_extract.model_input_names[0]

processed_features = BatchFeature({input_name: speech_inputs})

input_np = feat_extract.pad(processed_features, padding="longest", return_tensors="np")[input_name]
input_tf = feat_extract.pad(processed_features, padding="longest", return_tensors="tf")[input_name]

self.assertTrue(abs(input_np.sum() - input_tf.numpy().sum()) < 1e-2)

def test_attention_mask(self):
feat_dict = self.feat_extract_dict
feat_dict["return_attention_mask"] = True
feat_extract = self.feature_extraction_class(**feat_dict)
speech_inputs = self.feat_extract_tester.prepare_inputs_for_common()
input_lenghts = [len(x) for x in speech_inputs]
input_name = feat_extract.model_input_names[0]

processed = BatchFeature({input_name: speech_inputs})

processed = feat_extract.pad(processed, padding="longest", return_tensors="np")
self.assertIn("attention_mask", processed)
self.assertListEqual(list(processed.attention_mask.shape), list(processed[input_name].shape[:2]))
self.assertListEqual(processed.attention_mask.sum(-1).tolist(), input_lenghts)
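The batching, padding, and attention-mask tests removed here are covered by the new SequenceFeatureExtractionTestMixin in tests/test_sequence_feature_extraction_common.py (imported by the Wav2Vec2 test below, not rendered in this diff). A condensed sketch of the behaviour they exercise, using the Wav2Vec2 extractor as a concrete SequenceFeatureExtractor (the sequence lengths are illustrative):

    import numpy as np
    from transformers import BatchFeature, Wav2Vec2FeatureExtractor

    feat_extract = Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=16_000, padding_value=0.0)
    speech_inputs = [np.random.rand(length).tolist() for length in (800, 1000, 1200)]
    features = BatchFeature({"input_values": speech_inputs})

    longest = feat_extract.pad(features, padding="longest", return_tensors="np")["input_values"]
    fixed = feat_extract.pad(features, padding="max_length", max_length=1500, return_tensors="np")["input_values"]

    assert longest.shape[:2] == (3, 1200)  # padded to the longest input
    assert fixed.shape[:2] == (3, 1500)    # padded to the requested max_length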
4 changes: 2 additions & 2 deletions tests/test_feature_extraction_wav2vec2.py
Expand Up @@ -23,7 +23,7 @@
from transformers import WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST, Wav2Vec2Config, Wav2Vec2FeatureExtractor
from transformers.testing_utils import slow

from .test_feature_extraction_common import FeatureExtractionMixin
from .test_sequence_feature_extraction_common import SequenceFeatureExtractionTestMixin


global_rng = random.Random()
@@ -94,7 +94,7 @@ def _flatten(list_of_lists):
return speech_inputs


class Wav2Vec2FeatureExtractionTest(FeatureExtractionMixin, unittest.TestCase):
class Wav2Vec2FeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.TestCase):

feature_extraction_class = Wav2Vec2FeatureExtractor
