|
| 1 | +import itertools |
| 2 | + |
| 3 | +from torchaudio.backend import sox_io_backend |
| 4 | +from parameterized import parameterized |
| 5 | + |
| 6 | +from ..common_utils import ( |
| 7 | + TempDirMixin, |
| 8 | + PytorchTestCase, |
| 9 | + skipIfNoExec, |
| 10 | + skipIfNoExtension, |
| 11 | +) |
| 12 | +from .common import ( |
| 13 | + get_test_name, |
| 14 | + get_wav_data, |
| 15 | + load_wav, |
| 16 | +) |
| 17 | +from . import sox_utils |
| 18 | + |
| 19 | + |
| 20 | +class SaveTestBase(TempDirMixin, PytorchTestCase): |
| 21 | + def assert_wav(self, dtype, sample_rate, num_channels, num_frames): |
| 22 | + """`sox_io_backend.save` can save wav format.""" |
| 23 | + path = self.get_temp_path(f'test_wav_{dtype}_{sample_rate}_{num_channels}.wav') |
| 24 | + expected = get_wav_data(dtype, num_channels, num_frames=num_frames) |
| 25 | + sox_io_backend.save(path, expected, sample_rate) |
| 26 | + found = load_wav(path)[0] |
| 27 | + self.assertEqual(found, expected) |
| 28 | + |
| 29 | + def assert_mp3(self, sample_rate, num_channels, bit_rate, duration): |
| 30 | + """`sox_io_backend.save` can save mp3 format. |
| 31 | +
|
| 32 | + mp3 encoding introduces delay and boundary effects so |
| 33 | + we convert the resulting mp3 to wav and compare the results there |
| 34 | +
|
| 35 | + | |
| 36 | + | 1. Generate original wav with Sox |
| 37 | + | |
| 38 | + v |
| 39 | + -------------- wav ---------------- |
| 40 | + | | |
| 41 | + | 2.1. load with scipy | 3.1. Convert to mp3 with Sox |
| 42 | + | then save with torchaudio | |
| 43 | + v v |
| 44 | + mp3 mp3 |
| 45 | + | | |
| 46 | + | 2.2. Convert to wav with Sox | 3.2. Convert to wav with Sox |
| 47 | + | | |
| 48 | + v v |
| 49 | + wav wav |
| 50 | + | | |
| 51 | + | 2.3. load with scipy | 3.3. load with scipy |
| 52 | + | | |
| 53 | + v v |
| 54 | + tensor -------> compare <--------- tensor |
| 55 | +
|
| 56 | + """ |
| 57 | + src_path = self.get_temp_path(f'test_mp3_{sample_rate}_{num_channels}_{bit_rate}_{duration}.wav') |
| 58 | + mp3_path = f'{src_path}.mp3' |
| 59 | + wav_path = f'{mp3_path}.wav' |
| 60 | + mp3_path_sox = f'{src_path}.sox.mp3' |
| 61 | + wav_path_sox = f'{mp3_path_sox}.wav' |
| 62 | + |
| 63 | + # 1. Generate original wav |
| 64 | + sox_utils.gen_audio_file( |
| 65 | + src_path, sample_rate, num_channels, |
| 66 | + bit_depth=32, |
| 67 | + encoding='floating-point', |
| 68 | + duration=duration, |
| 69 | + ) |
| 70 | + # 2.1. Convert the original wav to mp3 with torchaudio |
| 71 | + sox_io_backend.save( |
| 72 | + mp3_path, load_wav(src_path)[0], sample_rate, compression=bit_rate) |
| 73 | + # 2.2. Convert the mp3 to wav with Sox |
| 74 | + sox_utils.convert_audio_file(mp3_path, wav_path) |
| 75 | + # 2.3. Load |
| 76 | + found = load_wav(wav_path)[0] |
| 77 | + |
| 78 | + # 3.1. Convert the original wav to mp3 with SoX |
| 79 | + sox_utils.convert_audio_file(src_path, mp3_path_sox, compression=bit_rate) |
| 80 | + # 3.2. Convert the mp3 to wav with Sox |
| 81 | + sox_utils.convert_audio_file(mp3_path_sox, wav_path_sox) |
| 82 | + # 3.3. Load |
| 83 | + expected = load_wav(wav_path_sox)[0] |
| 84 | + |
| 85 | + self.assertEqual(found, expected) |
| 86 | + |
| 87 | + def assert_flac(self, sample_rate, num_channels, compression_level, duration): |
| 88 | + """`sox_io_backend.save` can save flac format. |
| 89 | +
|
| 90 | + This test takes the same strategy as mp3 to compare the result |
| 91 | + """ |
| 92 | + src_path = self.get_temp_path(f'test_flac_{sample_rate}_{num_channels}_{compression_level}_{duration}.wav') |
| 93 | + flac_path = f'{src_path}.flac' |
| 94 | + wav_path = f'{flac_path}.wav' |
| 95 | + flac_path_sox = f'{src_path}.sox.flac' |
| 96 | + wav_path_sox = f'{flac_path_sox}.wav' |
| 97 | + |
| 98 | + # 1. Generate original wav |
| 99 | + sox_utils.gen_audio_file( |
| 100 | + src_path, sample_rate, num_channels, |
| 101 | + bit_depth=32, |
| 102 | + encoding='floating-point', |
| 103 | + duration=duration, |
| 104 | + ) |
| 105 | + # 2.1. Convert the original wav to flac with torchaudio |
| 106 | + sox_io_backend.save( |
| 107 | + flac_path, load_wav(src_path)[0], sample_rate, compression=compression_level) |
| 108 | + # 2.2. Convert the flac to wav with Sox |
| 109 | + # converting to 32 bit because flac file has 24 bit depth which scipy cannot handle. |
| 110 | + sox_utils.convert_audio_file(flac_path, wav_path, bit_depth=32) |
| 111 | + # 2.3. Load |
| 112 | + found = load_wav(wav_path)[0] |
| 113 | + |
| 114 | + # 3.1. Convert the original wav to flac with SoX |
| 115 | + sox_utils.convert_audio_file(src_path, flac_path_sox, compression=compression_level) |
| 116 | + # 3.2. Convert the flac to wav with Sox |
| 117 | + # converting to 32 bit because flac file has 24 bit depth which scipy cannot handle. |
| 118 | + sox_utils.convert_audio_file(flac_path_sox, wav_path_sox, bit_depth=32) |
| 119 | + # 3.3. Load |
| 120 | + expected = load_wav(wav_path_sox)[0] |
| 121 | + |
| 122 | + self.assertEqual(found, expected) |
| 123 | + |
| 124 | + def _assert_vorbis(self, sample_rate, num_channels, quality_level, duration): |
| 125 | + """`sox_io_backend.save` can save vorbis format. |
| 126 | +
|
| 127 | + This test takes the same strategy as mp3 to compare the result |
| 128 | + """ |
| 129 | + src_path = self.get_temp_path(f'test_vorbis_{sample_rate}_{num_channels}_{quality_level}_{duration}.wav') |
| 130 | + vorbis_path = f'{src_path}.vorbis' |
| 131 | + wav_path = f'{vorbis_path}.wav' |
| 132 | + vorbis_path_sox = f'{src_path}.sox.vorbis' |
| 133 | + wav_path_sox = f'{vorbis_path_sox}.wav' |
| 134 | + |
| 135 | + # 1. Generate original wav |
| 136 | + sox_utils.gen_audio_file( |
| 137 | + src_path, sample_rate, num_channels, |
| 138 | + bit_depth=16, |
| 139 | + encoding='signed-integer', |
| 140 | + duration=duration, |
| 141 | + ) |
| 142 | + # 2.1. Convert the original wav to vorbis with torchaudio |
| 143 | + sox_io_backend.save( |
| 144 | + vorbis_path, load_wav(src_path)[0], sample_rate, compression=quality_level) |
| 145 | + # 2.2. Convert the vorbis to wav with Sox |
| 146 | + sox_utils.convert_audio_file(vorbis_path, wav_path) |
| 147 | + # 2.3. Load |
| 148 | + found = load_wav(wav_path)[0] |
| 149 | + |
| 150 | + # 3.1. Convert the original wav to vorbis with SoX |
| 151 | + sox_utils.convert_audio_file(src_path, vorbis_path_sox, compression=quality_level) |
| 152 | + # 3.2. Convert the vorbis to wav with Sox |
| 153 | + sox_utils.convert_audio_file(vorbis_path_sox, wav_path_sox) |
| 154 | + # 3.3. Load |
| 155 | + expected = load_wav(wav_path_sox)[0] |
| 156 | + |
| 157 | + # sox's vorbis encoding has some randomness, which cause small number of samples yields |
| 158 | + # higher descrepency than the others. |
| 159 | + # so we allow small portions of data to be outside of absolute torelance. |
| 160 | + atol = 1.0e-4 |
| 161 | + max_failure_allowed = 0.05 # this percent of samples are allowed to outside of atol. |
| 162 | + failure_ratio = ((found - expected).abs() > atol).sum().item() / found.numel() |
| 163 | + if failure_ratio > max_failure_allowed: |
| 164 | + # it's failed and this will give a better error message. |
| 165 | + self.assertEqual(found, expected, atol=atol, rtol=1.3e-6) |
| 166 | + |
| 167 | + def assert_vorbis(self, *args, **kwargs): |
| 168 | + # sox's vorbis encoding has some randomness, so we run tests multiple time |
| 169 | + max_retry = 5 |
| 170 | + error = None |
| 171 | + for _ in range(max_retry): |
| 172 | + try: |
| 173 | + self._assert_vorbis(*args, **kwargs) |
| 174 | + break |
| 175 | + except AssertionError as e: |
| 176 | + error = e |
| 177 | + else: |
| 178 | + raise error |
| 179 | + |
| 180 | + |
| 181 | +@skipIfNoExec('sox') |
| 182 | +@skipIfNoExtension |
| 183 | +class TestSave(SaveTestBase): |
| 184 | + @parameterized.expand(list(itertools.product( |
| 185 | + ['float32', 'int32', 'int16', 'uint8'], |
| 186 | + [8000, 16000], |
| 187 | + [1, 2], |
| 188 | + )), name_func=get_test_name) |
| 189 | + def test_wav(self, dtype, sample_rate, num_channels): |
| 190 | + """`sox_io_backend.save` can save wav format.""" |
| 191 | + self.assert_wav(dtype, sample_rate, num_channels, num_frames=None) |
| 192 | + |
| 193 | + @parameterized.expand(list(itertools.product( |
| 194 | + ['float32'], |
| 195 | + [16000], |
| 196 | + [2], |
| 197 | + )), name_func=get_test_name) |
| 198 | + def test_wav_large(self, dtype, sample_rate, num_channels): |
| 199 | + """`sox_io_backend.save` can save large wav file.""" |
| 200 | + two_hours = 2 * 60 * 60 * sample_rate |
| 201 | + self.assert_wav(dtype, sample_rate, num_channels, num_frames=two_hours) |
| 202 | + |
| 203 | + @parameterized.expand(list(itertools.product( |
| 204 | + ['float32', 'int32', 'int16', 'uint8'], |
| 205 | + [4, 8, 16, 32], |
| 206 | + )), name_func=get_test_name) |
| 207 | + def test_multiple_channels(self, dtype, num_channels): |
| 208 | + """`sox_io_backend.save` can save wav with more than 2 channels.""" |
| 209 | + sample_rate = 8000 |
| 210 | + self.assert_wav(dtype, sample_rate, num_channels, num_frames=None) |
| 211 | + |
| 212 | + @parameterized.expand(list(itertools.product( |
| 213 | + [8000, 16000], |
| 214 | + [1, 2], |
| 215 | + [96, 128, 160, 192, 224, 256, 320], |
| 216 | + )), name_func=get_test_name) |
| 217 | + def test_mp3(self, sample_rate, num_channels, bit_rate): |
| 218 | + """`sox_io_backend.save` can save mp3 format.""" |
| 219 | + self.assert_mp3(sample_rate, num_channels, bit_rate, duration=1) |
| 220 | + |
| 221 | + @parameterized.expand(list(itertools.product( |
| 222 | + [16000], |
| 223 | + [2], |
| 224 | + [128], |
| 225 | + )), name_func=get_test_name) |
| 226 | + def test_mp3_large(self, sample_rate, num_channels, bit_rate): |
| 227 | + """`sox_io_backend.save` can save large mp3 file.""" |
| 228 | + two_hours = 2 * 60 * 60 |
| 229 | + self.assert_mp3(sample_rate, num_channels, bit_rate, duration=two_hours) |
| 230 | + |
| 231 | + @parameterized.expand(list(itertools.product( |
| 232 | + [8000, 16000], |
| 233 | + [1, 2], |
| 234 | + list(range(9)), |
| 235 | + )), name_func=get_test_name) |
| 236 | + def test_flac(self, sample_rate, num_channels, compression_level): |
| 237 | + """`sox_io_backend.save` can save flac format.""" |
| 238 | + self.assert_flac(sample_rate, num_channels, compression_level, duration=1) |
| 239 | + |
| 240 | + @parameterized.expand(list(itertools.product( |
| 241 | + [16000], |
| 242 | + [2], |
| 243 | + [0], |
| 244 | + )), name_func=get_test_name) |
| 245 | + def test_flac_large(self, sample_rate, num_channels, compression_level): |
| 246 | + """`sox_io_backend.save` can save large flac file.""" |
| 247 | + two_hours = 2 * 60 * 60 |
| 248 | + self.assert_flac(sample_rate, num_channels, compression_level, duration=two_hours) |
| 249 | + |
| 250 | + @parameterized.expand(list(itertools.product( |
| 251 | + [8000, 16000], |
| 252 | + [1, 2], |
| 253 | + [-1, 0, 1, 2, 3, 3.6, 5, 10], |
| 254 | + )), name_func=get_test_name) |
| 255 | + def test_vorbis(self, sample_rate, num_channels, quality_level): |
| 256 | + """`sox_io_backend.save` can save vorbis format.""" |
| 257 | + self.assert_vorbis(sample_rate, num_channels, quality_level, duration=1) |
| 258 | + |
| 259 | + # note: torchaudio can load large vorbis file, but cannot save large volbis file |
| 260 | + # the following test causes Segmentation fault |
| 261 | + # |
| 262 | + ''' |
| 263 | + @parameterized.expand(list(itertools.product( |
| 264 | + [16000], |
| 265 | + [2], |
| 266 | + [10], |
| 267 | + )), name_func=get_test_name) |
| 268 | + def test_vorbis_large(self, sample_rate, num_channels, quality_level): |
| 269 | + """`sox_io_backend.save` can save large vorbis file correctly.""" |
| 270 | + two_hours = 2 * 60 * 60 |
| 271 | + self.assert_vorbis(sample_rate, num_channels, quality_level, two_hours) |
| 272 | + ''' |
| 273 | + |
| 274 | + |
| 275 | +@skipIfNoExec('sox') |
| 276 | +@skipIfNoExtension |
| 277 | +class TestSaveParams(TempDirMixin, PytorchTestCase): |
| 278 | + """Test the correctness of optional parameters of `sox_io_backend.save`""" |
| 279 | + @parameterized.expand([(True, ), (False, )], name_func=get_test_name) |
| 280 | + def test_channels_first(self, channels_first): |
| 281 | + """channels_first swaps axes""" |
| 282 | + path = self.get_temp_path('test_channel_first_{channels_first}.wav') |
| 283 | + data = get_wav_data('int32', 2, channels_first=channels_first) |
| 284 | + sox_io_backend.save( |
| 285 | + path, data, 8000, channels_first=channels_first) |
| 286 | + found = load_wav(path)[0] |
| 287 | + expected = data if channels_first else data.transpose(1, 0) |
| 288 | + self.assertEqual(found, expected) |
| 289 | + |
| 290 | + @parameterized.expand([ |
| 291 | + 'float32', 'int32', 'int16', 'uint8' |
| 292 | + ], name_func=get_test_name) |
| 293 | + def test_noncontiguous(self, dtype): |
| 294 | + """Noncontiguous tensors are saved correctly""" |
| 295 | + path = self.get_temp_path('test_uncontiguous_{dtype}.wav') |
| 296 | + expected = get_wav_data(dtype, 4)[::2, ::2] |
| 297 | + assert not expected.is_contiguous() |
| 298 | + sox_io_backend.save(path, expected, 8000) |
| 299 | + found = load_wav(path)[0] |
| 300 | + self.assertEqual(found, expected) |
| 301 | + |
| 302 | + @parameterized.expand([ |
| 303 | + 'float32', 'int32', 'int16', 'uint8', |
| 304 | + ]) |
| 305 | + def test_tensor_preserve(self, dtype): |
| 306 | + """save function should not alter Tensor""" |
| 307 | + path = self.get_temp_path(f'test_preserve_{dtype}.wav') |
| 308 | + expected = get_wav_data(dtype, 4)[::2, ::2] |
| 309 | + |
| 310 | + data = expected.clone() |
| 311 | + sox_io_backend.save(path, data, 8000) |
| 312 | + |
| 313 | + self.assertEqual(data, expected) |
0 commit comments