Skip to content

Commit 1f49bd9

Browse files
committed
Add save function
1 parent 62ff088 commit 1f49bd9

File tree

10 files changed

+728
-1
lines changed

10 files changed

+728
-1
lines changed

test/sox_io_backend/sox_utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,10 @@ def gen_audio_file(
2929
command = [
3030
'sox',
3131
'-V', # verbose
32+
]
33+
if bit_depth is not None:
34+
command += ['--bits', str(bit_depth)]
35+
command += [
3236
'--rate', str(sample_rate),
3337
'--null', # no input
3438
'--channels', str(num_channels),

test/sox_io_backend/test_roundtrip.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import itertools
2+
3+
from torchaudio.backend import sox_io_backend
4+
from parameterized import parameterized
5+
6+
from ..common_utils import (
7+
TempDirMixin,
8+
PytorchTestCase,
9+
skipIfNoExec,
10+
skipIfNoExtension,
11+
)
12+
from .common import (
13+
get_test_name,
14+
get_wav_data,
15+
)
16+
from . import sox_utils
17+
18+
19+
@skipIfNoExec('sox')
@skipIfNoExtension
class TestRoundTripIO(TempDirMixin, PytorchTestCase):
    """save/load round trip should not degrade data for lossless formats"""

    @parameterized.expand(list(itertools.product(
        ['float32', 'int32', 'int16', 'uint8'],
        [8000, 16000],
        [1, 2],
        [False, True]
    )), name_func=get_test_name)
    def test_roundtrip_wav(self, dtype, sample_rate, num_channels, normalize):
        """save/load round trip should not degrade data for wav formats

        The tensor produced by ``get_wav_data`` is saved and re-loaded ten
        times in a row; every reloaded tensor must equal the original and
        every reported sample rate must equal the one used for saving.
        """
        original = get_wav_data(dtype, num_channels, normalize=normalize)
        data = original
        for i in range(10):
            # One file per iteration so each cycle reads what the previous cycle produced.
            path = self.get_temp_path(f'{dtype}_{sample_rate}_{num_channels}_{normalize}_{i}.wav')
            sox_io_backend.save(path, data, sample_rate)
            data, sr = sox_io_backend.load(path, normalize=normalize)
            # Use the TestCase assertion rather than a bare ``assert``: it is not
            # stripped under ``python -O`` and reports both values on failure.
            self.assertEqual(sr, sample_rate)
            self.assertEqual(original, data)

    @parameterized.expand(list(itertools.product(
        [8000, 16000],
        [1, 2],
        list(range(9)),
    )), name_func=get_test_name)
    def test_roundtrip_flac(self, sample_rate, num_channels, compression_level):
        """save/load round trip should not degrade data for flac formats

        The reference tensor is obtained by loading a flac file generated with
        ``sox_utils.gen_audio_file``; it is then saved and re-loaded ten times
        and compared against that reference after each cycle.
        """
        path = self.get_temp_path(f'{sample_rate}_{num_channels}_{compression_level}.flac')
        sox_utils.gen_audio_file(path, sample_rate, num_channels, compression=compression_level)
        original = sox_io_backend.load(path)[0]

        data = original
        for i in range(10):
            path = self.get_temp_path(f'{sample_rate}_{num_channels}_{compression_level}_{i}.flac')
            sox_io_backend.save(path, data, sample_rate)
            data, sr = sox_io_backend.load(path)
            # See test_roundtrip_wav: TestCase assertion instead of a bare ``assert``.
            self.assertEqual(sr, sample_rate)
            self.assertEqual(original, data)

test/sox_io_backend/test_save.py

Lines changed: 313 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,313 @@
1+
import itertools
2+
3+
from torchaudio.backend import sox_io_backend
4+
from parameterized import parameterized
5+
6+
from ..common_utils import (
7+
TempDirMixin,
8+
PytorchTestCase,
9+
skipIfNoExec,
10+
skipIfNoExtension,
11+
)
12+
from .common import (
13+
get_test_name,
14+
get_wav_data,
15+
load_wav,
16+
)
17+
from . import sox_utils
18+
19+
20+
class SaveTestBase(TempDirMixin, PytorchTestCase):
    """Shared assertion helpers for testing `sox_io_backend.save`, one per audio format."""

    def assert_wav(self, dtype, sample_rate, num_channels, num_frames):
        """`sox_io_backend.save` can save wav format.

        Args:
            dtype: dtype name forwarded to ``get_wav_data``.
            sample_rate: sample rate passed to ``sox_io_backend.save``.
            num_channels: channel count forwarded to ``get_wav_data``.
            num_frames: frame count forwarded to ``get_wav_data``.
        """
        # num_frames is part of the file name (as duration is for the other
        # formats) so that calls differing only in length do not share a file.
        path = self.get_temp_path(f'test_wav_{dtype}_{sample_rate}_{num_channels}_{num_frames}.wav')
        expected = get_wav_data(dtype, num_channels, num_frames=num_frames)
        sox_io_backend.save(path, expected, sample_rate)
        found = load_wav(path)[0]
        self.assertEqual(found, expected)

    def assert_mp3(self, sample_rate, num_channels, bit_rate, duration):
        """`sox_io_backend.save` can save mp3 format.

        mp3 encoding introduces delay and boundary effects so
        we convert the resulting mp3 to wav and compare the results there

                          |
                          | 1. Generate original wav with Sox
                          |
                          v
          -------------- wav ----------------
         |                                   |
         | 2.1. load with scipy              | 3.1. Convert to mp3 with Sox
         | then save with torchaudio         |
         v                                   v
        mp3                                 mp3
         |                                   |
         | 2.2. Convert to wav with Sox      | 3.2. Convert to wav with Sox
         |                                   |
         v                                   v
        wav                                 wav
         |                                   |
         | 2.3. load with scipy              | 3.3. load with scipy
         |                                   |
         v                                   v
        tensor -------> compare <--------- tensor

        """
        src_path = self.get_temp_path(f'test_mp3_{sample_rate}_{num_channels}_{bit_rate}_{duration}.wav')
        mp3_path = f'{src_path}.mp3'
        wav_path = f'{mp3_path}.wav'
        mp3_path_sox = f'{src_path}.sox.mp3'
        wav_path_sox = f'{mp3_path_sox}.wav'

        # 1. Generate original wav
        sox_utils.gen_audio_file(
            src_path, sample_rate, num_channels,
            bit_depth=32,
            encoding='floating-point',
            duration=duration,
        )
        # 2.1. Convert the original wav to mp3 with torchaudio
        sox_io_backend.save(
            mp3_path, load_wav(src_path)[0], sample_rate, compression=bit_rate)
        # 2.2. Convert the mp3 to wav with Sox
        sox_utils.convert_audio_file(mp3_path, wav_path)
        # 2.3. Load
        found = load_wav(wav_path)[0]

        # 3.1. Convert the original wav to mp3 with SoX
        sox_utils.convert_audio_file(src_path, mp3_path_sox, compression=bit_rate)
        # 3.2. Convert the mp3 to wav with Sox
        sox_utils.convert_audio_file(mp3_path_sox, wav_path_sox)
        # 3.3. Load
        expected = load_wav(wav_path_sox)[0]

        self.assertEqual(found, expected)

    def assert_flac(self, sample_rate, num_channels, compression_level, duration):
        """`sox_io_backend.save` can save flac format.

        This test takes the same strategy as mp3 to compare the result
        """
        src_path = self.get_temp_path(f'test_flac_{sample_rate}_{num_channels}_{compression_level}_{duration}.wav')
        flac_path = f'{src_path}.flac'
        wav_path = f'{flac_path}.wav'
        flac_path_sox = f'{src_path}.sox.flac'
        wav_path_sox = f'{flac_path_sox}.wav'

        # 1. Generate original wav
        sox_utils.gen_audio_file(
            src_path, sample_rate, num_channels,
            bit_depth=32,
            encoding='floating-point',
            duration=duration,
        )
        # 2.1. Convert the original wav to flac with torchaudio
        sox_io_backend.save(
            flac_path, load_wav(src_path)[0], sample_rate, compression=compression_level)
        # 2.2. Convert the flac to wav with Sox
        # converting to 32 bit because flac file has 24 bit depth which scipy cannot handle.
        sox_utils.convert_audio_file(flac_path, wav_path, bit_depth=32)
        # 2.3. Load
        found = load_wav(wav_path)[0]

        # 3.1. Convert the original wav to flac with SoX
        sox_utils.convert_audio_file(src_path, flac_path_sox, compression=compression_level)
        # 3.2. Convert the flac to wav with Sox
        # converting to 32 bit because flac file has 24 bit depth which scipy cannot handle.
        sox_utils.convert_audio_file(flac_path_sox, wav_path_sox, bit_depth=32)
        # 3.3. Load
        expected = load_wav(wav_path_sox)[0]

        self.assertEqual(found, expected)

    def _assert_vorbis(self, sample_rate, num_channels, quality_level, duration):
        """`sox_io_backend.save` can save vorbis format.

        This test takes the same strategy as mp3 to compare the result,
        except that a bounded fraction of samples may differ by more than
        the absolute tolerance.
        """
        src_path = self.get_temp_path(f'test_vorbis_{sample_rate}_{num_channels}_{quality_level}_{duration}.wav')
        vorbis_path = f'{src_path}.vorbis'
        wav_path = f'{vorbis_path}.wav'
        vorbis_path_sox = f'{src_path}.sox.vorbis'
        wav_path_sox = f'{vorbis_path_sox}.wav'

        # 1. Generate original wav
        sox_utils.gen_audio_file(
            src_path, sample_rate, num_channels,
            bit_depth=16,
            encoding='signed-integer',
            duration=duration,
        )
        # 2.1. Convert the original wav to vorbis with torchaudio
        sox_io_backend.save(
            vorbis_path, load_wav(src_path)[0], sample_rate, compression=quality_level)
        # 2.2. Convert the vorbis to wav with Sox
        sox_utils.convert_audio_file(vorbis_path, wav_path)
        # 2.3. Load
        found = load_wav(wav_path)[0]

        # 3.1. Convert the original wav to vorbis with SoX
        sox_utils.convert_audio_file(src_path, vorbis_path_sox, compression=quality_level)
        # 3.2. Convert the vorbis to wav with Sox
        sox_utils.convert_audio_file(vorbis_path_sox, wav_path_sox)
        # 3.3. Load
        expected = load_wav(wav_path_sox)[0]

        # sox's vorbis encoding has some randomness, which causes a small number of
        # samples to show a higher discrepancy than the others,
        # so we allow a small portion of the data to be outside of the absolute tolerance.
        atol = 1.0e-4
        max_failure_allowed = 0.05  # fraction of samples allowed to be outside of atol.
        failure_ratio = ((found - expected).abs() > atol).sum().item() / found.numel()
        if failure_ratio > max_failure_allowed:
            # it has failed; calling assertEqual here gives a better error message.
            self.assertEqual(found, expected, atol=atol, rtol=1.3e-6)

    def assert_vorbis(self, *args, **kwargs):
        """Run `_assert_vorbis` up to five times and pass if any attempt passes.

        All arguments are forwarded to `_assert_vorbis` unchanged. If every
        attempt raises ``AssertionError``, the last one is re-raised.
        """
        # sox's vorbis encoding has some randomness, so we run the test multiple times
        max_retry = 5
        error = None
        for _ in range(max_retry):
            try:
                self._assert_vorbis(*args, **kwargs)
                break
            except AssertionError as e:
                error = e
        else:
            # Reached only when no attempt hit ``break``, i.e. all attempts failed.
            raise error
179+
180+
181+
@skipIfNoExec('sox')
@skipIfNoExtension
class TestSave(SaveTestBase):
    """Per-format tests of `sox_io_backend.save`, built on the `SaveTestBase` helpers."""

    @parameterized.expand(
        list(itertools.product(
            ['float32', 'int32', 'int16', 'uint8'],  # dtype
            [8000, 16000],  # sample_rate
            [1, 2],  # num_channels
        )),
        name_func=get_test_name,
    )
    def test_wav(self, dtype, sample_rate, num_channels):
        """Saving wav works for each dtype / sample rate / channel count."""
        self.assert_wav(dtype, sample_rate, num_channels, num_frames=None)

    @parameterized.expand(
        [('float32', 16000, 2)],
        name_func=get_test_name,
    )
    def test_wav_large(self, dtype, sample_rate, num_channels):
        """Saving a two-hour-long wav works."""
        frames_in_two_hours = sample_rate * 60 * 60 * 2
        self.assert_wav(dtype, sample_rate, num_channels, num_frames=frames_in_two_hours)

    @parameterized.expand(
        list(itertools.product(
            ['float32', 'int32', 'int16', 'uint8'],  # dtype
            [4, 8, 16, 32],  # num_channels
        )),
        name_func=get_test_name,
    )
    def test_multiple_channels(self, dtype, num_channels):
        """Saving wav works with more than two channels."""
        self.assert_wav(dtype, 8000, num_channels, num_frames=None)

    @parameterized.expand(
        list(itertools.product(
            [8000, 16000],  # sample_rate
            [1, 2],  # num_channels
            [96, 128, 160, 192, 224, 256, 320],  # bit_rate
        )),
        name_func=get_test_name,
    )
    def test_mp3(self, sample_rate, num_channels, bit_rate):
        """Saving mp3 works for each sample rate / channel count / bit rate."""
        self.assert_mp3(sample_rate, num_channels, bit_rate, duration=1)

    @parameterized.expand(
        [(16000, 2, 128)],
        name_func=get_test_name,
    )
    def test_mp3_large(self, sample_rate, num_channels, bit_rate):
        """Saving a two-hour-long mp3 works."""
        seconds = 60 * 60 * 2
        self.assert_mp3(sample_rate, num_channels, bit_rate, duration=seconds)

    @parameterized.expand(
        list(itertools.product(
            [8000, 16000],  # sample_rate
            [1, 2],  # num_channels
            list(range(9)),  # compression_level
        )),
        name_func=get_test_name,
    )
    def test_flac(self, sample_rate, num_channels, compression_level):
        """Saving flac works for each sample rate / channel count / compression level."""
        self.assert_flac(sample_rate, num_channels, compression_level, duration=1)

    @parameterized.expand(
        [(16000, 2, 0)],
        name_func=get_test_name,
    )
    def test_flac_large(self, sample_rate, num_channels, compression_level):
        """Saving a two-hour-long flac works."""
        seconds = 60 * 60 * 2
        self.assert_flac(sample_rate, num_channels, compression_level, duration=seconds)

    @parameterized.expand(
        list(itertools.product(
            [8000, 16000],  # sample_rate
            [1, 2],  # num_channels
            [-1, 0, 1, 2, 3, 3.6, 5, 10],  # quality_level
        )),
        name_func=get_test_name,
    )
    def test_vorbis(self, sample_rate, num_channels, quality_level):
        """Saving vorbis works for each sample rate / channel count / quality level."""
        self.assert_vorbis(sample_rate, num_channels, quality_level, duration=1)

    # note: torchaudio can load a large vorbis file, but cannot save a large vorbis file;
    # the following test causes a Segmentation fault, so it is kept disabled
    # inside a string literal.
    #
    '''
    @parameterized.expand(list(itertools.product(
        [16000],
        [2],
        [10],
    )), name_func=get_test_name)
    def test_vorbis_large(self, sample_rate, num_channels, quality_level):
        """`sox_io_backend.save` can save large vorbis file correctly."""
        two_hours = 2 * 60 * 60
        self.assert_vorbis(sample_rate, num_channels, quality_level, two_hours)
    '''
273+
274+
275+
@skipIfNoExec('sox')
@skipIfNoExtension
class TestSaveParams(TempDirMixin, PytorchTestCase):
    """Test the correctness of optional parameters of `sox_io_backend.save`"""

    @parameterized.expand([(True, ), (False, )], name_func=get_test_name)
    def test_channels_first(self, channels_first):
        """channels_first swaps axes

        The file read back with ``load_wav`` must equal the saved tensor as-is
        when ``channels_first`` is true, and its transpose otherwise.
        """
        # f-string so that each parameterization writes to its own file.
        path = self.get_temp_path(f'test_channel_first_{channels_first}.wav')
        data = get_wav_data('int32', 2, channels_first=channels_first)
        sox_io_backend.save(
            path, data, 8000, channels_first=channels_first)
        found = load_wav(path)[0]
        expected = data if channels_first else data.transpose(1, 0)
        self.assertEqual(found, expected)

    @parameterized.expand([
        'float32', 'int32', 'int16', 'uint8'
    ], name_func=get_test_name)
    def test_noncontiguous(self, dtype):
        """Noncontiguous tensors are saved correctly"""
        # f-string so that each dtype writes to its own file.
        path = self.get_temp_path(f'test_uncontiguous_{dtype}.wav')
        # Strided slicing yields the non-contiguous view under test.
        expected = get_wav_data(dtype, 4)[::2, ::2]
        # Precondition check via TestCase so it is not stripped under ``python -O``.
        self.assertFalse(expected.is_contiguous())
        sox_io_backend.save(path, expected, 8000)
        found = load_wav(path)[0]
        self.assertEqual(found, expected)

    @parameterized.expand([
        'float32', 'int32', 'int16', 'uint8',
    ])
    def test_tensor_preserve(self, dtype):
        """save function should not alter Tensor"""
        path = self.get_temp_path(f'test_preserve_{dtype}.wav')
        expected = get_wav_data(dtype, 4)[::2, ::2]

        # Save a clone, then compare the clone against the untouched source.
        data = expected.clone()
        sox_io_backend.save(path, data, 8000)

        self.assertEqual(data, expected)

0 commit comments

Comments
 (0)