Skip to content

Commit bff3267

Browse files
Fix incorrect batching audio index calculation for Phi-4-Multimodal (#38103)
* fix Signed-off-by: Isotr0py <2037008807@qq.com> * add tests Signed-off-by: Isotr0py <2037008807@qq.com> * code format Signed-off-by: Isotr0py <2037008807@qq.com> * Update src/transformers/models/phi4_multimodal/feature_extraction_phi4_multimodal.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --------- Signed-off-by: Isotr0py <2037008807@qq.com> Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
1 parent 9f0402b commit bff3267

File tree

2 files changed

+289
-1
lines changed

2 files changed

+289
-1
lines changed

src/transformers/models/phi4_multimodal/feature_extraction_phi4_multimodal.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -300,7 +300,7 @@ def _torch_extract_fbank_features(
300300
to_mask_batch_idxs = torch.arange(batch_size)[audio_lengths != audio_lengths.max()]
301301
if to_mask_batch_idxs.numel() > 0:
302302
batch_idxs_down = (audio_lengths[to_mask_batch_idxs] - self.win_length) // self.hop_length + 1
303-
batch_idxs_up = audio_lengths[to_mask_batch_idxs] // self.hop_length + 1
303+
batch_idxs_up = (audio_lengths[to_mask_batch_idxs] // self.hop_length) - 1
304304
offset_idx = batch_idxs_down.min()
305305
max_idx = batch_idxs_up.max()
306306

Lines changed: 288 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,288 @@
1+
# Copyright 2025 HuggingFace Inc.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
import itertools
17+
import os
18+
import random
19+
import tempfile
20+
import unittest
21+
22+
import numpy as np
23+
from datasets import load_dataset
24+
25+
from transformers import Phi4MultimodalFeatureExtractor
26+
from transformers.testing_utils import check_json_file_has_correct_format, require_torch
27+
from transformers.utils.import_utils import is_torch_available
28+
29+
from ...test_sequence_feature_extraction_common import SequenceFeatureExtractionTestMixin
30+
31+
32+
if is_torch_available():
33+
import torch
34+
35+
global_rng = random.Random()
36+
37+
38+
def floats_list(shape, scale=1.0, rng=None, name=None):
39+
"""Creates a random float32 tensor"""
40+
if rng is None:
41+
rng = global_rng
42+
43+
values = []
44+
for batch_idx in range(shape[0]):
45+
values.append([])
46+
for _ in range(shape[1]):
47+
values[-1].append(rng.random() * scale)
48+
49+
return values
50+
51+
52+
class Phi4MultimodalFeatureExtractionTester:
53+
def __init__(
54+
self,
55+
parent,
56+
batch_size=7,
57+
min_seq_length=400,
58+
max_seq_length=2000,
59+
feature_size=80,
60+
hop_length=160,
61+
win_length=400,
62+
padding_value=0.0,
63+
sampling_rate=16_000,
64+
return_attention_mask=False,
65+
do_normalize=True,
66+
):
67+
self.parent = parent
68+
self.batch_size = batch_size
69+
self.min_seq_length = min_seq_length
70+
self.max_seq_length = max_seq_length
71+
self.seq_length_diff = (self.max_seq_length - self.min_seq_length) // (self.batch_size - 1)
72+
self.padding_value = padding_value
73+
self.sampling_rate = sampling_rate
74+
self.return_attention_mask = return_attention_mask
75+
self.do_normalize = do_normalize
76+
self.feature_size = feature_size
77+
self.win_length = win_length
78+
self.hop_length = hop_length
79+
80+
def prepare_feat_extract_dict(self):
81+
return {
82+
"feature_size": self.feature_size,
83+
"hop_length": self.hop_length,
84+
"win_length": self.win_length,
85+
"padding_value": self.padding_value,
86+
"sampling_rate": self.sampling_rate,
87+
"return_attention_mask": self.return_attention_mask,
88+
"do_normalize": self.do_normalize,
89+
}
90+
91+
def prepare_inputs_for_common(self, equal_length=False, numpify=False):
92+
def _flatten(list_of_lists):
93+
return list(itertools.chain(*list_of_lists))
94+
95+
if equal_length:
96+
speech_inputs = [floats_list((self.max_seq_length, self.feature_size)) for _ in range(self.batch_size)]
97+
else:
98+
# make sure that inputs increase in size
99+
speech_inputs = [
100+
floats_list((x, self.feature_size))
101+
for x in range(self.min_seq_length, self.max_seq_length, self.seq_length_diff)
102+
]
103+
if numpify:
104+
speech_inputs = [np.asarray(x) for x in speech_inputs]
105+
return speech_inputs
106+
107+
108+
class Phi4MultimodalFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.TestCase):
109+
feature_extraction_class = Phi4MultimodalFeatureExtractor
110+
111+
def setUp(self):
112+
self.feat_extract_tester = Phi4MultimodalFeatureExtractionTester(self)
113+
114+
def test_feat_extract_from_and_save_pretrained(self):
115+
feat_extract_first = self.feature_extraction_class(**self.feat_extract_dict)
116+
117+
with tempfile.TemporaryDirectory() as tmpdirname:
118+
saved_file = feat_extract_first.save_pretrained(tmpdirname)[0]
119+
check_json_file_has_correct_format(saved_file)
120+
feat_extract_second = self.feature_extraction_class.from_pretrained(tmpdirname)
121+
122+
dict_first = feat_extract_first.to_dict()
123+
dict_second = feat_extract_second.to_dict()
124+
mel_1 = feat_extract_first.mel_filters
125+
mel_2 = feat_extract_second.mel_filters
126+
self.assertTrue(np.allclose(mel_1, mel_2))
127+
self.assertEqual(dict_first, dict_second)
128+
129+
def test_feat_extract_to_json_file(self):
130+
feat_extract_first = self.feature_extraction_class(**self.feat_extract_dict)
131+
132+
with tempfile.TemporaryDirectory() as tmpdirname:
133+
json_file_path = os.path.join(tmpdirname, "feat_extract.json")
134+
feat_extract_first.to_json_file(json_file_path)
135+
feat_extract_second = self.feature_extraction_class.from_json_file(json_file_path)
136+
137+
dict_first = feat_extract_first.to_dict()
138+
dict_second = feat_extract_second.to_dict()
139+
mel_1 = feat_extract_first.mel_filters
140+
mel_2 = feat_extract_second.mel_filters
141+
self.assertTrue(np.allclose(mel_1, mel_2))
142+
self.assertEqual(dict_first, dict_second)
143+
144+
def test_feat_extract_from_pretrained_kwargs(self):
145+
feat_extract_first = self.feature_extraction_class(**self.feat_extract_dict)
146+
147+
with tempfile.TemporaryDirectory() as tmpdirname:
148+
saved_file = feat_extract_first.save_pretrained(tmpdirname)[0]
149+
check_json_file_has_correct_format(saved_file)
150+
feat_extract_second = self.feature_extraction_class.from_pretrained(
151+
tmpdirname, feature_size=2 * self.feat_extract_dict["feature_size"]
152+
)
153+
154+
mel_1 = feat_extract_first.mel_filters
155+
mel_2 = feat_extract_second.mel_filters
156+
self.assertTrue(2 * mel_1.shape[1] == mel_2.shape[1])
157+
158+
def test_call(self):
159+
# Tests that all call wrap to encode_plus and batch_encode_plus
160+
feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
161+
# create three inputs of length 800, 1000, and 1200
162+
speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
163+
np_speech_inputs = [np.asarray(speech_input) for speech_input in speech_inputs]
164+
pt_speech_inputs = [torch.tensor(speech_input) for speech_input in speech_inputs]
165+
166+
# Test feature size
167+
input_features = feature_extractor(np_speech_inputs, return_tensors="np").audio_input_features
168+
max_audio_len = (1200 - feature_extractor.win_length) // feature_extractor.hop_length + 1
169+
self.assertTrue(input_features.ndim == 3)
170+
self.assertTrue(input_features.shape[-1] == feature_extractor.feature_size)
171+
self.assertTrue(input_features.shape[-2] == max_audio_len)
172+
173+
# Test not batched input
174+
encoded_sequences_1 = feature_extractor(pt_speech_inputs[0], return_tensors="np").audio_input_features
175+
encoded_sequences_2 = feature_extractor(np_speech_inputs[0], return_tensors="np").audio_input_features
176+
self.assertTrue(np.allclose(encoded_sequences_1, encoded_sequences_2, atol=1e-3))
177+
178+
# Test batched
179+
encoded_sequences_1 = feature_extractor(pt_speech_inputs, return_tensors="np").audio_input_features
180+
encoded_sequences_2 = feature_extractor(np_speech_inputs, return_tensors="np").audio_input_features
181+
for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
182+
self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
183+
184+
# Test 2-D numpy arrays are batched.
185+
speech_inputs = [floats_list((1, x))[0] for x in (800, 800, 800)]
186+
np_speech_inputs = np.asarray(speech_inputs)
187+
pt_speech_inputs = torch.tensor(speech_inputs)
188+
encoded_sequences_1 = feature_extractor(pt_speech_inputs, return_tensors="np").audio_input_features
189+
encoded_sequences_2 = feature_extractor(np_speech_inputs, return_tensors="np").audio_input_features
190+
for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
191+
self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
192+
193+
@require_torch
194+
def test_double_precision_pad(self):
195+
import torch
196+
197+
feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
198+
np_speech_inputs = np.random.rand(100, 32).astype(np.float64)
199+
py_speech_inputs = np_speech_inputs.tolist()
200+
201+
for inputs in [py_speech_inputs, np_speech_inputs]:
202+
np_processed = feature_extractor.pad([{"audio_input_features": inputs}], return_tensors="np")
203+
self.assertTrue(np_processed.audio_input_features.dtype == np.float32)
204+
pt_processed = feature_extractor.pad([{"audio_input_features": inputs}], return_tensors="pt")
205+
self.assertTrue(pt_processed.audio_input_features.dtype == torch.float32)
206+
207+
def _load_datasamples(self, num_samples):
208+
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
209+
# automatic decoding with librispeech
210+
speech_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"]
211+
212+
return [x["array"] for x in speech_samples]
213+
214+
@require_torch
215+
def test_torch_integration(self):
216+
# fmt: off
217+
EXPECTED_INPUT_FEATURES = torch.tensor(
218+
[
219+
6.5243, 7.2267, 8.0917, 8.0041, 6.8247, 6.3216, 5.9599, 5.6770,
220+
5.7441, 5.6138, 6.6793, 6.8597, 5.5375, 6.5330, 5.4880, 7.3280,
221+
9.0736, 9.7665, 9.8773, 10.0828, 10.0518, 10.1736, 10.0145, 9.2545,
222+
11.0495, 11.6518, 10.8654, 10.2293, 9.1045, 9.4819,
223+
]
224+
)
225+
# fmt: on
226+
227+
input_speech = self._load_datasamples(1)
228+
feature_extractor = Phi4MultimodalFeatureExtractor()
229+
input_features = feature_extractor(input_speech, return_tensors="pt").audio_input_features
230+
231+
self.assertEqual(input_features.shape, (1, 584, 80))
232+
torch.testing.assert_close(input_features[0, 0, :30], EXPECTED_INPUT_FEATURES, rtol=1e-4, atol=1e-4)
233+
234+
@unittest.mock.patch(
235+
"transformers.models.phi4_multimodal.feature_extraction_phi4_multimodal.is_torch_available", lambda: False
236+
)
237+
def test_numpy_integration(self):
238+
# fmt: off
239+
EXPECTED_INPUT_FEATURES = np.array(
240+
[
241+
6.5242944, 7.226712, 8.091721, 8.004097, 6.824679, 6.3216243,
242+
5.959894, 5.676975, 5.744051, 5.61384, 6.6793485, 6.8597484,
243+
5.5374746, 6.532976, 5.4879804, 7.3279905, 9.073576, 9.766463,
244+
9.877262, 10.082759, 10.051792, 10.173581, 10.0144825, 9.254548,
245+
11.049487, 11.651841, 10.865354, 10.229329, 9.104464, 9.481946,
246+
]
247+
)
248+
# fmt: on
249+
250+
input_speech = self._load_datasamples(1)
251+
feature_extractor = Phi4MultimodalFeatureExtractor()
252+
input_features = feature_extractor(input_speech, return_tensors="np").audio_input_features
253+
self.assertEqual(input_features.shape, (1, 584, 80))
254+
self.assertTrue(np.allclose(input_features[0, 0, :30], EXPECTED_INPUT_FEATURES, atol=1e-4))
255+
256+
@require_torch
257+
def test_torch_integration_batch(self):
258+
# fmt: off
259+
EXPECTED_INPUT_FEATURES = torch.tensor(
260+
[
261+
[
262+
6.5243, 7.2267, 8.0917, 8.0041, 6.8247, 6.3216, 5.9599, 5.6770,
263+
5.7441, 5.6138, 6.6793, 6.8597, 5.5375, 6.5330, 5.4880, 7.3280,
264+
9.0736, 9.7665, 9.8773, 10.0828, 10.0518, 10.1736, 10.0145, 9.2545,
265+
11.0495, 11.6518, 10.8654, 10.2293, 9.1045, 9.4819
266+
],
267+
[
268+
7.5105, 7.9453, 8.6161, 7.7666, 7.2572, 6.8823, 6.3242, 6.1899,
269+
6.9706, 8.0810, 7.3227, 5.8580, 5.4990, 7.7373, 8.5447, 7.7203,
270+
6.3230, 7.1995, 7.1463, 7.3153, 7.4054, 7.2855, 6.9396, 7.0255,
271+
7.3285, 7.2748, 8.0742, 7.3998, 6.4813, 6.7509
272+
],
273+
[
274+
7.7932, 8.1604, 8.7653, 8.2080, 7.2630, 6.4537, 4.8394, 6.3153,
275+
8.0207, 8.3379, 6.0896, 5.7369, 5.8601, 4.7598, 4.8850, 6.2529,
276+
3.9354, 6.1577, 7.9921, 9.6577, 10.1449, 9.1414, 9.3361, 9.0022,
277+
9.2533, 10.0548, 10.4372, 8.8550, 9.1266, 9.9013
278+
]
279+
]
280+
)
281+
# fmt: on
282+
283+
input_speech = self._load_datasamples(3)
284+
feature_extractor = Phi4MultimodalFeatureExtractor()
285+
input_features = feature_extractor(input_speech, return_tensors="pt").audio_input_features
286+
self.assertEqual(input_features.shape, (3, 1247, 80))
287+
print(input_features[:, 0, :30])
288+
torch.testing.assert_close(input_features[:, 0, :30], EXPECTED_INPUT_FEATURES, rtol=1e-4, atol=1e-4)

0 commit comments

Comments
 (0)