Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use FFmpeg-based I/O as fallback in sox_io backend #2419

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 23 additions & 7 deletions test/torchaudio_unittest/backend/sox_io/info_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,23 +312,31 @@ def test_opus(self, bitrate, num_channels, compression_level):
@skipIfNoSox
class TestLoadWithoutExtension(PytorchTestCase):
def test_mp3(self):
"""Providing `format` allows to read mp3 without extension

libsox does not check header for mp3
"""MP3 file without extension can be loaded

Originally, we added `format` argument for this case, but now we use FFmpeg
for MP3 decoding, which works even without `format` argument.
https://github.com/pytorch/audio/issues/1040

The file was generated with the following command
ffmpeg -f lavfi -i "sine=frequency=1000:duration=5" -ar 16000 -f mp3 test_noext
"""
path = get_asset_path("mp3_without_ext")
sinfo = sox_io_backend.info(path, format="mp3")
sinfo = sox_io_backend.info(path)
assert sinfo.sample_rate == 16000
assert sinfo.num_frames == 81216
assert sinfo.num_frames == 0
assert sinfo.num_channels == 1
assert sinfo.bits_per_sample == 0 # bit_per_sample is irrelevant for compressed formats
assert sinfo.encoding == "MP3"

with open(path, "rb") as fileobj:
sinfo = sox_io_backend.info(fileobj)
assert sinfo.sample_rate == 16000
assert sinfo.num_frames == 0
assert sinfo.num_channels == 1
assert sinfo.bits_per_sample == 0
assert sinfo.encoding == "MP3"


class FileObjTestBase(TempDirMixin):
def _gen_file(self, ext, dtype, sample_rate, num_channels, num_frames, *, comments=None):
Expand All @@ -355,6 +363,14 @@ def _gen_comment_file(self, comments):
return comment_path


class Unseekable:
    """File-like wrapper that exposes only ``read``.

    The wrapped object's other methods (notably ``seek``) are deliberately
    not forwarded, so code receiving an instance of this class sees a
    file-like object without a ``seek`` attribute.
    """

    def __init__(self, fileobj):
        """Store the underlying file-like object to read from."""
        self.fileobj = fileobj

    def read(self, n=None):
        """Read from the wrapped object.

        Args:
            n: Size passed through unchanged to the wrapped object's
                ``read``. Defaults to ``None`` so the wrapper can also be
                called without a size, as is usual for file objects.

        Returns:
            Whatever the wrapped object's ``read`` returns.
        """
        return self.fileobj.read(n)


@skipIfNoSox
@skipIfNoExec("sox")
class TestFileObject(FileObjTestBase, PytorchTestCase):
Expand Down Expand Up @@ -435,7 +451,7 @@ def test_fileobj_large_header(self, ext, dtype):
num_channels = 2
comments = "metadata=" + " ".join(["value" for _ in range(1000)])

with self.assertRaisesRegex(RuntimeError, "Failed to fetch metadata from"):
with self.assertRaises(RuntimeError):
sinfo = self._query_fileobj(ext, dtype, sample_rate, num_channels, num_frames, comments=comments)

with self._set_buffer_size(16384):
Expand Down Expand Up @@ -545,7 +561,7 @@ def _query_http(self, ext, dtype, sample_rate, num_channels, num_frames):
url = self.get_url(audio_file)
format_ = ext if ext in ["mp3"] else None
with requests.get(url, stream=True) as resp:
return sox_io_backend.info(resp.raw, format=format_)
return sox_io_backend.info(Unseekable(resp.raw), format=format_)

@parameterized.expand(
[
Expand Down
143 changes: 73 additions & 70 deletions test/torchaudio_unittest/backend/sox_io/load_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import itertools
import tarfile

import torch
import torchaudio
from parameterized import parameterized
from torchaudio._internal import module_utils as _mod_utils
from torchaudio.backend import sox_io_backend
Expand All @@ -10,6 +12,7 @@
get_wav_data,
HttpServerMixin,
load_wav,
nested_params,
PytorchTestCase,
save_wav,
skipIfNoExec,
Expand Down Expand Up @@ -169,35 +172,6 @@ def test_multiple_channels(self, dtype, num_channels):
normalize = False
self.assert_wav(dtype, sample_rate, num_channels, normalize, duration=1)

@parameterized.expand(
list(
itertools.product(
[8000, 16000, 44100],
[1, 2],
[96, 128, 160, 192, 224, 256, 320],
)
),
name_func=name_func,
)
def test_mp3(self, sample_rate, num_channels, bit_rate):
"""`sox_io_backend.load` can load mp3 format correctly."""
self.assert_format("mp3", sample_rate, num_channels, compression=bit_rate, duration=1, atol=5e-05)

@parameterized.expand(
list(
itertools.product(
[16000],
[2],
[128],
)
),
name_func=name_func,
)
def test_mp3_large(self, sample_rate, num_channels, bit_rate):
"""`sox_io_backend.load` can load large mp3 file correctly."""
two_hours = 2 * 60 * 60
self.assert_format("mp3", sample_rate, num_channels, compression=bit_rate, duration=two_hours, atol=5e-05)

@parameterized.expand(
list(
itertools.product(
Expand Down Expand Up @@ -319,72 +293,92 @@ def test_amr_nb(self):
self.assert_format("amr-nb", sample_rate=8000, num_channels=1, bit_depth=32, duration=1)


@skipIfNoExec("sox")
@skipIfNoSox
class TestLoadParams(TempDirMixin, PytorchTestCase):
"""Test the correctness of frame parameters of `sox_io_backend.load`"""

original = None
path = None
def _test(self, func, frame_offset, num_frames, channels_first, normalize):
original = get_wav_data("int16", num_channels=2, normalize=False)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we pass `normalize=normalize` here and avoid the extra `normalize` branch later in the function?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The `normalize` argument of `get_wav_data` controls how the original data is generated, whereas the `normalize` argument in the test is the value passed to the load function, so that the data is converted to float32 on the fly, no matter what the original data type is. (Note that the WAV format stores data as int16; there is almost no application other than DL that handles float32 WAV files.)

To test the `normalize` argument of the load function, we need to test on a waveform whose dtype is not float32 (int16 is the common choice), so that we can check that switching `normalize` in the load function changes the dtype of the resulting Tensor. If the original dtype were float32, we could not tell whether the `normalize` argument of load is working properly.

path = self.get_temp_path("test.wav")
save_wav(path, original, sample_rate=8000)

def setUp(self):
super().setUp()
sample_rate = 8000
self.original = get_wav_data("float32", num_channels=2)
self.path = self.get_temp_path("test.wav")
save_wav(self.path, self.original, sample_rate)
output, _ = func(path, frame_offset, num_frames, normalize, channels_first, None)
frame_end = None if num_frames == -1 else frame_offset + num_frames
expected = original[:, slice(frame_offset, frame_end)]
if not channels_first:
expected = expected.T
if normalize:
expected = expected.to(torch.float32) / (2**15)
self.assertEqual(output, expected)

@nested_params(
[0, 1, 10, 100, 1000],
[-1, 1, 10, 100, 1000],
[True, False],
[True, False],
)
def test_sox(self, frame_offset, num_frames, channels_first, normalize):
"""The combination of frame_offset, num_frames, channels_first and normalize properly changes the output tensor"""

@parameterized.expand(
list(
itertools.product(
[0, 1, 10, 100, 1000],
[-1, 1, 10, 100, 1000],
)
),
name_func=name_func,
self._test(torch.ops.torchaudio.sox_io_load_audio_file, frame_offset, num_frames, channels_first, normalize)

# test file-like obj
def func(path, *args):
with open(path, "rb") as fileobj:
return torchaudio._torchaudio.load_audio_fileobj(fileobj, *args)

self._test(func, frame_offset, num_frames, channels_first, normalize)

@nested_params(
[0, 1, 10, 100, 1000],
[-1, 1, 10, 100, 1000],
[True, False],
[True, False],
)
def test_frame(self, frame_offset, num_frames):
"""num_frames and frame_offset correctly specify the region of data"""
found, _ = sox_io_backend.load(self.path, frame_offset, num_frames)
frame_end = None if num_frames == -1 else frame_offset + num_frames
self.assertEqual(found, self.original[:, frame_offset:frame_end])
def test_ffmpeg(self, frame_offset, num_frames, channels_first, normalize):
"""The combination of frame_offset, num_frames, channels_first and normalize properly changes the output tensor"""
from torchaudio.io._compat import load_audio, load_audio_fileobj

self._test(load_audio, frame_offset, num_frames, channels_first, normalize)

# test file-like obj
def func(path, *args):
with open(path, "rb") as fileobj:
return load_audio_fileobj(fileobj, *args)

@parameterized.expand([(True,), (False,)], name_func=name_func)
def test_channels_first(self, channels_first):
"""channels_first swaps axes"""
found, _ = sox_io_backend.load(self.path, channels_first=channels_first)
expected = self.original if channels_first else self.original.transpose(1, 0)
self.assertEqual(found, expected)
self._test(func, frame_offset, num_frames, channels_first, normalize)


@skipIfNoSox
class TestLoadWithoutExtension(PytorchTestCase):
def test_mp3(self):
"""Providing format allows to read mp3 without extension

