Make ASR pipeline compliant with Hub spec + add tests #33769

Merged: 8 commits, Oct 1, 2024
7 changes: 5 additions & 2 deletions src/transformers/pipelines/automatic_speech_recognition.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import warnings
 from collections import defaultdict
 from typing import TYPE_CHECKING, Dict, Optional, Union
@@ -269,8 +270,6 @@ def __call__(
                 The dictionary of ad-hoc parametrization of `generate_config` to be used for the generation call. For a
                 complete overview of generate, check the [following
                 guide](https://huggingface.co/docs/transformers/en/main_classes/text_generation).
-            max_new_tokens (`int`, *optional*):
-                The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt.
         Return:
             `Dict`: A dictionary with the following keys:
@@ -310,6 +309,10 @@ def _sanitize_parameters(

         forward_params = defaultdict(dict)
         if max_new_tokens is not None:
+            warnings.warn(
+                "`max_new_tokens` is deprecated and will be removed in version 4.49 of Transformers. To remove this warning, pass `max_new_tokens` as a key inside `generate_kwargs` instead.",
+                FutureWarning,
+            )
             forward_params["max_new_tokens"] = max_new_tokens
         if generate_kwargs is not None:
             if max_new_tokens is not None and "max_new_tokens" in generate_kwargs:
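For callers, the deprecation means moving `max_new_tokens` into `generate_kwargs`. A minimal before/after sketch (the model checkpoint and audio file name are illustrative):

```python
from transformers import pipeline

asr = pipeline("automatic-speech-recognition", model="openai/whisper-tiny")

# Deprecated: a top-level max_new_tokens now emits a FutureWarning
result = asr("sample.flac", max_new_tokens=128)

# Preferred: pass it inside generate_kwargs instead
result = asr("sample.flac", generate_kwargs={"max_new_tokens": 128})
```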
28 changes: 28 additions & 0 deletions src/transformers/testing_utils.py
@@ -31,6 +31,7 @@
 import unittest
 from collections import defaultdict
 from collections.abc import Mapping
+from dataclasses import MISSING, fields
 from functools import wraps
 from io import StringIO
 from pathlib import Path
@@ -2610,3 +2611,30 @@ def update_mapping_from_spec(device_fn_dict: Dict[str, Callable], attribute_name
 update_mapping_from_spec(BACKEND_MANUAL_SEED, "MANUAL_SEED_FN")
 update_mapping_from_spec(BACKEND_EMPTY_CACHE, "EMPTY_CACHE_FN")
 update_mapping_from_spec(BACKEND_DEVICE_COUNT, "DEVICE_COUNT_FN")
+
+
+def compare_pipeline_output_to_hub_spec(output, hub_spec):
+    missing_keys = []
+    unexpected_keys = []
+    all_field_names = {field.name for field in fields(hub_spec)}
+    matching_keys = sorted([key for key in output.keys() if key in all_field_names])
+
+    # Fields with a MISSING default are required and must be in the output
+    for field in fields(hub_spec):
+        if field.default is MISSING and field.name not in output:
+            missing_keys.append(field.name)
+
+    # All output keys must match either a required or optional field in the Hub spec
+    for output_key in output:
+        if output_key not in all_field_names:
+            unexpected_keys.append(output_key)
+
+    if missing_keys or unexpected_keys:
+        error = ["Pipeline output does not match Hub spec!"]
+        if matching_keys:
+            error.append(f"Matching keys: {matching_keys}")
+        if missing_keys:
+            error.append(f"Missing required keys in pipeline output: {missing_keys}")
+        if unexpected_keys:
+            error.append(f"Keys in pipeline output that are not in Hub spec: {unexpected_keys}")
+        raise KeyError("\n".join(error))
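To illustrate the helper's contract, here is a toy spec (a hypothetical dataclass, not a real Hub spec): fields without a default are treated as required output keys, fields with a default as optional, and any key outside the spec is rejected.

```python
from dataclasses import dataclass
from typing import Optional

from transformers.testing_utils import compare_pipeline_output_to_hub_spec


@dataclass
class ToySpec:
    label: str  # no default, so "label" is a required output key
    score: Optional[float] = None  # has a default, so "score" is optional


compare_pipeline_output_to_hub_spec({"label": "cat"}, ToySpec)  # passes

try:
    compare_pipeline_output_to_hub_spec({"score": 0.9, "extra": 1}, ToySpec)
except KeyError as err:
    print(err)  # reports "label" as missing and "extra" as unexpected
```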
6 changes: 2 additions & 4 deletions tests/pipelines/test_pipelines_audio_classification.py
@@ -13,14 +13,14 @@
 # limitations under the License.

 import unittest
-from dataclasses import fields

 import numpy as np
 from huggingface_hub import AudioClassificationOutputElement

 from transformers import MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING, TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING
 from transformers.pipelines import AudioClassificationPipeline, pipeline
 from transformers.testing_utils import (
+    compare_pipeline_output_to_hub_spec,
     is_pipeline_test,
     nested_simplify,
     require_tf,
@@ -68,10 +68,8 @@ def run_pipeline_test(self, audio_classifier, examples):

         self.run_torchaudio(audio_classifier)

-        spec_output_keys = {field.name for field in fields(AudioClassificationOutputElement)}
         for single_output in output:
-            output_keys = set(single_output.keys())
-            self.assertEqual(spec_output_keys, output_keys, msg="Pipeline output keys do not match HF Hub spec!")
+            compare_pipeline_output_to_hub_spec(single_output, AudioClassificationOutputElement)

     @require_torchaudio
     def run_torchaudio(self, audio_classifier):
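Concretely, each element of the audio-classification output is now checked on its own against `AudioClassificationOutputElement`, which (in current `huggingface_hub` versions) declares `label` and `score` fields. Values below are illustrative:

```python
from huggingface_hub import AudioClassificationOutputElement

from transformers.testing_utils import compare_pipeline_output_to_hub_spec

# One element of a typical audio-classification pipeline output
compare_pipeline_output_to_hub_spec(
    {"label": "dog_bark", "score": 0.98},
    AudioClassificationOutputElement,
)
```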
tests/pipelines/test_pipelines_automatic_speech_recognition.py
@@ -18,7 +18,7 @@
 import numpy as np
 import pytest
 from datasets import Audio, load_dataset
-from huggingface_hub import hf_hub_download, snapshot_download
+from huggingface_hub import AutomaticSpeechRecognitionOutput, hf_hub_download, snapshot_download

 from transformers import (
     MODEL_FOR_CTC_MAPPING,
@@ -36,6 +36,7 @@
 from transformers.pipelines.audio_utils import chunk_bytes_iter
 from transformers.pipelines.automatic_speech_recognition import _find_timestamp_sequence, chunk_iter
 from transformers.testing_utils import (
+    compare_pipeline_output_to_hub_spec,
     is_pipeline_test,
     is_torch_available,
     nested_simplify,
@@ -86,6 +87,8 @@ def run_pipeline_test(self, speech_recognizer, examples):
         outputs = speech_recognizer(audio)
         self.assertEqual(outputs, {"text": ANY(str)})

+        compare_pipeline_output_to_hub_spec(outputs, AutomaticSpeechRecognitionOutput)
+
         # Striding
         audio = {"raw": audio, "stride": (0, 4000), "sampling_rate": speech_recognizer.feature_extractor.sampling_rate}
         if speech_recognizer.type == "ctc":
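`AutomaticSpeechRecognitionOutput` requires `text` and (assuming the current `huggingface_hub` field layout) treats `chunks` as optional, which is why the plain `{"text": ...}` output above satisfies the check. A sketch of both shapes:

```python
from huggingface_hub import AutomaticSpeechRecognitionOutput

from transformers.testing_utils import compare_pipeline_output_to_hub_spec

# Plain transcription: only the required "text" key is present
compare_pipeline_output_to_hub_spec({"text": "hello world"}, AutomaticSpeechRecognitionOutput)

# With timestamps: the optional "chunks" key is also allowed
compare_pipeline_output_to_hub_spec(
    {"text": "hello world", "chunks": [{"text": "hello world", "timestamp": (0.0, 1.2)}]},
    AutomaticSpeechRecognitionOutput,
)
```

Note that the helper only validates top-level keys, so the nested chunk entries are not themselves checked against the chunk spec.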
5 changes: 3 additions & 2 deletions tests/test_pipeline_mixin.py
@@ -25,9 +25,9 @@
 from textwrap import dedent
 from typing import get_args

-from huggingface_hub import AudioClassificationInput
+from huggingface_hub import AudioClassificationInput, AutomaticSpeechRecognitionInput

-from transformers.pipelines import AudioClassificationPipeline
+from transformers.pipelines import AudioClassificationPipeline, AutomaticSpeechRecognitionPipeline
 from transformers.testing_utils import (
     is_pipeline_test,
     require_decord,
@@ -104,6 +104,7 @@
     # Adding a task to this list will cause its pipeline input signature to be checked against the corresponding
     # task spec in the HF Hub
     "audio-classification": (AudioClassificationPipeline, AudioClassificationInput),
+    "automatic-speech-recognition": (AutomaticSpeechRecognitionPipeline, AutomaticSpeechRecognitionInput),
 }

 for task, task_info in pipeline_test_mapping.items():
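The input-signature check driven by this mapping lives elsewhere in the file; conceptually, it compares the arguments a pipeline accepts against the fields of the Hub input dataclass. A rough sketch of the idea (not the actual test code):

```python
from dataclasses import MISSING, fields

from huggingface_hub import AutomaticSpeechRecognitionInput

# List the inputs the Hub spec expects a conforming pipeline to accept;
# fields without a default (or default_factory) are required.
for spec_field in fields(AutomaticSpeechRecognitionInput):
    required = spec_field.default is MISSING and spec_field.default_factory is MISSING
    print(f"{spec_field.name}: {'required' if required else 'optional'}")
```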