diff --git a/test/torchaudio_unittest/backend/sox_io/info_test.py b/test/torchaudio_unittest/backend/sox_io/info_test.py index 7938fbcda9..289de5bcbe 100644 --- a/test/torchaudio_unittest/backend/sox_io/info_test.py +++ b/test/torchaudio_unittest/backend/sox_io/info_test.py @@ -312,23 +312,31 @@ def test_opus(self, bitrate, num_channels, compression_level): @skipIfNoSox class TestLoadWithoutExtension(PytorchTestCase): def test_mp3(self): - """Providing `format` allows to read mp3 without extension - - libsox does not check header for mp3 + """MP3 file without extension can be loaded + Originally, we added `format` argument for this case, but now we use FFmpeg + for MP3 decoding, which works even without `format` argument. https://github.com/pytorch/audio/issues/1040 The file was generated with the following command ffmpeg -f lavfi -i "sine=frequency=1000:duration=5" -ar 16000 -f mp3 test_noext """ path = get_asset_path("mp3_without_ext") - sinfo = sox_io_backend.info(path, format="mp3") + sinfo = sox_io_backend.info(path) assert sinfo.sample_rate == 16000 - assert sinfo.num_frames == 81216 + assert sinfo.num_frames == 0 assert sinfo.num_channels == 1 assert sinfo.bits_per_sample == 0 # bit_per_sample is irrelevant for compressed formats assert sinfo.encoding == "MP3" + with open(path, "rb") as fileobj: + sinfo = sox_io_backend.info(fileobj) + assert sinfo.sample_rate == 16000 + assert sinfo.num_frames == 0 + assert sinfo.num_channels == 1 + assert sinfo.bits_per_sample == 0 + assert sinfo.encoding == "MP3" + class FileObjTestBase(TempDirMixin): def _gen_file(self, ext, dtype, sample_rate, num_channels, num_frames, *, comments=None): @@ -355,6 +363,14 @@ def _gen_comment_file(self, comments): return comment_path +class Unseekable: + def __init__(self, fileobj): + self.fileobj = fileobj + + def read(self, n): + return self.fileobj.read(n) + + @skipIfNoSox @skipIfNoExec("sox") class TestFileObject(FileObjTestBase, PytorchTestCase): @@ -435,7 +451,7 @@ def test_fileobj_large_header(self, ext, dtype): num_channels = 2 comments = "metadata=" + " ".join(["value" for _ in range(1000)]) - with self.assertRaisesRegex(RuntimeError, "Failed to fetch metadata from"): + with self.assertRaises(RuntimeError): sinfo = self._query_fileobj(ext, dtype, sample_rate, num_channels, num_frames, comments=comments) with self._set_buffer_size(16384): @@ -545,7 +561,7 @@ def _query_http(self, ext, dtype, sample_rate, num_channels, num_frames): url = self.get_url(audio_file) format_ = ext if ext in ["mp3"] else None with requests.get(url, stream=True) as resp: - return sox_io_backend.info(resp.raw, format=format_) + return sox_io_backend.info(Unseekable(resp.raw), format=format_) @parameterized.expand( [ diff --git a/test/torchaudio_unittest/backend/sox_io/load_test.py b/test/torchaudio_unittest/backend/sox_io/load_test.py index 760351b3e1..37a14619cf 100644 --- a/test/torchaudio_unittest/backend/sox_io/load_test.py +++ b/test/torchaudio_unittest/backend/sox_io/load_test.py @@ -2,6 +2,8 @@ import itertools import tarfile +import torch +import torchaudio from parameterized import parameterized from torchaudio._internal import module_utils as _mod_utils from torchaudio.backend import sox_io_backend @@ -10,6 +12,7 @@ get_wav_data, HttpServerMixin, load_wav, + nested_params, PytorchTestCase, save_wav, skipIfNoExec, @@ -169,35 +172,6 @@ def test_multiple_channels(self, dtype, num_channels): normalize = False self.assert_wav(dtype, sample_rate, num_channels, normalize, duration=1) - @parameterized.expand( - list( - itertools.product( - [8000, 16000, 44100], - [1, 2], - [96, 128, 160, 192, 224, 256, 320], - ) - ), - name_func=name_func, - ) - def test_mp3(self, sample_rate, num_channels, bit_rate): - """`sox_io_backend.load` can load mp3 format correctly.""" - self.assert_format("mp3", sample_rate, num_channels, compression=bit_rate, duration=1, atol=5e-05) - - @parameterized.expand( - list( - itertools.product( - [16000], - [2], - [128], - ) - ), - name_func=name_func, - ) - def test_mp3_large(self, sample_rate, num_channels, bit_rate): - """`sox_io_backend.load` can load large mp3 file correctly.""" - two_hours = 2 * 60 * 60 - self.assert_format("mp3", sample_rate, num_channels, compression=bit_rate, duration=two_hours, atol=5e-05) - @parameterized.expand( list( itertools.product( @@ -319,72 +293,92 @@ def test_amr_nb(self): self.assert_format("amr-nb", sample_rate=8000, num_channels=1, bit_depth=32, duration=1) -@skipIfNoExec("sox") @skipIfNoSox class TestLoadParams(TempDirMixin, PytorchTestCase): """Test the correctness of frame parameters of `sox_io_backend.load`""" - original = None - path = None + def _test(self, func, frame_offset, num_frames, channels_first, normalize): + original = get_wav_data("int16", num_channels=2, normalize=False) + path = self.get_temp_path("test.wav") + save_wav(path, original, sample_rate=8000) - def setUp(self): - super().setUp() - sample_rate = 8000 - self.original = get_wav_data("float32", num_channels=2) - self.path = self.get_temp_path("test.wav") - save_wav(self.path, self.original, sample_rate) + output, _ = func(path, frame_offset, num_frames, normalize, channels_first, None) + frame_end = None if num_frames == -1 else frame_offset + num_frames + expected = original[:, slice(frame_offset, frame_end)] + if not channels_first: + expected = expected.T + if normalize: + expected = expected.to(torch.float32) / (2**15) + self.assertEqual(output, expected) + + @nested_params( + [0, 1, 10, 100, 1000], + [-1, 1, 10, 100, 1000], + [True, False], + [True, False], + ) + def test_sox(self, frame_offset, num_frames, channels_first, normalize): + """The combination of properly changes the output tensor""" - @parameterized.expand( - list( - itertools.product( - [0, 1, 10, 100, 1000], - [-1, 1, 10, 100, 1000], - ) - ), - name_func=name_func, + self._test(torch.ops.torchaudio.sox_io_load_audio_file, frame_offset, num_frames, channels_first, normalize) + + # test file-like obj + def func(path, *args): + with open(path, "rb") as fileobj: + return torchaudio._torchaudio.load_audio_fileobj(fileobj, *args) + + self._test(func, frame_offset, num_frames, channels_first, normalize) + + @nested_params( + [0, 1, 10, 100, 1000], + [-1, 1, 10, 100, 1000], + [True, False], + [True, False], ) - def test_frame(self, frame_offset, num_frames): - """num_frames and frame_offset correctly specify the region of data""" - found, _ = sox_io_backend.load(self.path, frame_offset, num_frames) - frame_end = None if num_frames == -1 else frame_offset + num_frames - self.assertEqual(found, self.original[:, frame_offset:frame_end]) + def test_ffmpeg(self, frame_offset, num_frames, channels_first, normalize): + """The combination of properly changes the output tensor""" + from torchaudio.io._compat import load_audio, load_audio_fileobj + + self._test(load_audio, frame_offset, num_frames, channels_first, normalize) + + # test file-like obj + def func(path, *args): + with open(path, "rb") as fileobj: + return load_audio_fileobj(fileobj, *args) - @parameterized.expand([(True,), (False,)], name_func=name_func) - def test_channels_first(self, channels_first): - """channels_first swaps axes""" - found, _ = sox_io_backend.load(self.path, channels_first=channels_first) - expected = self.original if channels_first else self.original.transpose(1, 0) - self.assertEqual(found, expected) + self._test(func, frame_offset, num_frames, channels_first, normalize) @skipIfNoSox class TestLoadWithoutExtension(PytorchTestCase): def test_mp3(self): - """Providing format allows to read mp3 without extension - - libsox does not check header for mp3 + """MP3 file without extension can be loaded + Originally, we added `format` argument for this case, but now we use FFmpeg + for MP3 decoding, which works even without `format` argument. https://github.com/pytorch/audio/issues/1040 The file was generated with the following command ffmpeg -f lavfi -i "sine=frequency=1000:duration=5" -ar 16000 -f mp3 test_noext """ path = get_asset_path("mp3_without_ext") - _, sr = sox_io_backend.load(path, format="mp3") + _, sr = sox_io_backend.load(path) + assert sr == 16000 + + with open(path, "rb") as fileobj: + _, sr = sox_io_backend.load(fileobj) assert sr == 16000 class CloggedFileObj: def __init__(self, fileobj): self.fileobj = fileobj - self.buffer = b"" - def read(self, n): - if not self.buffer: - self.buffer += self.fileobj.read(n) - ret = self.buffer[:2] - self.buffer = self.buffer[2:] - return ret + def read(self, _): + return self.fileobj.read(2) + + def seek(self, offset, whence): + return self.fileobj.seek(offset, whence) @skipIfNoSox @@ -557,6 +551,14 @@ def test_tarfile(self, ext, kwargs): self.assertEqual(expected, found) +class Unseekable: + def __init__(self, fileobj): + self.fileobj = fileobj + + def read(self, n): + return self.fileobj.read(n) + + @skipIfNoSox @skipIfNoExec("sox") @skipIfNoModule("requests") @@ -587,10 +589,11 @@ def test_requests(self, ext, kwargs): url = self.get_url(audio_file) with requests.get(url, stream=True) as resp: - found, sr = sox_io_backend.load(resp.raw, format=format_) + found, sr = sox_io_backend.load(Unseekable(resp.raw), format=format_) assert sr == sample_rate - self.assertEqual(expected, found) + if ext != "mp3": + self.assertEqual(expected, found) @parameterized.expand( list( diff --git a/test/torchaudio_unittest/backend/sox_io/save_test.py b/test/torchaudio_unittest/backend/sox_io/save_test.py index 848bf41310..5db7a5a9f8 100644 --- a/test/torchaudio_unittest/backend/sox_io/save_test.py +++ b/test/torchaudio_unittest/backend/sox_io/save_test.py @@ -1,6 +1,5 @@ import io import os -import unittest import torch from parameterized import parameterized @@ -179,31 +178,6 @@ def test_save_wav_dtype(self, test_mode, params): (dtype,) = params self.assert_save_consistency("wav", src_dtype=dtype, test_mode=test_mode) - @nested_params( - ["path", "fileobj", "bytesio"], - [ - None, - -4.2, - -0.2, - 0, - 0.2, - 96, - 128, - 160, - 192, - 224, - 256, - 320, - ], - ) - def test_save_mp3(self, test_mode, bit_rate): - if test_mode in ["fileobj", "bytesio"]: - if bit_rate is not None and bit_rate < 1: - raise unittest.SkipTest( - "mp3 format with variable bit rate is known to " "not yield the exact same result as sox command." - ) - self.assert_save_consistency("mp3", compression=bit_rate, test_mode=test_mode) - @nested_params( ["path", "fileobj", "bytesio"], [8, 16, 24], @@ -349,7 +323,6 @@ def test_save_gsm(self, test_mode): @parameterized.expand( [ ("wav", "PCM_S", 16), - ("mp3",), ("flac",), ("vorbis",), ("sph", "PCM_S", 16), @@ -437,5 +410,5 @@ def test_save_fail(self): When attempted to save into a non-existing dir, error message must contain the file path. """ path = os.path.join("non_existing_directory", "foo.wav") - with self.assertRaisesRegex(RuntimeError, "^Error saving audio file: failed to open file {0}$".format(path)): + with self.assertRaisesRegex(RuntimeError, path): sox_io_backend.save(path, torch.zeros(1, 1), 8000) diff --git a/test/torchaudio_unittest/backend/sox_io/smoke_test.py b/test/torchaudio_unittest/backend/sox_io/smoke_test.py index b3e39b61cb..4329209bc8 100644 --- a/test/torchaudio_unittest/backend/sox_io/smoke_test.py +++ b/test/torchaudio_unittest/backend/sox_io/smoke_test.py @@ -1,11 +1,8 @@ import io import itertools -import unittest from parameterized import parameterized -from torchaudio._internal.module_utils import is_sox_available from torchaudio.backend import sox_io_backend -from torchaudio.utils import sox_utils from torchaudio_unittest.common_utils import ( get_wav_data, skipIfNoSox, @@ -16,12 +13,6 @@ from .common import name_func -skipIfNoMP3 = unittest.skipIf( - not is_sox_available() or "mp3" not in sox_utils.list_read_formats() or "mp3" not in sox_utils.list_write_formats(), - '"sox_io" backend does not support MP3', -) - - @skipIfNoSox class SmokeTest(TempDirMixin, TorchaudioTestCase): """Run smoke test on various audio format @@ -73,7 +64,6 @@ def test_wav(self, dtype, sample_rate, num_channels): ) ) ) - @skipIfNoMP3 def test_mp3(self, sample_rate, num_channels, bit_rate): """Run smoke test on mp3 format""" self.run_smoke_test("mp3", sample_rate, num_channels, compression=bit_rate) @@ -159,7 +149,6 @@ def test_wav(self, dtype, sample_rate, num_channels): ) ) ) - @skipIfNoMP3 def test_mp3(self, sample_rate, num_channels, bit_rate): """Run smoke test on mp3 format""" self.run_smoke_test("mp3", sample_rate, num_channels, compression=bit_rate) diff --git a/torchaudio/backend/sox_io_backend.py b/torchaudio/backend/sox_io_backend.py index b6574afaff..f7cd846c01 100644 --- a/torchaudio/backend/sox_io_backend.py +++ b/torchaudio/backend/sox_io_backend.py @@ -33,10 +33,18 @@ def _fail_load_fileobj(fileobj, *args, **kwargs): raise RuntimeError(f"Failed to load audio from {fileobj}") -_fallback_info = _fail_info -_fallback_info_fileobj = _fail_info_fileobj -_fallback_load = _fail_load -_fallback_load_fileobj = _fail_load_fileobj +if torchaudio._extension._FFMPEG_INITIALIZED: + import torchaudio.io._compat as _compat + + _fallback_info = _compat.info_audio + _fallback_info_fileobj = _compat.info_audio_fileobj + _fallback_load = _compat.load_audio + _fallback_load_fileobj = _compat.load_audio_fileobj +else: + _fallback_info = _fail_info + _fallback_info_fileobj = _fail_info_fileobj + _fallback_load = _fail_load + _fallback_load_filebj = _fail_load_fileobj @_mod_utils.requires_sox() diff --git a/torchaudio/csrc/ffmpeg/stream_reader_binding.cpp b/torchaudio/csrc/ffmpeg/stream_reader_binding.cpp index 32635f72cc..5bf12a37ee 100644 --- a/torchaudio/csrc/ffmpeg/stream_reader_binding.cpp +++ b/torchaudio/csrc/ffmpeg/stream_reader_binding.cpp @@ -26,18 +26,6 @@ c10::intrusive_ptr init( get_input_format_context(src, device, map(option))); } -std::tuple, int64_t> load(const std::string& src) { - StreamReaderBinding s{get_input_format_context(src, {}, {})}; - int i = static_cast(s.find_best_audio_stream()); - auto sinfo = s.StreamReader::get_src_stream_info(i); - int64_t sample_rate = static_cast(sinfo.sample_rate); - s.add_audio_stream(i, -1, -1, {}, {}, {}); - s.process_all_packets(); - auto tensors = s.pop_chunks(); - assert(tensors.size() > 0); - return std::make_tuple<>(tensors[0], sample_rate); -} - using S = const c10::intrusive_ptr&; TORCH_LIBRARY_FRAGMENT(torchaudio, m) { @@ -47,7 +35,6 @@ TORCH_LIBRARY_FRAGMENT(torchaudio, m) { av_log_set_level(AV_LOG_ERROR); } }); - m.def("torchaudio::ffmpeg_load", load); m.class_("ffmpeg_StreamReader") .def(torch::init<>(init)) .def("num_src_streams", [](S self) { return self->num_src_streams(); }) diff --git a/torchaudio/io/_compat.py b/torchaudio/io/_compat.py new file mode 100644 index 0000000000..c97d51ef3f --- /dev/null +++ b/torchaudio/io/_compat.py @@ -0,0 +1,110 @@ +from typing import Dict, Optional, Tuple + +import torch +import torchaudio +from torchaudio.backend.common import AudioMetaData + + +# Note: need to comply TorchScript syntax -- need annotation and no f-string nor global +def _info_audio( + s: torch.classes.torchaudio.ffmpeg_StreamReader, +): + i = s.find_best_audio_stream() + sinfo = s.get_src_stream_info(i) + return AudioMetaData( + int(sinfo[7]), + sinfo[5], + sinfo[8], + sinfo[6], + sinfo[1].upper(), + ) + + +def info_audio( + src: str, + format: Optional[str], +) -> AudioMetaData: + s = torch.classes.torchaudio.ffmpeg_StreamReader(src, format, None) + return _info_audio(s) + + +def info_audio_fileobj( + src, + format: Optional[str], +) -> AudioMetaData: + s = torchaudio._torchaudio_ffmpeg.StreamReaderFileObj(src, format, None, 4096) + return _info_audio(s) + + +def _get_load_filter( + frame_offset: int = 0, + num_frames: int = -1, + convert: bool = True, +) -> Optional[str]: + if frame_offset < 0: + raise RuntimeError("Invalid argument: frame_offset must be non-negative. Found: {}".format(frame_offset)) + if num_frames == 0 or num_frames < -1: + raise RuntimeError("Invalid argument: num_frames must be -1 or greater than 0. Found: {}".format(num_frames)) + + # All default values -> no filter + if frame_offset == 0 and num_frames == -1 and not convert: + return None + # Only convert + aformat = "aformat=sample_fmts=fltp" + if frame_offset == 0 and num_frames == -1 and convert: + return aformat + # At least one of frame_offset or num_frames has non-default value + if num_frames > 0: + atrim = "atrim=start_sample={}:end_sample={}".format(frame_offset, frame_offset + num_frames) + else: + atrim = "atrim=start_sample={}".format(frame_offset) + if not convert: + return atrim + return "{},{}".format(atrim, aformat) + + +# Note: need to comply TorchScript syntax -- need annotation and no f-string nor global +def _load_audio( + s: torch.classes.torchaudio.ffmpeg_StreamReader, + frame_offset: int = 0, + num_frames: int = -1, + convert: bool = True, + channels_first: bool = True, +) -> Tuple[torch.Tensor, int]: + i = s.find_best_audio_stream() + sinfo = s.get_src_stream_info(i) + sample_rate = int(sinfo[7]) + option: Dict[str, str] = {} + s.add_audio_stream(i, -1, -1, _get_load_filter(frame_offset, num_frames, convert), None, option) + s.process_all_packets() + waveform = s.pop_chunks()[0] + if waveform is None: + raise RuntimeError("Failed to decode audio.") + assert waveform is not None + if channels_first: + waveform = waveform.T + return waveform, sample_rate + + +def load_audio( + src: str, + frame_offset: int = 0, + num_frames: int = -1, + convert: bool = True, + channels_first: bool = True, + format: Optional[str] = None, +) -> Tuple[torch.Tensor, int]: + s = torch.classes.torchaudio.ffmpeg_StreamReader(src, format, None) + return _load_audio(s, frame_offset, num_frames, convert, channels_first) + + +def load_audio_fileobj( + src: str, + frame_offset: int = 0, + num_frames: int = -1, + convert: bool = True, + channels_first: bool = True, + format: Optional[str] = None, +) -> Tuple[torch.Tensor, int]: + s = torchaudio._torchaudio_ffmpeg.StreamReaderFileObj(src, format, None, 4096) + return _load_audio(s, frame_offset, num_frames, convert, channels_first) diff --git a/torchaudio/utils/__init__.py b/torchaudio/utils/__init__.py index 87761d6b98..90874dd225 100644 --- a/torchaudio/utils/__init__.py +++ b/torchaudio/utils/__init__.py @@ -4,7 +4,7 @@ from .download import download_asset if _mod_utils.is_sox_available(): - sox_utils.set_verbosity(1) + sox_utils.set_verbosity(0) __all__ = [