libsox does not check header for mp3
"""MP3 file without extension can be loaded

Originally, we added `format` argument for this case, but now we use FFmpeg
for MP3 decoding, which works even without `format` argument.
https://github.com/pytorch/audio/issues/1040

The file was generated with the following command
ffmpeg -f lavfi -i "sine=frequency=1000:duration=5" -ar 16000 -f mp3 test_noext
"""
path = get_asset_path("mp3_without_ext")
_, sr = sox_io_backend.load(path, format="mp3")
_, sr = sox_io_backend.load(path)
assert sr == 16000

with open(path, "rb") as fileobj:
_, sr = sox_io_backend.load(fileobj)
assert sr == 16000


class CloggedFileObj:
def __init__(self, fileobj):
self.fileobj = fileobj
self.buffer = b""

def read(self, n):
if not self.buffer:
self.buffer += self.fileobj.read(n)
ret = self.buffer[:2]
self.buffer = self.buffer[2:]
return ret
def read(self, _):
return self.fileobj.read(2)

def seek(self, offset, whence):
return self.fileobj.seek(offset, whence)


@skipIfNoSox
Expand Down Expand Up @@ -557,6 +551,14 @@ def test_tarfile(self, ext, kwargs):
self.assertEqual(expected, found)


class Unseekable:
    """File-like wrapper that exposes only ``read``.

    The wrapped object's other methods (notably ``seek``) are deliberately
    not forwarded, so code receiving an instance of this class sees a
    file-like object without a ``seek`` attribute.
    """

    def __init__(self, fileobj):
        """Store the underlying file-like object to read from."""
        self.fileobj = fileobj

    def read(self, n=None):
        """Read from the wrapped object.

        Args:
            n: Size passed through unchanged to the wrapped object's
                ``read``. Defaults to ``None`` so the wrapper can also be
                called without a size, as is usual for file objects.

        Returns:
            Whatever the wrapped object's ``read`` returns.
        """
        return self.fileobj.read(n)


