Skip to content

Commit

Permalink
Use FFmpeg-based I/O as fallback in sox_io backend (pytorch#2419)
Browse files Browse the repository at this point in the history
Summary:
This commit add fallback mechanism to `info` and `load` functions of sox_io backend.
If torchaudio is compiled to use FFmpeg, and runtime dependencies are properly loaded,
in case `info` and `load` fail, it fallback to FFmpeg-based implementation.

Depends on pytorch#2416, pytorch#2417, pytorch#2418

Pull Request resolved: pytorch#2419

Differential Revision: D36740306

Pulled By: mthrok

fbshipit-source-id: 91dfdd199959d83ce643ccc38cf163ce29fba55e
  • Loading branch information
mthrok committed May 31, 2022
1 parent b56f60b commit 53a3aa7
Show file tree
Hide file tree
Showing 14 changed files with 333 additions and 82 deletions.
30 changes: 23 additions & 7 deletions test/torchaudio_unittest/backend/sox_io/info_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,23 +312,31 @@ def test_opus(self, bitrate, num_channels, compression_level):
@skipIfNoSox
class TestLoadWithoutExtension(PytorchTestCase):
def test_mp3(self):
"""Providing `format` allows to read mp3 without extension
libsox does not check header for mp3
"""MP3 file without extension can be loaded
Originally, we added `format` argument for this case, but now we use FFmpeg
for MP3 decoding, which works even without `format` argument.
https://github.com/pytorch/audio/issues/1040
The file was generated with the following command
ffmpeg -f lavfi -i "sine=frequency=1000:duration=5" -ar 16000 -f mp3 test_noext
"""
path = get_asset_path("mp3_without_ext")
sinfo = sox_io_backend.info(path, format="mp3")
sinfo = sox_io_backend.info(path)
assert sinfo.sample_rate == 16000
assert sinfo.num_frames == 81216
assert sinfo.num_frames == 0
assert sinfo.num_channels == 1
assert sinfo.bits_per_sample == 0 # bit_per_sample is irrelevant for compressed formats
assert sinfo.encoding == "MP3"

with open(path, "rb") as fileobj:
sinfo = sox_io_backend.info(fileobj)
assert sinfo.sample_rate == 16000
assert sinfo.num_frames == 0
assert sinfo.num_channels == 1
assert sinfo.bits_per_sample == 0
assert sinfo.encoding == "MP3"


class FileObjTestBase(TempDirMixin):
def _gen_file(self, ext, dtype, sample_rate, num_channels, num_frames, *, comments=None):
Expand All @@ -355,6 +363,14 @@ def _gen_comment_file(self, comments):
return comment_path


class Unseekable:
def __init__(self, fileobj):
self.fileobj = fileobj

def read(self, n):
return self.fileobj.read(n)


@skipIfNoSox
@skipIfNoExec("sox")
class TestFileObject(FileObjTestBase, PytorchTestCase):
Expand Down Expand Up @@ -435,7 +451,7 @@ def test_fileobj_large_header(self, ext, dtype):
num_channels = 2
comments = "metadata=" + " ".join(["value" for _ in range(1000)])

with self.assertRaisesRegex(RuntimeError, "Failed to fetch metadata from"):
with self.assertRaises(RuntimeError):
sinfo = self._query_fileobj(ext, dtype, sample_rate, num_channels, num_frames, comments=comments)

with self._set_buffer_size(16384):
Expand Down Expand Up @@ -545,7 +561,7 @@ def _query_http(self, ext, dtype, sample_rate, num_channels, num_frames):
url = self.get_url(audio_file)
format_ = ext if ext in ["mp3"] else None
with requests.get(url, stream=True) as resp:
return sox_io_backend.info(resp.raw, format=format_)
return sox_io_backend.info(Unseekable(resp.raw), format=format_)

@parameterized.expand(
[
Expand Down
118 changes: 76 additions & 42 deletions test/torchaudio_unittest/backend/sox_io/load_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,18 @@
import itertools
import tarfile

import torch
import torchaudio
from parameterized import parameterized
from torchaudio._internal import module_utils as _mod_utils
from torchaudio.backend import sox_io_backend
from torchaudio_unittest.common_utils import (
ffmpeg_utils,
get_asset_path,
get_wav_data,
HttpServerMixin,
load_wav,
nested_params,
PytorchTestCase,
save_wav,
skipIfNoExec,
Expand Down Expand Up @@ -81,7 +85,10 @@ def assert_format(
)
# 2. Convert to wav with sox
wav_bit_depth = 32 if bit_depth == 24 else None # for 24-bit wav
sox_utils.convert_audio_file(path, ref_path, bit_depth=wav_bit_depth)
if format == "mp3":
ffmpeg_utils.convert_to_wav(path, ref_path)
else:
sox_utils.convert_audio_file(path, ref_path, bit_depth=wav_bit_depth)
# 3. Load the given format with torchaudio
data, sr = sox_io_backend.load(path, normalize=normalize)
# 4. Load wav with scipy
Expand Down Expand Up @@ -319,72 +326,90 @@ def test_amr_nb(self):
self.assert_format("amr-nb", sample_rate=8000, num_channels=1, bit_depth=32, duration=1)


@skipIfNoExec("sox")
@skipIfNoSox
class TestLoadParams(TempDirMixin, PytorchTestCase):
"""Test the correctness of frame parameters of `sox_io_backend.load`"""

original = None
path = None
def _test(self, func, frame_offset, num_frames, channels_first, normalize):
original = get_wav_data("int16", num_channels=2, normalize=False)
path = self.get_temp_path("test.wav")
save_wav(path, original, sample_rate=8000)

def setUp(self):
super().setUp()
sample_rate = 8000
self.original = get_wav_data("float32", num_channels=2)
self.path = self.get_temp_path("test.wav")
save_wav(self.path, self.original, sample_rate)
output, _ = func(path, frame_offset, num_frames, normalize, channels_first, None)
frame_end = None if num_frames == -1 else frame_offset + num_frames
expected = original[:, slice(frame_offset, frame_end)]
if not channels_first:
expected = expected.T
if normalize:
expected = expected.to(torch.float32) / (2**15)
self.assertEqual(output, expected)

@nested_params(
[0, 1, 10, 100, 1000],
[-1, 1, 10, 100, 1000],
[True, False],
[True, False],
)
def test_sox(self, frame_offset, num_frames, channels_first, normalize):
"""The combination of properly changes the output tensor"""

@parameterized.expand(
list(
itertools.product(
[0, 1, 10, 100, 1000],
[-1, 1, 10, 100, 1000],
)
),
name_func=name_func,
self._test(torch.ops.torchaudio.sox_io_load_audio_file, frame_offset, num_frames, channels_first, normalize)

# test file-like obj
def func(path, *args):
with open(path, "rb") as fileobj:
return torchaudio._torchaudio.load_audio_fileobj(fileobj, *args)

self._test(func, frame_offset, num_frames, channels_first, normalize)

@nested_params(
[0, 1, 10, 100, 1000],
[-1, 1, 10, 100, 1000],
[True, False],
[True, False],
)
def test_frame(self, frame_offset, num_frames):
"""num_frames and frame_offset correctly specify the region of data"""
found, _ = sox_io_backend.load(self.path, frame_offset, num_frames)
frame_end = None if num_frames == -1 else frame_offset + num_frames
self.assertEqual(found, self.original[:, frame_offset:frame_end])
def test_ffmpeg(self, frame_offset, num_frames, channels_first, normalize):
"""The combination of properly changes the output tensor"""
self._test(torch.ops.torchaudio.ffmpeg_load_audio, frame_offset, num_frames, channels_first, normalize)

@parameterized.expand([(True,), (False,)], name_func=name_func)
def test_channels_first(self, channels_first):
"""channels_first swaps axes"""
found, _ = sox_io_backend.load(self.path, channels_first=channels_first)
expected = self.original if channels_first else self.original.transpose(1, 0)
self.assertEqual(found, expected)
# test file-like obj
def func(path, *args):
with open(path, "rb") as fileobj:
return torchaudio._torchaudio_ffmpeg.load_audio_fileobj(fileobj, *args)

self._test(func, frame_offset, num_frames, channels_first, normalize)


@skipIfNoSox
class TestLoadWithoutExtension(PytorchTestCase):
def test_mp3(self):
"""Providing format allows to read mp3 without extension
libsox does not check header for mp3
"""MP3 file without extension can be loaded
Originally, we added `format` argument for this case, but now we use FFmpeg
for MP3 decoding, which works even without `format` argument.
https://github.com/pytorch/audio/issues/1040
The file was generated with the following command
ffmpeg -f lavfi -i "sine=frequency=1000:duration=5" -ar 16000 -f mp3 test_noext
"""
path = get_asset_path("mp3_without_ext")
_, sr = sox_io_backend.load(path, format="mp3")
_, sr = sox_io_backend.load(path)
assert sr == 16000

with open(path, "rb") as fileobj:
_, sr = sox_io_backend.load(fileobj)
assert sr == 16000


class CloggedFileObj:
def __init__(self, fileobj):
self.fileobj = fileobj
self.buffer = b""

def read(self, n):
if not self.buffer:
self.buffer += self.fileobj.read(n)
ret = self.buffer[:2]
self.buffer = self.buffer[2:]
return ret
def read(self, _):
return self.fileobj.read(2)

def seek(self, offset, whence):
return self.fileobj.seek(offset, whence)


@skipIfNoSox
Expand Down Expand Up @@ -557,6 +582,14 @@ def test_tarfile(self, ext, kwargs):
self.assertEqual(expected, found)


class Unseekable:
def __init__(self, fileobj):
self.fileobj = fileobj

def read(self, n):
return self.fileobj.read(n)


@skipIfNoSox
@skipIfNoExec("sox")
@skipIfNoModule("requests")
Expand Down Expand Up @@ -587,10 +620,11 @@ def test_requests(self, ext, kwargs):

url = self.get_url(audio_file)
with requests.get(url, stream=True) as resp:
found, sr = sox_io_backend.load(resp.raw, format=format_)
found, sr = sox_io_backend.load(Unseekable(resp.raw), format=format_)

assert sr == sample_rate
self.assertEqual(expected, found)
if ext != "mp3":
self.assertEqual(expected, found)

@parameterized.expand(
list(
Expand Down
13 changes: 10 additions & 3 deletions test/torchaudio_unittest/backend/sox_io/save_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from parameterized import parameterized
from torchaudio.backend import sox_io_backend
from torchaudio_unittest.common_utils import (
ffmpeg_utils,
get_wav_data,
load_wav,
nested_params,
Expand Down Expand Up @@ -130,7 +131,10 @@ def assert_save_consistency(
else:
raise ValueError(f"Unexpected test mode: {test_mode}")
# 2.2. Convert the target format to wav with sox
sox_utils.convert_audio_file(tgt_path, tst_path, encoding=cmp_encoding, bit_depth=cmp_bit_depth)
if format == "mp3":
ffmpeg_utils.convert_to_wav(tgt_path, tst_path)
else:
sox_utils.convert_audio_file(tgt_path, tst_path, encoding=cmp_encoding, bit_depth=cmp_bit_depth)
# 2.3. Load with SciPy
found = load_wav(tst_path, normalize=False)[0]

Expand All @@ -140,7 +144,10 @@ def assert_save_consistency(
src_path, sox_path, compression=compression, encoding=sox_encoding, bit_depth=bits_per_sample
)
# 3.2. Convert the target format to wav with sox
sox_utils.convert_audio_file(sox_path, ref_path, encoding=cmp_encoding, bit_depth=cmp_bit_depth)
if format == "mp3":
ffmpeg_utils.convert_to_wav(sox_path, ref_path)
else:
sox_utils.convert_audio_file(sox_path, ref_path, encoding=cmp_encoding, bit_depth=cmp_bit_depth)
# 3.3. Load with SciPy
expected = load_wav(ref_path, normalize=False)[0]

Expand Down Expand Up @@ -437,5 +444,5 @@ def test_save_fail(self):
When attempted to save into a non-existing dir, error message must contain the file path.
"""
path = os.path.join("non_existing_directory", "foo.wav")
with self.assertRaisesRegex(RuntimeError, "^Error saving audio file: failed to open file {0}$".format(path)):
with self.assertRaisesRegex(RuntimeError, path):
sox_io_backend.save(path, torch.zeros(1, 1), 8000)
11 changes: 0 additions & 11 deletions test/torchaudio_unittest/backend/sox_io/smoke_test.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
import io
import itertools
import unittest

from parameterized import parameterized
from torchaudio._internal.module_utils import is_sox_available
from torchaudio.backend import sox_io_backend
from torchaudio.utils import sox_utils
from torchaudio_unittest.common_utils import (
get_wav_data,
skipIfNoSox,
Expand All @@ -16,12 +13,6 @@
from .common import name_func


skipIfNoMP3 = unittest.skipIf(
not is_sox_available() or "mp3" not in sox_utils.list_read_formats() or "mp3" not in sox_utils.list_write_formats(),
'"sox_io" backend does not support MP3',
)


@skipIfNoSox
class SmokeTest(TempDirMixin, TorchaudioTestCase):
"""Run smoke test on various audio format
Expand Down Expand Up @@ -73,7 +64,6 @@ def test_wav(self, dtype, sample_rate, num_channels):
)
)
)
@skipIfNoMP3
def test_mp3(self, sample_rate, num_channels, bit_rate):
"""Run smoke test on mp3 format"""
self.run_smoke_test("mp3", sample_rate, num_channels, compression=bit_rate)
Expand Down Expand Up @@ -159,7 +149,6 @@ def test_wav(self, dtype, sample_rate, num_channels):
)
)
)
@skipIfNoMP3
def test_mp3(self, sample_rate, num_channels, bit_rate):
"""Run smoke test on mp3 format"""
self.run_smoke_test("mp3", sample_rate, num_channels, compression=bit_rate)
Expand Down
10 changes: 10 additions & 0 deletions test/torchaudio_unittest/common_utils/ffmpeg_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import subprocess
import sys


