diff --git a/test/torchaudio_unittest/backend/sox_io/info_test.py b/test/torchaudio_unittest/backend/sox_io/info_test.py index ea59ee46f44..289de5bcbe2 100644 --- a/test/torchaudio_unittest/backend/sox_io/info_test.py +++ b/test/torchaudio_unittest/backend/sox_io/info_test.py @@ -312,23 +312,31 @@ def test_opus(self, bitrate, num_channels, compression_level): @skipIfNoSox class TestLoadWithoutExtension(PytorchTestCase): def test_mp3(self): - """Providing `format` allows to read mp3 without extension - - libsox does not check header for mp3 + """MP3 file without extension can be loaded + Originally, we added `format` argument for this case, but now we use FFmpeg + for MP3 decoding, which works even without `format` argument. https://github.com/pytorch/audio/issues/1040 The file was generated with the following command ffmpeg -f lavfi -i "sine=frequency=1000:duration=5" -ar 16000 -f mp3 test_noext """ path = get_asset_path("mp3_without_ext") - sinfo = sox_io_backend.info(path, format="mp3") + sinfo = sox_io_backend.info(path) assert sinfo.sample_rate == 16000 - assert sinfo.num_frames == 81216 + assert sinfo.num_frames == 0 assert sinfo.num_channels == 1 assert sinfo.bits_per_sample == 0 # bit_per_sample is irrelevant for compressed formats assert sinfo.encoding == "MP3" + with open(path, "rb") as fileobj: + sinfo = sox_io_backend.info(fileobj) + assert sinfo.sample_rate == 16000 + assert sinfo.num_frames == 0 + assert sinfo.num_channels == 1 + assert sinfo.bits_per_sample == 0 + assert sinfo.encoding == "MP3" + class FileObjTestBase(TempDirMixin): def _gen_file(self, ext, dtype, sample_rate, num_channels, num_frames, *, comments=None): @@ -355,6 +363,14 @@ def _gen_comment_file(self, comments): return comment_path +class Unseekable: + def __init__(self, fileobj): + self.fileobj = fileobj + + def read(self, n): + return self.fileobj.read(n) + + @skipIfNoSox @skipIfNoExec("sox") class TestFileObject(FileObjTestBase, PytorchTestCase): @@ -435,7 +451,7 @@ def test_fileobj_large_header(self, ext, dtype): num_channels = 2 comments = "metadata=" + " ".join(["value" for _ in range(1000)]) - with self.assertRaisesRegex(RuntimeError, "^Error loading audio file:"): + with self.assertRaises(RuntimeError): sinfo = self._query_fileobj(ext, dtype, sample_rate, num_channels, num_frames, comments=comments) with self._set_buffer_size(16384): @@ -545,7 +561,7 @@ def _query_http(self, ext, dtype, sample_rate, num_channels, num_frames): url = self.get_url(audio_file) format_ = ext if ext in ["mp3"] else None with requests.get(url, stream=True) as resp: - return sox_io_backend.info(resp.raw, format=format_) + return sox_io_backend.info(Unseekable(resp.raw), format=format_) @parameterized.expand( [ @@ -583,5 +599,5 @@ def test_info_fail(self): When attempted to get info on a non-existing file, error message must contain the file path. """ path = "non_existing_audio.wav" - with self.assertRaisesRegex(RuntimeError, "^Error loading audio file: failed to open file {0}$".format(path)): + with self.assertRaisesRegex(RuntimeError, path): sox_io_backend.info(path) diff --git a/test/torchaudio_unittest/backend/sox_io/load_test.py b/test/torchaudio_unittest/backend/sox_io/load_test.py index 7a6b251493d..d0fd4d6f90d 100644 --- a/test/torchaudio_unittest/backend/sox_io/load_test.py +++ b/test/torchaudio_unittest/backend/sox_io/load_test.py @@ -2,14 +2,18 @@ import itertools import tarfile +import torch +import torchaudio from parameterized import parameterized from torchaudio._internal import module_utils as _mod_utils from torchaudio.backend import sox_io_backend from torchaudio_unittest.common_utils import ( + ffmpeg_utils, get_asset_path, get_wav_data, HttpServerMixin, load_wav, + nested_params, PytorchTestCase, save_wav, skipIfNoExec, @@ -81,7 +85,10 @@ def assert_format( ) # 2. Convert to wav with sox wav_bit_depth = 32 if bit_depth == 24 else None # for 24-bit wav - sox_utils.convert_audio_file(path, ref_path, bit_depth=wav_bit_depth) + if format == "mp3": + ffmpeg_utils.convert_to_wav(path, ref_path) + else: + sox_utils.convert_audio_file(path, ref_path, bit_depth=wav_bit_depth) # 3. Load the given format with torchaudio data, sr = sox_io_backend.load(path, normalize=normalize) # 4. Load wav with scipy @@ -319,72 +326,90 @@ def test_amr_nb(self): self.assert_format("amr-nb", sample_rate=8000, num_channels=1, bit_depth=32, duration=1) -@skipIfNoExec("sox") @skipIfNoSox class TestLoadParams(TempDirMixin, PytorchTestCase): """Test the correctness of frame parameters of `sox_io_backend.load`""" - original = None - path = None + def _test(self, func, frame_offset, num_frames, channels_first, normalize): + original = get_wav_data("int16", num_channels=2, normalize=False) + path = self.get_temp_path("test.wav") + save_wav(path, original, sample_rate=8000) - def setUp(self): - super().setUp() - sample_rate = 8000 - self.original = get_wav_data("float32", num_channels=2) - self.path = self.get_temp_path("test.wav") - save_wav(self.path, self.original, sample_rate) + output, _ = func(path, frame_offset, num_frames, normalize, channels_first, None) + frame_end = None if num_frames == -1 else frame_offset + num_frames + expected = original[:, slice(frame_offset, frame_end)] + if not channels_first: + expected = expected.T + if normalize: + expected = expected.to(torch.float32) / (2**15) + self.assertEqual(output, expected) + + @nested_params( + [0, 1, 10, 100, 1000], + [-1, 1, 10, 100, 1000], + [True, False], + [True, False], + ) + def test_sox(self, frame_offset, num_frames, channels_first, normalize): + """The combination of properly changes the output tensor""" - @parameterized.expand( - list( - itertools.product( - [0, 1, 10, 100, 1000], - [-1, 1, 10, 100, 1000], - ) - ), - name_func=name_func, + self._test(torch.ops.torchaudio.sox_io_load_audio_file, frame_offset, num_frames, channels_first, normalize) + + # test file-like obj + def func(path, *args): + with open(path, "rb") as fileobj: + return torchaudio._torchaudio.load_audio_fileobj(fileobj, *args) + + self._test(func, frame_offset, num_frames, channels_first, normalize) + + @nested_params( + [0, 1, 10, 100, 1000], + [-1, 1, 10, 100, 1000], + [True, False], + [True, False], ) - def test_frame(self, frame_offset, num_frames): - """num_frames and frame_offset correctly specify the region of data""" - found, _ = sox_io_backend.load(self.path, frame_offset, num_frames) - frame_end = None if num_frames == -1 else frame_offset + num_frames - self.assertEqual(found, self.original[:, frame_offset:frame_end]) + def test_ffmpeg(self, frame_offset, num_frames, channels_first, normalize): + """The combination of properly changes the output tensor""" + self._test(torch.ops.torchaudio.ffmpeg_load_audio, frame_offset, num_frames, channels_first, normalize) - @parameterized.expand([(True,), (False,)], name_func=name_func) - def test_channels_first(self, channels_first): - """channels_first swaps axes""" - found, _ = sox_io_backend.load(self.path, channels_first=channels_first) - expected = self.original if channels_first else self.original.transpose(1, 0) - self.assertEqual(found, expected) + # test file-like obj + def func(path, *args): + with open(path, "rb") as fileobj: + return torchaudio._torchaudio_ffmpeg.load_audio_fileobj(fileobj, *args) + + self._test(func, frame_offset, num_frames, channels_first, normalize) @skipIfNoSox class TestLoadWithoutExtension(PytorchTestCase): def test_mp3(self): - """Providing format allows to read mp3 without extension - - libsox does not check header for mp3 + """MP3 file without extension can be loaded + Originally, we added `format` argument for this case, but now we use FFmpeg + for MP3 decoding, which works even without `format` argument. https://github.com/pytorch/audio/issues/1040 The file was generated with the following command ffmpeg -f lavfi -i "sine=frequency=1000:duration=5" -ar 16000 -f mp3 test_noext """ path = get_asset_path("mp3_without_ext") - _, sr = sox_io_backend.load(path, format="mp3") + _, sr = sox_io_backend.load(path) + assert sr == 16000 + + with open(path, "rb") as fileobj: + _, sr = sox_io_backend.load(fileobj) assert sr == 16000 class CloggedFileObj: def __init__(self, fileobj): self.fileobj = fileobj - self.buffer = b"" - def read(self, n): - if not self.buffer: - self.buffer += self.fileobj.read(n) - ret = self.buffer[:2] - self.buffer = self.buffer[2:] - return ret + def read(self, _): + return self.fileobj.read(2) + + def seek(self, offset, whence): + return self.fileobj.seek(offset, whence) @skipIfNoSox @@ -557,6 +582,14 @@ def test_tarfile(self, ext, kwargs): self.assertEqual(expected, found) +class Unseekable: + def __init__(self, fileobj): + self.fileobj = fileobj + + def read(self, n): + return self.fileobj.read(n) + + @skipIfNoSox @skipIfNoExec("sox") @skipIfNoModule("requests") @@ -587,10 +620,11 @@ def test_requests(self, ext, kwargs): url = self.get_url(audio_file) with requests.get(url, stream=True) as resp: - found, sr = sox_io_backend.load(resp.raw, format=format_) + found, sr = sox_io_backend.load(Unseekable(resp.raw), format=format_) assert sr == sample_rate - self.assertEqual(expected, found) + if ext != "mp3": + self.assertEqual(expected, found) @parameterized.expand( list( @@ -627,5 +661,5 @@ def test_load_fail(self): When attempted to load a non-existing file, error message must contain the file path. """ path = "non_existing_audio.wav" - with self.assertRaisesRegex(RuntimeError, "^Error loading audio file: failed to open file {0}$".format(path)): + with self.assertRaisesRegex(RuntimeError, path): sox_io_backend.load(path) diff --git a/test/torchaudio_unittest/backend/sox_io/save_test.py b/test/torchaudio_unittest/backend/sox_io/save_test.py index 848bf413108..59e2ff4678c 100644 --- a/test/torchaudio_unittest/backend/sox_io/save_test.py +++ b/test/torchaudio_unittest/backend/sox_io/save_test.py @@ -6,6 +6,7 @@ from parameterized import parameterized from torchaudio.backend import sox_io_backend from torchaudio_unittest.common_utils import ( + ffmpeg_utils, get_wav_data, load_wav, nested_params, @@ -130,7 +131,10 @@ def assert_save_consistency( else: raise ValueError(f"Unexpected test mode: {test_mode}") # 2.2. Convert the target format to wav with sox - sox_utils.convert_audio_file(tgt_path, tst_path, encoding=cmp_encoding, bit_depth=cmp_bit_depth) + if format == "mp3": + ffmpeg_utils.convert_to_wav(tgt_path, tst_path) + else: + sox_utils.convert_audio_file(tgt_path, tst_path, encoding=cmp_encoding, bit_depth=cmp_bit_depth) # 2.3. Load with SciPy found = load_wav(tst_path, normalize=False)[0] @@ -140,7 +144,10 @@ def assert_save_consistency( src_path, sox_path, compression=compression, encoding=sox_encoding, bit_depth=bits_per_sample ) # 3.2. Convert the target format to wav with sox - sox_utils.convert_audio_file(sox_path, ref_path, encoding=cmp_encoding, bit_depth=cmp_bit_depth) + if format == "mp3": + ffmpeg_utils.convert_to_wav(sox_path, ref_path) + else: + sox_utils.convert_audio_file(sox_path, ref_path, encoding=cmp_encoding, bit_depth=cmp_bit_depth) # 3.3. Load with SciPy expected = load_wav(ref_path, normalize=False)[0] @@ -437,5 +444,5 @@ def test_save_fail(self): When attempted to save into a non-existing dir, error message must contain the file path. """ path = os.path.join("non_existing_directory", "foo.wav") - with self.assertRaisesRegex(RuntimeError, "^Error saving audio file: failed to open file {0}$".format(path)): + with self.assertRaisesRegex(RuntimeError, path): sox_io_backend.save(path, torch.zeros(1, 1), 8000) diff --git a/test/torchaudio_unittest/backend/sox_io/smoke_test.py b/test/torchaudio_unittest/backend/sox_io/smoke_test.py index b3e39b61cb2..4329209bc85 100644 --- a/test/torchaudio_unittest/backend/sox_io/smoke_test.py +++ b/test/torchaudio_unittest/backend/sox_io/smoke_test.py @@ -1,11 +1,8 @@ import io import itertools -import unittest from parameterized import parameterized -from torchaudio._internal.module_utils import is_sox_available from torchaudio.backend import sox_io_backend -from torchaudio.utils import sox_utils from torchaudio_unittest.common_utils import ( get_wav_data, skipIfNoSox, @@ -16,12 +13,6 @@ from .common import name_func -skipIfNoMP3 = unittest.skipIf( - not is_sox_available() or "mp3" not in sox_utils.list_read_formats() or "mp3" not in sox_utils.list_write_formats(), - '"sox_io" backend does not support MP3', -) - - @skipIfNoSox class SmokeTest(TempDirMixin, TorchaudioTestCase): """Run smoke test on various audio format @@ -73,7 +64,6 @@ def test_wav(self, dtype, sample_rate, num_channels): ) ) ) - @skipIfNoMP3 def test_mp3(self, sample_rate, num_channels, bit_rate): """Run smoke test on mp3 format""" self.run_smoke_test("mp3", sample_rate, num_channels, compression=bit_rate) @@ -159,7 +149,6 @@ def test_wav(self, dtype, sample_rate, num_channels): ) ) ) - @skipIfNoMP3 def test_mp3(self, sample_rate, num_channels, bit_rate): """Run smoke test on mp3 format""" self.run_smoke_test("mp3", sample_rate, num_channels, compression=bit_rate) diff --git a/test/torchaudio_unittest/common_utils/ffmpeg_utils.py b/test/torchaudio_unittest/common_utils/ffmpeg_utils.py new file mode 100644 index 00000000000..0f91ac7f1dd --- /dev/null +++ b/test/torchaudio_unittest/common_utils/ffmpeg_utils.py @@ -0,0 +1,10 @@ +import subprocess +import sys + + +def convert_to_wav(src_path, dst_path): + """Convert audio file with `ffmpeg` command.""" + # TODO: parameterize codec + command = ["ffmpeg", "-y", "-i", src_path, "-c:a", "pcm_f32le", dst_path] + print(" ".join(command), file=sys.stderr) + subprocess.run(command, check=True) diff --git a/test/torchaudio_unittest/io/stream_reader_test.py b/test/torchaudio_unittest/io/stream_reader_test.py index 4e05ba056f6..8f0e61a676b 100644 --- a/test/torchaudio_unittest/io/stream_reader_test.py +++ b/test/torchaudio_unittest/io/stream_reader_test.py @@ -360,6 +360,20 @@ def test_seek_negative(self): s.seek(-1.0) +def _to_fltp(original): + denom = { + torch.uint8: 2**7, + torch.int16: 2**15, + torch.int32: 2**31, + }[original.dtype] + + fltp = original.to(torch.float32) + if original.dtype == torch.uint8: + fltp -= 128 + fltp /= denom + return fltp + + @skipIfNoFFmpeg @_media_source class StreamReaderAudioTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestCase): @@ -399,9 +413,15 @@ def test_basic_audio_stream(self, dtype, num_channels): # provide the matching dtype self._test_wav(src, original, fmt=fmt) - if not self.test_fileobj: - # use the internal dtype ffmpeg picks - self._test_wav(src, original, fmt=None) + # use the internal dtype ffmpeg picks + if self.test_fileobj: + src.seek(0) + self._test_wav(src, original, fmt=None) + # convert to float32 + expected = _to_fltp(original) + if self.test_fileobj: + src.seek(0) + self._test_wav(src, expected, fmt="fltp") @nested_params( ["int16", "uint8", "int32"], # "float", "double", "int64"] diff --git a/torchaudio/backend/sox_io_backend.py b/torchaudio/backend/sox_io_backend.py index 6b10822fc26..ab99a8139ff 100644 --- a/torchaudio/backend/sox_io_backend.py +++ b/torchaudio/backend/sox_io_backend.py @@ -8,6 +8,52 @@ from .common import AudioMetaData +# Note: need to comply TorchScript syntax -- need annotation and no f-string +def _alt_info(filepath: str, format: Optional[str]) -> AudioMetaData: + return AudioMetaData(*torch.ops.torchaudio.ffmpeg_get_audio_info(filepath, format)) + + +def _alt_info_fileobj(fileobj, format: Optional[str]) -> AudioMetaData: + return AudioMetaData(*torchaudio._torchaudio_ffmpeg.get_audio_info_fileobj(fileobj, format)) + + +# Note: need to comply TorchScript syntax -- need annotation and no f-string +def _fail_info(filepath: str, format: Optional[str]) -> AudioMetaData: + raise RuntimeError("Failed to fetch metadata from {}".format(filepath)) + + +def _fail_info_fileobj(fileobj, format: Optional[str]) -> AudioMetaData: + raise RuntimeError("Failed to fetch metadata from {}".format(fileobj)) + + +# Note: need to comply TorchScript syntax -- need annotation and no f-string +def _fail_load( + filepath: str, + frame_offset: int = 0, + num_frames: int = -1, + normalize: bool = True, + channels_first: bool = True, + format: Optional[str] = None, +) -> Tuple[torch.Tensor, int]: + raise RuntimeError("Failed to load audio from {}".format(filepath)) + + +def _fail_load_fileobj(fileobj, *args, **kwargs): + raise RuntimeError(f"Failed to load audio from {fileobj}") + + +if torchaudio._extension._FFMPEG_INITIALIZED: + _fallback_info = _alt_info + _fallback_info_fileobj = _alt_info_fileobj + _fallback_load = torch.ops.torchaudio.ffmpeg_load_audio + _fallback_load_fileobj = torchaudio._torchaudio_ffmpeg.load_audio_fileobj +else: + _fallback_info = _fail_info + _fallback_info_fileobj = _fail_info_fileobj + _fallback_load = _fail_load + _fallback_load_filebj = _fail_load_fileobj + + @_mod_utils.requires_sox() def info( filepath: str, @@ -46,11 +92,14 @@ def info( if not torch.jit.is_scripting(): if hasattr(filepath, "read"): sinfo = torchaudio._torchaudio.get_info_fileobj(filepath, format) - return AudioMetaData(*sinfo) + if sinfo is not None: + return AudioMetaData(*sinfo) + return _fallback_info_fileobj(filepath, format) filepath = os.fspath(filepath) sinfo = torch.ops.torchaudio.sox_io_get_info(filepath, format) - assert sinfo is not None # for TorchScript compatibility - return AudioMetaData(*sinfo) + if sinfo is not None: + return AudioMetaData(*sinfo) + return _fallback_info(filepath, format) @_mod_utils.requires_sox() @@ -145,15 +194,19 @@ def load( """ if not torch.jit.is_scripting(): if hasattr(filepath, "read"): - return torchaudio._torchaudio.load_audio_fileobj( + ret = torchaudio._torchaudio.load_audio_fileobj( filepath, frame_offset, num_frames, normalize, channels_first, format ) + if ret is not None: + return ret + return _fallback_load_fileobj(filepath, frame_offset, num_frames, normalize, channels_first, format) filepath = os.fspath(filepath) ret = torch.ops.torchaudio.sox_io_load_audio_file( filepath, frame_offset, num_frames, normalize, channels_first, format ) - assert ret is not None # for TorchScript compatibility - return ret + if ret is not None: + return ret + return _fallback_load(filepath, frame_offset, num_frames, normalize, channels_first, format) @_mod_utils.requires_sox() diff --git a/torchaudio/csrc/ffmpeg/pybind/pybind.cpp b/torchaudio/csrc/ffmpeg/pybind/pybind.cpp index 46e633262c1..1a7e18dffb9 100644 --- a/torchaudio/csrc/ffmpeg/pybind/pybind.cpp +++ b/torchaudio/csrc/ffmpeg/pybind/pybind.cpp @@ -7,6 +7,14 @@ namespace ffmpeg { namespace { PYBIND11_MODULE(_torchaudio_ffmpeg, m) { + m.def( + "load_audio_fileobj", + &torchaudio::ffmpeg::load_audio_fileobj, + "Load audio from file object."); + m.def( + "get_audio_info_fileobj", + &torchaudio::ffmpeg::get_audio_info_fileobj, + "Get metadata of audio in file object."); py::class_>( m, "StreamReaderFileObj") .def(py::init< diff --git a/torchaudio/csrc/ffmpeg/pybind/stream_reader.cpp b/torchaudio/csrc/ffmpeg/pybind/stream_reader.cpp index ede150c774e..e0143b16074 100644 --- a/torchaudio/csrc/ffmpeg/pybind/stream_reader.cpp +++ b/torchaudio/csrc/ffmpeg/pybind/stream_reader.cpp @@ -85,5 +85,36 @@ StreamReaderFileObj::StreamReaderFileObj( option.value_or(OptionDict{}), pAVIO)) {} +std::tuple, int64_t> load_audio_fileobj( + py::object fileobj, + const c10::optional& frame_offset, + const c10::optional& num_frames, + bool convert, + bool channels_first, + const c10::optional& format) { + FileObj f{fileobj, 4086}; + return load_audio( + get_input_format_context( + static_cast(py::str(fileobj.attr("__str__")())), + format, + {}, + f.pAVIO), + frame_offset, + num_frames, + convert, + channels_first); +} + +MetaDataTuple get_audio_info_fileobj( + py::object fileobj, + const c10::optional& format) { + FileObj f{fileobj, 4086}; + return get_audio_info(get_input_format_context( + static_cast(py::str(fileobj.attr("__str__")())), + format, + {}, + f.pAVIO)); +} + } // namespace ffmpeg } // namespace torchaudio diff --git a/torchaudio/csrc/ffmpeg/pybind/stream_reader.h b/torchaudio/csrc/ffmpeg/pybind/stream_reader.h index 7b12ae5c020..bbda5374e18 100644 --- a/torchaudio/csrc/ffmpeg/pybind/stream_reader.h +++ b/torchaudio/csrc/ffmpeg/pybind/stream_reader.h @@ -24,5 +24,20 @@ class StreamReaderFileObj : protected FileObj, public StreamReaderBinding { int64_t buffer_size); }; +std::tuple, int64_t> load_audio_fileobj( + py::object fileobj, + const c10::optional& frame_offset, + const c10::optional& num_frames, + bool convert, + bool channels_first, + const c10::optional& format); + +using MetaDataTuple = + std::tuple; + +MetaDataTuple get_audio_info_fileobj( + py::object fileobj, + const c10::optional& format); + } // namespace ffmpeg } // namespace torchaudio diff --git a/torchaudio/csrc/ffmpeg/stream_reader_binding.cpp b/torchaudio/csrc/ffmpeg/stream_reader_binding.cpp index 32635f72cc0..b1818d35cd7 100644 --- a/torchaudio/csrc/ffmpeg/stream_reader_binding.cpp +++ b/torchaudio/csrc/ffmpeg/stream_reader_binding.cpp @@ -26,16 +26,25 @@ c10::intrusive_ptr init( get_input_format_context(src, device, map(option))); } -std::tuple, int64_t> load(const std::string& src) { - StreamReaderBinding s{get_input_format_context(src, {}, {})}; - int i = static_cast(s.find_best_audio_stream()); - auto sinfo = s.StreamReader::get_src_stream_info(i); - int64_t sample_rate = static_cast(sinfo.sample_rate); - s.add_audio_stream(i, -1, -1, {}, {}, {}); - s.process_all_packets(); - auto tensors = s.pop_chunks(); - assert(tensors.size() > 0); - return std::make_tuple<>(tensors[0], sample_rate); +std::tuple load_audio( + const std::string& src, + const c10::optional& frame_offset, + const c10::optional& num_frames, + bool convert, + bool channels_first, + const c10::optional& format) { + return load_audio( + get_input_format_context(src, format, {}), + frame_offset, + num_frames, + convert, + channels_first); +} + +MetaDataTuple get_audio_info( + const std::string& src, + const c10::optional& format) { + return get_audio_info(get_input_format_context(src, format, {})); } using S = const c10::intrusive_ptr&; @@ -47,7 +56,8 @@ TORCH_LIBRARY_FRAGMENT(torchaudio, m) { av_log_set_level(AV_LOG_ERROR); } }); - m.def("torchaudio::ffmpeg_load", load); + m.def("torchaudio::ffmpeg_load_audio", load_audio); + m.def("torchaudio::ffmpeg_get_audio_info", get_audio_info); m.class_("ffmpeg_StreamReader") .def(torch::init<>(init)) .def("num_src_streams", [](S self) { return self->num_src_streams(); }) diff --git a/torchaudio/csrc/ffmpeg/stream_reader_wrapper.cpp b/torchaudio/csrc/ffmpeg/stream_reader_wrapper.cpp index 4675098f3bb..cd8bf854dc4 100644 --- a/torchaudio/csrc/ffmpeg/stream_reader_wrapper.cpp +++ b/torchaudio/csrc/ffmpeg/stream_reader_wrapper.cpp @@ -60,5 +60,82 @@ void StreamReaderBinding::process_all_packets() { } while (!ret); } +namespace { + +c10::optional get_load_filter( + const c10::optional& frame_offset, + const c10::optional& num_frames, + bool convert) { + if (!frame_offset && !num_frames && !convert) { + return {}; + } + std::string aformat = "aformat=sample_fmts=fltp'"; + if (!frame_offset && !num_frames && convert) { + return {aformat}; + } + + // At least one of frame_offset or num_frames is present + auto atrim = [&]() -> std::string { + std::vector parts; + if (frame_offset && frame_offset.value() > 0) { + parts.emplace_back( + "start_sample=" + std::to_string(frame_offset.value())); + } + if (num_frames && num_frames.value() > 0) { + auto offset = frame_offset.value_or(0); + parts.emplace_back( + "end_sample=" + std::to_string(offset + num_frames.value())); + } + return {"atrim=" + c10::Join(":", parts)}; + }(); + + if (!convert) { + return {atrim}; + } + return {c10::Join(",", std::vector{atrim, aformat})}; +} + +} // namespace + +std::tuple load_audio( + AVFormatContextPtr&& p, + const c10::optional& frame_offset, + const c10::optional& num_frames, + bool convert, + bool channels_first) { + StreamReaderBinding s{std::move(p)}; + int i = static_cast(s.find_best_audio_stream()); + auto sinfo = s.StreamReader::get_src_stream_info(i); + int64_t sample_rate = static_cast(sinfo.sample_rate); + s.add_audio_stream( + i, -1, -1, get_load_filter(frame_offset, num_frames, convert), {}, {}); + s.process_all_packets(); + auto chunk = s.pop_chunks()[0]; + if (!chunk) { + throw std::runtime_error("Failed to decode an audio."); + } + auto tensor = chunk.value(); + if (channels_first) { + tensor = tensor.transpose(1, 0); + } + return std::make_tuple<>(tensor, sample_rate); +} + +MetaDataTuple get_audio_info(AVFormatContextPtr&& p) { + StreamReaderBinding s{std::move(p)}; + int i = static_cast(s.find_best_audio_stream()); + auto sinfo = s.StreamReader::get_src_stream_info(i); + std::string cdc{sinfo.codec_name}; + std::transform(cdc.begin(), cdc.end(), cdc.begin(), [](unsigned char c) { + return std::toupper(c); + }); + return std::make_tuple( + static_cast(sinfo.sample_rate), + sinfo.num_frames, + sinfo.num_channels, + sinfo.bits_per_sample, + cdc); +} + } // namespace ffmpeg } // namespace torchaudio diff --git a/torchaudio/csrc/ffmpeg/stream_reader_wrapper.h b/torchaudio/csrc/ffmpeg/stream_reader_wrapper.h index fc4e3acce4c..0b9a734dd3d 100644 --- a/torchaudio/csrc/ffmpeg/stream_reader_wrapper.h +++ b/torchaudio/csrc/ffmpeg/stream_reader_wrapper.h @@ -42,5 +42,24 @@ struct StreamReaderBinding : public StreamReader, void process_all_packets(); }; +// These are temporary implementations, to be used as fallback for sox_io +// backend. +// +// When we implement FFmpeg-based equivalents, we should not be constrained on +// these parameters and revise interface. (for example, resampling should +// be part of parameter otherwise frame_offset and num_frames are not +// fully useful) +std::tuple load_audio( + AVFormatContextPtr&& p, + const c10::optional& frame_offset, + const c10::optional& num_frames, + bool convert, + bool channels_first); + +using MetaDataTuple = + std::tuple; + +MetaDataTuple get_audio_info(AVFormatContextPtr&& p); + } // namespace ffmpeg } // namespace torchaudio diff --git a/torchaudio/csrc/pybind/sox/effects.cpp b/torchaudio/csrc/pybind/sox/effects.cpp index 8ef82d5fd75..f5a6bd6ba08 100644 --- a/torchaudio/csrc/pybind/sox/effects.cpp +++ b/torchaudio/csrc/pybind/sox/effects.cpp @@ -83,7 +83,10 @@ auto apply_effects_fileobj( /*filetype=*/format.has_value() ? format.value().c_str() : nullptr)); // In case of streamed data, length can be 0 - validate_input_memfile(sf); + if (static_cast(sf) == nullptr || + sf->encoding.encoding == SOX_ENCODING_UNKNOWN) { + return {}; + } // Prepare output buffer std::vector out_buffer; diff --git a/torchaudio/csrc/pybind/sox/io.cpp b/torchaudio/csrc/pybind/sox/io.cpp index 2c87663ecb5..cb595202351 100644 --- a/torchaudio/csrc/pybind/sox/io.cpp +++ b/torchaudio/csrc/pybind/sox/io.cpp @@ -60,8 +60,10 @@ auto get_info_fileobj(py::object fileobj, c10::optional format) /*encoding=*/nullptr, /*filetype=*/format.has_value() ? format.value().c_str() : nullptr)); - // In case of streamed data, length can be 0 - validate_input_memfile(sf); + if (static_cast(sf) == nullptr || + sf->encoding.encoding == SOX_ENCODING_UNKNOWN) { + return c10::optional{}; + } return std::forward_as_tuple( static_cast(sf->signal.rate), diff --git a/torchaudio/csrc/sox/effects.cpp b/torchaudio/csrc/sox/effects.cpp index fe806b05981..dcf5f9decdc 100644 --- a/torchaudio/csrc/sox/effects.cpp +++ b/torchaudio/csrc/sox/effects.cpp @@ -103,7 +103,10 @@ auto apply_effects_file( /*encoding=*/nullptr, /*filetype=*/format.has_value() ? format.value().c_str() : nullptr)); - validate_input_file(sf, path); + if (static_cast(sf) == nullptr || + sf->encoding.encoding == SOX_ENCODING_UNKNOWN) { + return {}; + } const auto dtype = get_dtype(sf->encoding.encoding, sf->signal.precision); diff --git a/torchaudio/csrc/sox/io.cpp b/torchaudio/csrc/sox/io.cpp index 6876c7d310a..1c4f4771c0a 100644 --- a/torchaudio/csrc/sox/io.cpp +++ b/torchaudio/csrc/sox/io.cpp @@ -19,7 +19,11 @@ c10::optional get_info_file( /*encoding=*/nullptr, /*filetype=*/format.has_value() ? format.value().c_str() : nullptr)); - validate_input_file(sf, path); + if (static_cast(sf) == nullptr || + sf->encoding.encoding == SOX_ENCODING_UNKNOWN) { + return {}; + } + return std::forward_as_tuple( static_cast(sf->signal.rate), static_cast(sf->signal.length / sf->signal.channels), diff --git a/torchaudio/csrc/sox/utils.cpp b/torchaudio/csrc/sox/utils.cpp index 2208ed2074c..d8d14dd6d4f 100644 --- a/torchaudio/csrc/sox/utils.cpp +++ b/torchaudio/csrc/sox/utils.cpp @@ -95,10 +95,6 @@ void validate_input_file(const SoxFormat& sf, const std::string& path) { } } -void validate_input_memfile(const SoxFormat& sf) { - return validate_input_file(sf, ""); -} - void validate_input_tensor(const torch::Tensor tensor) { if (!tensor.device().is_cpu()) { throw std::runtime_error("Input tensor has to be on CPU."); diff --git a/torchaudio/csrc/sox/utils.h b/torchaudio/csrc/sox/utils.h index 73b76e71b92..ca84b600432 100644 --- a/torchaudio/csrc/sox/utils.h +++ b/torchaudio/csrc/sox/utils.h @@ -52,13 +52,6 @@ struct SoxFormat { sox_format_t* fd_; }; -/// -/// Verify that input file is found, has known encoding, and not empty -void validate_input_file(const SoxFormat& sf, const std::string& path); - -/// Verify that input memory buffer has known encoding, and not empty -void validate_input_memfile(const SoxFormat& sf); - /// /// Verify that input Tensor is 2D, CPU and either uin8, int16, int32 or float32 void validate_input_tensor(const torch::Tensor); diff --git a/torchaudio/utils/__init__.py b/torchaudio/utils/__init__.py index 87761d6b985..90874dd2253 100644 --- a/torchaudio/utils/__init__.py +++ b/torchaudio/utils/__init__.py @@ -4,7 +4,7 @@ from .download import download_asset if _mod_utils.is_sox_available(): - sox_utils.set_verbosity(1) + sox_utils.set_verbosity(0) __all__ = [