Hotfix chunk_length_s instead of _ms. #15029

Merged · 4 commits · Jan 4, 2022
Changes from all commits
103 changes: 59 additions & 44 deletions src/transformers/pipelines/automatic_speech_recognition.py
@@ -66,6 +66,39 @@ def ffmpeg_read(bpayload: bytes, sampling_rate: int) -> np.array:
return audio


def apply_stride(tokens, stride):
    max_token_n = tokens.shape[-1]
    max_input_n = max(input_n for input_n, _, _ in stride)
    ratio = max_token_n / max_input_n
    for i, (input_n, left, right) in enumerate(stride):
        token_n = int(round(input_n * ratio))
        left_token = int(round(left / input_n * token_n))
        right_token = int(round((input_n - right) / input_n * token_n))
        # This is CTC: to preserve the decoding, we duplicate the first
        # letter over the left stride and the last letter over the right stride.

        first_letter = tokens[i, left_token]
        tokens[i, :left_token] = first_letter

        last_letter = tokens[i, right_token - 1]
        tokens[i, right_token:] = last_letter
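
As a quick illustration of the duplication (a toy sketch using the function above; it matches the ApplyStrideTest cases added further down, with toy token ids rather than real model output):

import torch

tokens = torch.arange(10).long().reshape((2, 5))
# Row 0 has a 20/100 left stride, row 1 a 20/100 right stride.
apply_stride(tokens, [(100, 20, 0), (100, 0, 20)])
print(tokens.tolist())  # [[1, 1, 2, 3, 4], [5, 6, 7, 8, 8]]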


def chunk_iter(inputs, feature_extractor, chunk_len, stride_left, stride_right):
    inputs_len = inputs.shape[0]
    step = chunk_len - stride_left - stride_right
@systemdevart commented on Mar 11, 2024:

What does stride mean in this context? Because step doesn't equal chunk_size + stride_left, it is kind of hard to interpret what exactly stride_left or stride_right represent.


If you calculate the step this way, the total overlap with the neighboring chunks is 2 * (stride_left + stride_right).
Consider, for example, chunk_len = 15, stride_left = 3, stride_right = 3, and step = 15 - 3 - 3 = 9:
chunk 0: start=0, end=15
chunk 1: start=9, end=24
chunk 2: start=18, end=33

For chunk 1, the overlap with chunk 0 is 6 and the overlap with chunk 2 is 6, totaling 12 of overlap, but the intended overlap was 3 with the left chunk and 3 with the right chunk, totaling 6.
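
To make those numbers concrete, here is a minimal standalone sketch of the stepping logic (chunk_boundaries is a hypothetical helper mirroring the loop below, not part of the pipeline):

def chunk_boundaries(inputs_len, chunk_len, stride_left, stride_right):
    # Mirrors chunk_iter below: advance by the "useful" length only,
    # and zero out the stride on the outer edge of the first/last chunk.
    step = chunk_len - stride_left - stride_right
    for i in range(0, inputs_len, step):
        stop = min(i + chunk_len, inputs_len)
        left = 0 if i == 0 else stride_left
        right = 0 if i + step >= inputs_len else stride_right
        if stop - i > left:
            yield (i, stop, left, right)

for start, stop, left, right in chunk_boundaries(33, 15, 3, 3):
    print(f"chunk [{start:2d}, {stop:2d})  stride=({left}, {right})")
# chunk [ 0, 15)  stride=(0, 3)
# chunk [ 9, 24)  stride=(3, 3)
# chunk [18, 33)  stride=(3, 3)
# chunk [27, 33)  stride=(3, 0)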

    for i in range(0, inputs_len, step):
        # each chunk carries stride_left/stride_right extra samples of
        # context around its useful center
        chunk = inputs[i : i + chunk_len]
        processed = feature_extractor(chunk, sampling_rate=feature_extractor.sampling_rate, return_tensors="pt")
        _stride_left = 0 if i == 0 else stride_left
        is_last = i + step >= inputs_len
        _stride_right = 0 if is_last else stride_right

        if chunk.shape[0] > _stride_left:
            yield {"is_last": is_last, "stride": (chunk.shape[0], _stride_left, _stride_right), **processed}


class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
"""
Pipeline that aims at extracting spoken text contained within some audio.
@@ -85,11 +118,11 @@ def __init__(self, feature_extractor: Union["SequenceFeatureExtractor", str], *a
    tokenizer ([`PreTrainedTokenizer`]):
        The tokenizer that will be used by the pipeline to encode data for the model. This object inherits from
        [`PreTrainedTokenizer`].
-    chunk_length_ms (`int`, *optional*, defaults to 0):
+    chunk_length_s (`float`, *optional*, defaults to 0):
        The input length for each chunk. If `0` then chunking is disabled (default). Only available for CTC
        models.
-    stride_length_ms (`int`, *optional*, defaults to `chunk_length_ms / 6`):
-        The length of stride on the left and right of each chunk. Used only with `chunk_length_ms > 0`. This
+    stride_length_s (`float`, *optional*, defaults to `chunk_length_s / 6`):
+        The length of stride on the left and right of each chunk. Used only with `chunk_length_s > 0`. This
        enables the model to *see* more context and infer letters better than without this context but the
        pipeline discards the stride bits at the end to make the final reconstitution as perfect as possible.
framework (`str`, *optional*):
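
For intuition on the unit change: the seconds-based arguments are converted into sample counts with the feature extractor's sampling rate (a sketch of the arithmetic used in preprocess further down, assuming a 16 kHz extractor):

sampling_rate = 16_000                             # e.g. Wav2Vec2's feature extractor
chunk_len = int(round(10.0 * sampling_rate))       # chunk_length_s=10.0 -> 160000 samples
stride_len = int(round(10.0 / 6 * sampling_rate))  # default stride_length_s -> 26667 samples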
@@ -111,6 +144,7 @@ def __init__(self, feature_extractor: Union["SequenceFeatureExtractor", str], *a
raise ValueError("The AutomaticSpeechRecognitionPipeline is only available in PyTorch.")

self.check_model_type(dict(MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING.items() + MODEL_FOR_CTC_MAPPING.items()))
+        self.is_ctc = self.model.__class__ in MODEL_FOR_CTC_MAPPING.values()

def __call__(
self,
@@ -139,13 +173,13 @@ def __call__(
def _sanitize_parameters(self, **kwargs):
# No parameters on this pipeline right now
preprocess_params = {}
-        if "chunk_length_ms" in kwargs:
-            preprocess_params["chunk_length_ms"] = kwargs["chunk_length_ms"]
-        if "stride_length_ms" in kwargs:
-            preprocess_params["stride_length_ms"] = kwargs["stride_length_ms"]
+        if "chunk_length_s" in kwargs:
+            preprocess_params["chunk_length_s"] = kwargs["chunk_length_s"]
+        if "stride_length_s" in kwargs:
+            preprocess_params["stride_length_s"] = kwargs["stride_length_s"]
return preprocess_params, {}, {}
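
With the renamed arguments, a call looks like this (a usage sketch, not part of this diff; the checkpoint and file name are only examples, and a single float for stride_length_s applies to both sides):

from transformers import pipeline

asr = pipeline(
    "automatic-speech-recognition",
    model="facebook/wav2vec2-base-960h",  # any CTC checkpoint
    chunk_length_s=10.0,                  # seconds now, not milliseconds
    stride_length_s=(4.0, 2.0),           # (left, right) stride in seconds
)
text = asr("sample.flac")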

-    def preprocess(self, inputs, chunk_length_ms=0, stride_length_ms=None):
+    def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None):
if isinstance(inputs, str):
with open(inputs, "rb") as f:
inputs = f.read()
@@ -158,48 +192,37 @@ def preprocess(self, inputs, chunk_length_ms=0, stride_length_ms=None):
if len(inputs.shape) != 1:
raise ValueError("We expect a single channel audio input for AutomaticSpeechRecognitionPipeline")

-        if chunk_length_ms:
-            if stride_length_ms is None:
-                stride_length_ms = chunk_length_ms // 6
-            inputs_len = len(inputs)
-            chunk_len = chunk_length_ms * self.feature_extractor.sampling_rate // 1000
-            stride_len = stride_length_ms * self.feature_extractor.sampling_rate // 1000
+        if chunk_length_s:
+            if stride_length_s is None:
+                stride_length_s = chunk_length_s / 6
+
+            chunk_len = int(round(chunk_length_s * self.feature_extractor.sampling_rate))
+
+            if isinstance(stride_length_s, (int, float)):
+                stride_length_s = [stride_length_s, stride_length_s]

-            # Redefine chunk_len to useful chunk length
-            # Not the size
-            # chunk_len = chunk_len - 2 * stride_len
+            stride_left = int(round(stride_length_s[0] * self.feature_extractor.sampling_rate))
+            stride_right = int(round(stride_length_s[1] * self.feature_extractor.sampling_rate))

-            if self.model.__class__ not in MODEL_FOR_CTC_MAPPING.values():
+            if not self.is_ctc:
                raise ValueError(
-                    "`chunk_length_ms` is only valid for CTC models, use other chunking options for other models"
+                    "`chunk_length_s` is only valid for CTC models, use other chunking options for other models"
                )
-            if chunk_len < stride_len:
+            if chunk_len < stride_left + stride_right:
raise ValueError("Chunk length must be superior to stride length")

-            # make sure that
-            step = chunk_len
-            for i in range(0, inputs_len, step):
-                # add start and end paddings to the chunk
-                start = 0 if i - stride_len < 0 else i - stride_len
-                stop = inputs_len if i + chunk_len + stride_len > inputs_len else i + chunk_len + stride_len
-                chunk = inputs[start:stop]
-                processed = self.feature_extractor(
-                    chunk, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="pt"
-                )
-                stride_left = i - start
-                stride_right = max(stop - (i + chunk_len), 0)
-                is_last = i + step > inputs_len
-
-                yield {"is_last": is_last, "stride": (stop - start, stride_left, stride_right), **processed}
+            for item in chunk_iter(inputs, self.feature_extractor, chunk_len, stride_left, stride_right):
+                yield item
else:
processed = self.feature_extractor(
inputs, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="pt"
)
yield {"is_last": True, **processed}

def _forward(self, model_inputs):
-        model_class = self.model.__class__
        is_last = model_inputs.pop("is_last")
+        model_class = self.model.__class__
if model_class in MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING.values():
encoder = self.model.get_encoder()
# we need to pass `processed.get("attention_mask")` here since audio encoder
@@ -217,15 +240,7 @@ def _forward(self, model_inputs):
if isinstance(stride, tuple):
stride = [stride]

-                max_token_n = tokens.shape[-1]
-                max_input_n = max(input_n for input_n, _, _ in stride)
-                ratio = max_token_n / max_input_n
-                for i, (input_n, left, right) in enumerate(stride):
-                    token_n = int(input_n * ratio) + 1
-                    left_token = int(left / input_n * token_n)
-                    right_token = int((input_n - right) / input_n * token_n) + 1
-                    tokens[i, :left_token] = self.tokenizer.pad_token_id
-                    tokens[i, right_token:] = self.tokenizer.pad_token_id
+                apply_stride(tokens, stride)
else:
logger.warning("This is an unknown class, treating it as CTC.")
outputs = self.model(**model_inputs)
2 changes: 1 addition & 1 deletion src/transformers/pipelines/pt_utils.py
@@ -276,7 +276,7 @@ def __next__(self):
else:
item = processed
is_last = item.pop("is_last")
-            accumulator.append(item)
+                accumulator.append(item)
return accumulator
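
The packing contract this hunk repairs, reduced to a minimal sketch (a simplified stand-in for PipelinePackIterator that ignores loader batching): each item is appended exactly once, and a group is emitted when an item arrives with is_last=True.

def pack(stream):
    # Accumulate items until one carries is_last=True, then emit the group.
    accumulator = []
    for item in stream:
        is_last = item.pop("is_last")
        accumulator.append(item)
        if is_last:
            yield accumulator
            accumulator = []

stream = [{"id": 0, "is_last": False}, {"id": 1, "is_last": True}]
print(list(pack(stream)))  # [[{'id': 0}, {'id': 1}]]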


125 changes: 120 additions & 5 deletions tests/test_pipelines_automatic_speech_recognition.py
@@ -27,11 +27,24 @@
Wav2Vec2ForCTC,
)
from transformers.pipelines import AutomaticSpeechRecognitionPipeline, pipeline
-from transformers.testing_utils import is_pipeline_test, require_tf, require_torch, require_torchaudio, slow
+from transformers.pipelines.automatic_speech_recognition import apply_stride, chunk_iter
+from transformers.testing_utils import (
+    is_pipeline_test,
+    is_torch_available,
+    nested_simplify,
+    require_tf,
+    require_torch,
+    require_torchaudio,
+    slow,
+)

from .test_pipelines_common import ANY, PipelineTestCaseMeta


+if is_torch_available():
+    import torch


# We can't use this mixin because it assumes TF support.
# from .test_pipelines_common import CustomInputPipelineCommonMixin

@@ -245,17 +258,119 @@ def test_chunking(self):
tokenizer=tokenizer,
feature_extractor=feature_extractor,
framework="pt",
-            chunk_length_ms=10_000,
+            chunk_length_s=10.0,
)

from datasets import load_dataset

ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
audio = ds[40]["audio"]["array"]

-        n_repeats = 100
+        n_repeats = 10
audio = np.tile(audio, n_repeats)
output = speech_recognizer([audio], batch_size=2)
expected_text = "A MAN SAID TO THE UNIVERSE SIR I EXIST " * n_repeats
expected = [{"text": expected_text.strip()}]
self.assertEqual(output, expected)

@require_torch
def test_chunk_iterator(self):
feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")
inputs = torch.arange(100).long()

outs = list(chunk_iter(inputs, feature_extractor, 100, 0, 0))
self.assertEqual(len(outs), 1)
self.assertEqual([o["stride"] for o in outs], [(100, 0, 0)])
self.assertEqual([o["input_values"].shape for o in outs], [(1, 100)])
self.assertEqual([o["is_last"] for o in outs], [True])

# two chunks no stride
outs = list(chunk_iter(inputs, feature_extractor, 50, 0, 0))
self.assertEqual(len(outs), 2)
self.assertEqual([o["stride"] for o in outs], [(50, 0, 0), (50, 0, 0)])
self.assertEqual([o["input_values"].shape for o in outs], [(1, 50), (1, 50)])
self.assertEqual([o["is_last"] for o in outs], [False, True])

# two chunks incomplete last
outs = list(chunk_iter(inputs, feature_extractor, 80, 0, 0))
self.assertEqual(len(outs), 2)
self.assertEqual([o["stride"] for o in outs], [(80, 0, 0), (20, 0, 0)])
self.assertEqual([o["input_values"].shape for o in outs], [(1, 80), (1, 20)])
self.assertEqual([o["is_last"] for o in outs], [False, True])

@require_torch
def test_chunk_iterator_stride(self):
feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")
inputs = torch.arange(100).long()
input_values = feature_extractor(inputs, sampling_rate=feature_extractor.sampling_rate, return_tensors="pt")[
"input_values"
]

outs = list(chunk_iter(inputs, feature_extractor, 100, 20, 10))
self.assertEqual(len(outs), 2)
self.assertEqual([o["stride"] for o in outs], [(100, 0, 10), (30, 20, 0)])
self.assertEqual([o["input_values"].shape for o in outs], [(1, 100), (1, 30)])
self.assertEqual([o["is_last"] for o in outs], [False, True])

outs = list(chunk_iter(inputs, feature_extractor, 80, 20, 10))
self.assertEqual(len(outs), 2)
self.assertEqual([o["stride"] for o in outs], [(80, 0, 10), (50, 20, 0)])
self.assertEqual([o["input_values"].shape for o in outs], [(1, 80), (1, 50)])
self.assertEqual([o["is_last"] for o in outs], [False, True])

outs = list(chunk_iter(inputs, feature_extractor, 90, 20, 0))
self.assertEqual(len(outs), 2)
self.assertEqual([o["stride"] for o in outs], [(90, 0, 0), (30, 20, 0)])
self.assertEqual([o["input_values"].shape for o in outs], [(1, 90), (1, 30)])

inputs = torch.LongTensor([i % 2 for i in range(100)])
input_values = feature_extractor(inputs, sampling_rate=feature_extractor.sampling_rate, return_tensors="pt")[
"input_values"
]
outs = list(chunk_iter(inputs, feature_extractor, 30, 5, 5))
self.assertEqual(len(outs), 5)
self.assertEqual([o["stride"] for o in outs], [(30, 0, 5), (30, 5, 5), (30, 5, 5), (30, 5, 5), (20, 5, 0)])
self.assertEqual([o["input_values"].shape for o in outs], [(1, 30), (1, 30), (1, 30), (1, 30), (1, 20)])
self.assertEqual([o["is_last"] for o in outs], [False, False, False, False, True])
# (0, 25)
self.assertEqual(nested_simplify(input_values[:, :30]), nested_simplify(outs[0]["input_values"]))
# (25, 45)
self.assertEqual(nested_simplify(input_values[:, 20:50]), nested_simplify(outs[1]["input_values"]))
# (45, 65)
self.assertEqual(nested_simplify(input_values[:, 40:70]), nested_simplify(outs[2]["input_values"]))
# (65, 85)
self.assertEqual(nested_simplify(input_values[:, 60:90]), nested_simplify(outs[3]["input_values"]))
# (85, 100)
self.assertEqual(nested_simplify(input_values[:, 80:100]), nested_simplify(outs[4]["input_values"]))


@require_torch
class ApplyStrideTest(unittest.TestCase):
def test_apply_stride(self):
tokens = torch.arange(10).long().reshape((2, 5))

# No stride
apply_stride(tokens, [(100, 0, 0), (100, 0, 0)])

expected = torch.arange(10).long().reshape((2, 5))
self.assertEqual(expected.tolist(), tokens.tolist())

def test_apply_stride_real_stride(self):
# Stride aligned
tokens = torch.arange(10).long().reshape((2, 5))
apply_stride(tokens, [(100, 20, 0), (100, 0, 20)])
self.assertEqual([[1, 1, 2, 3, 4], [5, 6, 7, 8, 8]], tokens.tolist())

# Stride rounded
tokens = torch.arange(10).long().reshape((2, 5))
apply_stride(tokens, [(100, 15, 0), (100, 0, 15)])
self.assertEqual([[1, 1, 2, 3, 4], [5, 6, 7, 8, 8]], tokens.tolist())

# No stride rounded
tokens = torch.arange(10).long().reshape((2, 5))
apply_stride(tokens, [(100, 5, 0), (100, 0, 5)])
self.assertEqual([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]], tokens.tolist())

def test_apply_stride_with_padding(self):
# Stride aligned
tokens = torch.arange(10).long().reshape((2, 5))
apply_stride(tokens, [(100, 20, 0), (60, 0, 20)])
self.assertEqual([[1, 1, 2, 3, 4], [5, 6, 6, 6, 6]], tokens.tolist())
11 changes: 11 additions & 0 deletions tests/test_pipelines_common.py
@@ -584,3 +584,14 @@ def add(number, extra=0):

outputs = [item for item in dataset]
self.assertEqual(outputs, [[{"id": 2}, {"id": 3}], [{"id": 4}, {"id": 5}]])

# is_last is False across an entire batch
dummy_dataset = [{"id": [0, 1, 2], "is_last": [False, False, False]}, {"id": [3], "is_last": [True]}]

def add(number, extra=0):
return {"id": [i + extra for i in number["id"]], "is_last": number["is_last"]}

dataset = PipelinePackIterator(dummy_dataset, add, {"extra": 2}, loader_batch_size=3)

outputs = [item for item in dataset]
self.assertEqual(outputs, [[{"id": 2}, {"id": 3}, {"id": 4}, {"id": 5}]])