Skip to content

Commit

Permalink
Use FFmpeg-based I/O as fallback in sox_io backend (#2419)
Browse files Browse the repository at this point in the history
Summary:
This commit add fallback mechanism to `info` and `load` functions of sox_io backend.
If torchaudio is compiled to use FFmpeg, and runtime dependencies are properly loaded,
in case `info` and `load` fail, it fallback to FFmpeg-based implementation.

BC-breaking changes:
 - FFmpeg does not report the number of frames for MP3, this is because MP3 does not store the information of the number of frames. It can be estimated from the audio duration and sample rate, but it might be inaccurate, so we keep it 0.

Depends on
- #2416
- #2417
- #2418
- #2423
- #2427

Pull Request resolved: #2419

Reviewed By: carolineechen

Differential Revision: D36740306

Pulled By: mthrok

fbshipit-source-id: 9e2ad095b8b39e41404970de0d8d9b5aaa856c97
  • Loading branch information
mthrok authored and facebook-github-bot committed Jun 2, 2022
1 parent a61b90c commit 19c60a0
Show file tree
Hide file tree
Showing 8 changed files with 220 additions and 134 deletions.
30 changes: 23 additions & 7 deletions test/torchaudio_unittest/backend/sox_io/info_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,23 +312,31 @@ def test_opus(self, bitrate, num_channels, compression_level):
@skipIfNoSox
class TestLoadWithoutExtension(PytorchTestCase):
def test_mp3(self):
"""Providing `format` allows to read mp3 without extension
libsox does not check header for mp3
"""MP3 file without extension can be loaded
Originally, we added `format` argument for this case, but now we use FFmpeg
for MP3 decoding, which works even without `format` argument.
https://github.com/pytorch/audio/issues/1040
The file was generated with the following command
ffmpeg -f lavfi -i "sine=frequency=1000:duration=5" -ar 16000 -f mp3 test_noext
"""
path = get_asset_path("mp3_without_ext")
sinfo = sox_io_backend.info(path, format="mp3")
sinfo = sox_io_backend.info(path)
assert sinfo.sample_rate == 16000
assert sinfo.num_frames == 81216
assert sinfo.num_frames == 0
assert sinfo.num_channels == 1
assert sinfo.bits_per_sample == 0 # bit_per_sample is irrelevant for compressed formats
assert sinfo.encoding == "MP3"

with open(path, "rb") as fileobj:
sinfo = sox_io_backend.info(fileobj)
assert sinfo.sample_rate == 16000
assert sinfo.num_frames == 0
assert sinfo.num_channels == 1
assert sinfo.bits_per_sample == 0
assert sinfo.encoding == "MP3"


class FileObjTestBase(TempDirMixin):
def _gen_file(self, ext, dtype, sample_rate, num_channels, num_frames, *, comments=None):
Expand All @@ -355,6 +363,14 @@ def _gen_comment_file(self, comments):
return comment_path


class Unseekable:
def __init__(self, fileobj):
self.fileobj = fileobj

def read(self, n):
return self.fileobj.read(n)


@skipIfNoSox
@skipIfNoExec("sox")
class TestFileObject(FileObjTestBase, PytorchTestCase):
Expand Down Expand Up @@ -435,7 +451,7 @@ def test_fileobj_large_header(self, ext, dtype):
num_channels = 2
comments = "metadata=" + " ".join(["value" for _ in range(1000)])

with self.assertRaisesRegex(RuntimeError, "Failed to fetch metadata from"):
with self.assertRaises(RuntimeError):
sinfo = self._query_fileobj(ext, dtype, sample_rate, num_channels, num_frames, comments=comments)

with self._set_buffer_size(16384):
Expand Down Expand Up @@ -545,7 +561,7 @@ def _query_http(self, ext, dtype, sample_rate, num_channels, num_frames):
url = self.get_url(audio_file)
format_ = ext if ext in ["mp3"] else None
with requests.get(url, stream=True) as resp:
return sox_io_backend.info(resp.raw, format=format_)
return sox_io_backend.info(Unseekable(resp.raw), format=format_)

@parameterized.expand(
[
Expand Down
143 changes: 73 additions & 70 deletions test/torchaudio_unittest/backend/sox_io/load_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import itertools
import tarfile

import torch
import torchaudio
from parameterized import parameterized
from torchaudio._internal import module_utils as _mod_utils
from torchaudio.backend import sox_io_backend
Expand All @@ -10,6 +12,7 @@
get_wav_data,
HttpServerMixin,
load_wav,
nested_params,
PytorchTestCase,
save_wav,
skipIfNoExec,
Expand Down Expand Up @@ -169,35 +172,6 @@ def test_multiple_channels(self, dtype, num_channels):
normalize = False
self.assert_wav(dtype, sample_rate, num_channels, normalize, duration=1)

@parameterized.expand(
list(
itertools.product(
[8000, 16000, 44100],
[1, 2],
[96, 128, 160, 192, 224, 256, 320],
)
),
name_func=name_func,
)
def test_mp3(self, sample_rate, num_channels, bit_rate):
"""`sox_io_backend.load` can load mp3 format correctly."""
self.assert_format("mp3", sample_rate, num_channels, compression=bit_rate, duration=1, atol=5e-05)

@parameterized.expand(
list(
itertools.product(
[16000],
[2],
[128],
)
),
name_func=name_func,
)
def test_mp3_large(self, sample_rate, num_channels, bit_rate):
"""`sox_io_backend.load` can load large mp3 file correctly."""
two_hours = 2 * 60 * 60
self.assert_format("mp3", sample_rate, num_channels, compression=bit_rate, duration=two_hours, atol=5e-05)

@parameterized.expand(
list(
itertools.product(
Expand Down Expand Up @@ -319,72 +293,92 @@ def test_amr_nb(self):
self.assert_format("amr-nb", sample_rate=8000, num_channels=1, bit_depth=32, duration=1)


@skipIfNoExec("sox")
@skipIfNoSox
class TestLoadParams(TempDirMixin, PytorchTestCase):
"""Test the correctness of frame parameters of `sox_io_backend.load`"""

original = None
path = None
def _test(self, func, frame_offset, num_frames, channels_first, normalize):
original = get_wav_data("int16", num_channels=2, normalize=False)
path = self.get_temp_path("test.wav")
save_wav(path, original, sample_rate=8000)

def setUp(self):
super().setUp()
sample_rate = 8000
self.original = get_wav_data("float32", num_channels=2)
self.path = self.get_temp_path("test.wav")
save_wav(self.path, self.original, sample_rate)
output, _ = func(path, frame_offset, num_frames, normalize, channels_first, None)
frame_end = None if num_frames == -1 else frame_offset + num_frames
expected = original[:, slice(frame_offset, frame_end)]
if not channels_first:
expected = expected.T
if normalize:
expected = expected.to(torch.float32) / (2**15)
self.assertEqual(output, expected)

@nested_params(
[0, 1, 10, 100, 1000],
[-1, 1, 10, 100, 1000],
[True, False],
[True, False],
)
def test_sox(self, frame_offset, num_frames, channels_first, normalize):
"""The combination of properly changes the output tensor"""

@parameterized.expand(
list(
itertools.product(
[0, 1, 10, 100, 1000],
[-1, 1, 10, 100, 1000],
)
),
name_func=name_func,
self._test(torch.ops.torchaudio.sox_io_load_audio_file, frame_offset, num_frames, channels_first, normalize)

# test file-like obj
def func(path, *args):
with open(path, "rb") as fileobj:
return torchaudio._torchaudio.load_audio_fileobj(fileobj, *args)

self._test(func, frame_offset, num_frames, channels_first, normalize)

@nested_params(
[0, 1, 10, 100, 1000],
[-1, 1, 10, 100, 1000],
[True, False],
[True, False],
)
def test_frame(self, frame_offset, num_frames):
"""num_frames and frame_offset correctly specify the region of data"""
found, _ = sox_io_backend.load(self.path, frame_offset, num_frames)
frame_end = None if num_frames == -1 else frame_offset + num_frames
self.assertEqual(found, self.original[:, frame_offset:frame_end])
def test_ffmpeg(self, frame_offset, num_frames, channels_first, normalize):
"""The combination of properly changes the output tensor"""
from torchaudio.io._compat import load_audio, load_audio_fileobj

self._test(load_audio, frame_offset, num_frames, channels_first, normalize)

# test file-like obj
def func(path, *args):
with open(path, "rb") as fileobj:
return load_audio_fileobj(fileobj, *args)

@parameterized.expand([(True,), (False,)], name_func=name_func)
def test_channels_first(self, channels_first):
"""channels_first swaps axes"""
found, _ = sox_io_backend.load(self.path, channels_first=channels_first)
expected = self.original if channels_first else self.original.transpose(1, 0)
self.assertEqual(found, expected)
self._test(func, frame_offset, num_frames, channels_first, normalize)


@skipIfNoSox
class TestLoadWithoutExtension(PytorchTestCase):
def test_mp3(self):
"""Providing format allows to read mp3 without extension
libsox does not check header for mp3
"""MP3 file without extension can be loaded
Originally, we added `format` argument for this case, but now we use FFmpeg
for MP3 decoding, which works even without `format` argument.
https://github.com/pytorch/audio/issues/1040
The file was generated with the following command
ffmpeg -f lavfi -i "sine=frequency=1000:duration=5" -ar 16000 -f mp3 test_noext
"""
path = get_asset_path("mp3_without_ext")
_, sr = sox_io_backend.load(path, format="mp3")
_, sr = sox_io_backend.load(path)
assert sr == 16000

with open(path, "rb") as fileobj:
_, sr = sox_io_backend.load(fileobj)
assert sr == 16000


class CloggedFileObj:
def __init__(self, fileobj):
self.fileobj = fileobj
self.buffer = b""

def read(self, n):
if not self.buffer:
self.buffer += self.fileobj.read(n)
ret = self.buffer[:2]
self.buffer = self.buffer[2:]
return ret
def read(self, _):
return self.fileobj.read(2)

def seek(self, offset, whence):
return self.fileobj.seek(offset, whence)


@skipIfNoSox
Expand Down Expand Up @@ -557,6 +551,14 @@ def test_tarfile(self, ext, kwargs):
self.assertEqual(expected, found)


class Unseekable:
def __init__(self, fileobj):
self.fileobj = fileobj

def read(self, n):
return self.fileobj.read(n)


@skipIfNoSox
@skipIfNoExec("sox")
@skipIfNoModule("requests")
Expand Down Expand Up @@ -587,10 +589,11 @@ def test_requests(self, ext, kwargs):

url = self.get_url(audio_file)
with requests.get(url, stream=True) as resp:
found, sr = sox_io_backend.load(resp.raw, format=format_)
found, sr = sox_io_backend.load(Unseekable(resp.raw), format=format_)

assert sr == sample_rate
self.assertEqual(expected, found)
if ext != "mp3":
self.assertEqual(expected, found)

@parameterized.expand(
list(
Expand Down
29 changes: 1 addition & 28 deletions test/torchaudio_unittest/backend/sox_io/save_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import io
import os
import unittest

import torch
from parameterized import parameterized
Expand Down Expand Up @@ -179,31 +178,6 @@ def test_save_wav_dtype(self, test_mode, params):
(dtype,) = params
self.assert_save_consistency("wav", src_dtype=dtype, test_mode=test_mode)

@nested_params(
["path", "fileobj", "bytesio"],
[
None,
-4.2,
-0.2,
0,
0.2,
96,
128,
160,
192,
224,
256,
320,
],
)
def test_save_mp3(self, test_mode, bit_rate):
if test_mode in ["fileobj", "bytesio"]:
if bit_rate is not None and bit_rate < 1:
raise unittest.SkipTest(
"mp3 format with variable bit rate is known to " "not yield the exact same result as sox command."
)
self.assert_save_consistency("mp3", compression=bit_rate, test_mode=test_mode)

@nested_params(
["path", "fileobj", "bytesio"],
[8, 16, 24],
Expand Down Expand Up @@ -349,7 +323,6 @@ def test_save_gsm(self, test_mode):
@parameterized.expand(
[
("wav", "PCM_S", 16),
("mp3",),
("flac",),
("vorbis",),
("sph", "PCM_S", 16),
Expand Down Expand Up @@ -437,5 +410,5 @@ def test_save_fail(self):
When attempted to save into a non-existing dir, error message must contain the file path.
"""
path = os.path.join("non_existing_directory", "foo.wav")
with self.assertRaisesRegex(RuntimeError, "^Error saving audio file: failed to open file {0}$".format(path)):
with self.assertRaisesRegex(RuntimeError, path):
sox_io_backend.save(path, torch.zeros(1, 1), 8000)
11 changes: 0 additions & 11 deletions test/torchaudio_unittest/backend/sox_io/smoke_test.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
import io
import itertools
import unittest

from parameterized import parameterized
from torchaudio._internal.module_utils import is_sox_available
from torchaudio.backend import sox_io_backend
from torchaudio.utils import sox_utils
from torchaudio_unittest.common_utils import (
get_wav_data,
skipIfNoSox,
Expand All @@ -16,12 +13,6 @@
from .common import name_func


skipIfNoMP3 = unittest.skipIf(
not is_sox_available() or "mp3" not in sox_utils.list_read_formats() or "mp3" not in sox_utils.list_write_formats(),
'"sox_io" backend does not support MP3',
)


@skipIfNoSox
class SmokeTest(TempDirMixin, TorchaudioTestCase):
"""Run smoke test on various audio format
Expand Down Expand Up @@ -73,7 +64,6 @@ def test_wav(self, dtype, sample_rate, num_channels):
)
)
)
@skipIfNoMP3
def test_mp3(self, sample_rate, num_channels, bit_rate):
"""Run smoke test on mp3 format"""
self.run_smoke_test("mp3", sample_rate, num_channels, compression=bit_rate)
Expand Down Expand Up @@ -159,7 +149,6 @@ def test_wav(self, dtype, sample_rate, num_channels):
)
)
)
@skipIfNoMP3
def test_mp3(self, sample_rate, num_channels, bit_rate):
"""Run smoke test on mp3 format"""
self.run_smoke_test("mp3", sample_rate, num_channels, compression=bit_rate)
Expand Down
Loading

0 comments on commit 19c60a0

Please sign in to comment.