Skip to content

Commit

Permalink
Make ASR pipeline compliant with Hub spec + add tests (huggingface#33769)
Browse files Browse the repository at this point in the history

* Remove max_new_tokens arg

* Add ASR pipeline to testing

* make fixup

* Factor the output test out into a util

* Full error reporting

* Full error reporting

* Update src/transformers/pipelines/automatic_speech_recognition.py

Co-authored-by: Lysandre Debut <hi@lysand.re>

* Small comment

---------

Co-authored-by: Lysandre Debut <hi@lysand.re>
  • Loading branch information
Rocketknight1 and LysandreJik authored Oct 1, 2024
1 parent 0256520 commit a43e84c
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 9 deletions.
7 changes: 5 additions & 2 deletions src/transformers/pipelines/automatic_speech_recognition.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from collections import defaultdict
from typing import TYPE_CHECKING, Dict, Optional, Union

Expand Down Expand Up @@ -269,8 +270,6 @@ def __call__(
The dictionary of ad-hoc parametrization of `generate_config` to be used for the generation call. For a
complete overview of generate, check the [following
guide](https://huggingface.co/docs/transformers/en/main_classes/text_generation).
max_new_tokens (`int`, *optional*):
The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt.
Return:
`Dict`: A dictionary with the following keys:
Expand Down Expand Up @@ -310,6 +309,10 @@ def _sanitize_parameters(

forward_params = defaultdict(dict)
if max_new_tokens is not None:
warnings.warn(
"`max_new_tokens` is deprecated and will be removed in version 4.49 of Transformers. To remove this warning, pass `max_new_tokens` as a key inside `generate_kwargs` instead.",
FutureWarning,
)
forward_params["max_new_tokens"] = max_new_tokens
if generate_kwargs is not None:
if max_new_tokens is not None and "max_new_tokens" in generate_kwargs:
Expand Down
28 changes: 28 additions & 0 deletions src/transformers/testing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import unittest
from collections import defaultdict
from collections.abc import Mapping
from dataclasses import MISSING, fields
from functools import wraps
from io import StringIO
from pathlib import Path
Expand Down Expand Up @@ -2610,3 +2611,30 @@ def update_mapping_from_spec(device_fn_dict: Dict[str, Callable], attribute_name
update_mapping_from_spec(BACKEND_MANUAL_SEED, "MANUAL_SEED_FN")
update_mapping_from_spec(BACKEND_EMPTY_CACHE, "EMPTY_CACHE_FN")
update_mapping_from_spec(BACKEND_DEVICE_COUNT, "DEVICE_COUNT_FN")


def compare_pipeline_output_to_hub_spec(output, hub_spec):
    """Validate a single pipeline output dict against a Hub task-spec dataclass.

    Args:
        output: A mapping produced by a pipeline (e.g. one element of its output list).
        hub_spec: The dataclass from `huggingface_hub` describing the expected output
            schema for the task.

    Raises:
        KeyError: If the output is missing a required spec field (one whose dataclass
            default is `MISSING`) or contains a key not present in the spec at all.
            The message lists matching, missing, and unexpected keys to ease debugging.
    """
    spec_fields = fields(hub_spec)
    spec_field_names = {spec_field.name for spec_field in spec_fields}

    # Fields with a MISSING default are required and must appear in the output.
    missing_keys = [
        spec_field.name
        for spec_field in spec_fields
        if spec_field.default is MISSING and spec_field.name not in output
    ]
    # Every output key must correspond to some (required or optional) spec field.
    unexpected_keys = [output_key for output_key in output if output_key not in spec_field_names]

    if not missing_keys and not unexpected_keys:
        return

    matching_keys = sorted(output_key for output_key in output.keys() if output_key in spec_field_names)
    error_lines = ["Pipeline output does not match Hub spec!"]
    if matching_keys:
        error_lines.append(f"Matching keys: {matching_keys}")
    if missing_keys:
        error_lines.append(f"Missing required keys in pipeline output: {missing_keys}")
    if unexpected_keys:
        error_lines.append(f"Keys in pipeline output that are not in Hub spec: {unexpected_keys}")
    raise KeyError("\n".join(error_lines))
6 changes: 2 additions & 4 deletions tests/pipelines/test_pipelines_audio_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@
# limitations under the License.

import unittest
from dataclasses import fields

import numpy as np
from huggingface_hub import AudioClassificationOutputElement

from transformers import MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING, TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING
from transformers.pipelines import AudioClassificationPipeline, pipeline
from transformers.testing_utils import (
compare_pipeline_output_to_hub_spec,
is_pipeline_test,
nested_simplify,
require_tf,
Expand Down Expand Up @@ -68,10 +68,8 @@ def run_pipeline_test(self, audio_classifier, examples):

self.run_torchaudio(audio_classifier)

spec_output_keys = {field.name for field in fields(AudioClassificationOutputElement)}
for single_output in output:
output_keys = set(single_output.keys())
self.assertEqual(spec_output_keys, output_keys, msg="Pipeline output keys do not match HF Hub spec!")
compare_pipeline_output_to_hub_spec(single_output, AudioClassificationOutputElement)

@require_torchaudio
def run_torchaudio(self, audio_classifier):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import numpy as np
import pytest
from datasets import Audio, load_dataset
from huggingface_hub import hf_hub_download, snapshot_download
from huggingface_hub import AutomaticSpeechRecognitionOutput, hf_hub_download, snapshot_download

from transformers import (
MODEL_FOR_CTC_MAPPING,
Expand All @@ -36,6 +36,7 @@
from transformers.pipelines.audio_utils import chunk_bytes_iter
from transformers.pipelines.automatic_speech_recognition import _find_timestamp_sequence, chunk_iter
from transformers.testing_utils import (
compare_pipeline_output_to_hub_spec,
is_pipeline_test,
is_torch_available,
nested_simplify,
Expand Down Expand Up @@ -86,6 +87,8 @@ def run_pipeline_test(self, speech_recognizer, examples):
outputs = speech_recognizer(audio)
self.assertEqual(outputs, {"text": ANY(str)})

compare_pipeline_output_to_hub_spec(outputs, AutomaticSpeechRecognitionOutput)

# Striding
audio = {"raw": audio, "stride": (0, 4000), "sampling_rate": speech_recognizer.feature_extractor.sampling_rate}
if speech_recognizer.type == "ctc":
Expand Down
5 changes: 3 additions & 2 deletions tests/test_pipeline_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@
from textwrap import dedent
from typing import get_args

from huggingface_hub import AudioClassificationInput
from huggingface_hub import AudioClassificationInput, AutomaticSpeechRecognitionInput

from transformers.pipelines import AudioClassificationPipeline
from transformers.pipelines import AudioClassificationPipeline, AutomaticSpeechRecognitionPipeline
from transformers.testing_utils import (
is_pipeline_test,
require_decord,
Expand Down Expand Up @@ -104,6 +104,7 @@
# Adding a task to this list will cause its pipeline input signature to be checked against the corresponding
# task spec in the HF Hub
"audio-classification": (AudioClassificationPipeline, AudioClassificationInput),
"automatic-speech-recognition": (AutomaticSpeechRecognitionPipeline, AutomaticSpeechRecognitionInput),
}

for task, task_info in pipeline_test_mapping.items():
Expand Down

0 comments on commit a43e84c

Please sign in to comment.