Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make chuking smartly (long files) work on asr ctc_with_lm. #15219

Merged
merged 7 commits into from
Jan 19, 2022
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 41 additions & 9 deletions src/transformers/pipelines/automatic_speech_recognition.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,14 +66,27 @@ def ffmpeg_read(bpayload: bytes, sampling_rate: int) -> np.array:
return audio


def apply_stride(tokens, stride):
max_token_n = tokens.shape[-1]
def audio_to_logits(tokens_or_logits, stride):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What exactly does the function do? Could we add some docstring? Also I don't understand the name audio_to_logits really

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's just change the stride number to go from audio space (10s at 16_000 means (160_000, 8_000, 8_000)) for instance. to logits_space (2333, 160, 160) for instance.

Do you think of a better name ? Doctstring could help a little.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah I see. Maybe get_output_stride_from_input_stride(input_shape, stride) and directly pass the shape? Yeah think a little docstring can help here

# Shape is [B, SEQ] for tokens
# [B, SEQ, V] for logits
max_token_n = tokens_or_logits.shape[1]
max_input_n = max(input_n for input_n, _, _ in stride)
ratio = max_token_n / max_input_n
for i, (input_n, left, right) in enumerate(stride):
new_strides = []
for input_n, left, right in stride:
token_n = int(round(input_n * ratio))
left_token = int(round(left / input_n * token_n))
right_token = int(round((input_n - right) / input_n * token_n))
left = int(round(left / input_n * token_n))
right = int(round(right / input_n * token_n))
new_stride = (token_n, left, right)
new_strides.append(new_stride)
return new_strides


def apply_stride(tokens, stride):
new_stride = audio_to_logits(tokens, stride)
for i, (input_n, left, right) in enumerate(new_stride):
left_token = left
right_token = input_n - right
# This is CTC to preseve decoding, we need to duplicate
# next letter, and last letter

Expand Down Expand Up @@ -215,7 +228,7 @@ def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None):
stride_left = int(round(stride_length_s[0] * self.feature_extractor.sampling_rate))
stride_right = int(round(stride_length_s[1] * self.feature_extractor.sampling_rate))

if self.type != "ctc":
if self.type not in {"ctc", "ctc_with_lm"}:
raise ValueError(
"`chunk_length_s` is only valid for CTC models, use other chunking options for other models"
)
Expand Down Expand Up @@ -244,9 +257,15 @@ def _forward(self, model_inputs):
)
out = {"tokens": tokens}
elif self.type == "ctc_with_lm":
stride = model_inputs.pop("stride", None)
outputs = self.model(**model_inputs)
out = {"logits": outputs.logits}

logits = outputs.logits
out = {"logits": logits}
if stride is not None:
# Send stride to `postprocess`.
# it needs to be handled there where
# the pieces are to be concatenated.
out["stride"] = audio_to_logits(logits, stride)
elif self.type == "ctc":
stride = model_inputs.pop("stride", None)
outputs = self.model(**model_inputs)
Expand All @@ -266,7 +285,20 @@ def _forward(self, model_inputs):

def postprocess(self, model_outputs):
if self.type == "ctc_with_lm":
logits = np.concatenate([outputs["logits"].numpy() for outputs in model_outputs], axis=1)
final_logits = []
for outputs in model_outputs:
logits = outputs["logits"].numpy()
stride = outputs.get("stride", None)
if stride is not None:
total_n, left, right = stride
# Total_n might be < logits.shape[1]
# because of padding, that's why
# we need to reconstruct this information
# This won't work with left padding (which doesn't exist right now)
right_n = total_n - right
logits = logits[:, left:right_n]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice, yes I think that's the only approach that'll work right now

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Padding is not possible here sadly

Copy link
Contributor Author

@Narsil Narsil Jan 19, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Padding works, it's handled by the batching mecanism. All I mentionned here is why we need so much information (logits might get padded while stride no).

We can make it work, relatively trivially for left padding.

if self.feature_extractor.padding_side == "left":
    left = logits.shape[1] - total_n + left_n
    right = logits.shaoe[1] - total_n + right_n
else:
   left = left_n
   right_n = total_n - right_n

Just thought this was overly complex since left padding doesn't seem likely here.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah I see - ok yeah

final_logits.append(logits)
logits = np.concatenate(final_logits, axis=1)
logits = logits.squeeze(0)
text = self.decoder.decode_beams(logits)[0][0]
else:
Expand Down
34 changes: 34 additions & 0 deletions tests/test_pipelines_automatic_speech_recognition.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,22 @@ def test_chunking_fast(self):
self.assertEqual(output, [{"text": ANY(str)}])
self.assertEqual(output[0]["text"][:6], "ZBT ZC")

@require_torch
def test_chunking_fast_with_lm(self):
speech_recognizer = pipeline(
model="hf-internal-testing/processor_with_lm",
chunk_length_s=10.0,
)

ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
audio = ds[40]["audio"]["array"]

n_repeats = 2
audio_tiled = np.tile(audio, n_repeats)
output = speech_recognizer([audio_tiled], batch_size=2)
self.assertEqual(output, [{"text": ANY(str)}])
self.assertEqual(output[0]["text"][:6], "<s> <s")

@require_torch
@require_pyctcdecode
def test_with_lm_fast(self):
Expand Down Expand Up @@ -340,6 +356,24 @@ def test_chunking(self):
expected = [{"text": expected_text.strip()}]
self.assertEqual(output, expected)

@require_torch
@slow
def test_chunking_with_lm(self):
speech_recognizer = pipeline(
task="automatic-speech-recognition",
model="patrickvonplaten/wav2vec2-base-100h-with-lm",
chunk_length_s=10.0,
)
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
audio = ds[40]["audio"]["array"]

n_repeats = 10
audio = np.tile(audio, n_repeats)
output = speech_recognizer([audio], batch_size=2)
expected_text = "A MAN SAID TO THE UNIVERSE SIR I EXIST " * n_repeats
expected = [{"text": expected_text.strip()}]
self.assertEqual(output, expected)

@require_torch
def test_chunk_iterator(self):
feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")
Expand Down