diff --git a/src/transformers/pipelines/automatic_speech_recognition.py b/src/transformers/pipelines/automatic_speech_recognition.py
index bfebba5abbe..e4745bcd81e 100644
--- a/src/transformers/pipelines/automatic_speech_recognition.py
+++ b/src/transformers/pipelines/automatic_speech_recognition.py
@@ -66,6 +66,39 @@ def ffmpeg_read(bpayload: bytes, sampling_rate: int) -> np.array:
     return audio
 
 
+def apply_stride(tokens, stride):
+    max_token_n = tokens.shape[-1]
+    max_input_n = max(input_n for input_n, _, _ in stride)
+    ratio = max_token_n / max_input_n
+    for i, (input_n, left, right) in enumerate(stride):
+        token_n = int(round(input_n * ratio))
+        left_token = int(round(left / input_n * token_n))
+        right_token = int(round((input_n - right) / input_n * token_n))
+        # This is CTC, so to preserve decoding we need to duplicate
+        # the first letter, and the last letter
+
+        first_letter = tokens[i, left_token]
+        tokens[i, :left_token] = first_letter
+
+        last_letter = tokens[i, right_token - 1]
+        tokens[i, right_token:] = last_letter
+
+
+def chunk_iter(inputs, feature_extractor, chunk_len, stride_left, stride_right):
+    inputs_len = inputs.shape[0]
+    step = chunk_len - stride_left - stride_right
+    for i in range(0, inputs_len, step):
+        # add start and end paddings to the chunk
+        chunk = inputs[i : i + chunk_len]
+        processed = feature_extractor(chunk, sampling_rate=feature_extractor.sampling_rate, return_tensors="pt")
+        _stride_left = 0 if i == 0 else stride_left
+        is_last = i + step >= inputs_len
+        _stride_right = 0 if is_last else stride_right
+
+        if chunk.shape[0] > _stride_left:
+            yield {"is_last": is_last, "stride": (chunk.shape[0], _stride_left, _stride_right), **processed}
+
+
 class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
     """
     Pipeline that aims at extracting spoken text contained within some audio.
@@ -85,11 +118,11 @@ def __init__(self, feature_extractor: Union["SequenceFeatureExtractor", str], *a
         tokenizer ([`PreTrainedTokenizer`]):
             The tokenizer that will be used by the pipeline to encode data for the model. This object inherits from
             [`PreTrainedTokenizer`].
-        chunk_length_ms (`int`, *optional*, defaults to 0):
+        chunk_length_s (`float`, *optional*, defaults to 0):
             The input length for in each chunk. If `0` then chunking is disabled (default). Only available for CTC
             models.
-        stride_length_ms (`int`, *optional*, defaults to `chunk_length_ms / 6`):
-            The length of stride on the left and right of each chunk. Used only with `chunk_length_ms > 0`. This
+        stride_length_s (`float`, *optional*, defaults to `chunk_length_s / 6`):
+            The length of stride on the left and right of each chunk. Used only with `chunk_length_s > 0`. This
             enables the model to *see* more context and infer letters better than without this context but the
             pipeline discards the stride bits at the end to make the final reconstitution as perfect as possible.
         framework (`str`, *optional*):
@@ -111,6 +144,7 @@ def __init__(self, feature_extractor: Union["SequenceFeatureExtractor", str], *a
             raise ValueError("The AutomaticSpeechRecognitionPipeline is only available in PyTorch.")
 
         self.check_model_type(dict(MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING.items() + MODEL_FOR_CTC_MAPPING.items()))
+        self.is_ctc = self.model.__class__ in MODEL_FOR_CTC_MAPPING.values()
 
     def __call__(
         self,
@@ -139,13 +173,13 @@ def __call__(
     def _sanitize_parameters(self, **kwargs):
         # No parameters on this pipeline right now
         preprocess_params = {}
-        if "chunk_length_ms" in kwargs:
-            preprocess_params["chunk_length_ms"] = kwargs["chunk_length_ms"]
-        if "stride_length_ms" in kwargs:
-            preprocess_params["stride_length_ms"] = kwargs["stride_length_ms"]
+        if "chunk_length_s" in kwargs:
+            preprocess_params["chunk_length_s"] = kwargs["chunk_length_s"]
+        if "stride_length_s" in kwargs:
+            preprocess_params["stride_length_s"] = kwargs["stride_length_s"]
         return preprocess_params, {}, {}
 
-    def preprocess(self, inputs, chunk_length_ms=0, stride_length_ms=None):
+    def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None):
         if isinstance(inputs, str):
             with open(inputs, "rb") as f:
                 inputs = f.read()
@@ -158,39 +192,28 @@ def preprocess(self, inputs, chunk_length_ms=0, stride_length_ms=None):
         if len(inputs.shape) != 1:
             raise ValueError("We expect a single channel audio input for AutomaticSpeechRecognitionPipeline")
 
-        if chunk_length_ms:
-            if stride_length_ms is None:
-                stride_length_ms = chunk_length_ms // 6
-            inputs_len = len(inputs)
-            chunk_len = chunk_length_ms * self.feature_extractor.sampling_rate // 1000
-            stride_len = stride_length_ms * self.feature_extractor.sampling_rate // 1000
+        if chunk_length_s:
+            if stride_length_s is None:
+                stride_length_s = chunk_length_s / 6
+
+            chunk_len = int(round(chunk_length_s * self.feature_extractor.sampling_rate))
+
+            if isinstance(stride_length_s, (int, float)):
+                stride_length_s = [stride_length_s, stride_length_s]
 
-            # Redefine chunk_len to useful chunk length
-            # Not the size
-            # chunk_len = chunk_len - 2 * stride_len
+            stride_left = int(round(stride_length_s[0] * self.feature_extractor.sampling_rate))
+            stride_right = int(round(stride_length_s[1] * self.feature_extractor.sampling_rate))
 
-            if self.model.__class__ not in MODEL_FOR_CTC_MAPPING.values():
+            if not self.is_ctc:
                 raise ValueError(
-                    "`chunk_length_ms` is only valid for CTC models, use other chunking options for other models"
+                    "`chunk_length_s` is only valid for CTC models, use other chunking options for other models"
                 )
-            if chunk_len < stride_len:
+            if chunk_len < stride_left + stride_right:
                 raise ValueError("Chunk length must be superior to stride length")
 
             # make sure that
-            step = chunk_len
-            for i in range(0, inputs_len, step):
-                # add start and end paddings to the chunk
-                start = 0 if i - stride_len < 0 else i - stride_len
-                stop = inputs_len if i + chunk_len + stride_len > inputs_len else i + chunk_len + stride_len
-                chunk = inputs[start:stop]
-                processed = self.feature_extractor(
-                    chunk, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="pt"
-                )
-                stride_left = i - start
-                stride_right = max(stop - (i + chunk_len), 0)
-                is_last = i + step > inputs_len
-
-                yield {"is_last": is_last, "stride": (stop - start, stride_left, stride_right), **processed}
+            for item in chunk_iter(inputs, self.feature_extractor, chunk_len, stride_left, stride_right):
+                yield item
         else:
             processed = self.feature_extractor(
                 inputs, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="pt"
@@ -198,8 +221,8 @@ def preprocess(self, inputs, chunk_length_ms=0, stride_length_ms=None):
             yield {"is_last": True, **processed}
 
     def _forward(self, model_inputs):
-        model_class = self.model.__class__
         is_last = model_inputs.pop("is_last")
+        model_class = self.model.__class__
         if model_class in MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING.values():
             encoder = self.model.get_encoder()
             # we need to pass `processed.get("attention_mask")` here since audio encoder
@@ -217,15 +240,7 @@ def _forward(self, model_inputs):
             if isinstance(stride, tuple):
                 stride = [stride]
 
-            max_token_n = tokens.shape[-1]
-            max_input_n = max(input_n for input_n, _, _ in stride)
-            ratio = max_token_n / max_input_n
-            for i, (input_n, left, right) in enumerate(stride):
-                token_n = int(input_n * ratio) + 1
-                left_token = int(left / input_n * token_n)
-                right_token = int((input_n - right) / input_n * token_n) + 1
-                tokens[i, :left_token] = self.tokenizer.pad_token_id
-                tokens[i, right_token:] = self.tokenizer.pad_token_id
+            apply_stride(tokens, stride)
         else:
             logger.warning("This is an unknown class, treating it as CTC.")
             outputs = self.model(**model_inputs)
diff --git a/src/transformers/pipelines/pt_utils.py b/src/transformers/pipelines/pt_utils.py
index 4a94caf4936..8eb7ac77982 100644
--- a/src/transformers/pipelines/pt_utils.py
+++ b/src/transformers/pipelines/pt_utils.py
@@ -276,7 +276,7 @@ def __next__(self):
             else:
                 item = processed
                 is_last = item.pop("is_last")
-            accumulator.append(item)
+                accumulator.append(item)
         return accumulator
 
 
diff --git a/tests/test_pipelines_automatic_speech_recognition.py b/tests/test_pipelines_automatic_speech_recognition.py
index ec8996e922f..f951b8a90f5 100644
--- a/tests/test_pipelines_automatic_speech_recognition.py
+++ b/tests/test_pipelines_automatic_speech_recognition.py
@@ -27,11 +27,24 @@
     Wav2Vec2ForCTC,
 )
 from transformers.pipelines import AutomaticSpeechRecognitionPipeline, pipeline
-from transformers.testing_utils import is_pipeline_test, require_tf, require_torch, require_torchaudio, slow
+from transformers.pipelines.automatic_speech_recognition import apply_stride, chunk_iter
+from transformers.testing_utils import (
+    is_pipeline_test,
+    is_torch_available,
+    nested_simplify,
+    require_tf,
+    require_torch,
+    require_torchaudio,
+    slow,
+)
 
 from .test_pipelines_common import ANY, PipelineTestCaseMeta
 
 
+if is_torch_available():
+    import torch
+
+
 # We can't use this mixin because it assumes TF support.
 # from .test_pipelines_common import CustomInputPipelineCommonMixin
@@ -245,17 +258,119 @@ def test_chunking(self):
             tokenizer=tokenizer,
             feature_extractor=feature_extractor,
             framework="pt",
-            chunk_length_ms=10_000,
+            chunk_length_s=10.0,
         )
-        from datasets import load_dataset
-
         ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
         audio = ds[40]["audio"]["array"]
-        n_repeats = 100
+        n_repeats = 10
         audio = np.tile(audio, n_repeats)
         output = speech_recognizer([audio], batch_size=2)
         expected_text = "A MAN SAID TO THE UNIVERSE SIR I EXIST " * n_repeats
         expected = [{"text": expected_text.strip()}]
         self.assertEqual(output, expected)
+
+    @require_torch
+    def test_chunk_iterator(self):
+        feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")
+        inputs = torch.arange(100).long()
+
+        outs = list(chunk_iter(inputs, feature_extractor, 100, 0, 0))
+        self.assertEqual(len(outs), 1)
+        self.assertEqual([o["stride"] for o in outs], [(100, 0, 0)])
+        self.assertEqual([o["input_values"].shape for o in outs], [(1, 100)])
+        self.assertEqual([o["is_last"] for o in outs], [True])
+
+        # two chunks no stride
+        outs = list(chunk_iter(inputs, feature_extractor, 50, 0, 0))
+        self.assertEqual(len(outs), 2)
+        self.assertEqual([o["stride"] for o in outs], [(50, 0, 0), (50, 0, 0)])
+        self.assertEqual([o["input_values"].shape for o in outs], [(1, 50), (1, 50)])
+        self.assertEqual([o["is_last"] for o in outs], [False, True])
+
+        # two chunks incomplete last
+        outs = list(chunk_iter(inputs, feature_extractor, 80, 0, 0))
+        self.assertEqual(len(outs), 2)
+        self.assertEqual([o["stride"] for o in outs], [(80, 0, 0), (20, 0, 0)])
+        self.assertEqual([o["input_values"].shape for o in outs], [(1, 80), (1, 20)])
+        self.assertEqual([o["is_last"] for o in outs], [False, True])
+
+    @require_torch
+    def test_chunk_iterator_stride(self):
+        feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")
+        inputs = torch.arange(100).long()
+        input_values = feature_extractor(inputs, sampling_rate=feature_extractor.sampling_rate, return_tensors="pt")[
+            "input_values"
+        ]
+
+        outs = list(chunk_iter(inputs, feature_extractor, 100, 20, 10))
+        self.assertEqual(len(outs), 2)
+        self.assertEqual([o["stride"] for o in outs], [(100, 0, 10), (30, 20, 0)])
+        self.assertEqual([o["input_values"].shape for o in outs], [(1, 100), (1, 30)])
+        self.assertEqual([o["is_last"] for o in outs], [False, True])
+
+        outs = list(chunk_iter(inputs, feature_extractor, 80, 20, 10))
+        self.assertEqual(len(outs), 2)
+        self.assertEqual([o["stride"] for o in outs], [(80, 0, 10), (50, 20, 0)])
+        self.assertEqual([o["input_values"].shape for o in outs], [(1, 80), (1, 50)])
+        self.assertEqual([o["is_last"] for o in outs], [False, True])
+
+        outs = list(chunk_iter(inputs, feature_extractor, 90, 20, 0))
+        self.assertEqual(len(outs), 2)
+        self.assertEqual([o["stride"] for o in outs], [(90, 0, 0), (30, 20, 0)])
+        self.assertEqual([o["input_values"].shape for o in outs], [(1, 90), (1, 30)])
+
+        inputs = torch.LongTensor([i % 2 for i in range(100)])
+        input_values = feature_extractor(inputs, sampling_rate=feature_extractor.sampling_rate, return_tensors="pt")[
+            "input_values"
+        ]
+        outs = list(chunk_iter(inputs, feature_extractor, 30, 5, 5))
+        self.assertEqual(len(outs), 5)
+        self.assertEqual([o["stride"] for o in outs], [(30, 0, 5), (30, 5, 5), (30, 5, 5), (30, 5, 5), (20, 5, 0)])
+        self.assertEqual([o["input_values"].shape for o in outs], [(1, 30), (1, 30), (1, 30), (1, 30), (1, 20)])
+        self.assertEqual([o["is_last"] for o in outs], [False, False, False, False, True])
+        # (0, 25)
+        self.assertEqual(nested_simplify(input_values[:, :30]), nested_simplify(outs[0]["input_values"]))
+        # (25, 45)
+        self.assertEqual(nested_simplify(input_values[:, 20:50]), nested_simplify(outs[1]["input_values"]))
+        # (45, 65)
+        self.assertEqual(nested_simplify(input_values[:, 40:70]), nested_simplify(outs[2]["input_values"]))
+        # (65, 85)
+        self.assertEqual(nested_simplify(input_values[:, 60:90]), nested_simplify(outs[3]["input_values"]))
+        # (85, 100)
+        self.assertEqual(nested_simplify(input_values[:, 80:100]), nested_simplify(outs[4]["input_values"]))
+
+
+@require_torch
+class ApplyStrideTest(unittest.TestCase):
+    def test_apply_stride(self):
+        tokens = torch.arange(10).long().reshape((2, 5))
+
+        # No stride
+        apply_stride(tokens, [(100, 0, 0), (100, 0, 0)])
+
+        expected = torch.arange(10).long().reshape((2, 5))
+        self.assertEqual(expected.tolist(), tokens.tolist())
+
+    def test_apply_stride_real_stride(self):
+        # Stride aligned
+        tokens = torch.arange(10).long().reshape((2, 5))
+        apply_stride(tokens, [(100, 20, 0), (100, 0, 20)])
+        self.assertEqual([[1, 1, 2, 3, 4], [5, 6, 7, 8, 8]], tokens.tolist())
+
+        # Stride rounded
+        tokens = torch.arange(10).long().reshape((2, 5))
+        apply_stride(tokens, [(100, 15, 0), (100, 0, 15)])
+        self.assertEqual([[1, 1, 2, 3, 4], [5, 6, 7, 8, 8]], tokens.tolist())
+
+        # No stride rounded
+        tokens = torch.arange(10).long().reshape((2, 5))
+        apply_stride(tokens, [(100, 5, 0), (100, 0, 5)])
+        self.assertEqual([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]], tokens.tolist())
+
+    def test_apply_stride_with_padding(self):
+        # Stride aligned
+        tokens = torch.arange(10).long().reshape((2, 5))
+        apply_stride(tokens, [(100, 20, 0), (60, 0, 20)])
+        self.assertEqual([[1, 1, 2, 3, 4], [5, 6, 6, 6, 6]], tokens.tolist())
diff --git a/tests/test_pipelines_common.py b/tests/test_pipelines_common.py
index e761647e668..c832f3c8917 100644
--- a/tests/test_pipelines_common.py
+++ b/tests/test_pipelines_common.py
@@ -584,3 +584,14 @@ def add(number, extra=0):
 
         outputs = [item for item in dataset]
         self.assertEqual(outputs, [[{"id": 2}, {"id": 3}], [{"id": 4}, {"id": 5}]])
+
+        # `is_last` stays False across an entire batch
+        dummy_dataset = [{"id": [0, 1, 2], "is_last": [False, False, False]}, {"id": [3], "is_last": [True]}]
+
+        def add(number, extra=0):
+            return {"id": [i + extra for i in number["id"]], "is_last": number["is_last"]}
+
+        dataset = PipelinePackIterator(dummy_dataset, add, {"extra": 2}, loader_batch_size=3)
+
+        outputs = [item for item in dataset]
+        self.assertEqual(outputs, [[{"id": 2}, {"id": 3}, {"id": 4}, {"id": 5}]])
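
Usage sketch (not part of the patch): the docstring above describes `chunk_length_s` and `stride_length_s`, but the diff itself carries no end-to-end example. The snippet below shows how the chunked CTC path might be exercised once the patch is applied; the checkpoint name and the file "sample.flac" are illustrative assumptions rather than values taken from the diff, and the construction mirrors the one used in `test_chunking`.

    from transformers import pipeline

    # chunk_length_s enables chunking (CTC models only); stride_length_s
    # defaults to chunk_length_s / 6 and may also be given as a [left, right] pair.
    asr = pipeline(
        "automatic-speech-recognition",
        model="facebook/wav2vec2-base-960h",  # assumed checkpoint with a 16 kHz feature extractor
        chunk_length_s=10.0,
        stride_length_s=[4.0, 2.0],
    )
    print(asr("sample.flac")["text"])  # "sample.flac" is a placeholder audio file

    # With a 16 kHz feature extractor these settings give chunk_len = 160_000 samples,
    # stride_left = 64_000 and stride_right = 32_000, so chunk_iter advances by
    # step = chunk_len - stride_left - stride_right = 64_000 samples per chunk.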