@skipIfNoSox
@skipIfNoExec("sox")
@skipIfNoModule("requests")
Expand Down Expand Up @@ -587,10 +589,11 @@ def test_requests(self, ext, kwargs):

url = self.get_url(audio_file)
with requests.get(url, stream=True) as resp:
found, sr = sox_io_backend.load(resp.raw, format=format_)
found, sr = sox_io_backend.load(Unseekable(resp.raw), format=format_)

assert sr == sample_rate
self.assertEqual(expected, found)
if ext != "mp3":
self.assertEqual(expected, found)

@parameterized.expand(
list(
Expand Down
29 changes: 1 addition & 28 deletions test/torchaudio_unittest/backend/sox_io/save_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import io
import os
import unittest

import torch
from parameterized import parameterized
Expand Down Expand Up @@ -179,31 +178,6 @@ def test_save_wav_dtype(self, test_mode, params):
(dtype,) = params
self.assert_save_consistency("wav", src_dtype=dtype, test_mode=test_mode)

@nested_params(
["path", "fileobj", "bytesio"],
[
None,
-4.2,
-0.2,
0,
0.2,
96,
128,
160,
192,
224,
256,
320,
],
)
def test_save_mp3(self, test_mode, bit_rate):
if test_mode in ["fileobj", "bytesio"]:
if bit_rate is not None and bit_rate < 1:
raise unittest.SkipTest(
"mp3 format with variable bit rate is known to " "not yield the exact same result as sox command."
)
self.assert_save_consistency("mp3", compression=bit_rate, test_mode=test_mode)

@nested_params(
["path", "fileobj", "bytesio"],
[8, 16, 24],
Expand Down Expand Up @@ -349,7 +323,6 @@ def test_save_gsm(self, test_mode):
@parameterized.expand(
[
("wav", "PCM_S", 16),
("mp3",),
("flac",),
("vorbis",),
("sph", "PCM_S", 16),
Expand Down Expand Up @@ -437,5 +410,5 @@ def test_save_fail(self):
When attempting to save into a non-existing directory, the error message must contain the file path.
"""
path = os.path.join("non_existing_directory", "foo.wav")
with self.assertRaisesRegex(RuntimeError, "^Error saving audio file: failed to open file {0}$".format(path)):
with self.assertRaisesRegex(RuntimeError, path):
sox_io_backend.save(path, torch.zeros(1, 1), 8000)
11 changes: 0 additions & 11 deletions test/torchaudio_unittest/backend/sox_io/smoke_test.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
import io
import itertools
import unittest

from parameterized import parameterized
from torchaudio._internal.module_utils import is_sox_available
from torchaudio.backend import sox_io_backend
from torchaudio.utils import sox_utils
from torchaudio_unittest.common_utils import (
get_wav_data,
skipIfNoSox,
Expand All @@ -16,12 +13,6 @@
from .common import name_func


skipIfNoMP3 = unittest.skipIf(
not is_sox_available() or "mp3" not in sox_utils.list_read_formats() or "mp3" not in sox_utils.list_write_formats(),
'"sox_io" backend does not support MP3',
)


@skipIfNoSox
class SmokeTest(TempDirMixin, TorchaudioTestCase):
"""Run smoke test on various audio format
Expand Down Expand Up @@ -73,7 +64,6 @@ def test_wav(self, dtype, sample_rate, num_channels):
)
)
)
@skipIfNoMP3
def test_mp3(self, sample_rate, num_channels, bit_rate):
"""Run smoke test on mp3 format"""
self.run_smoke_test("mp3", sample_rate, num_channels, compression=bit_rate)
Expand Down Expand Up @@ -159,7 +149,6 @@ def test_wav(self, dtype, sample_rate, num_channels):
)
)
)
@skipIfNoMP3
def test_mp3(self, sample_rate, num_channels, bit_rate):
"""Run smoke test on mp3 format"""
self.run_smoke_test("mp3", sample_rate, num_channels, compression=bit_rate)
Expand Down
Loading