From c009e906d69fee6d9a60b4e214be832e63e41687 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 8 Mar 2021 16:34:36 +0300 Subject: [PATCH 1/9] save first version --- .../source/main_classes/feature_extractor.rst | 8 +- src/transformers/__init__.py | 4 +- .../feature_extraction_saving_utils.py | 297 ++++++++++++++++ ...y => feature_extraction_sequence_utils.py} | 324 ++---------------- .../wav2vec2/feature_extraction_wav2vec2.py | 8 +- .../models/wav2vec2/processing_wav2vec2.py | 10 +- .../test_feature_extraction_saving_common.py | 50 +++ tests/test_feature_extraction_wav2vec2.py | 4 +- ...est_sequence_feature_extraction_common.py} | 55 +-- 9 files changed, 406 insertions(+), 354 deletions(-) create mode 100644 src/transformers/feature_extraction_saving_utils.py rename src/transformers/{feature_extraction_utils.py => feature_extraction_sequence_utils.py} (56%) create mode 100644 tests/test_feature_extraction_saving_common.py rename tests/{test_feature_extraction_common.py => test_sequence_feature_extraction_common.py} (83%) diff --git a/docs/source/main_classes/feature_extractor.rst b/docs/source/main_classes/feature_extractor.rst index 6d99cc2504bc85..5aa9f112a3ab9d 100644 --- a/docs/source/main_classes/feature_extractor.rst +++ b/docs/source/main_classes/feature_extractor.rst @@ -19,15 +19,15 @@ such as processing audio files to, *e.g.*, Log-Mel Spectrogram features, but als conversion to Numpy, PyTorch, and TensorFlow tensors. -PreTrainedFeatureExtractor +PreTrainedSequenceFeatureExtractor ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. autoclass:: transformers.PreTrainedFeatureExtractor +.. autoclass:: transformers.PreTrainedSequenceFeatureExtractor :members: from_pretrained, save_pretrained, pad -BatchFeature +BatchSequenceFeature ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. autoclass:: transformers.BatchFeature +.. autoclass:: transformers.BatchSequenceFeature :members: diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 2b6a037892c1e4..290411216dffd4 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -246,7 +246,7 @@ "SpecialTokensMixin", "TokenSpan", ], - "feature_extraction_utils": ["PreTrainedFeatureExtractor", "BatchFeature"], + "feature_extraction_sequence_utils": ["PreTrainedSequenceFeatureExtractor", "BatchSequenceFeature"], "trainer_callback": [ "DefaultFlowCallback", "EarlyStoppingCallback", @@ -1250,7 +1250,7 @@ ) # Feature Extractor - from .feature_extraction_utils import BatchFeature, PreTrainedFeatureExtractor + from .feature_extraction_sequence_utils import BatchSequenceFeature, PreTrainedSequenceFeatureExtractor # Files and general utilities from .file_utils import ( diff --git a/src/transformers/feature_extraction_saving_utils.py b/src/transformers/feature_extraction_saving_utils.py new file mode 100644 index 00000000000000..a7278762999bc6 --- /dev/null +++ b/src/transformers/feature_extraction_saving_utils.py @@ -0,0 +1,297 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+ Feature extraction common class for python feature extractors.
+"""
+import copy
+import json
+import os
+from typing import Any, Dict, Tuple, Union
+
+from .file_utils import FEATURE_EXTRACTOR_NAME, cached_path, hf_bucket_url, is_remote_url
+from .utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+PreTrainedFeatureExtractor = Union["PreTrainedSequenceFeatureExtractor"]
+
+
+class FeatureExtractionSavingUtilsMixin:
+    """
+    This is a general feature extraction class for speech recognition.
+
+    Args:
+        feature_size (:obj:`int`):
+            The feature dimension of the extracted features.
+        sampling_rate (:obj:`int`):
+            The sampling rate at which the audio files should be digitized, expressed in Hertz (Hz).
+        padding_value (:obj:`float`):
+            The value that is used to fill the padding values / vectors.
+    """
+
+    @classmethod
+    def from_pretrained(
+        cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
+    ) -> "PreTrainedFeatureExtractor":
+        r"""
+        Instantiate a :class:`~transformers.PreTrainedFeatureExtractor` (or a derived class) from a pretrained feature
+        extractor.
+
+        Args:
+            pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):
+                This can be either:
+
+                - a string, the `model id` of a pretrained feature_extractor hosted inside a model repo on
+                  huggingface.co. Valid model ids can be located at the root-level, like ``bert-base-uncased``, or
+                  namespaced under a user or organization name, like ``dbmdz/bert-base-german-cased``.
+                - a path to a `directory` containing a feature extractor file saved using the
+                  :func:`~transformers.PreTrainedFeatureExtractor.save_pretrained` method, e.g.,
+                  ``./my_model_directory/``.
+                - a path or url to a saved feature extractor JSON `file`, e.g.,
+                  ``./my_model_directory/feature_extraction_config.json``.
+            cache_dir (:obj:`str` or :obj:`os.PathLike`, `optional`):
+                Path to a directory in which a downloaded pretrained model feature extractor should be cached if the
+                standard cache should not be used.
+            force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
+                Whether or not to force (re-)downloading the feature extractor files, overriding the cached versions
+                if they exist.
+            resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
+                Whether or not to delete incompletely received files. Attempts to resume the download if such a file
+                exists.
+            proxies (:obj:`Dict[str, str]`, `optional`):
+                A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
+                'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
+            use_auth_token (:obj:`str` or `bool`, `optional`):
+                The token to use as HTTP bearer authorization for remote files. If :obj:`True`, will use the token
+                generated when running :obj:`transformers-cli login` (stored in :obj:`~/.huggingface`).
+            revision (:obj:`str`, `optional`, defaults to :obj:`"main"`):
+                The specific model version to use.
+                It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing
+                models and other artifacts on huggingface.co, so ``revision`` can be any identifier allowed by git.
+            return_unused_kwargs (:obj:`bool`, `optional`, defaults to :obj:`False`):
+                If :obj:`False`, then this function returns just the final feature extractor object. If :obj:`True`,
+                then this function returns a :obj:`Tuple(feature_extractor, unused_kwargs)` where `unused_kwargs` is a
+                dictionary consisting of the key/value pairs whose keys are not feature extractor attributes: i.e., the
+                part of ``kwargs`` which has not been used to update ``feature_extractor`` and is otherwise ignored.
+            kwargs (:obj:`Dict[str, Any]`, `optional`):
+                The values in kwargs of any keys which are feature extractor attributes will be used to override the
+                loaded values. Behavior concerning key/value pairs whose keys are *not* feature extractor attributes is
+                controlled by the ``return_unused_kwargs`` keyword parameter.
+
+            .. note::
+
+                Passing :obj:`use_auth_token=True` is required when you want to use a private model.
+
+        Returns:
+            :class:`~transformers.PreTrainedFeatureExtractor`: The feature extractor object instantiated from this
+            pretrained model.
+
+        Examples::
+
+            # We can't instantiate directly the base class `PreTrainedFeatureExtractor` so let's show the examples on a
+            # derived class: Wav2Vec2FeatureExtractor
+            feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained('facebook/wav2vec2-base-960h')    # Download feature_extraction_config from huggingface.co and cache.
+            feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained('./test/saved_model/')  # E.g. feature_extractor (or model) was saved using `save_pretrained('./test/saved_model/')`
+            feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained('./test/saved_model/preprocessor_config.json')
+            feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained('facebook/wav2vec2-base-960h', return_attention_mask=False, foo=False)
+            assert feature_extractor.return_attention_mask is False
+            feature_extractor, unused_kwargs = Wav2Vec2FeatureExtractor.from_pretrained('facebook/wav2vec2-base-960h', return_attention_mask=False,
+                                                                                        foo=False, return_unused_kwargs=True)
+            assert feature_extractor.return_attention_mask is False
+            assert unused_kwargs == {'foo': False}
+        """
+        feature_extractor_dict, kwargs = cls.get_feature_extractor_dict(pretrained_model_name_or_path, **kwargs)
+
+        return cls.from_dict(feature_extractor_dict, **kwargs)
+
+    def save_pretrained(self, save_directory: Union[str, os.PathLike]):
+        """
+        Save a feature_extractor object to the directory ``save_directory``, so that it can be re-loaded using the
+        :func:`~transformers.PreTrainedFeatureExtractor.from_pretrained` class method.
+
+        Args:
+            save_directory (:obj:`str` or :obj:`os.PathLike`):
+                Directory where the feature extractor JSON file will be saved (will be created if it does not exist).
+ """ + if os.path.isfile(save_directory): + raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file") + os.makedirs(save_directory, exist_ok=True) + # If we save using the predefined names, we can load using `from_pretrained` + output_feature_extractor_file = os.path.join(save_directory, FEATURE_EXTRACTOR_NAME) + + self.to_json_file(output_feature_extractor_file) + logger.info(f"Configuration saved in {output_feature_extractor_file}") + + @classmethod + def get_feature_extractor_dict( + cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs + ) -> Tuple[Dict[str, Any], Dict[str, Any]]: + """ + From a ``pretrained_model_name_or_path``, resolve to a dictionary of parameters, to be used for instantiating a + :class:`~transformers.PreTrainedFeatureExtractor` using ``from_dict``. + + Parameters: + pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`): + The identifier of the pre-trained checkpoint from which we want the dictionary of parameters. + + Returns: + :obj:`Tuple[Dict, Dict]`: The dictionary(ies) that will be used to instantiate the feature extractor + object. + """ + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + resume_download = kwargs.pop("resume_download", False) + proxies = kwargs.pop("proxies", None) + use_auth_token = kwargs.pop("use_auth_token", None) + local_files_only = kwargs.pop("local_files_only", False) + revision = kwargs.pop("revision", None) + + pretrained_model_name_or_path = str(pretrained_model_name_or_path) + if os.path.isdir(pretrained_model_name_or_path): + feature_extractor_file = os.path.join(pretrained_model_name_or_path, FEATURE_EXTRACTOR_NAME) + elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path): + feature_extractor_file = pretrained_model_name_or_path + else: + feature_extractor_file = hf_bucket_url( + pretrained_model_name_or_path, filename=FEATURE_EXTRACTOR_NAME, revision=revision, mirror=None + ) + + try: + # Load from URL or cache if already cached + resolved_feature_extractor_file = cached_path( + feature_extractor_file, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + ) + # Load feature_extractor dict + with open(resolved_feature_extractor_file, "r", encoding="utf-8") as reader: + text = reader.read() + feature_extractor_dict = json.loads(text) + + except EnvironmentError as err: + logger.error(err) + msg = ( + f"Can't load feature extractor for '{pretrained_model_name_or_path}'. Make sure that:\n\n" + f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n" + f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a {FEATURE_EXTRACTOR_NAME} file\n\n" + ) + raise EnvironmentError(msg) + + except json.JSONDecodeError: + msg = ( + f"Couldn't reach server at '{feature_extractor_file}' to download feature extractor configuration file or " + "feature extractor configuration file is not a valid JSON file. " + f"Please check network or file content here: {resolved_feature_extractor_file}." 
+ ) + raise EnvironmentError(msg) + + if resolved_feature_extractor_file == feature_extractor_file: + logger.info(f"loading feature extractor configuration file {feature_extractor_file}") + else: + logger.info( + f"loading feature extractor configuration file {feature_extractor_file} from cache at {resolved_feature_extractor_file}" + ) + + return feature_extractor_dict, kwargs + + @classmethod + def from_dict(cls, feature_extractor_dict: Dict[str, Any], **kwargs) -> "PreTrainedFeatureExtractor": + """ + Instantiates a :class:`~transformers.PreTrainedFeatureExtractor` from a Python dictionary of parameters. + + Args: + feature_extractor_dict (:obj:`Dict[str, Any]`): + Dictionary that will be used to instantiate the feature extractor object. Such a dictionary can be + retrieved from a pretrained checkpoint by leveraging the + :func:`~transformers.PreTrainedFeatureExtractor.to_dict` method. + kwargs (:obj:`Dict[str, Any]`): + Additional parameters from which to initialize the feature extractor object. + + Returns: + :class:`~transformers.PreTrainedFeatureExtractor`: The feature extractor object instantiated from those + parameters. + """ + return_unused_kwargs = kwargs.pop("return_unused_kwargs", False) + + feature_extractor = cls(**feature_extractor_dict) + + # Update feature_extractor with kwargs if needed + to_remove = [] + for key, value in kwargs.items(): + if hasattr(feature_extractor, key): + setattr(feature_extractor, key, value) + to_remove.append(key) + for key in to_remove: + kwargs.pop(key, None) + + logger.info(f"Feature extractor {feature_extractor}") + if return_unused_kwargs: + return feature_extractor, kwargs + else: + return feature_extractor + + def to_dict(self) -> Dict[str, Any]: + """ + Serializes this instance to a Python dictionary. + + Returns: + :obj:`Dict[str, Any]`: Dictionary of all the attributes that make up this feature extractor instance. + """ + output = copy.deepcopy(self.__dict__) + + return output + + @classmethod + def from_json_file(cls, json_file: Union[str, os.PathLike]) -> "PreTrainedFeatureExtractor": + """ + Instantiates a :class:`~transformers.PreTrainedFeatureExtractor` from the path to a JSON file of parameters. + + Args: + json_file (:obj:`str` or :obj:`os.PathLike`): + Path to the JSON file containing the parameters. + + Returns: + :class:`~transformers.PreTrainedFeatureExtractor`: The feature_extractor object instantiated from that JSON + file. + """ + with open(json_file, "r", encoding="utf-8") as reader: + text = reader.read() + feature_extractor_dict = json.loads(text) + return cls(**feature_extractor_dict) + + def to_json_string(self) -> str: + """ + Serializes this instance to a JSON string. + + Returns: + :obj:`str`: String containing all the attributes that make up this feature_extractor instance in JSON + format. + """ + return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" + + def to_json_file(self, json_file_path: Union[str, os.PathLike]): + """ + Save this instance to a JSON file. + + Args: + json_file_path (:obj:`str` or :obj:`os.PathLike`): + Path to the JSON file in which this feature_extractor instance's parameters will be saved. 
+ """ + with open(json_file_path, "w", encoding="utf-8") as writer: + writer.write(self.to_json_string()) + + def __repr__(self): + return f"{self.__class__.__name__} {self.to_json_string()}" diff --git a/src/transformers/feature_extraction_utils.py b/src/transformers/feature_extraction_sequence_utils.py similarity index 56% rename from src/transformers/feature_extraction_utils.py rename to src/transformers/feature_extraction_sequence_utils.py index 250a144313e71f..c84923b33ff8f3 100644 --- a/src/transformers/feature_extraction_utils.py +++ b/src/transformers/feature_extraction_sequence_utils.py @@ -15,16 +15,13 @@ """ Feature extraction common class for python feature extractors. """ -import copy -import json -import os from collections import UserDict -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union import numpy as np +from .feature_extraction_saving_utils import FeatureExtractionSavingUtilsMixin from .file_utils import ( - FEATURE_EXTRACTOR_NAME, PaddingStrategy, TensorType, _is_jax, @@ -32,10 +29,7 @@ _is_tensorflow, _is_torch, _is_torch_device, - cached_path, - hf_bucket_url, is_flax_available, - is_remote_url, is_tf_available, is_torch_available, to_py_obj, @@ -52,9 +46,9 @@ import torch -class BatchFeature(UserDict): +class BatchSequenceFeature(UserDict): r""" - Holds the output of the :meth:`~transformers.PreTrainedFeatureExtractor.pad` and feature extractor specific + Holds the output of the :meth:`~transformers.PreTrainedSequenceFeatureExtractor.pad` and feature extractor specific ``__call__`` methods. This class is derived from a python dictionary and can be used as a dictionary. @@ -169,8 +163,8 @@ def convert_to_tensors(self, tensor_type: Optional[Union[str, TensorType]] = Non return self @torch_required - # Copied from transformers.tokenization_utils_base.BatchEncoding.to with BatchEncoding->BatchFeature - def to(self, device: Union[str, "torch.device"]) -> "BatchFeature": + # Copied from transformers.tokenization_utils_base.BatchEncoding.to with BatchEncoding->BatchSequenceFeature + def to(self, device: Union[str, "torch.device"]) -> "BatchSequenceFeature": """ Send all values to device by calling :obj:`v.to(device)` (PyTorch only). @@ -178,8 +172,8 @@ def to(self, device: Union[str, "torch.device"]) -> "BatchFeature": device (:obj:`str` or :obj:`torch.device`): The device to put the tensors on. Returns: - :class:`~transformers.BatchFeature`: The same instance of :class:`~transformers.BatchFeature` after - modification. + :class:`~transformers.BatchSequenceFeature`: The same instance of + :class:`~transformers.BatchSequenceFeature` after modification. """ # This check catches things like APEX blindly calling "to" on all inputs to a module @@ -188,11 +182,11 @@ def to(self, device: Union[str, "torch.device"]) -> "BatchFeature": if isinstance(device, str) or _is_torch_device(device) or isinstance(device, int): self.data = {k: v.to(device=device) for k, v in self.data.items()} else: - logger.warning(f"Attempting to cast a BatchFeature to type {str(device)}. This is not supported.") + logger.warning(f"Attempting to cast a BatchSequenceFeature to type {str(device)}. This is not supported.") return self -class PreTrainedFeatureExtractor: +class PreTrainedSequenceFeatureExtractor(FeatureExtractionSavingUtilsMixin): """ This is a general feature extraction class for speech recognition. 
@@ -221,284 +215,21 @@ def __init__(self, feature_size: int, sampling_rate: int, padding_value: float, logger.error(f"Can't set {key} with value {value} for {self}") raise err - @classmethod - def from_pretrained( - cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs - ) -> "PreTrainedFeatureExtractor": - r""" - Instantiate a :class:`~transformers.PreTrainedFeatureExtractor` (or a derived class) from a pretrained feature - extractor. - - Args: - pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`): - This can be either: - - - a string, the `model id` of a pretrained feature_extractor hosted inside a model repo on - huggingface.co. Valid model ids can be located at the root-level, like ``bert-base-uncased``, or - namespaced under a user or organization name, like ``dbmdz/bert-base-german-cased``. - - a path to a `directory` containing a feature extractor file saved using the - :func:`~transformers.PreTrainedFeatureExtractor.save_pretrained` method, e.g., - ``./my_model_directory/``. - - a path or url to a saved feature extractor JSON `file`, e.g., - ``./my_model_directory/feature_extraction_config.json``. - cache_dir (:obj:`str` or :obj:`os.PathLike`, `optional`): - Path to a directory in which a downloaded pretrained model feature extractor should be cached if the - standard cache should not be used. - force_download (:obj:`bool`, `optional`, defaults to :obj:`False`): - Whether or not to force to (re-)download the feature extractor files and override the cached versions - if they exist. - resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`): - Whether or not to delete incompletely received file. Attempts to resume the download if such a file - exists. - proxies (:obj:`Dict[str, str]`, `optional`): - A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128', - 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request. - use_auth_token (:obj:`str` or `bool`, `optional`): - The token to use as HTTP bearer authorization for remote files. If :obj:`True`, will use the token - generated when running :obj:`transformers-cli login` (stored in :obj:`~/.huggingface`). - revision(:obj:`str`, `optional`, defaults to :obj:`"main"`): - The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a - git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any - identifier allowed by git. - return_unused_kwargs (:obj:`bool`, `optional`, defaults to :obj:`False`): - If :obj:`False`, then this function returns just the final feature extractor object. - - If :obj:`True`, then this functions returns a :obj:`Tuple(feature_extractor, unused_kwargs)` where - `unused_kwargs` is a dictionary consisting of the key/value pairs whose keys are not feature extractor - attributes: i.e., the part of ``kwargs`` which has not been used to update ``feature_extractor`` and is - otherwise ignored. - kwargs (:obj:`Dict[str, Any]`, `optional`): - The values in kwargs of any keys which are feature extractor attributes will be used to override the - loaded values. Behavior concerning key/value pairs whose keys are *not* feature extractor attributes is - controlled by the ``return_unused_kwargs`` keyword parameter. - - .. note:: - - Passing :obj:`use_auth_token=True` is required when you want to use a private model. 
- - - Returns: - :class:`~transformers.PreTrainedFeatureExtractor`: The feature extractor object instantiated from this - pretrained model. - - Examples:: - - # We can't instantiate directly the base class `PreTrainedFeatureExtractor` so let's show the examples on a - # derived class: Wav2Vec2FeatureExtractor - feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained('facebook/wav2vec2-base-960h') # Download feature_extraction_config from huggingface.co and cache. - feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained('./test/saved_model/') # E.g. feature_extractor (or model) was saved using `save_pretrained('./test/saved_model/')` - feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained('./test/saved_model/preprocessor_config.json') - feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained('facebook/wav2vec2-base-960h', return_attention_mask=False, foo=False) - assert feature_extractor.return_attention_mask is False - feature_extractor, unused_kwargs = Wav2Vec2FeatureExtractor.from_pretrained('facebook/wav2vec2-base-960h', return_attention_mask=False, - foo=False, return_unused_kwargs=True) - assert feature_extractor.return_attention_mask is False - assert unused_kwargs == {'foo': False} - - """ - feature_extractor_dict, kwargs = cls.get_feature_extractor_dict(pretrained_model_name_or_path, **kwargs) - - return cls.from_dict(feature_extractor_dict, **kwargs) - - def save_pretrained(self, save_directory: Union[str, os.PathLike]): - """ - Save a feature_extractor object to the directory ``save_directory``, so that it can be re-loaded using the - :func:`~transformers.PreTrainedFeatureExtractor.from_pretrained` class method. - - Args: - save_directory (:obj:`str` or :obj:`os.PathLike`): - Directory where the feature extractor JSON file will be saved (will be created if it does not exist). - """ - if os.path.isfile(save_directory): - raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file") - os.makedirs(save_directory, exist_ok=True) - # If we save using the predefined names, we can load using `from_pretrained` - output_feature_extractor_file = os.path.join(save_directory, FEATURE_EXTRACTOR_NAME) - - self.to_json_file(output_feature_extractor_file) - logger.info(f"Configuration saved in {output_feature_extractor_file}") - - @classmethod - def get_feature_extractor_dict( - cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs - ) -> Tuple[Dict[str, Any], Dict[str, Any]]: - """ - From a ``pretrained_model_name_or_path``, resolve to a dictionary of parameters, to be used for instantiating a - :class:`~transformers.PreTrainedFeatureExtractor` using ``from_dict``. - - Parameters: - pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`): - The identifier of the pre-trained checkpoint from which we want the dictionary of parameters. - - Returns: - :obj:`Tuple[Dict, Dict]`: The dictionary(ies) that will be used to instantiate the feature extractor - object. 
- """ - cache_dir = kwargs.pop("cache_dir", None) - force_download = kwargs.pop("force_download", False) - resume_download = kwargs.pop("resume_download", False) - proxies = kwargs.pop("proxies", None) - use_auth_token = kwargs.pop("use_auth_token", None) - local_files_only = kwargs.pop("local_files_only", False) - revision = kwargs.pop("revision", None) - - pretrained_model_name_or_path = str(pretrained_model_name_or_path) - if os.path.isdir(pretrained_model_name_or_path): - feature_extractor_file = os.path.join(pretrained_model_name_or_path, FEATURE_EXTRACTOR_NAME) - elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path): - feature_extractor_file = pretrained_model_name_or_path - else: - feature_extractor_file = hf_bucket_url( - pretrained_model_name_or_path, filename=FEATURE_EXTRACTOR_NAME, revision=revision, mirror=None - ) - - try: - # Load from URL or cache if already cached - resolved_feature_extractor_file = cached_path( - feature_extractor_file, - cache_dir=cache_dir, - force_download=force_download, - proxies=proxies, - resume_download=resume_download, - local_files_only=local_files_only, - use_auth_token=use_auth_token, - ) - # Load feature_extractor dict - with open(resolved_feature_extractor_file, "r", encoding="utf-8") as reader: - text = reader.read() - feature_extractor_dict = json.loads(text) - - except EnvironmentError as err: - logger.error(err) - msg = ( - f"Can't load feature extractor for '{pretrained_model_name_or_path}'. Make sure that:\n\n" - f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n" - f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a {FEATURE_EXTRACTOR_NAME} file\n\n" - ) - raise EnvironmentError(msg) - - except json.JSONDecodeError: - msg = ( - f"Couldn't reach server at '{feature_extractor_file}' to download feature extractor configuration file or " - "feature extractor configuration file is not a valid JSON file. " - f"Please check network or file content here: {resolved_feature_extractor_file}." - ) - raise EnvironmentError(msg) - - if resolved_feature_extractor_file == feature_extractor_file: - logger.info(f"loading feature extractor configuration file {feature_extractor_file}") - else: - logger.info( - f"loading feature extractor configuration file {feature_extractor_file} from cache at {resolved_feature_extractor_file}" - ) - - return feature_extractor_dict, kwargs - - @classmethod - def from_dict(cls, feature_extractor_dict: Dict[str, Any], **kwargs) -> "PreTrainedFeatureExtractor": - """ - Instantiates a :class:`~transformers.PreTrainedFeatureExtractor` from a Python dictionary of parameters. - - Args: - feature_extractor_dict (:obj:`Dict[str, Any]`): - Dictionary that will be used to instantiate the feature extractor object. Such a dictionary can be - retrieved from a pretrained checkpoint by leveraging the - :func:`~transformers.PreTrainedFeatureExtractor.to_dict` method. - kwargs (:obj:`Dict[str, Any]`): - Additional parameters from which to initialize the feature extractor object. - - Returns: - :class:`~transformers.PreTrainedFeatureExtractor`: The feature extractor object instantiated from those - parameters. 
- """ - return_unused_kwargs = kwargs.pop("return_unused_kwargs", False) - - feature_extractor = cls(**feature_extractor_dict) - - # Update feature_extractor with kwargs if needed - to_remove = [] - for key, value in kwargs.items(): - if hasattr(feature_extractor, key): - setattr(feature_extractor, key, value) - to_remove.append(key) - for key in to_remove: - kwargs.pop(key, None) - - logger.info(f"Feature extractor {feature_extractor}") - if return_unused_kwargs: - return feature_extractor, kwargs - else: - return feature_extractor - - def to_dict(self) -> Dict[str, Any]: - """ - Serializes this instance to a Python dictionary. - - Returns: - :obj:`Dict[str, Any]`: Dictionary of all the attributes that make up this feature extractor instance. - """ - output = copy.deepcopy(self.__dict__) - - return output - - @classmethod - def from_json_file(cls, json_file: Union[str, os.PathLike]) -> "PreTrainedFeatureExtractor": - """ - Instantiates a :class:`~transformers.PreTrainedFeatureExtractor` from the path to a JSON file of parameters. - - Args: - json_file (:obj:`str` or :obj:`os.PathLike`): - Path to the JSON file containing the parameters. - - Returns: - :class:`~transformers.PreTrainedFeatureExtractor`: The feature_extractor object instantiated from that JSON - file. - - """ - with open(json_file, "r", encoding="utf-8") as reader: - text = reader.read() - feature_extractor_dict = json.loads(text) - return cls(**feature_extractor_dict) - - def to_json_string(self) -> str: - """ - Serializes this instance to a JSON string. - - Returns: - :obj:`str`: String containing all the attributes that make up this feature_extractor instance in JSON - format. - """ - return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" - - def to_json_file(self, json_file_path: Union[str, os.PathLike]): - """ - Save this instance to a JSON file. - - Args: - json_file_path (:obj:`str` or :obj:`os.PathLike`): - Path to the JSON file in which this feature_extractor instance's parameters will be saved. - """ - with open(json_file_path, "w", encoding="utf-8") as writer: - writer.write(self.to_json_string()) - - def __repr__(self): - return f"{self.__class__.__name__} {self.to_json_string()}" - def pad( self, processed_features: Union[ - BatchFeature, - List[BatchFeature], - Dict[str, BatchFeature], - Dict[str, List[BatchFeature]], - List[Dict[str, BatchFeature]], + BatchSequenceFeature, + List[BatchSequenceFeature], + Dict[str, BatchSequenceFeature], + Dict[str, List[BatchSequenceFeature]], + List[Dict[str, BatchSequenceFeature]], ], padding: Union[bool, str, PaddingStrategy] = True, max_length: Optional[int] = None, pad_to_multiple_of: Optional[int] = None, return_attention_mask: Optional[bool] = None, return_tensors: Optional[Union[str, TensorType]] = None, - ) -> BatchFeature: + ) -> BatchSequenceFeature: """ Pad input values / input vectors or a batch of input values / input vectors up to predefined length or to the max sequence length in the batch. @@ -513,11 +244,12 @@ def pad( the case of PyTorch tensors, you will lose the specific device of your tensors however. Args: - processed_features (:class:`~transformers.BatchFeature`, list of :class:`~transformers.BatchFeature`, :obj:`Dict[str, List[float]]`, :obj:`Dict[str, List[List[float]]` or :obj:`List[Dict[str, List[float]]]`): - Processed inputs. 
-                Can represent one input (:class:`~transformers.BatchFeature` or :obj:`Dict[str,
-                List[float]]`) or a batch of input values / vectors (list of :class:`~transformers.BatchFeature`,
-                `Dict[str, List[List[float]]]` or `List[Dict[str, List[float]]]`) so you can use this method during
-                preprocessing as well as in a PyTorch Dataloader collate function.
+            processed_features (:class:`~transformers.BatchSequenceFeature`, list of :class:`~transformers.BatchSequenceFeature`, :obj:`Dict[str, List[float]]`, :obj:`Dict[str, List[List[float]]` or :obj:`List[Dict[str, List[float]]]`):
+                Processed inputs. Can represent one input (:class:`~transformers.BatchSequenceFeature` or
+                :obj:`Dict[str, List[float]]`) or a batch of input values / vectors (list of
+                :class:`~transformers.BatchSequenceFeature`, `Dict[str, List[List[float]]]` or `List[Dict[str,
+                List[float]]]`) so you can use this method during preprocessing as well as in a PyTorch Dataloader
+                collate function.
 
                 Instead of :obj:`List[float]` you can have tensors (numpy arrays, PyTorch tensors or TensorFlow
                 tensors), see the note above for the return type.
@@ -552,7 +284,9 @@ def pad(
         """
         # If we have a list of dicts, let's convert it in a dict of lists
         # We do this to allow using this method as a collate_fn function in PyTorch Dataloader
-        if isinstance(processed_features, (list, tuple)) and isinstance(processed_features[0], (dict, BatchFeature)):
+        if isinstance(processed_features, (list, tuple)) and isinstance(
+            processed_features[0], (dict, BatchSequenceFeature)
+        ):
             processed_features = {
                 key: [example[key] for example in processed_features] for key in processed_features[0].keys()
             }
@@ -560,7 +294,7 @@ def pad(
         # The model's main input name, usually `input_values`, has to be passed for padding
         if self.model_input_names[0] not in processed_features:
             raise ValueError(
-                "You should supply an instance of :class:`~transformers.BatchFeature` or list of :class:`~transformers.BatchFeature` to this method"
+                "You should supply an instance of :class:`~transformers.BatchSequenceFeature` or list of :class:`~transformers.BatchSequenceFeature` to this method "
                 f"that includes {self.model_input_names[0]}, but you provided {list(processed_features.keys())}"
             )
@@ -615,7 +349,7 @@ def pad(
                 pad_to_multiple_of=pad_to_multiple_of,
                 return_attention_mask=return_attention_mask,
             )
-            return BatchFeature(processed_features, tensor_type=return_tensors)
+            return BatchSequenceFeature(processed_features, tensor_type=return_tensors)
 
         batch_size = len(required_input)
         assert all(
@@ -642,11 +376,11 @@ def pad(
                     batch_outputs[key] = []
                 batch_outputs[key].append(value)
 
-        return BatchFeature(batch_outputs, tensor_type=return_tensors)
+        return BatchSequenceFeature(batch_outputs, tensor_type=return_tensors)
 
     def _pad(
         self,
-        processed_features: Union[Dict[str, List[float]], BatchFeature],
+        processed_features: Union[Dict[str, List[float]], BatchSequenceFeature],
         max_length: Optional[int] = None,
         padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
         pad_to_multiple_of: Optional[int] = None,
diff --git a/src/transformers/models/wav2vec2/feature_extraction_wav2vec2.py b/src/transformers/models/wav2vec2/feature_extraction_wav2vec2.py
index bc4297c1ac19a7..4a2a868de47302 100644
--- a/src/transformers/models/wav2vec2/feature_extraction_wav2vec2.py
+++ b/src/transformers/models/wav2vec2/feature_extraction_wav2vec2.py
@@ -20,7 +20,7 @@
 
 import numpy as np
 
-from ...feature_extraction_utils import BatchFeature, PreTrainedFeatureExtractor
+from ...feature_extraction_sequence_utils import BatchSequenceFeature, PreTrainedSequenceFeatureExtractor
 from ...file_utils import PaddingStrategy, TensorType
 from ...utils import logging
 
@@ -28,7 +28,7 @@
 logger = logging.get_logger(__name__)
 
 
-class Wav2Vec2FeatureExtractor(PreTrainedFeatureExtractor):
+class Wav2Vec2FeatureExtractor(PreTrainedSequenceFeatureExtractor):
     r"""
     Constructs a Wav2Vec2 feature extractor.
 
@@ -93,7 +93,7 @@ def __call__(
         return_tensors: Optional[Union[str, TensorType]] = None,
         sampling_rate: Optional[int] = None,
         **kwargs
-    ) -> BatchFeature:
+    ) -> BatchSequenceFeature:
         """
         Main method to featurize and prepare for the model one or several sequence(s).
@@ -179,7 +179,7 @@
             raw_speech = self.zero_mean_unit_var_norm(raw_speech)
 
         # convert into correct format for padding
-        encoded_inputs = BatchFeature({"input_values": raw_speech})
+        encoded_inputs = BatchSequenceFeature({"input_values": raw_speech})
 
         padded_inputs = self.pad(
             encoded_inputs,
diff --git a/src/transformers/models/wav2vec2/processing_wav2vec2.py b/src/transformers/models/wav2vec2/processing_wav2vec2.py
index 71202a2ff07f78..9676225447dbc3 100644
--- a/src/transformers/models/wav2vec2/processing_wav2vec2.py
+++ b/src/transformers/models/wav2vec2/processing_wav2vec2.py
@@ -59,7 +59,8 @@ def save_pretrained(self, save_directory):
 
         .. note::
 
-            This class method is simply calling :meth:`~transformers.PreTrainedFeatureExtractor.save_pretrained` and
+            This class method is simply calling
+            :meth:`~transformers.PreTrainedSequenceFeatureExtractor.save_pretrained` and
             :meth:`~transformers.tokenization_utils_base.PreTrainedTokenizer.save_pretrained`. Please refer to the
             docstrings of the methods above for more information.
 
@@ -80,7 +81,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
         .. note::
 
             This class method is simply calling Wav2Vec2FeatureExtractor's
-            :meth:`~transformers.PreTrainedFeatureExtractor.from_pretrained` and Wav2Vec2CTCTokenizer's
+            :meth:`~transformers.PreTrainedSequenceFeatureExtractor.from_pretrained` and Wav2Vec2CTCTokenizer's
             :meth:`~transformers.tokenization_utils_base.PreTrainedTokenizer.from_pretrained`. Please refer to the
             docstrings of the methods above for more information.
 
@@ -92,12 +93,13 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
                   huggingface.co. Valid model ids can be located at the root-level, like ``bert-base-uncased``, or
                   namespaced under a user or organization name, like ``dbmdz/bert-base-german-cased``.
                 - a path to a `directory` containing a feature extractor file saved using the
-                  :meth:`~transformers.PreTrainedFeatureExtractor.save_pretrained` method, e.g.,
+                  :meth:`~transformers.PreTrainedSequenceFeatureExtractor.save_pretrained` method, e.g.,
                   ``./my_model_directory/``.
                 - a path or url to a saved feature extractor JSON `file`, e.g.,
                   ``./my_model_directory/feature_extraction_config.json``.
**kwargs - Additional keyword arguments passed along to both :class:`~transformers.PreTrainedFeatureExtractor` and + Additional keyword arguments passed along to both + :class:`~transformers.PreTrainedSequenceFeatureExtractor` and :class:`~transformers.PreTrainedTokenizer` """ feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(pretrained_model_name_or_path, **kwargs) diff --git a/tests/test_feature_extraction_saving_common.py b/tests/test_feature_extraction_saving_common.py new file mode 100644 index 00000000000000..49dfa6dfd4dbcb --- /dev/null +++ b/tests/test_feature_extraction_saving_common.py @@ -0,0 +1,50 @@ +# coding=utf-8 +# Copyright 2021 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import json +import os +import tempfile + + +class FeatureExtractionSavingTestMixin: + def test_feat_extract_to_json_string(self): + feat_extract = self.feature_extraction_class(**self.feat_extract_dict) + obj = json.loads(feat_extract.to_json_string()) + for key, value in self.feat_extract_dict.items(): + self.assertEqual(obj[key], value) + + def test_feat_extract_to_json_file(self): + feat_extract_first = self.feature_extraction_class(**self.feat_extract_dict) + + with tempfile.TemporaryDirectory() as tmpdirname: + json_file_path = os.path.join(tmpdirname, "feat_extract.json") + feat_extract_first.to_json_file(json_file_path) + feat_extract_second = self.feature_extraction_class.from_json_file(json_file_path) + + self.assertEqual(feat_extract_second.to_dict(), feat_extract_first.to_dict()) + + def test_feat_extract_from_and_save_pretrained(self): + feat_extract_first = self.feature_extraction_class(**self.feat_extract_dict) + + with tempfile.TemporaryDirectory() as tmpdirname: + feat_extract_first.save_pretrained(tmpdirname) + feat_extract_second = self.feature_extraction_class.from_pretrained(tmpdirname) + + self.assertEqual(feat_extract_second.to_dict(), feat_extract_first.to_dict()) + + def test_init_without_params(self): + feat_extract = self.feature_extraction_class() + self.assertIsNotNone(feat_extract) diff --git a/tests/test_feature_extraction_wav2vec2.py b/tests/test_feature_extraction_wav2vec2.py index 179bafe6137ab9..771974a3982179 100644 --- a/tests/test_feature_extraction_wav2vec2.py +++ b/tests/test_feature_extraction_wav2vec2.py @@ -23,7 +23,7 @@ from transformers import WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST, Wav2Vec2Config, Wav2Vec2FeatureExtractor from transformers.testing_utils import slow -from .test_feature_extraction_common import FeatureExtractionMixin +from .test_sequence_feature_extraction_common import SequenceFeatureExtractionTestMixin global_rng = random.Random() @@ -94,7 +94,7 @@ def _flatten(list_of_lists): return speech_inputs -class Wav2Vec2FeatureExtractionTest(FeatureExtractionMixin, unittest.TestCase): +class Wav2Vec2FeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.TestCase): feature_extraction_class = Wav2Vec2FeatureExtractor diff --git a/tests/test_feature_extraction_common.py 
b/tests/test_sequence_feature_extraction_common.py
similarity index 83%
rename from tests/test_feature_extraction_common.py
rename to tests/test_sequence_feature_extraction_common.py
index 77b82019bd4e5f..0dfd1deffbc931 100644
--- a/tests/test_feature_extraction_common.py
+++ b/tests/test_sequence_feature_extraction_common.py
@@ -14,17 +14,15 @@
 # limitations under the License.
 
 
-import json
-import os
-import tempfile
-
 import numpy as np
 
-from transformers import BatchFeature
+from transformers import BatchSequenceFeature
 from transformers.testing_utils import require_tf, require_torch
 
+from .test_feature_extraction_saving_common import FeatureExtractionSavingTestMixin
+
 
-class FeatureExtractionMixin:
+class SequenceFeatureExtractionTestMixin(FeatureExtractionSavingTestMixin):
 
     # to overwrite at feature extractor specific tests
     feat_extract_tester = None
@@ -40,46 +38,17 @@ def test_feat_extract_common_properties(self):
         self.assertTrue(hasattr(feat_extract, "sampling_rate"))
         self.assertTrue(hasattr(feat_extract, "padding_value"))
 
-    def test_feat_extract_to_json_string(self):
-        feat_extract = self.feature_extraction_class(**self.feat_extract_dict)
-        obj = json.loads(feat_extract.to_json_string())
-        for key, value in self.feat_extract_dict.items():
-            self.assertEqual(obj[key], value)
-
-    def test_feat_extract_to_json_file(self):
-        feat_extract_first = self.feature_extraction_class(**self.feat_extract_dict)
-
-        with tempfile.TemporaryDirectory() as tmpdirname:
-            json_file_path = os.path.join(tmpdirname, "feat_extract.json")
-            feat_extract_first.to_json_file(json_file_path)
-            feat_extract_second = self.feature_extraction_class.from_json_file(json_file_path)
-
-        self.assertEqual(feat_extract_second.to_dict(), feat_extract_first.to_dict())
-
-    def test_feat_extract_from_and_save_pretrained(self):
-        feat_extract_first = self.feature_extraction_class(**self.feat_extract_dict)
-
-        with tempfile.TemporaryDirectory() as tmpdirname:
-            feat_extract_first.save_pretrained(tmpdirname)
-            feat_extract_second = self.feature_extraction_class.from_pretrained(tmpdirname)
-
-        self.assertEqual(feat_extract_second.to_dict(), feat_extract_first.to_dict())
-
-    def test_init_without_params(self):
-        feat_extract = self.feature_extraction_class()
-        self.assertIsNotNone(feat_extract)
-
     def test_batch_feature(self):
         speech_inputs = self.feat_extract_tester.prepare_inputs_for_common()
         feat_extract = self.feature_extraction_class(**self.feat_extract_dict)
         input_name = feat_extract.model_input_names[0]
 
-        processed_features = BatchFeature({input_name: speech_inputs})
+        processed_features = BatchSequenceFeature({input_name: speech_inputs})
 
         self.assertTrue(all(len(x) == len(y) for x, y in zip(speech_inputs, processed_features[input_name])))
 
         speech_inputs = self.feat_extract_tester.prepare_inputs_for_common(equal_length=True)
-        processed_features = BatchFeature({input_name: speech_inputs}, tensor_type="np")
+        processed_features = BatchSequenceFeature({input_name: speech_inputs}, tensor_type="np")
 
         batch_features_input = processed_features[input_name]
 
@@ -97,7 +66,7 @@ def test_batch_feature_pt(self):
         feat_extract = self.feature_extraction_class(**self.feat_extract_dict)
         input_name = feat_extract.model_input_names[0]
 
-        processed_features = BatchFeature({input_name: speech_inputs}, tensor_type="pt")
+        processed_features = BatchSequenceFeature({input_name: speech_inputs}, tensor_type="pt")
 
         batch_features_input = processed_features[input_name]
 
@@ -115,7 +84,7 @@ def test_batch_feature_tf(self):
         feat_extract =
self.feature_extraction_class(**self.feat_extract_dict) input_name = feat_extract.model_input_names[0] - processed_features = BatchFeature({input_name: speech_inputs}, tensor_type="tf") + processed_features = BatchSequenceFeature({input_name: speech_inputs}, tensor_type="tf") batch_features_input = processed_features[input_name] @@ -148,7 +117,7 @@ def _inputs_are_equal(input_1, input_2): speech_inputs = self.feat_extract_tester.prepare_inputs_for_common(numpify=numpify) input_name = feat_extract.model_input_names[0] - processed_features = BatchFeature({input_name: speech_inputs}) + processed_features = BatchSequenceFeature({input_name: speech_inputs}) pad_diff = self.feat_extract_tester.seq_length_diff pad_max_length = self.feat_extract_tester.max_seq_length + pad_diff @@ -248,7 +217,7 @@ def test_padding_accepts_tensors_pt(self): speech_inputs = self.feat_extract_tester.prepare_inputs_for_common() input_name = feat_extract.model_input_names[0] - processed_features = BatchFeature({input_name: speech_inputs}) + processed_features = BatchSequenceFeature({input_name: speech_inputs}) input_np = feat_extract.pad(processed_features, padding="longest", return_tensors="np")[input_name] input_pt = feat_extract.pad(processed_features, padding="longest", return_tensors="pt")[input_name] @@ -261,7 +230,7 @@ def test_padding_accepts_tensors_tf(self): speech_inputs = self.feat_extract_tester.prepare_inputs_for_common() input_name = feat_extract.model_input_names[0] - processed_features = BatchFeature({input_name: speech_inputs}) + processed_features = BatchSequenceFeature({input_name: speech_inputs}) input_np = feat_extract.pad(processed_features, padding="longest", return_tensors="np")[input_name] input_tf = feat_extract.pad(processed_features, padding="longest", return_tensors="tf")[input_name] @@ -276,7 +245,7 @@ def test_attention_mask(self): input_lenghts = [len(x) for x in speech_inputs] input_name = feat_extract.model_input_names[0] - processed = BatchFeature({input_name: speech_inputs}) + processed = BatchSequenceFeature({input_name: speech_inputs}) processed = feat_extract.pad(processed, padding="longest", return_tensors="np") self.assertIn("attention_mask", processed) From 88ed32f315904a7e441532ccafc502581dd3c073 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 8 Mar 2021 17:45:17 +0300 Subject: [PATCH 2/9] finish refactor --- .../source/main_classes/feature_extractor.rst | 7 +-- .../feature_extraction_saving_utils.py | 52 +++++++++---------- .../feature_extraction_sequence_utils.py | 5 +- src/transformers/tokenization_utils_base.py | 3 +- 4 files changed, 31 insertions(+), 36 deletions(-) diff --git a/docs/source/main_classes/feature_extractor.rst b/docs/source/main_classes/feature_extractor.rst index 5aa9f112a3ab9d..58d88878d50fc1 100644 --- a/docs/source/main_classes/feature_extractor.rst +++ b/docs/source/main_classes/feature_extractor.rst @@ -14,9 +14,10 @@ Feature Extractor ----------------------------------------------------------------------------------------------------------------------- -A feature extractor is in charge of preparing read-in audio files for a speech model. This includes feature extraction, -such as processing audio files to, *e.g.*, Log-Mel Spectrogram features, but also padding, normalization, and -conversion to Numpy, PyTorch, and TensorFlow tensors. +A feature extractor is in charge of preparing input features for a multi-modal model. 
+This includes feature extraction
+from sequences, *e.g.*, pre-processing audio files to Log-Mel Spectrogram features, feature extraction from images,
+*e.g.*, cropping image files, but also padding, normalization, and conversion to Numpy, PyTorch, and TensorFlow
+tensors.
 
 
 PreTrainedSequenceFeatureExtractor
diff --git a/src/transformers/feature_extraction_saving_utils.py b/src/transformers/feature_extraction_saving_utils.py
index a7278762999bc6..4f8f8c5928eef8 100644
--- a/src/transformers/feature_extraction_saving_utils.py
+++ b/src/transformers/feature_extraction_saving_utils.py
@@ -13,8 +13,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """
- Feature extraction common class for python feature extractors.
+ Feature extraction saving/loading class for common feature extractors.
 """
+
 import copy
 import json
 import os
@@ -26,20 +27,13 @@
 
 logger = logging.get_logger(__name__)
 
-PreTrainedFeatureExtractor = Union["PreTrainedSequenceFeatureExtractor"]
+PreTrainedFeatureExtractor = Union["PreTrainedSequenceFeatureExtractor"]  # noqa: F821
 
 
 class FeatureExtractionSavingUtilsMixin:
     """
-    This is a general feature extraction class for speech recognition.
-
-    Args:
-        feature_size (:obj:`int`):
-            The feature dimension of the extracted features.
-        sampling_rate (:obj:`int`):
-            The sampling rate at which the audio files should be digitized, expressed in Hertz (Hz).
-        padding_value (:obj:`float`):
-            The value that is used to fill the padding values / vectors.
+    This is a feature extraction mixin used to provide saving/loading functionality for sequential and image feature
+    extractors.
     """
 
     @classmethod
@@ -47,8 +41,8 @@ def from_pretrained(
         cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
     ) -> "PreTrainedFeatureExtractor":
         r"""
-        Instantiate a :class:`~transformers.PreTrainedFeatureExtractor` (or a derived class) from a pretrained feature
-        extractor.
+        Instantiate a :class:`~transformers.PreTrainedSequenceFeatureExtractor` (or a derived class) from a pretrained
+        feature extractor.
 
         Args:
             pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):
@@ -58,7 +52,7 @@
                   huggingface.co. Valid model ids can be located at the root-level, like ``bert-base-uncased``, or
                   namespaced under a user or organization name, like ``dbmdz/bert-base-german-cased``.
                 - a path to a `directory` containing a feature extractor file saved using the
-                  :func:`~transformers.PreTrainedFeatureExtractor.save_pretrained` method, e.g.,
+                  :func:`~transformers.PreTrainedSequenceFeatureExtractor.save_pretrained` method, e.g.,
                   ``./my_model_directory/``.
                 - a path or url to a saved feature extractor JSON `file`, e.g.,
                   ``./my_model_directory/feature_extraction_config.json``.
@@ -93,11 +87,11 @@ def from_pretrained(
                 Passing :obj:`use_auth_token=True` is required when you want to use a private model.
 
         Returns:
-            :class:`~transformers.PreTrainedFeatureExtractor`: The feature extractor object instantiated from this
-            pretrained model.
+            :class:`~transformers.PreTrainedSequenceFeatureExtractor`: The feature extractor object instantiated from
+            this pretrained model.
Examples:: - # We can't instantiate directly the base class `PreTrainedFeatureExtractor` so let's show the examples on a + # We can't instantiate directly the base class `PreTrainedSequenceFeatureExtractor` so let's show the examples on a # derived class: Wav2Vec2FeatureExtractor feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained('facebook/wav2vec2-base-960h') # Download feature_extraction_config from huggingface.co and cache. feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained('./test/saved_model/') # E.g. feature_extractor (or model) was saved using `save_pretrained('./test/saved_model/')` @@ -116,7 +110,7 @@ def from_pretrained( def save_pretrained(self, save_directory: Union[str, os.PathLike]): """ Save a feature_extractor object to the directory ``save_directory``, so that it can be re-loaded using the - :func:`~transformers.PreTrainedFeatureExtractor.from_pretrained` class method. + :func:`~transformers.PreTrainedSequenceFeatureExtractor.from_pretrained` class method. Args: save_directory (:obj:`str` or :obj:`os.PathLike`): @@ -137,7 +131,7 @@ def get_feature_extractor_dict( ) -> Tuple[Dict[str, Any], Dict[str, Any]]: """ From a ``pretrained_model_name_or_path``, resolve to a dictionary of parameters, to be used for instantiating a - :class:`~transformers.PreTrainedFeatureExtractor` using ``from_dict``. + :class:`~transformers.PreTrainedSequenceFeatureExtractor` using ``from_dict``. Parameters: pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`): @@ -208,21 +202,22 @@ def get_feature_extractor_dict( return feature_extractor_dict, kwargs @classmethod - def from_dict(cls, feature_extractor_dict: Dict[str, Any], **kwargs) -> "PreTrainedFeatureExtractor": + def from_dict(cls, feature_extractor_dict: Dict[str, Any], **kwargs) -> PreTrainedFeatureExtractor: """ - Instantiates a :class:`~transformers.PreTrainedFeatureExtractor` from a Python dictionary of parameters. + Instantiates a :class:`~transformers.PreTrainedSequenceFeatureExtractor` from a Python dictionary of + parameters. Args: feature_extractor_dict (:obj:`Dict[str, Any]`): Dictionary that will be used to instantiate the feature extractor object. Such a dictionary can be retrieved from a pretrained checkpoint by leveraging the - :func:`~transformers.PreTrainedFeatureExtractor.to_dict` method. + :func:`~transformers.PreTrainedSequenceFeatureExtractor.to_dict` method. kwargs (:obj:`Dict[str, Any]`): Additional parameters from which to initialize the feature extractor object. Returns: - :class:`~transformers.PreTrainedFeatureExtractor`: The feature extractor object instantiated from those - parameters. + :class:`~transformers.PreTrainedSequenceFeatureExtractor`: The feature extractor object instantiated from + those parameters. """ return_unused_kwargs = kwargs.pop("return_unused_kwargs", False) @@ -255,17 +250,18 @@ def to_dict(self) -> Dict[str, Any]: return output @classmethod - def from_json_file(cls, json_file: Union[str, os.PathLike]) -> "PreTrainedFeatureExtractor": + def from_json_file(cls, json_file: Union[str, os.PathLike]) -> PreTrainedFeatureExtractor: """ - Instantiates a :class:`~transformers.PreTrainedFeatureExtractor` from the path to a JSON file of parameters. + Instantiates a :class:`~transformers.PreTrainedSequenceFeatureExtractor` from the path to a JSON file of + parameters. Args: json_file (:obj:`str` or :obj:`os.PathLike`): Path to the JSON file containing the parameters. 
         Returns:
-            :class:`~transformers.PreTrainedFeatureExtractor`: The feature_extractor object instantiated from that JSON
-            file.
+            :class:`~transformers.PreTrainedSequenceFeatureExtractor`: The feature_extractor object instantiated from
+            that JSON file.
         """
         with open(json_file, "r", encoding="utf-8") as reader:
             text = reader.read()
diff --git a/src/transformers/feature_extraction_sequence_utils.py b/src/transformers/feature_extraction_sequence_utils.py
index c84923b33ff8f3..fe28ace79d2752 100644
--- a/src/transformers/feature_extraction_sequence_utils.py
+++ b/src/transformers/feature_extraction_sequence_utils.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """
- Feature extraction common class for python feature extractors.
+ Sequence feature extraction class for common feature extractors to preprocess sequences.
 """
 from collections import UserDict
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
@@ -172,8 +172,7 @@ def to(self, device: Union[str, "torch.device"]) -> "BatchSequenceFeature":
             device (:obj:`str` or :obj:`torch.device`): The device to put the tensors on.
 
         Returns:
-            :class:`~transformers.BatchSequenceFeature`: The same instance of
-            :class:`~transformers.BatchSequenceFeature` after modification.
+            :class:`~transformers.BatchSequenceFeature`: The same instance after modification.
         """
 
         # This check catches things like APEX blindly calling "to" on all inputs to a module
diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py
index aefe209b65edf1..20678875d7b138 100644
--- a/src/transformers/tokenization_utils_base.py
+++ b/src/transformers/tokenization_utils_base.py
@@ -727,8 +727,7 @@ def to(self, device: Union[str, "torch.device"]) -> "BatchEncoding":
             device (:obj:`str` or :obj:`torch.device`): The device to put the tensors on.
 
         Returns:
-            :class:`~transformers.BatchEncoding`: The same instance of :class:`~transformers.BatchEncoding` after
-            modification.
+            :class:`~transformers.BatchEncoding`: The same instance after modification.
         """
 
         # This check catches things like APEX blindly calling "to" on all inputs to a module

From add513aef6fa986d0e5f072e2658401167ea25d9 Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Mon, 8 Mar 2021 18:08:20 +0300
Subject: [PATCH 3/9] finish refactor

---
 .../source/main_classes/feature_extractor.rst |   4 +-
 src/transformers/__init__.py                  |   4 +-
 ....py => feature_extraction_common_utils.py} | 166 ++++++++++++++-
 .../feature_extraction_sequence_utils.py      | 189 ++----------------
 .../wav2vec2/feature_extraction_wav2vec2.py   |   7 +-
 ...test_sequence_feature_extraction_common.py |  18 +-
 6 files changed, 198 insertions(+), 190 deletions(-)
 rename src/transformers/{feature_extraction_saving_utils.py => feature_extraction_common_utils.py} (71%)

diff --git a/docs/source/main_classes/feature_extractor.rst b/docs/source/main_classes/feature_extractor.rst
index 58d88878d50fc1..adbda92a01b1fb 100644
--- a/docs/source/main_classes/feature_extractor.rst
+++ b/docs/source/main_classes/feature_extractor.rst
@@ -27,8 +27,8 @@ PreTrainedSequenceFeatureExtractor
     :members: from_pretrained, save_pretrained, pad
 
 
-BatchSequenceFeature
+BatchFeature
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
-.. autoclass:: transformers.BatchSequenceFeature
+..
autoclass:: transformers.BatchFeature :members: diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 290411216dffd4..cd61380d6f1efc 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -246,7 +246,7 @@ "SpecialTokensMixin", "TokenSpan", ], - "feature_extraction_sequence_utils": ["PreTrainedSequenceFeatureExtractor", "BatchSequenceFeature"], + "feature_extraction_sequence_utils": ["PreTrainedSequenceFeatureExtractor", "BatchFeature"], "trainer_callback": [ "DefaultFlowCallback", "EarlyStoppingCallback", @@ -1250,7 +1250,7 @@ ) # Feature Extractor - from .feature_extraction_sequence_utils import BatchSequenceFeature, PreTrainedSequenceFeatureExtractor + from .feature_extraction_common_utils import BatchFeature, PreTrainedSequenceFeatureExtractor # Files and general utilities from .file_utils import ( diff --git a/src/transformers/feature_extraction_saving_utils.py b/src/transformers/feature_extraction_common_utils.py similarity index 71% rename from src/transformers/feature_extraction_saving_utils.py rename to src/transformers/feature_extraction_common_utils.py index 4f8f8c5928eef8..d81eff5349ea86 100644 --- a/src/transformers/feature_extraction_saving_utils.py +++ b/src/transformers/feature_extraction_common_utils.py @@ -19,17 +19,177 @@ import copy import json import os -from typing import Any, Dict, Tuple, Union - -from .file_utils import FEATURE_EXTRACTOR_NAME, cached_path, hf_bucket_url, is_remote_url +from collections import UserDict +from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union + +import numpy as np + +from .file_utils import ( + FEATURE_EXTRACTOR_NAME, + TensorType, + _is_jax, + _is_numpy, + _is_torch_device, + cached_path, + hf_bucket_url, + is_flax_available, + is_remote_url, + is_tf_available, + is_torch_available, + torch_required, +) from .utils import logging +if TYPE_CHECKING: + if is_torch_available(): + import torch + + logger = logging.get_logger(__name__) PreTrainedFeatureExtractor = Union["PreTrainedSequenceFeatureExtractor"] # noqa: F821 +class BatchFeature(UserDict): + r""" + Holds the output of the :meth:`~transformers.PreTrainedSequenceFeatureExtractor.pad` and feature extractor specific + ``__call__`` methods. + + This class is derived from a python dictionary and can be used as a dictionary. + + Args: + data (:obj:`dict`): + Dictionary of lists/arrays/tensors returned by the __call__/pad methods ('input_values', 'attention_mask', + etc.). + tensor_type (:obj:`Union[None, str, TensorType]`, `optional`): + You can give a tensor_type here to convert the lists of integers in PyTorch/TensorFlow/Numpy Tensors at + initialization. + """ + + def __init__(self, data: Optional[Dict[str, Any]] = None, tensor_type: Union[None, str, TensorType] = None): + super().__init__(data) + self.convert_to_tensors(tensor_type=tensor_type) + + def __getitem__(self, item: str) -> Union[Any]: + """ + If the key is a string, returns the value of the dict associated to :obj:`key` ('input_values', + 'attention_mask', etc.). 
+ """ + if isinstance(item, str): + return self.data[item] + else: + raise KeyError("Indexing with integers is not available when using Python based feature extractors") + + def __getattr__(self, item: str): + try: + return self.data[item] + except KeyError: + raise AttributeError + + def __getstate__(self): + return {"data": self.data} + + def __setstate__(self, state): + if "data" in state: + self.data = state["data"] + + # Copied from transformers.tokenization_utils_base.BatchEncoding.keys + def keys(self): + return self.data.keys() + + # Copied from transformers.tokenization_utils_base.BatchEncoding.values + def values(self): + return self.data.values() + + # Copied from transformers.tokenization_utils_base.BatchEncoding.items + def items(self): + return self.data.items() + + def convert_to_tensors(self, tensor_type: Optional[Union[str, TensorType]] = None): + """ + Convert the inner content to tensors. + + Args: + tensor_type (:obj:`str` or :class:`~transformers.file_utils.TensorType`, `optional`): + The type of tensors to use. If :obj:`str`, should be one of the values of the enum + :class:`~transformers.file_utils.TensorType`. If :obj:`None`, no modification is done. + """ + if tensor_type is None: + return self + + # Convert to TensorType + if not isinstance(tensor_type, TensorType): + tensor_type = TensorType(tensor_type) + + # Get a function reference for the correct framework + if tensor_type == TensorType.TENSORFLOW: + if not is_tf_available(): + raise ImportError( + "Unable to convert output to TensorFlow tensors format, TensorFlow is not installed." + ) + import tensorflow as tf + + as_tensor = tf.constant + is_tensor = tf.is_tensor + elif tensor_type == TensorType.PYTORCH: + if not is_torch_available(): + raise ImportError("Unable to convert output to PyTorch tensors format, PyTorch is not installed.") + import torch + + as_tensor = torch.tensor + is_tensor = torch.is_tensor + elif tensor_type == TensorType.JAX: + if not is_flax_available(): + raise ImportError("Unable to convert output to JAX tensors format, JAX is not installed.") + import jax.numpy as jnp # noqa: F811 + + as_tensor = jnp.array + is_tensor = _is_jax + else: + as_tensor = np.asarray + is_tensor = _is_numpy + + # Do the tensor conversion in batch + for key, value in self.items(): + try: + if not is_tensor(value): + tensor = as_tensor(value) + + self[key] = tensor + except: # noqa E722 + if key == "overflowing_values": + raise ValueError("Unable to create tensor returning overflowing values of different lengths. ") + raise ValueError( + "Unable to create tensor, you should probably activate padding " + "with 'padding=True' to have batched tensors with the same length." + ) + + return self + + @torch_required + # Copied from transformers.tokenization_utils_base.BatchEncoding.to with BatchEncoding->BatchFeature + def to(self, device: Union[str, "torch.device"]) -> "BatchFeature": + """ + Send all values to device by calling :obj:`v.to(device)` (PyTorch only). + + Args: + device (:obj:`str` or :obj:`torch.device`): The device to put the tensors on. + + Returns: + :class:`~transformers.BatchFeature`: The same instance after modification. 
+ """ + + # This check catches things like APEX blindly calling "to" on all inputs to a module + # Otherwise it passes the casts down and casts the LongTensor containing the token idxs + # into a HalfTensor + if isinstance(device, str) or _is_torch_device(device) or isinstance(device, int): + self.data = {k: v.to(device=device) for k, v in self.data.items()} + else: + logger.warning(f"Attempting to cast a BatchFeature to type {str(device)}. This is not supported.") + return self + + class FeatureExtractionSavingUtilsMixin: """ This is a feature extraction mixin used to provide saving/loading functionality for sequential and image feature diff --git a/src/transformers/feature_extraction_sequence_utils.py b/src/transformers/feature_extraction_sequence_utils.py index fe28ace79d2752..8069a35a82f777 100644 --- a/src/transformers/feature_extraction_sequence_utils.py +++ b/src/transformers/feature_extraction_sequence_utils.py @@ -15,25 +15,19 @@ """ Sequence feature extraction class for common feature extrcactors to preprocess sequences. """ -from collections import UserDict -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union +from typing import Dict, List, Optional, Union import numpy as np -from .feature_extraction_saving_utils import FeatureExtractionSavingUtilsMixin +from .feature_extraction_common_utils import BatchFeature, FeatureExtractionSavingUtilsMixin from .file_utils import ( PaddingStrategy, TensorType, - _is_jax, - _is_numpy, _is_tensorflow, _is_torch, - _is_torch_device, - is_flax_available, is_tf_available, is_torch_available, to_py_obj, - torch_required, ) from .utils import logging @@ -41,150 +35,6 @@ logger = logging.get_logger(__name__) -if TYPE_CHECKING: - if is_torch_available(): - import torch - - -class BatchSequenceFeature(UserDict): - r""" - Holds the output of the :meth:`~transformers.PreTrainedSequenceFeatureExtractor.pad` and feature extractor specific - ``__call__`` methods. - - This class is derived from a python dictionary and can be used as a dictionary. - - Args: - data (:obj:`dict`): - Dictionary of lists/arrays/tensors returned by the __call__/pad methods ('input_values', 'attention_mask', - etc.). - tensor_type (:obj:`Union[None, str, TensorType]`, `optional`): - You can give a tensor_type here to convert the lists of integers in PyTorch/TensorFlow/Numpy Tensors at - initialization. - """ - - def __init__(self, data: Optional[Dict[str, Any]] = None, tensor_type: Union[None, str, TensorType] = None): - super().__init__(data) - self.convert_to_tensors(tensor_type=tensor_type) - - def __getitem__(self, item: str) -> Union[Any]: - """ - If the key is a string, returns the value of the dict associated to :obj:`key` ('input_values', - 'attention_mask', etc.). 
- """ - if isinstance(item, str): - return self.data[item] - else: - raise KeyError("Indexing with integers is not available when using Python based feature extractors") - - def __getattr__(self, item: str): - try: - return self.data[item] - except KeyError: - raise AttributeError - - def __getstate__(self): - return {"data": self.data} - - def __setstate__(self, state): - if "data" in state: - self.data = state["data"] - - # Copied from transformers.tokenization_utils_base.BatchEncoding.keys - def keys(self): - return self.data.keys() - - # Copied from transformers.tokenization_utils_base.BatchEncoding.values - def values(self): - return self.data.values() - - # Copied from transformers.tokenization_utils_base.BatchEncoding.items - def items(self): - return self.data.items() - - def convert_to_tensors(self, tensor_type: Optional[Union[str, TensorType]] = None): - """ - Convert the inner content to tensors. - - Args: - tensor_type (:obj:`str` or :class:`~transformers.file_utils.TensorType`, `optional`): - The type of tensors to use. If :obj:`str`, should be one of the values of the enum - :class:`~transformers.file_utils.TensorType`. If :obj:`None`, no modification is done. - """ - if tensor_type is None: - return self - - # Convert to TensorType - if not isinstance(tensor_type, TensorType): - tensor_type = TensorType(tensor_type) - - # Get a function reference for the correct framework - if tensor_type == TensorType.TENSORFLOW: - if not is_tf_available(): - raise ImportError( - "Unable to convert output to TensorFlow tensors format, TensorFlow is not installed." - ) - import tensorflow as tf - - as_tensor = tf.constant - is_tensor = tf.is_tensor - elif tensor_type == TensorType.PYTORCH: - if not is_torch_available(): - raise ImportError("Unable to convert output to PyTorch tensors format, PyTorch is not installed.") - import torch - - as_tensor = torch.tensor - is_tensor = torch.is_tensor - elif tensor_type == TensorType.JAX: - if not is_flax_available(): - raise ImportError("Unable to convert output to JAX tensors format, JAX is not installed.") - import jax.numpy as jnp # noqa: F811 - - as_tensor = jnp.array - is_tensor = _is_jax - else: - as_tensor = np.asarray - is_tensor = _is_numpy - - # Do the tensor conversion in batch - for key, value in self.items(): - try: - if not is_tensor(value): - tensor = as_tensor(value) - - self[key] = tensor - except: # noqa E722 - if key == "overflowing_values": - raise ValueError("Unable to create tensor returning overflowing values of different lengths. ") - raise ValueError( - "Unable to create tensor, you should probably activate padding " - "with 'padding=True' to have batched tensors with the same length." - ) - - return self - - @torch_required - # Copied from transformers.tokenization_utils_base.BatchEncoding.to with BatchEncoding->BatchSequenceFeature - def to(self, device: Union[str, "torch.device"]) -> "BatchSequenceFeature": - """ - Send all values to device by calling :obj:`v.to(device)` (PyTorch only). - - Args: - device (:obj:`str` or :obj:`torch.device`): The device to put the tensors on. - - Returns: - :class:`~transformers.BatchSequenceFeature`: The same instance after modification. 
- """ - - # This check catches things like APEX blindly calling "to" on all inputs to a module - # Otherwise it passes the casts down and casts the LongTensor containing the token idxs - # into a HalfTensor - if isinstance(device, str) or _is_torch_device(device) or isinstance(device, int): - self.data = {k: v.to(device=device) for k, v in self.data.items()} - else: - logger.warning(f"Attempting to cast a BatchSequenceFeature to type {str(device)}. This is not supported.") - return self - - class PreTrainedSequenceFeatureExtractor(FeatureExtractionSavingUtilsMixin): """ This is a general feature extraction class for speech recognition. @@ -217,18 +67,18 @@ def __init__(self, feature_size: int, sampling_rate: int, padding_value: float, def pad( self, processed_features: Union[ - BatchSequenceFeature, - List[BatchSequenceFeature], - Dict[str, BatchSequenceFeature], - Dict[str, List[BatchSequenceFeature]], - List[Dict[str, BatchSequenceFeature]], + BatchFeature, + List[BatchFeature], + Dict[str, BatchFeature], + Dict[str, List[BatchFeature]], + List[Dict[str, BatchFeature]], ], padding: Union[bool, str, PaddingStrategy] = True, max_length: Optional[int] = None, pad_to_multiple_of: Optional[int] = None, return_attention_mask: Optional[bool] = None, return_tensors: Optional[Union[str, TensorType]] = None, - ) -> BatchSequenceFeature: + ) -> BatchFeature: """ Pad input values / input vectors or a batch of input values / input vectors up to predefined length or to the max sequence length in the batch. @@ -243,12 +93,11 @@ def pad( the case of PyTorch tensors, you will lose the specific device of your tensors however. Args: - processed_features (:class:`~transformers.BatchSequenceFeature`, list of :class:`~transformers.BatchSequenceFeature`, :obj:`Dict[str, List[float]]`, :obj:`Dict[str, List[List[float]]` or :obj:`List[Dict[str, List[float]]]`): - Processed inputs. Can represent one input (:class:`~transformers.BatchSequenceFeature` or - :obj:`Dict[str, List[float]]`) or a batch of input values / vectors (list of - :class:`~transformers.BatchSequenceFeature`, `Dict[str, List[List[float]]]` or `List[Dict[str, - List[float]]]`) so you can use this method during preprocessing as well as in a PyTorch Dataloader - collate function. + processed_features (:class:`~transformers.BatchFeature`, list of :class:`~transformers.BatchFeature`, :obj:`Dict[str, List[float]]`, :obj:`Dict[str, List[List[float]]` or :obj:`List[Dict[str, List[float]]]`): + Processed inputs. Can represent one input (:class:`~transformers.BatchFeature` or :obj:`Dict[str, + List[float]]`) or a batch of input values / vectors (list of :class:`~transformers.BatchFeature`, + `Dict[str, List[List[float]]]` or `List[Dict[str, List[float]]]`) so you can use this method during + preprocessing as well as in a PyTorch Dataloader collate function. Instead of :obj:`List[float]` you can have tensors (numpy arrays, PyTorch tensors or TensorFlow tensors), see the note above for the return type. 
@@ -283,9 +132,7 @@ def pad( """ # If we have a list of dicts, let's convert it into a dict of lists # We do this to allow using this method as a collate_fn function in PyTorch Dataloader - if isinstance(processed_features, (list, tuple)) and isinstance( - processed_features[0], (dict, BatchSequenceFeature) - ): + if isinstance(processed_features, (list, tuple)) and isinstance(processed_features[0], (dict, BatchFeature)): processed_features = { key: [example[key] for example in processed_features] for key in processed_features[0].keys() } @@ -293,7 +140,7 @@ def pad( # The model's main input name, usually `input_values`, has to be passed for padding if self.model_input_names[0] not in processed_features: raise ValueError( - "You should supply an instance of :class:`~transformers.BatchSequenceFeature` or list of :class:`~transformers.BatchSequenceFeature` to this method" + "You should supply an instance of :class:`~transformers.BatchFeature` or list of :class:`~transformers.BatchFeature` to this method " f"that includes {self.model_input_names[0]}, but you provided {list(processed_features.keys())}" ) @@ -348,7 +195,7 @@ def pad( pad_to_multiple_of=pad_to_multiple_of, return_attention_mask=return_attention_mask, ) - return BatchSequenceFeature(processed_features, tensor_type=return_tensors) + return BatchFeature(processed_features, tensor_type=return_tensors) batch_size = len(required_input) assert all( @@ -375,11 +222,11 @@ def pad( batch_outputs[key] = [] batch_outputs[key].append(value) - return BatchSequenceFeature(batch_outputs, tensor_type=return_tensors) + return BatchFeature(batch_outputs, tensor_type=return_tensors) def _pad( self, - processed_features: Union[Dict[str, List[float]], BatchSequenceFeature], + processed_features: Union[Dict[str, List[float]], BatchFeature], max_length: Optional[int] = None, padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, pad_to_multiple_of: Optional[int] = None, diff --git a/src/transformers/models/wav2vec2/feature_extraction_wav2vec2.py b/src/transformers/models/wav2vec2/feature_extraction_wav2vec2.py index 4a2a868de47302..b9a0354dccc619 100644 --- a/src/transformers/models/wav2vec2/feature_extraction_wav2vec2.py +++ b/src/transformers/models/wav2vec2/feature_extraction_wav2vec2.py @@ -20,7 +20,8 @@ import numpy as np -from ...feature_extraction_sequence_utils import BatchSequenceFeature, PreTrainedSequenceFeatureExtractor +from ...feature_extraction_common_utils import BatchFeature +from ...feature_extraction_sequence_utils import PreTrainedSequenceFeatureExtractor from ...file_utils import PaddingStrategy, TensorType from ...utils import logging @@ -93,7 +94,7 @@ def __call__( return_tensors: Optional[Union[str, TensorType]] = None, sampling_rate: Optional[int] = None, **kwargs - ) -> BatchSequenceFeature: + ) -> BatchFeature: """ Main method to featurize and prepare for the model one or several sequence(s). 
@@ -179,7 +180,7 @@ def __call__( raw_speech = self.zero_mean_unit_var_norm(raw_speech) # convert into correct format for padding - encoded_inputs = BatchSequenceFeature({"input_values": raw_speech}) + encoded_inputs = BatchFeature({"input_values": raw_speech}) padded_inputs = self.pad( encoded_inputs, diff --git a/tests/test_sequence_feature_extraction_common.py b/tests/test_sequence_feature_extraction_common.py index 0dfd1deffbc931..af0dd3d3dd6fbd 100644 --- a/tests/test_sequence_feature_extraction_common.py +++ b/tests/test_sequence_feature_extraction_common.py @@ -16,7 +16,7 @@ import numpy as np -from transformers import BatchSequenceFeature +from transformers import BatchFeature from transformers.testing_utils import require_tf, require_torch from .test_feature_extraction_saving_common import FeatureExtractionSavingTestMixin @@ -43,12 +43,12 @@ def test_batch_feature(self): feat_extract = self.feature_extraction_class(**self.feat_extract_dict) input_name = feat_extract.model_input_names[0] - processed_features = BatchSequenceFeature({input_name: speech_inputs}) + processed_features = BatchFeature({input_name: speech_inputs}) self.assertTrue(all(len(x) == len(y) for x, y in zip(speech_inputs, processed_features[input_name]))) speech_inputs = self.feat_extract_tester.prepare_inputs_for_common(equal_length=True) - processed_features = BatchSequenceFeature({input_name: speech_inputs}, tensor_type="np") + processed_features = BatchFeature({input_name: speech_inputs}, tensor_type="np") batch_features_input = processed_features[input_name] @@ -66,7 +66,7 @@ def test_batch_feature_pt(self): feat_extract = self.feature_extraction_class(**self.feat_extract_dict) input_name = feat_extract.model_input_names[0] - processed_features = BatchSequenceFeature({input_name: speech_inputs}, tensor_type="pt") + processed_features = BatchFeature({input_name: speech_inputs}, tensor_type="pt") batch_features_input = processed_features[input_name] @@ -84,7 +84,7 @@ def test_batch_feature_tf(self): feat_extract = self.feature_extraction_class(**self.feat_extract_dict) input_name = feat_extract.model_input_names[0] - processed_features = BatchSequenceFeature({input_name: speech_inputs}, tensor_type="tf") + processed_features = BatchFeature({input_name: speech_inputs}, tensor_type="tf") batch_features_input = processed_features[input_name] @@ -117,7 +117,7 @@ def _inputs_are_equal(input_1, input_2): speech_inputs = self.feat_extract_tester.prepare_inputs_for_common(numpify=numpify) input_name = feat_extract.model_input_names[0] - processed_features = BatchSequenceFeature({input_name: speech_inputs}) + processed_features = BatchFeature({input_name: speech_inputs}) pad_diff = self.feat_extract_tester.seq_length_diff pad_max_length = self.feat_extract_tester.max_seq_length + pad_diff @@ -217,7 +217,7 @@ def test_padding_accepts_tensors_pt(self): speech_inputs = self.feat_extract_tester.prepare_inputs_for_common() input_name = feat_extract.model_input_names[0] - processed_features = BatchSequenceFeature({input_name: speech_inputs}) + processed_features = BatchFeature({input_name: speech_inputs}) input_np = feat_extract.pad(processed_features, padding="longest", return_tensors="np")[input_name] input_pt = feat_extract.pad(processed_features, padding="longest", return_tensors="pt")[input_name] @@ -230,7 +230,7 @@ def test_padding_accepts_tensors_tf(self): speech_inputs = self.feat_extract_tester.prepare_inputs_for_common() input_name = feat_extract.model_input_names[0] - processed_features = 
BatchSequenceFeature({input_name: speech_inputs}) + processed_features = BatchFeature({input_name: speech_inputs}) input_np = feat_extract.pad(processed_features, padding="longest", return_tensors="np")[input_name] input_tf = feat_extract.pad(processed_features, padding="longest", return_tensors="tf")[input_name] @@ -245,7 +245,7 @@ def test_attention_mask(self): input_lenghts = [len(x) for x in speech_inputs] input_name = feat_extract.model_input_names[0] - processed = BatchSequenceFeature({input_name: speech_inputs}) + processed = BatchFeature({input_name: speech_inputs}) processed = feat_extract.pad(processed, padding="longest", return_tensors="np") self.assertIn("attention_mask", processed) From ce029eb9d5d8c0bd5e974491f670acfd6294000e Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 8 Mar 2021 18:20:56 +0300 Subject: [PATCH 4/9] correct naming --- docs/source/main_classes/feature_extractor.rst | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/docs/source/main_classes/feature_extractor.rst b/docs/source/main_classes/feature_extractor.rst index adbda92a01b1fb..567c20bd2361d8 100644 --- a/docs/source/main_classes/feature_extractor.rst +++ b/docs/source/main_classes/feature_extractor.rst @@ -20,11 +20,18 @@ from sequences, *e.g.*, pre-processing audio files to Log-Mel Spectrogram featur tensors. +FeatureExtractionSavingUtilsMixin +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.feature_extraction_common_utils.FeatureExtractionSavingUtilsMixin + :members: from_pretrained, save_pretrained + + PreTrainedSequenceFeatureExtractor ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autoclass:: transformers.PreTrainedSequenceFeatureExtractor - :members: from_pretrained, save_pretrained, pad + :members: pad BatchFeature From 94210bbea7ef596fce1bdc2f4fb5ed2bdab69040 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 8 Mar 2021 18:26:00 +0300 Subject: [PATCH 5/9] correct naming --- src/transformers/feature_extraction_common_utils.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/transformers/feature_extraction_common_utils.py b/src/transformers/feature_extraction_common_utils.py index d81eff5349ea86..d8d91de6b2c184 100644 --- a/src/transformers/feature_extraction_common_utils.py +++ b/src/transformers/feature_extraction_common_utils.py @@ -243,14 +243,19 @@ def from_pretrained( kwargs (:obj:`Dict[str, Any]`, `optional`): The values in kwargs of any keys which are feature extractor attributes will be used to override the loaded values. Behavior concerning key/value pairs whose keys are *not* feature extractor attributes is - controlled by the ``return_unused_kwargs`` keyword parameter. .. note:: + controlled by the ``return_unused_kwargs`` keyword parameter. + + .. note:: + Passing :obj:`use_auth_token=True` is required when you want to use a private model. + Returns: :class:`~transformers.PreTrainedSequenceFeatureExtractor`: The feature extractor object instantiated from this pretrained model. Examples:: + # We can't instantiate directly the base class `PreTrainedSequenceFeatureExtractor` so let's show the examples on a # derived class: Wav2Vec2FeatureExtractor feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained('facebook/wav2vec2-base-960h') # Download feature_extraction_config from huggingface.co and cache. 
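The list-of-dicts branch that the ``pad`` docstring above describes can be exercised end to end as a ``DataLoader`` collate function. A minimal sketch, assuming the ``facebook/wav2vec2-base-960h`` checkpoint is reachable and using a purely illustrative toy dataset of unequal-length waveforms::

    import numpy as np
    from torch.utils.data import DataLoader
    from transformers import Wav2Vec2FeatureExtractor

    feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")

    # toy dataset: raw mono waveforms of unequal length, as lists of floats
    dataset = [{"input_values": np.random.randn(n).tolist()} for n in (800, 1000, 1200, 600)]

    def collate_fn(examples):
        # `pad` accepts a list of {"input_values": ...} dicts and returns a
        # BatchFeature padded to the longest example in the batch
        return feature_extractor.pad(examples, padding="longest", return_tensors="pt")

    loader = DataLoader(dataset, batch_size=2, collate_fn=collate_fn)
    batch = next(iter(loader))  # batch["input_values"].shape == torch.Size([2, 1000])

Because ``pad`` accepts both a single ``BatchFeature`` and a list of dicts, the same call serves preprocessing and batching, which is exactly what the list-of-dicts conversion at the top of ``pad`` enables.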
From 17a116f91066f684782d6a7e7f60acb9a56d40e4 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 8 Mar 2021 21:26:56 +0300 Subject: [PATCH 6/9] shorter names --- .../source/main_classes/feature_extractor.rst | 4 +-- src/transformers/__init__.py | 4 +-- .../feature_extraction_common_utils.py | 36 +++++++++---------- .../feature_extraction_sequence_utils.py | 2 +- .../wav2vec2/feature_extraction_wav2vec2.py | 4 +-- .../models/wav2vec2/processing_wav2vec2.py | 10 +++--- 6 files changed, 28 insertions(+), 32 deletions(-) diff --git a/docs/source/main_classes/feature_extractor.rst b/docs/source/main_classes/feature_extractor.rst index 567c20bd2361d8..8bbe9164ee05de 100644 --- a/docs/source/main_classes/feature_extractor.rst +++ b/docs/source/main_classes/feature_extractor.rst @@ -27,10 +27,10 @@ FeatureExtractionSavingUtilsMixin :members: from_pretrained, save_pretrained -PreTrainedSequenceFeatureExtractor +SequenceFeatureExtractor ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. autoclass:: transformers.PreTrainedSequenceFeatureExtractor +.. autoclass:: transformers.SequenceFeatureExtractor :members: pad diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index cd61380d6f1efc..85afaece4de71c 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -246,7 +246,7 @@ "SpecialTokensMixin", "TokenSpan", ], - "feature_extraction_sequence_utils": ["PreTrainedSequenceFeatureExtractor", "BatchFeature"], + "feature_extraction_sequence_utils": ["SequenceFeatureExtractor", "BatchFeature"], "trainer_callback": [ "DefaultFlowCallback", "EarlyStoppingCallback", @@ -1250,7 +1250,7 @@ ) # Feature Extractor - from .feature_extraction_common_utils import BatchFeature, PreTrainedSequenceFeatureExtractor + from .feature_extraction_common_utils import BatchFeature, SequenceFeatureExtractor # Files and general utilities from .file_utils import ( diff --git a/src/transformers/feature_extraction_common_utils.py b/src/transformers/feature_extraction_common_utils.py index d8d91de6b2c184..1a295d2cca5357 100644 --- a/src/transformers/feature_extraction_common_utils.py +++ b/src/transformers/feature_extraction_common_utils.py @@ -48,12 +48,12 @@ logger = logging.get_logger(__name__) -PreTrainedFeatureExtractor = Union["PreTrainedSequenceFeatureExtractor"] # noqa: F821 +PreTrainedFeatureExtractor = Union["SequenceFeatureExtractor"] # noqa: F821 class BatchFeature(UserDict): r""" - Holds the output of the :meth:`~transformers.PreTrainedSequenceFeatureExtractor.pad` and feature extractor specific + Holds the output of the :meth:`~transformers.SequenceFeatureExtractor.pad` and feature extractor specific ``__call__`` methods. This class is derived from a python dictionary and can be used as a dictionary. @@ -201,8 +201,8 @@ def from_pretrained( cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs ) -> "PreTrainedFeatureExtractor": r""" - Instantiate a :class:`~transformers.PreTrainedSequenceFeatureExtractor` (or a derived class) from a pretrained - feature extractor. + Instantiate a :class:`~transformers.SequenceFeatureExtractor` (or a derived class) from a pretrained feature + extractor. Args: pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`): @@ -212,7 +212,7 @@ def from_pretrained( huggingface.co. Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under a user or organization name, like ``dbmdz/bert-base-german-cased``. 
- a path to a `directory` containing a feature extractor file saved using the - :func:`~transformers.PreTrainedSequenceFeatureExtractor.save_pretrained` method, e.g., + :func:`~transformers.SequenceFeatureExtractor.save_pretrained` method, e.g., ``./my_model_directory/``. - a path or url to a saved feature extractor JSON `file`, e.g., ``./my_model_directory/feature_extraction_config.json``. @@ -251,12 +251,12 @@ def from_pretrained( Returns: - :class:`~transformers.PreTrainedSequenceFeatureExtractor`: The feature extractor object instantiated from - this pretrained model. + :class:`~transformers.SequenceFeatureExtractor`: The feature extractor object instantiated from this + pretrained model. Examples:: - # We can't instantiate directly the base class `PreTrainedSequenceFeatureExtractor` so let's show the examples on a + # We can't instantiate directly the base class `SequenceFeatureExtractor` so let's show the examples on a # derived class: Wav2Vec2FeatureExtractor feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained('facebook/wav2vec2-base-960h') # Download feature_extraction_config from huggingface.co and cache. feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained('./test/saved_model/') # E.g. feature_extractor (or model) was saved using `save_pretrained('./test/saved_model/')` @@ -275,7 +275,7 @@ def from_pretrained( def save_pretrained(self, save_directory: Union[str, os.PathLike]): """ Save a feature_extractor object to the directory ``save_directory``, so that it can be re-loaded using the - :func:`~transformers.PreTrainedSequenceFeatureExtractor.from_pretrained` class method. + :func:`~transformers.SequenceFeatureExtractor.from_pretrained` class method. Args: save_directory (:obj:`str` or :obj:`os.PathLike`): @@ -296,7 +296,7 @@ def get_feature_extractor_dict( ) -> Tuple[Dict[str, Any], Dict[str, Any]]: """ From a ``pretrained_model_name_or_path``, resolve to a dictionary of parameters, to be used for instantiating a - :class:`~transformers.PreTrainedSequenceFeatureExtractor` using ``from_dict``. + :class:`~transformers.SequenceFeatureExtractor` using ``from_dict``. Parameters: pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`): @@ -369,20 +369,19 @@ def get_feature_extractor_dict( @classmethod def from_dict(cls, feature_extractor_dict: Dict[str, Any], **kwargs) -> PreTrainedFeatureExtractor: """ - Instantiates a :class:`~transformers.PreTrainedSequenceFeatureExtractor` from a Python dictionary of - parameters. + Instantiates a :class:`~transformers.SequenceFeatureExtractor` from a Python dictionary of parameters. Args: feature_extractor_dict (:obj:`Dict[str, Any]`): Dictionary that will be used to instantiate the feature extractor object. Such a dictionary can be retrieved from a pretrained checkpoint by leveraging the - :func:`~transformers.PreTrainedSequenceFeatureExtractor.to_dict` method. + :func:`~transformers.SequenceFeatureExtractor.to_dict` method. kwargs (:obj:`Dict[str, Any]`): Additional parameters from which to initialize the feature extractor object. Returns: - :class:`~transformers.PreTrainedSequenceFeatureExtractor`: The feature extractor object instantiated from - those parameters. + :class:`~transformers.SequenceFeatureExtractor`: The feature extractor object instantiated from those + parameters. 
""" return_unused_kwargs = kwargs.pop("return_unused_kwargs", False) @@ -417,16 +416,15 @@ def to_dict(self) -> Dict[str, Any]: @classmethod def from_json_file(cls, json_file: Union[str, os.PathLike]) -> PreTrainedFeatureExtractor: """ - Instantiates a :class:`~transformers.PreTrainedSequenceFeatureExtractor` from the path to a JSON file of - parameters. + Instantiates a :class:`~transformers.SequenceFeatureExtractor` from the path to a JSON file of parameters. Args: json_file (:obj:`str` or :obj:`os.PathLike`): Path to the JSON file containing the parameters. Returns: - :class:`~transformers.PreTrainedSequenceFeatureExtractor`: The feature_extractor object instantiated from - that JSON file. + :class:`~transformers.SequenceFeatureExtractor`: The feature_extractor object instantiated from that JSON + file. """ with open(json_file, "r", encoding="utf-8") as reader: text = reader.read() diff --git a/src/transformers/feature_extraction_sequence_utils.py b/src/transformers/feature_extraction_sequence_utils.py index 8069a35a82f777..e5b130a3212ca5 100644 --- a/src/transformers/feature_extraction_sequence_utils.py +++ b/src/transformers/feature_extraction_sequence_utils.py @@ -35,7 +35,7 @@ logger = logging.get_logger(__name__) -class PreTrainedSequenceFeatureExtractor(FeatureExtractionSavingUtilsMixin): +class SequenceFeatureExtractor(FeatureExtractionSavingUtilsMixin): """ This is a general feature extraction class for speech recognition. diff --git a/src/transformers/models/wav2vec2/feature_extraction_wav2vec2.py b/src/transformers/models/wav2vec2/feature_extraction_wav2vec2.py index b9a0354dccc619..a623a3bd97c4f2 100644 --- a/src/transformers/models/wav2vec2/feature_extraction_wav2vec2.py +++ b/src/transformers/models/wav2vec2/feature_extraction_wav2vec2.py @@ -21,7 +21,7 @@ import numpy as np from ...feature_extraction_common_utils import BatchFeature -from ...feature_extraction_sequence_utils import PreTrainedSequenceFeatureExtractor +from ...feature_extraction_sequence_utils import SequenceFeatureExtractor from ...file_utils import PaddingStrategy, TensorType from ...utils import logging @@ -29,7 +29,7 @@ logger = logging.get_logger(__name__) -class Wav2Vec2FeatureExtractor(PreTrainedSequenceFeatureExtractor): +class Wav2Vec2FeatureExtractor(SequenceFeatureExtractor): r""" Constructs a Wav2Vec2 feature extractor. diff --git a/src/transformers/models/wav2vec2/processing_wav2vec2.py b/src/transformers/models/wav2vec2/processing_wav2vec2.py index 9676225447dbc3..b53cfa4fc911df 100644 --- a/src/transformers/models/wav2vec2/processing_wav2vec2.py +++ b/src/transformers/models/wav2vec2/processing_wav2vec2.py @@ -59,8 +59,7 @@ def save_pretrained(self, save_directory): .. note:: - This class method is simply calling - :meth:`~transformers.PreTrainedSequenceFeatureExtractor.save_pretrained` and + This class method is simply calling :meth:`~transformers.SequenceFeatureExtractor.save_pretrained` and :meth:`~transformers.tokenization_utils_base.PreTrainedTokenizer.save_pretrained`. Please refer to the docstrings of the methods above for more information. @@ -81,7 +80,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): .. note:: This class method is simply calling Wav2Vec2FeatureExtractor's - :meth:`~transformers.PreTrainedSequenceFeatureExtractor.from_pretrained` and Wav2Vec2CTCTokenizer's + :meth:`~transformers.SequenceFeatureExtractor.from_pretrained` and Wav2Vec2CTCTokenizer's :meth:`~transformers.tokenization_utils_base.PreTrainedTokenizer.from_pretrained`. 
Please refer to the docstrings of the methods above for more information. @@ -93,13 +92,12 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): huggingface.co. Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under a user or organization name, like ``dbmdz/bert-base-german-cased``. - a path to a `directory` containing a feature extractor file saved using the - :meth:`~transformers.PreTrainedSequenceFeatureExtractor.save_pretrained` method, e.g., + :meth:`~transformers.SequenceFeatureExtractor.save_pretrained` method, e.g., ``./my_model_directory/``. - a path or url to a saved feature extractor JSON `file`, e.g., ``./my_model_directory/feature_extraction_config.json``. **kwargs - Additional keyword arguments passed along to both - :class:`~transformers.PreTrainedSequenceFeatureExtractor` and + Additional keyword arguments passed along to both :class:`~transformers.SequenceFeatureExtractor` and :class:`~transformers.PreTrainedTokenizer` """ feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(pretrained_model_name_or_path, **kwargs) From e1ee6466c832c43a5298b472143d30abfccf0699 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 8 Mar 2021 21:35:02 +0300 Subject: [PATCH 7/9] Update src/transformers/feature_extraction_common_utils.py Co-authored-by: Lysandre Debut --- src/transformers/feature_extraction_common_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/feature_extraction_common_utils.py b/src/transformers/feature_extraction_common_utils.py index f933d5b7efa590..047ced504d54d6 100644 --- a/src/transformers/feature_extraction_common_utils.py +++ b/src/transformers/feature_extraction_common_utils.py @@ -200,7 +200,7 @@ class FeatureExtractionSavingUtilsMixin: @classmethod def from_pretrained( cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs - ) -> "PreTrainedFeatureExtractor": + ) -> PreTrainedFeatureExtractor: r""" Instantiate a :class:`~transformers.SequenceFeatureExtractor` (or a derived class) from a pretrained feature extractor. From 3090d54aa158521b359ba403a9fa857edf3e5be6 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 9 Mar 2021 10:37:38 +0300 Subject: [PATCH 8/9] change name --- docs/source/main_classes/feature_extractor.rst | 2 +- src/transformers/__init__.py | 2 +- src/transformers/feature_extraction_sequence_utils.py | 2 +- ...e_extraction_common_utils.py => feature_extraction_utils.py} | 2 +- src/transformers/models/wav2vec2/feature_extraction_wav2vec2.py | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) rename src/transformers/{feature_extraction_common_utils.py => feature_extraction_utils.py} (99%) diff --git a/docs/source/main_classes/feature_extractor.rst b/docs/source/main_classes/feature_extractor.rst index 8bbe9164ee05de..25ceff4524e1e7 100644 --- a/docs/source/main_classes/feature_extractor.rst +++ b/docs/source/main_classes/feature_extractor.rst @@ -23,7 +23,7 @@ tensors. FeatureExtractionSavingUtilsMixin ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. autoclass:: transformers.feature_extraction_common_utils.FeatureExtractionSavingUtilsMixin +.. 
autoclass:: transformers.feature_extraction_utils.FeatureExtractionSavingUtilsMixin :members: from_pretrained, save_pretrained diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 85afaece4de71c..d8188c2454e46a 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -1250,7 +1250,7 @@ ) # Feature Extractor - from .feature_extraction_common_utils import BatchFeature, SequenceFeatureExtractor + from .feature_extraction_utils import BatchFeature, SequenceFeatureExtractor # Files and general utilities from .file_utils import ( diff --git a/src/transformers/feature_extraction_sequence_utils.py b/src/transformers/feature_extraction_sequence_utils.py index e5b130a3212ca5..588968cec7f16f 100644 --- a/src/transformers/feature_extraction_sequence_utils.py +++ b/src/transformers/feature_extraction_sequence_utils.py @@ -19,7 +19,7 @@ import numpy as np -from .feature_extraction_common_utils import BatchFeature, FeatureExtractionSavingUtilsMixin +from .feature_extraction_utils import BatchFeature, FeatureExtractionSavingUtilsMixin from .file_utils import ( PaddingStrategy, TensorType, diff --git a/src/transformers/feature_extraction_common_utils.py b/src/transformers/feature_extraction_utils.py similarity index 99% rename from src/transformers/feature_extraction_common_utils.py rename to src/transformers/feature_extraction_utils.py index 047ced504d54d6..df2974026b652c 100644 --- a/src/transformers/feature_extraction_common_utils.py +++ b/src/transformers/feature_extraction_utils.py @@ -202,7 +202,7 @@ def from_pretrained( cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs ) -> PreTrainedFeatureExtractor: r""" - Instantiate a :class:`~transformers.SequenceFeatureExtractor` (or a derived class) from a pretrained feature + Instantiate a type of :class:`~transformers.feature_extraction_common:class:`~transformers.SequenceFeatureExtractor` (or a derived class) from a pretrained feature extractor. 
Args: diff --git a/src/transformers/models/wav2vec2/feature_extraction_wav2vec2.py b/src/transformers/models/wav2vec2/feature_extraction_wav2vec2.py index a623a3bd97c4f2..265a9ffb97a203 100644 --- a/src/transformers/models/wav2vec2/feature_extraction_wav2vec2.py +++ b/src/transformers/models/wav2vec2/feature_extraction_wav2vec2.py @@ -20,7 +20,7 @@ import numpy as np -from ...feature_extraction_common_utils import BatchFeature +from ...feature_extraction_utils import BatchFeature from ...feature_extraction_sequence_utils import SequenceFeatureExtractor from ...file_utils import PaddingStrategy, TensorType from ...utils import logging From 49b33644abc952520e209e9213b48438df3cc8f8 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 9 Mar 2021 11:00:58 +0300 Subject: [PATCH 9/9] finish --- .../source/main_classes/feature_extractor.rst | 4 +-- .../feature_extraction_sequence_utils.py | 4 +-- src/transformers/feature_extraction_utils.py | 36 ++++++++++--------- .../wav2vec2/feature_extraction_wav2vec2.py | 2 +- .../models/wav2vec2/processing_wav2vec2.py | 9 ++--- ...n.py => test_feature_extraction_common.py} | 0 ...test_sequence_feature_extraction_common.py | 2 +- 7 files changed, 30 insertions(+), 27 deletions(-) rename tests/{test_feature_extraction_saving_common.py => test_feature_extraction_common.py} (100%) diff --git a/docs/source/main_classes/feature_extractor.rst b/docs/source/main_classes/feature_extractor.rst index 25ceff4524e1e7..d8d95941538eb5 100644 --- a/docs/source/main_classes/feature_extractor.rst +++ b/docs/source/main_classes/feature_extractor.rst @@ -20,10 +20,10 @@ from sequences, *e.g.*, pre-processing audio files to Log-Mel Spectrogram featur tensors. -FeatureExtractionSavingUtilsMixin +FeatureExtractionMixin ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. autoclass:: transformers.feature_extraction_utils.FeatureExtractionSavingUtilsMixin +.. autoclass:: transformers.feature_extraction_utils.FeatureExtractionMixin :members: from_pretrained, save_pretrained diff --git a/src/transformers/feature_extraction_sequence_utils.py b/src/transformers/feature_extraction_sequence_utils.py index 588968cec7f16f..318e7a3dfb1b68 100644 --- a/src/transformers/feature_extraction_sequence_utils.py +++ b/src/transformers/feature_extraction_sequence_utils.py @@ -19,7 +19,7 @@ import numpy as np -from .feature_extraction_utils import BatchFeature, FeatureExtractionSavingUtilsMixin +from .feature_extraction_utils import BatchFeature, FeatureExtractionMixin from .file_utils import ( PaddingStrategy, TensorType, @@ -35,7 +35,7 @@ logger = logging.get_logger(__name__) -class SequenceFeatureExtractor(FeatureExtractionSavingUtilsMixin): +class SequenceFeatureExtractor(FeatureExtractionMixin): """ This is a general feature extraction class for speech recognition. diff --git a/src/transformers/feature_extraction_utils.py b/src/transformers/feature_extraction_utils.py index df2974026b652c..9995026541462d 100644 --- a/src/transformers/feature_extraction_utils.py +++ b/src/transformers/feature_extraction_utils.py @@ -191,7 +191,7 @@ def to(self, device: Union[str, "torch.device"]) -> "BatchFeature": return self -class FeatureExtractionSavingUtilsMixin: +class FeatureExtractionMixin: """ This is a feature extraction mixin used to provide saving/loading functionality for sequential and image feature extractors. 
@@ -202,8 +202,8 @@ def from_pretrained( cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs ) -> PreTrainedFeatureExtractor: r""" - Instantiate a type of :class:`~transformers.feature_extraction_common:class:`~transformers.SequenceFeatureExtractor` (or a derived class) from a pretrained feature - extractor. + Instantiate a type of :class:`~transformers.feature_extraction_utils.FeatureExtractionMixin` from a feature + extractor, *e.g.* a derived class of :class:`~transformers.SequenceFeatureExtractor`. Args: pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`): @@ -213,7 +213,7 @@ def from_pretrained( huggingface.co. Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under a user or organization name, like ``dbmdz/bert-base-german-cased``. - a path to a `directory` containing a feature extractor file saved using the - :func:`~transformers.SequenceFeatureExtractor.save_pretrained` method, e.g., + :func:`~transformers.feature_extraction_utils.FeatureExtractionMixin.save_pretrained` method, e.g., ``./my_model_directory/``. - a path or url to a saved feature extractor JSON `file`, e.g., ``./my_model_directory/feature_extraction_config.json``. @@ -252,13 +252,12 @@ def from_pretrained( Returns: - :class:`~transformers.SequenceFeatureExtractor`: The feature extractor object instantiated from this - pretrained model. + A feature extractor of type :class:`~transformers.feature_extraction_utils.FeatureExtractionMixin`. Examples:: - # We can't instantiate directly the base class `SequenceFeatureExtractor` so let's show the examples on a - # derived class: Wav2Vec2FeatureExtractor + # We can't instantiate directly the base class `FeatureExtractionMixin` nor `SequenceFeatureExtractor` so let's show the examples on a + # derived class: `Wav2Vec2FeatureExtractor` feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained('facebook/wav2vec2-base-960h') # Download feature_extraction_config from huggingface.co and cache. feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained('./test/saved_model/') # E.g. feature_extractor (or model) was saved using `save_pretrained('./test/saved_model/')` feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained('./test/saved_model/preprocessor_config.json') @@ -276,7 +275,7 @@ def from_pretrained( def save_pretrained(self, save_directory: Union[str, os.PathLike]): """ Save a feature_extractor object to the directory ``save_directory``, so that it can be re-loaded using the - :func:`~transformers.SequenceFeatureExtractor.from_pretrained` class method. + :func:`~transformers.feature_extraction_utils.FeatureExtractionMixin.from_pretrained` class method. Args: save_directory (:obj:`str` or :obj:`os.PathLike`): @@ -297,7 +296,8 @@ def get_feature_extractor_dict( ) -> Tuple[Dict[str, Any], Dict[str, Any]]: """ From a ``pretrained_model_name_or_path``, resolve to a dictionary of parameters, to be used for instantiating a - :class:`~transformers.SequenceFeatureExtractor` using ``from_dict``. + feature extractor of type :class:`~transformers.feature_extraction_utils.FeatureExtractionMixin` using + ``from_dict``. Parameters: pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`): @@ -374,19 +374,20 @@ def get_feature_extractor_dict( @classmethod def from_dict(cls, feature_extractor_dict: Dict[str, Any], **kwargs) -> PreTrainedFeatureExtractor: """ - Instantiates a :class:`~transformers.SequenceFeatureExtractor` from a Python dictionary of parameters. 
+ Instantiates a type of :class:`~transformers.feature_extraction_utils.FeatureExtractionMixin` from a Python + dictionary of parameters. Args: feature_extractor_dict (:obj:`Dict[str, Any]`): Dictionary that will be used to instantiate the feature extractor object. Such a dictionary can be retrieved from a pretrained checkpoint by leveraging the - :func:`~transformers.SequenceFeatureExtractor.to_dict` method. + :func:`~transformers.feature_extraction_utils.FeatureExtractionMixin.to_dict` method. kwargs (:obj:`Dict[str, Any]`): Additional parameters from which to initialize the feature extractor object. Returns: - :class:`~transformers.SequenceFeatureExtractor`: The feature extractor object instantiated from those - parameters. + :class:`~transformers.feature_extraction_utils.FeatureExtractionMixin`: The feature extractor object + instantiated from those parameters. """ return_unused_kwargs = kwargs.pop("return_unused_kwargs", False) @@ -421,15 +422,16 @@ def to_dict(self) -> Dict[str, Any]: @classmethod def from_json_file(cls, json_file: Union[str, os.PathLike]) -> PreTrainedFeatureExtractor: """ - Instantiates a :class:`~transformers.SequenceFeatureExtractor` from the path to a JSON file of parameters. + Instantiates a feature extractor of type :class:`~transformers.feature_extraction_utils.FeatureExtractionMixin` + from the path to a JSON file of parameters. Args: json_file (:obj:`str` or :obj:`os.PathLike`): Path to the JSON file containing the parameters. Returns: - :class:`~transformers.SequenceFeatureExtractor`: The feature_extractor object instantiated from that JSON - file. + A feature extractor of type :class:`~transformers.feature_extraction_utils.FeatureExtractionMixin`: The + feature_extractor object instantiated from that JSON file. """ with open(json_file, "r", encoding="utf-8") as reader: text = reader.read() diff --git a/src/transformers/models/wav2vec2/feature_extraction_wav2vec2.py b/src/transformers/models/wav2vec2/feature_extraction_wav2vec2.py index 265a9ffb97a203..6e49ba4d69352a 100644 --- a/src/transformers/models/wav2vec2/feature_extraction_wav2vec2.py +++ b/src/transformers/models/wav2vec2/feature_extraction_wav2vec2.py @@ -20,8 +20,8 @@ import numpy as np -from ...feature_extraction_utils import BatchFeature from ...feature_extraction_sequence_utils import SequenceFeatureExtractor +from ...feature_extraction_utils import BatchFeature from ...file_utils import PaddingStrategy, TensorType from ...utils import logging diff --git a/src/transformers/models/wav2vec2/processing_wav2vec2.py b/src/transformers/models/wav2vec2/processing_wav2vec2.py index b53cfa4fc911df..88e3235abd7d4f 100644 --- a/src/transformers/models/wav2vec2/processing_wav2vec2.py +++ b/src/transformers/models/wav2vec2/processing_wav2vec2.py @@ -59,7 +59,8 @@ def save_pretrained(self, save_directory): .. note:: - This class method is simply calling :meth:`~transformers.SequenceFeatureExtractor.save_pretrained` and + This class method is simply calling + :meth:`~transformers.feature_extraction_utils.FeatureExtractionMixin.save_pretrained` and :meth:`~transformers.tokenization_utils_base.PreTrainedTokenizer.save_pretrained`. Please refer to the docstrings of the methods above for more information. @@ -80,9 +81,9 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): .. 
note:: This class method is simply calling Wav2Vec2FeatureExtractor's - :meth:`~transformers.SequenceFeatureExtractor.from_pretrained` and Wav2Vec2CTCTokenizer's - :meth:`~transformers.tokenization_utils_base.PreTrainedTokenizer.from_pretrained`. Please refer to the - docstrings of the methods above for more information. + :meth:`~transformers.feature_extraction_utils.FeatureExtractionMixin.from_pretrained` and + Wav2Vec2CTCTokenizer's :meth:`~transformers.tokenization_utils_base.PreTrainedTokenizer.from_pretrained`. + Please refer to the docstrings of the methods above for more information. Args: pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`): diff --git a/tests/test_feature_extraction_saving_common.py b/tests/test_feature_extraction_common.py similarity index 100% rename from tests/test_feature_extraction_saving_common.py rename to tests/test_feature_extraction_common.py diff --git a/tests/test_sequence_feature_extraction_common.py b/tests/test_sequence_feature_extraction_common.py index af0dd3d3dd6fbd..8c1777553ac6bd 100644 --- a/tests/test_sequence_feature_extraction_common.py +++ b/tests/test_sequence_feature_extraction_common.py @@ -19,7 +19,7 @@ from transformers import BatchFeature from transformers.testing_utils import require_tf, require_torch -from .test_feature_extraction_saving_common import FeatureExtractionSavingTestMixin +from .test_feature_extraction_common import FeatureExtractionSavingTestMixin class SequenceFeatureExtractionTestMixin(FeatureExtractionSavingTestMixin):
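Taken together, the end state of this series splits the stack into ``FeatureExtractionMixin`` (``from_pretrained``/``save_pretrained``), ``SequenceFeatureExtractor`` (``pad``), and ``BatchFeature`` (tensor conversion). A minimal round-trip sketch; the temporary directory and random inputs are illustrative only::

    import tempfile

    import numpy as np
    from transformers import BatchFeature, Wav2Vec2FeatureExtractor

    feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")

    # save/load round trip via FeatureExtractionMixin
    with tempfile.TemporaryDirectory() as tmp_dir:
        feature_extractor.save_pretrained(tmp_dir)
        feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(tmp_dir)

    # wrap unequal-length inputs under the model's main input name, then pad;
    # `pad` comes from SequenceFeatureExtractor and returns a BatchFeature
    speech_inputs = [np.random.randn(800).tolist(), np.random.randn(1000).tolist()]
    features = BatchFeature({"input_values": speech_inputs})
    batch = feature_extractor.pad(features, padding="longest", return_tensors="np")
    print(batch["input_values"].shape)  # (2, 1000)

``return_tensors`` is handled by ``BatchFeature.convert_to_tensors``; ``"pt"``, ``"tf"`` and ``"jax"`` behave the same way when the corresponding framework is installed.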