def convert_to_wav(src_path, dst_path):
"""Convert audio file with `ffmpeg` command."""
# TODO: parameterize codec
command = ["ffmpeg", "-y", "-i", src_path, "-c:a", "pcm_f32le", dst_path]
print(" ".join(command), file=sys.stderr)
subprocess.run(command, check=True)
26 changes: 23 additions & 3 deletions test/torchaudio_unittest/io/stream_reader_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,20 @@ def test_seek_negative(self):
s.seek(-1.0)


def _to_fltp(original):
denom = {
torch.uint8: 2**7,
torch.int16: 2**15,
torch.int32: 2**31,
}[original.dtype]

fltp = original.to(torch.float32)
if original.dtype == torch.uint8:
fltp -= 128
fltp /= denom
return fltp


@skipIfNoFFmpeg
@_media_source
class StreamReaderAudioTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestCase):
Expand Down Expand Up @@ -399,9 +413,15 @@ def test_basic_audio_stream(self, dtype, num_channels):

# provide the matching dtype
self._test_wav(src, original, fmt=fmt)
if not self.test_fileobj:
# use the internal dtype ffmpeg picks
self._test_wav(src, original, fmt=None)
# use the internal dtype ffmpeg picks
if self.test_fileobj:
src.seek(0)
self._test_wav(src, original, fmt=None)
# convert to float32
expected = _to_fltp(original)
if self.test_fileobj:
src.seek(0)
self._test_wav(src, expected, fmt="fltp")

@nested_params(
["int16", "uint8", "int32"], # "float", "double", "int64"]
Expand Down
Loading

0 comments on commit 53a3aa7

Please sign in to comment.