From 19c60a08a3ce51eaf74883a3952cb6fabad1ac0a Mon Sep 17 00:00:00 2001 From: moto <855818+mthrok@users.noreply.github.com> Date: Wed, 1 Jun 2022 17:40:13 -0700 Subject: [PATCH] Use FFmpeg-based I/O as fallback in sox_io backend (#2419) Summary: This commit adds a fallback mechanism to the `info` and `load` functions of the sox_io backend. If torchaudio is compiled to use FFmpeg and the runtime dependencies are properly loaded, then when `info` or `load` fails, it falls back to the FFmpeg-based implementation. BC-breaking changes: - FFmpeg does not report the number of frames for MP3, because MP3 does not store the number of frames. It can be estimated from the audio duration and sample rate, but the estimate might be inaccurate, so we keep it 0. Depends on - https://github.com/pytorch/audio/issues/2416 - https://github.com/pytorch/audio/issues/2417 - https://github.com/pytorch/audio/issues/2418 - https://github.com/pytorch/audio/issues/2423 - https://github.com/pytorch/audio/issues/2427 Pull Request resolved: https://github.com/pytorch/audio/pull/2419 Reviewed By: carolineechen Differential Revision: D36740306 Pulled By: mthrok fbshipit-source-id: 9e2ad095b8b39e41404970de0d8d9b5aaa856c97 --- .../backend/sox_io/info_test.py | 30 +++- .../backend/sox_io/load_test.py | 143 +++++++++--------- .../backend/sox_io/save_test.py | 29 +--- .../backend/sox_io/smoke_test.py | 11 -- torchaudio/backend/sox_io_backend.py | 16 +- .../csrc/ffmpeg/stream_reader_binding.cpp | 13 -- torchaudio/io/_compat.py | 110 ++++++++++++++ torchaudio/utils/__init__.py | 2 +- 8 files changed, 220 insertions(+), 134 deletions(-) create mode 100644 torchaudio/io/_compat.py diff --git a/test/torchaudio_unittest/backend/sox_io/info_test.py b/test/torchaudio_unittest/backend/sox_io/info_test.py index 7938fbcda9..289de5bcbe 100644 --- a/test/torchaudio_unittest/backend/sox_io/info_test.py +++ b/test/torchaudio_unittest/backend/sox_io/info_test.py @@ -312,23 +312,31 @@ def test_opus(self, bitrate, 
num_channels, compression_level): @skipIfNoSox class TestLoadWithoutExtension(PytorchTestCase): def test_mp3(self): - """Providing `format` allows to read mp3 without extension - - libsox does not check header for mp3 + """MP3 file without extension can be loaded + Originally, we added `format` argument for this case, but now we use FFmpeg + for MP3 decoding, which works even without `format` argument. https://github.com/pytorch/audio/issues/1040 The file was generated with the following command ffmpeg -f lavfi -i "sine=frequency=1000:duration=5" -ar 16000 -f mp3 test_noext """ path = get_asset_path("mp3_without_ext") - sinfo = sox_io_backend.info(path, format="mp3") + sinfo = sox_io_backend.info(path) assert sinfo.sample_rate == 16000 - assert sinfo.num_frames == 81216 + assert sinfo.num_frames == 0 assert sinfo.num_channels == 1 assert sinfo.bits_per_sample == 0 # bit_per_sample is irrelevant for compressed formats assert sinfo.encoding == "MP3" + with open(path, "rb") as fileobj: + sinfo = sox_io_backend.info(fileobj) + assert sinfo.sample_rate == 16000 + assert sinfo.num_frames == 0 + assert sinfo.num_channels == 1 + assert sinfo.bits_per_sample == 0 + assert sinfo.encoding == "MP3" + class FileObjTestBase(TempDirMixin): def _gen_file(self, ext, dtype, sample_rate, num_channels, num_frames, *, comments=None): @@ -355,6 +363,14 @@ def _gen_comment_file(self, comments): return comment_path +class Unseekable: + def __init__(self, fileobj): + self.fileobj = fileobj + + def read(self, n): + return self.fileobj.read(n) + + @skipIfNoSox @skipIfNoExec("sox") class TestFileObject(FileObjTestBase, PytorchTestCase): @@ -435,7 +451,7 @@ def test_fileobj_large_header(self, ext, dtype): num_channels = 2 comments = "metadata=" + " ".join(["value" for _ in range(1000)]) - with self.assertRaisesRegex(RuntimeError, "Failed to fetch metadata from"): + with self.assertRaises(RuntimeError): sinfo = self._query_fileobj(ext, dtype, sample_rate, num_channels, num_frames, 
comments=comments) with self._set_buffer_size(16384): @@ -545,7 +561,7 @@ def _query_http(self, ext, dtype, sample_rate, num_channels, num_frames): url = self.get_url(audio_file) format_ = ext if ext in ["mp3"] else None with requests.get(url, stream=True) as resp: - return sox_io_backend.info(resp.raw, format=format_) + return sox_io_backend.info(Unseekable(resp.raw), format=format_) @parameterized.expand( [ diff --git a/test/torchaudio_unittest/backend/sox_io/load_test.py b/test/torchaudio_unittest/backend/sox_io/load_test.py index 760351b3e1..37a14619cf 100644 --- a/test/torchaudio_unittest/backend/sox_io/load_test.py +++ b/test/torchaudio_unittest/backend/sox_io/load_test.py @@ -2,6 +2,8 @@ import itertools import tarfile +import torch +import torchaudio from parameterized import parameterized from torchaudio._internal import module_utils as _mod_utils from torchaudio.backend import sox_io_backend @@ -10,6 +12,7 @@ get_wav_data, HttpServerMixin, load_wav, + nested_params, PytorchTestCase, save_wav, skipIfNoExec, @@ -169,35 +172,6 @@ def test_multiple_channels(self, dtype, num_channels): normalize = False self.assert_wav(dtype, sample_rate, num_channels, normalize, duration=1) - @parameterized.expand( - list( - itertools.product( - [8000, 16000, 44100], - [1, 2], - [96, 128, 160, 192, 224, 256, 320], - ) - ), - name_func=name_func, - ) - def test_mp3(self, sample_rate, num_channels, bit_rate): - """`sox_io_backend.load` can load mp3 format correctly.""" - self.assert_format("mp3", sample_rate, num_channels, compression=bit_rate, duration=1, atol=5e-05) - - @parameterized.expand( - list( - itertools.product( - [16000], - [2], - [128], - ) - ), - name_func=name_func, - ) - def test_mp3_large(self, sample_rate, num_channels, bit_rate): - """`sox_io_backend.load` can load large mp3 file correctly.""" - two_hours = 2 * 60 * 60 - self.assert_format("mp3", sample_rate, num_channels, compression=bit_rate, duration=two_hours, atol=5e-05) - @parameterized.expand( list( 
itertools.product( @@ -319,72 +293,92 @@ def test_amr_nb(self): self.assert_format("amr-nb", sample_rate=8000, num_channels=1, bit_depth=32, duration=1) -@skipIfNoExec("sox") @skipIfNoSox class TestLoadParams(TempDirMixin, PytorchTestCase): """Test the correctness of frame parameters of `sox_io_backend.load`""" - original = None - path = None + def _test(self, func, frame_offset, num_frames, channels_first, normalize): + original = get_wav_data("int16", num_channels=2, normalize=False) + path = self.get_temp_path("test.wav") + save_wav(path, original, sample_rate=8000) - def setUp(self): - super().setUp() - sample_rate = 8000 - self.original = get_wav_data("float32", num_channels=2) - self.path = self.get_temp_path("test.wav") - save_wav(self.path, self.original, sample_rate) + output, _ = func(path, frame_offset, num_frames, normalize, channels_first, None) + frame_end = None if num_frames == -1 else frame_offset + num_frames + expected = original[:, slice(frame_offset, frame_end)] + if not channels_first: + expected = expected.T + if normalize: + expected = expected.to(torch.float32) / (2**15) + self.assertEqual(output, expected) + + @nested_params( + [0, 1, 10, 100, 1000], + [-1, 1, 10, 100, 1000], + [True, False], + [True, False], + ) + def test_sox(self, frame_offset, num_frames, channels_first, normalize): + """The combination of parameters properly changes the output tensor""" - @parameterized.expand( - list( - itertools.product( - [0, 1, 10, 100, 1000], - [-1, 1, 10, 100, 1000], - ) - ), - name_func=name_func, + self._test(torch.ops.torchaudio.sox_io_load_audio_file, frame_offset, num_frames, channels_first, normalize) + + # test file-like obj + def func(path, *args): + with open(path, "rb") as fileobj: + return torchaudio._torchaudio.load_audio_fileobj(fileobj, *args) + + self._test(func, frame_offset, num_frames, channels_first, normalize) + + @nested_params( + [0, 1, 10, 100, 1000], + [-1, 1, 10, 100, 1000], + [True, False], + [True, False], ) - def 
test_frame(self, frame_offset, num_frames): - """num_frames and frame_offset correctly specify the region of data""" - found, _ = sox_io_backend.load(self.path, frame_offset, num_frames) - frame_end = None if num_frames == -1 else frame_offset + num_frames - self.assertEqual(found, self.original[:, frame_offset:frame_end]) + def test_ffmpeg(self, frame_offset, num_frames, channels_first, normalize): + """The combination of parameters properly changes the output tensor""" + from torchaudio.io._compat import load_audio, load_audio_fileobj + + self._test(load_audio, frame_offset, num_frames, channels_first, normalize) + + # test file-like obj + def func(path, *args): + with open(path, "rb") as fileobj: + return load_audio_fileobj(fileobj, *args) - @parameterized.expand([(True,), (False,)], name_func=name_func) - def test_channels_first(self, channels_first): - """channels_first swaps axes""" - found, _ = sox_io_backend.load(self.path, channels_first=channels_first) - expected = self.original if channels_first else self.original.transpose(1, 0) - self.assertEqual(found, expected) + self._test(func, frame_offset, num_frames, channels_first, normalize) @skipIfNoSox class TestLoadWithoutExtension(PytorchTestCase): def test_mp3(self): - """Providing format allows to read mp3 without extension - - libsox does not check header for mp3 + """MP3 file without extension can be loaded + Originally, we added `format` argument for this case, but now we use FFmpeg + for MP3 decoding, which works even without `format` argument. 
https://github.com/pytorch/audio/issues/1040 The file was generated with the following command ffmpeg -f lavfi -i "sine=frequency=1000:duration=5" -ar 16000 -f mp3 test_noext """ path = get_asset_path("mp3_without_ext") - _, sr = sox_io_backend.load(path, format="mp3") + _, sr = sox_io_backend.load(path) + assert sr == 16000 + + with open(path, "rb") as fileobj: + _, sr = sox_io_backend.load(fileobj) assert sr == 16000 class CloggedFileObj: def __init__(self, fileobj): self.fileobj = fileobj - self.buffer = b"" - def read(self, n): - if not self.buffer: - self.buffer += self.fileobj.read(n) - ret = self.buffer[:2] - self.buffer = self.buffer[2:] - return ret + def read(self, _): + return self.fileobj.read(2) + + def seek(self, offset, whence): + return self.fileobj.seek(offset, whence) @skipIfNoSox @@ -557,6 +551,14 @@ def test_tarfile(self, ext, kwargs): self.assertEqual(expected, found) +class Unseekable: + def __init__(self, fileobj): + self.fileobj = fileobj + + def read(self, n): + return self.fileobj.read(n) + + @skipIfNoSox @skipIfNoExec("sox") @skipIfNoModule("requests") @@ -587,10 +589,11 @@ def test_requests(self, ext, kwargs): url = self.get_url(audio_file) with requests.get(url, stream=True) as resp: - found, sr = sox_io_backend.load(resp.raw, format=format_) + found, sr = sox_io_backend.load(Unseekable(resp.raw), format=format_) assert sr == sample_rate - self.assertEqual(expected, found) + if ext != "mp3": + self.assertEqual(expected, found) @parameterized.expand( list( diff --git a/test/torchaudio_unittest/backend/sox_io/save_test.py b/test/torchaudio_unittest/backend/sox_io/save_test.py index 848bf41310..5db7a5a9f8 100644 --- a/test/torchaudio_unittest/backend/sox_io/save_test.py +++ b/test/torchaudio_unittest/backend/sox_io/save_test.py @@ -1,6 +1,5 @@ import io import os -import unittest import torch from parameterized import parameterized @@ -179,31 +178,6 @@ def test_save_wav_dtype(self, test_mode, params): (dtype,) = params 
self.assert_save_consistency("wav", src_dtype=dtype, test_mode=test_mode) - @nested_params( - ["path", "fileobj", "bytesio"], - [ - None, - -4.2, - -0.2, - 0, - 0.2, - 96, - 128, - 160, - 192, - 224, - 256, - 320, - ], - ) - def test_save_mp3(self, test_mode, bit_rate): - if test_mode in ["fileobj", "bytesio"]: - if bit_rate is not None and bit_rate < 1: - raise unittest.SkipTest( - "mp3 format with variable bit rate is known to " "not yield the exact same result as sox command." - ) - self.assert_save_consistency("mp3", compression=bit_rate, test_mode=test_mode) - @nested_params( ["path", "fileobj", "bytesio"], [8, 16, 24], @@ -349,7 +323,6 @@ def test_save_gsm(self, test_mode): @parameterized.expand( [ ("wav", "PCM_S", 16), - ("mp3",), ("flac",), ("vorbis",), ("sph", "PCM_S", 16), @@ -437,5 +410,5 @@ def test_save_fail(self): When attempted to save into a non-existing dir, error message must contain the file path. """ path = os.path.join("non_existing_directory", "foo.wav") - with self.assertRaisesRegex(RuntimeError, "^Error saving audio file: failed to open file {0}$".format(path)): + with self.assertRaisesRegex(RuntimeError, path): sox_io_backend.save(path, torch.zeros(1, 1), 8000) diff --git a/test/torchaudio_unittest/backend/sox_io/smoke_test.py b/test/torchaudio_unittest/backend/sox_io/smoke_test.py index b3e39b61cb..4329209bc8 100644 --- a/test/torchaudio_unittest/backend/sox_io/smoke_test.py +++ b/test/torchaudio_unittest/backend/sox_io/smoke_test.py @@ -1,11 +1,8 @@ import io import itertools -import unittest from parameterized import parameterized -from torchaudio._internal.module_utils import is_sox_available from torchaudio.backend import sox_io_backend -from torchaudio.utils import sox_utils from torchaudio_unittest.common_utils import ( get_wav_data, skipIfNoSox, @@ -16,12 +13,6 @@ from .common import name_func -skipIfNoMP3 = unittest.skipIf( - not is_sox_available() or "mp3" not in sox_utils.list_read_formats() or "mp3" not in 
sox_utils.list_write_formats(), - '"sox_io" backend does not support MP3', -) - - @skipIfNoSox class SmokeTest(TempDirMixin, TorchaudioTestCase): """Run smoke test on various audio format @@ -73,7 +64,6 @@ def test_wav(self, dtype, sample_rate, num_channels): ) ) ) - @skipIfNoMP3 def test_mp3(self, sample_rate, num_channels, bit_rate): """Run smoke test on mp3 format""" self.run_smoke_test("mp3", sample_rate, num_channels, compression=bit_rate) @@ -159,7 +149,6 @@ def test_wav(self, dtype, sample_rate, num_channels): ) ) ) - @skipIfNoMP3 def test_mp3(self, sample_rate, num_channels, bit_rate): """Run smoke test on mp3 format""" self.run_smoke_test("mp3", sample_rate, num_channels, compression=bit_rate) diff --git a/torchaudio/backend/sox_io_backend.py b/torchaudio/backend/sox_io_backend.py index b6574afaff..f7cd846c01 100644 --- a/torchaudio/backend/sox_io_backend.py +++ b/torchaudio/backend/sox_io_backend.py @@ -33,10 +33,18 @@ def _fail_load_fileobj(fileobj, *args, **kwargs): raise RuntimeError(f"Failed to load audio from {fileobj}") -_fallback_info = _fail_info -_fallback_info_fileobj = _fail_info_fileobj -_fallback_load = _fail_load -_fallback_load_fileobj = _fail_load_fileobj +if torchaudio._extension._FFMPEG_INITIALIZED: + import torchaudio.io._compat as _compat + + _fallback_info = _compat.info_audio + _fallback_info_fileobj = _compat.info_audio_fileobj + _fallback_load = _compat.load_audio + _fallback_load_fileobj = _compat.load_audio_fileobj +else: + _fallback_info = _fail_info + _fallback_info_fileobj = _fail_info_fileobj + _fallback_load = _fail_load + _fallback_load_fileobj = _fail_load_fileobj @_mod_utils.requires_sox() diff --git a/torchaudio/csrc/ffmpeg/stream_reader_binding.cpp b/torchaudio/csrc/ffmpeg/stream_reader_binding.cpp index 32635f72cc..5bf12a37ee 100644 --- a/torchaudio/csrc/ffmpeg/stream_reader_binding.cpp +++ b/torchaudio/csrc/ffmpeg/stream_reader_binding.cpp @@ -26,18 +26,6 @@ c10::intrusive_ptr init( get_input_format_context(src, 
device, map(option))); } -std::tuple, int64_t> load(const std::string& src) { - StreamReaderBinding s{get_input_format_context(src, {}, {})}; - int i = static_cast(s.find_best_audio_stream()); - auto sinfo = s.StreamReader::get_src_stream_info(i); - int64_t sample_rate = static_cast(sinfo.sample_rate); - s.add_audio_stream(i, -1, -1, {}, {}, {}); - s.process_all_packets(); - auto tensors = s.pop_chunks(); - assert(tensors.size() > 0); - return std::make_tuple<>(tensors[0], sample_rate); -} - using S = const c10::intrusive_ptr&; TORCH_LIBRARY_FRAGMENT(torchaudio, m) { @@ -47,7 +35,6 @@ TORCH_LIBRARY_FRAGMENT(torchaudio, m) { av_log_set_level(AV_LOG_ERROR); } }); - m.def("torchaudio::ffmpeg_load", load); m.class_("ffmpeg_StreamReader") .def(torch::init<>(init)) .def("num_src_streams", [](S self) { return self->num_src_streams(); }) diff --git a/torchaudio/io/_compat.py b/torchaudio/io/_compat.py new file mode 100644 index 0000000000..c97d51ef3f --- /dev/null +++ b/torchaudio/io/_compat.py @@ -0,0 +1,110 @@ +from typing import Dict, Optional, Tuple + +import torch +import torchaudio +from torchaudio.backend.common import AudioMetaData + + +# Note: need to comply TorchScript syntax -- need annotation and no f-string nor global +def _info_audio( + s: torch.classes.torchaudio.ffmpeg_StreamReader, +): + i = s.find_best_audio_stream() + sinfo = s.get_src_stream_info(i) + return AudioMetaData( + int(sinfo[7]), + sinfo[5], + sinfo[8], + sinfo[6], + sinfo[1].upper(), + ) + + +def info_audio( + src: str, + format: Optional[str], +) -> AudioMetaData: + s = torch.classes.torchaudio.ffmpeg_StreamReader(src, format, None) + return _info_audio(s) + + +def info_audio_fileobj( + src, + format: Optional[str], +) -> AudioMetaData: + s = torchaudio._torchaudio_ffmpeg.StreamReaderFileObj(src, format, None, 4096) + return _info_audio(s) + + +def _get_load_filter( + frame_offset: int = 0, + num_frames: int = -1, + convert: bool = True, +) -> Optional[str]: + if frame_offset < 0: + raise 
RuntimeError("Invalid argument: frame_offset must be non-negative. Found: {}".format(frame_offset)) + if num_frames == 0 or num_frames < -1: + raise RuntimeError("Invalid argument: num_frames must be -1 or greater than 0. Found: {}".format(num_frames)) + + # All default values -> no filter + if frame_offset == 0 and num_frames == -1 and not convert: + return None + # Only convert + aformat = "aformat=sample_fmts=fltp" + if frame_offset == 0 and num_frames == -1 and convert: + return aformat + # At least one of frame_offset or num_frames has non-default value + if num_frames > 0: + atrim = "atrim=start_sample={}:end_sample={}".format(frame_offset, frame_offset + num_frames) + else: + atrim = "atrim=start_sample={}".format(frame_offset) + if not convert: + return atrim + return "{},{}".format(atrim, aformat) + + +# Note: need to comply TorchScript syntax -- need annotation and no f-string nor global +def _load_audio( + s: torch.classes.torchaudio.ffmpeg_StreamReader, + frame_offset: int = 0, + num_frames: int = -1, + convert: bool = True, + channels_first: bool = True, +) -> Tuple[torch.Tensor, int]: + i = s.find_best_audio_stream() + sinfo = s.get_src_stream_info(i) + sample_rate = int(sinfo[7]) + option: Dict[str, str] = {} + s.add_audio_stream(i, -1, -1, _get_load_filter(frame_offset, num_frames, convert), None, option) + s.process_all_packets() + waveform = s.pop_chunks()[0] + if waveform is None: + raise RuntimeError("Failed to decode audio.") + assert waveform is not None + if channels_first: + waveform = waveform.T + return waveform, sample_rate + + +def load_audio( + src: str, + frame_offset: int = 0, + num_frames: int = -1, + convert: bool = True, + channels_first: bool = True, + format: Optional[str] = None, +) -> Tuple[torch.Tensor, int]: + s = torch.classes.torchaudio.ffmpeg_StreamReader(src, format, None) + return _load_audio(s, frame_offset, num_frames, convert, channels_first) + + +def load_audio_fileobj( + src, + frame_offset: int = 0, + 
num_frames: int = -1, + convert: bool = True, + channels_first: bool = True, + format: Optional[str] = None, +) -> Tuple[torch.Tensor, int]: + s = torchaudio._torchaudio_ffmpeg.StreamReaderFileObj(src, format, None, 4096) + return _load_audio(s, frame_offset, num_frames, convert, channels_first) diff --git a/torchaudio/utils/__init__.py b/torchaudio/utils/__init__.py index 87761d6b98..90874dd225 100644 --- a/torchaudio/utils/__init__.py +++ b/torchaudio/utils/__init__.py @@ -4,7 +4,7 @@ from .download import download_asset if _mod_utils.is_sox_available(): - sox_utils.set_verbosity(1) + sox_utils.set_verbosity(0) __all__ = [