From 36db690fa4e615a32c7c2c1f2ab303418bfdfd28 Mon Sep 17 00:00:00 2001 From: Connor Henderson Date: Mon, 13 Feb 2023 17:20:56 -0500 Subject: [PATCH 1/7] fix: Change is_last chunk calc and add conditional break --- .../pipelines/automatic_speech_recognition.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/transformers/pipelines/automatic_speech_recognition.py b/src/transformers/pipelines/automatic_speech_recognition.py index 08c568e78a9c..8d9f5903a94a 100644 --- a/src/transformers/pipelines/automatic_speech_recognition.py +++ b/src/transformers/pipelines/automatic_speech_recognition.py @@ -56,14 +56,14 @@ def rescale_stride(stride, ratio): def chunk_iter(inputs, feature_extractor, chunk_len, stride_left, stride_right, rescale=True, dtype=None): inputs_len = inputs.shape[0] step = chunk_len - stride_left - stride_right - for i in range(0, inputs_len, step): - # add start and end paddings to the chunk - chunk = inputs[i : i + chunk_len] + for chunk_start_idx in range(0, inputs_len, step): + chunk_end_idx = chunk_start_idx + chunk_len + chunk = inputs[chunk_start_idx : chunk_end_idx] processed = feature_extractor(chunk, sampling_rate=feature_extractor.sampling_rate, return_tensors="pt") if dtype is not None: processed = processed.to(dtype=dtype) - _stride_left = 0 if i == 0 else stride_left - is_last = i + step + stride_left >= inputs_len + _stride_left = 0 if chunk_start_idx == 0 else stride_left + is_last = chunk_end_idx >= inputs_len _stride_right = 0 if is_last else stride_right chunk_len = chunk.shape[0] @@ -77,6 +77,8 @@ def chunk_iter(inputs, feature_extractor, chunk_len, stride_left, stride_right, stride = rescale_stride([stride], ratio)[0] if chunk.shape[0] > _stride_left: yield {"is_last": is_last, "stride": stride, **processed} + if is_last: + break def _find_timestamp_sequence(sequences, tokenizer, feature_extractor, max_source_positions): From 750c914e9d3ea6b65fd419efa8a82c59d9df2e35 Mon Sep 17 00:00:00 2001 From: Connor Henderson Date: Mon, 13 Feb 2023 17:27:04 -0500 Subject: [PATCH 2/7] format fix --- src/transformers/pipelines/automatic_speech_recognition.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/pipelines/automatic_speech_recognition.py b/src/transformers/pipelines/automatic_speech_recognition.py index 8d9f5903a94a..d1f21ebbb555 100644 --- a/src/transformers/pipelines/automatic_speech_recognition.py +++ b/src/transformers/pipelines/automatic_speech_recognition.py @@ -58,7 +58,7 @@ def chunk_iter(inputs, feature_extractor, chunk_len, stride_left, stride_right, step = chunk_len - stride_left - stride_right for chunk_start_idx in range(0, inputs_len, step): chunk_end_idx = chunk_start_idx + chunk_len - chunk = inputs[chunk_start_idx : chunk_end_idx] + chunk = inputs[chunk_start_idx:chunk_end_idx] processed = feature_extractor(chunk, sampling_rate=feature_extractor.sampling_rate, return_tensors="pt") if dtype is not None: processed = processed.to(dtype=dtype) From 0440ddf9961b4fe7a014b9eb4ed0ae628aa2310d Mon Sep 17 00:00:00 2001 From: Connor Henderson Date: Mon, 13 Feb 2023 18:46:09 -0500 Subject: [PATCH 3/7] account for 0 and full stride_rights, add comment --- src/transformers/pipelines/automatic_speech_recognition.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/transformers/pipelines/automatic_speech_recognition.py b/src/transformers/pipelines/automatic_speech_recognition.py index d1f21ebbb555..2780355d953b 100644 --- a/src/transformers/pipelines/automatic_speech_recognition.py +++ b/src/transformers/pipelines/automatic_speech_recognition.py @@ -63,7 +63,8 @@ def chunk_iter(inputs, feature_extractor, chunk_len, stride_left, stride_right, if dtype is not None: processed = processed.to(dtype=dtype) _stride_left = 0 if chunk_start_idx == 0 else stride_left - is_last = chunk_end_idx >= inputs_len + # all right strides must be full, otherwise it is the last item + is_last = chunk_end_idx > inputs_len if stride_right > 0 else chunk_end_idx >= inputs_len _stride_right = 0 if is_last else stride_right chunk_len = chunk.shape[0] From 5ebeceb88b54728fdfb7fba5aca548083d923542 Mon Sep 17 00:00:00 2001 From: Connor Henderson Date: Mon, 13 Feb 2023 19:06:03 -0500 Subject: [PATCH 4/7] add new test --- .../pipelines/test_pipelines_automatic_speech_recognition.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/pipelines/test_pipelines_automatic_speech_recognition.py b/tests/pipelines/test_pipelines_automatic_speech_recognition.py index 2bda09fe00c7..8dfbc806c040 100644 --- a/tests/pipelines/test_pipelines_automatic_speech_recognition.py +++ b/tests/pipelines/test_pipelines_automatic_speech_recognition.py @@ -1109,6 +1109,11 @@ def test_chunk_iterator_stride(self): self.assertEqual(len(outs), 2) self.assertEqual([o["stride"] for o in outs], [(90, 0, 0), (30, 20, 0)]) self.assertEqual([o["input_values"].shape for o in outs], [(1, 90), (1, 30)]) + + outs = list(chunk_iter(inputs, feature_extractor, 36, 6, 6, ratio)) + self.assertEqual(len(outs), 4) + self.assertEqual([o["stride"] for o in outs], [(36, 0, 6), (36, 6, 6), (36, 6, 6), (28, 6, 0)]) + self.assertEqual([o["input_values"].shape for o in outs], [(1, 36), (1, 36), (1, 36), (1, 28)]) inputs = torch.LongTensor([i % 2 for i in range(100)]) input_values = feature_extractor(inputs, sampling_rate=feature_extractor.sampling_rate, return_tensors="pt")[ From b85c1507ff89d2e87cd9cd629ec50fe055ad27d2 Mon Sep 17 00:00:00 2001 From: Connor Henderson Date: Mon, 13 Feb 2023 21:04:13 -0500 Subject: [PATCH 5/7] make style --- tests/pipelines/test_pipelines_automatic_speech_recognition.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/pipelines/test_pipelines_automatic_speech_recognition.py b/tests/pipelines/test_pipelines_automatic_speech_recognition.py index 8dfbc806c040..5b29d29e20fe 100644 --- a/tests/pipelines/test_pipelines_automatic_speech_recognition.py +++ b/tests/pipelines/test_pipelines_automatic_speech_recognition.py @@ -1109,7 +1109,7 @@ def test_chunk_iterator_stride(self): self.assertEqual(len(outs), 2) self.assertEqual([o["stride"] for o in outs], [(90, 0, 0), (30, 20, 0)]) self.assertEqual([o["input_values"].shape for o in outs], [(1, 90), (1, 30)]) - + outs = list(chunk_iter(inputs, feature_extractor, 36, 6, 6, ratio)) self.assertEqual(len(outs), 4) self.assertEqual([o["stride"] for o in outs], [(36, 0, 6), (36, 6, 6), (36, 6, 6), (28, 6, 0)]) From 023996547c6c0cffd6263ac4a369d15cf8492acc Mon Sep 17 00:00:00 2001 From: Connor Henderson Date: Mon, 20 Feb 2023 23:26:08 -0500 Subject: [PATCH 6/7] update slow whisper asr test timestamps --- .../pipelines/test_pipelines_automatic_speech_recognition.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/pipelines/test_pipelines_automatic_speech_recognition.py b/tests/pipelines/test_pipelines_automatic_speech_recognition.py index 5b29d29e20fe..7f452466ba4a 100644 --- a/tests/pipelines/test_pipelines_automatic_speech_recognition.py +++ b/tests/pipelines/test_pipelines_automatic_speech_recognition.py @@ -548,11 +548,11 @@ def test_whisper_timestamp_prediction(self): }, { "text": " the thousands of spectators, retrievality is not worth thinking about.", - "timestamp": (19.6, 24.98), + "timestamp": (19.6, 26.66), }, { "text": " His instant panic was followed by a small, sharp blow high on his chest.", - "timestamp": (24.98, 30.98), + "timestamp": (26.66, 31.060000000000002), }, ], "text": ( From a87c9f11494b8c28788bb5125076f0780ec3c648 Mon Sep 17 00:00:00 2001 From: Connor Henderson Date: Thu, 23 Feb 2023 13:39:36 -0500 Subject: [PATCH 7/7] use nested_simplify on output and round timestamp to hundreths place --- .../pipelines/test_pipelines_automatic_speech_recognition.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/pipelines/test_pipelines_automatic_speech_recognition.py b/tests/pipelines/test_pipelines_automatic_speech_recognition.py index 7f452466ba4a..f5b0f78dff47 100644 --- a/tests/pipelines/test_pipelines_automatic_speech_recognition.py +++ b/tests/pipelines/test_pipelines_automatic_speech_recognition.py @@ -526,7 +526,7 @@ def test_whisper_timestamp_prediction(self): output = pipe(array, chunk_length_s=10) self.assertDictEqual( - output, + nested_simplify(output), { "chunks": [ {"text": " A man said to the universe, Sir, I exist.", "timestamp": (0.0, 5.5)}, @@ -552,7 +552,7 @@ def test_whisper_timestamp_prediction(self): }, { "text": " His instant panic was followed by a small, sharp blow high on his chest.", - "timestamp": (26.66, 31.060000000000002), + "timestamp": (26.66, 31.06), }, ], "text": (