Skip to content

Commit

Permalink
Add TorchScript-able "save" func to sox_io backend (#732)
Browse files Browse the repository at this point in the history
This is part of a series of PRs adding the new "sox_io" backend (#726). It depends on #718, #728 and #731.

This PR adds a `save` function to the "sox_io" backend, which can save a Tensor to a file in the following audio formats:
 - `wav`
 - `mp3`
 - `flac`
 - `ogg/vorbis`
  • Loading branch information
mthrok authored Jul 1, 2020
1 parent ea42513 commit 3324283
Show file tree
Hide file tree
Showing 13 changed files with 750 additions and 45 deletions.
36 changes: 21 additions & 15 deletions test/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,28 +92,34 @@ def set_audio_backend(backend):
class TempDirMixin:
"""Mixin to provide easy access to temp dir"""
temp_dir_ = None
base_temp_dir = None
temp_dir = None

def setUp(self):
super().setUp()
self._init_temp_dir()
@classmethod
def setUpClass(cls):
super().setUpClass()
# If TORCHAUDIO_TEST_TEMP_DIR is set, use it instead of temporary directory.
# this is handy for debugging.
key = 'TORCHAUDIO_TEST_TEMP_DIR'
if key in os.environ:
cls.base_temp_dir = os.environ[key]
else:
cls.temp_dir_ = tempfile.TemporaryDirectory()
cls.base_temp_dir = cls.temp_dir_.name

def tearDown(self):
@classmethod
def tearDownClass(cls):
super().tearDownClass()
self._clean_up_temp_dir()
if isinstance(cls.temp_dir_, tempfile.TemporaryDirectory):
cls.temp_dir_.cleanup()

def _init_temp_dir(self):
self.temp_dir_ = tempfile.TemporaryDirectory()
self.temp_dir = self.temp_dir_.name

def _clean_up_temp_dir(self):
if self.temp_dir_ is not None:
self.temp_dir_.cleanup()
self.temp_dir_ = None
self.temp_dir = None
def setUp(self):
self.temp_dir = os.path.join(self.base_temp_dir, self.id())

def get_temp_path(self, *paths):
return os.path.join(self.temp_dir, *paths)
path = os.path.join(self.temp_dir, *paths)
os.makedirs(os.path.dirname(path), exist_ok=True)
return path


class TestBaseMixin:
Expand Down
13 changes: 11 additions & 2 deletions test/sox_io_backend/sox_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,16 @@ def gen_audio_file(
'Use get_wav_data and save_wav to generate wav file for accurate result.')
command = [
'sox',
'-V', # verbose
'-V3', # verbose
'-R',
# -R is supposed to make the output repeatable, though the implementation looks
# suspicious: it does not appear to set the seed to a fixed value.
# https://fossies.org/dox/sox-14.4.2/sox_8c_source.html
# search "sox_globals.repeatable"
]
if bit_depth is not None:
command += ['--bits', str(bit_depth)]
command += [
'--rate', str(sample_rate),
'--null', # no input
'--channels', str(num_channels),
Expand Down Expand Up @@ -60,7 +69,7 @@ def convert_audio_file(
src_path, dst_path,
*, bit_depth=None, compression=None):
"""Convert audio file with `sox` command."""
command = ['sox', '-V', str(src_path)]
command = ['sox', '-V3', '-R', str(src_path)]
if bit_depth is not None:
command += ['--bits', str(bit_depth)]
if compression is not None:
Expand Down
10 changes: 5 additions & 5 deletions test/sox_io_backend/test_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class TestInfo(TempDirMixin, PytorchTestCase):
def test_wav(self, dtype, sample_rate, num_channels):
"""`sox_io_backend.info` can check wav file correctly"""
duration = 1
path = self.get_temp_path(f'{dtype}_{sample_rate}_{num_channels}.wav')
path = self.get_temp_path('data.wav')
data = get_wav_data(dtype, num_channels, normalize=False, num_frames=duration * sample_rate)
save_wav(path, data, sample_rate)
info = sox_io_backend.info(path)
Expand All @@ -44,7 +44,7 @@ def test_wav(self, dtype, sample_rate, num_channels):
def test_wav_multiple_channels(self, dtype, sample_rate, num_channels):
"""`sox_io_backend.info` can check wav file with channels more than 2 correctly"""
duration = 1
path = self.get_temp_path(f'{dtype}_{sample_rate}_{num_channels}.wav')
path = self.get_temp_path('data.wav')
data = get_wav_data(dtype, num_channels, normalize=False, num_frames=duration * sample_rate)
save_wav(path, data, sample_rate)
info = sox_io_backend.info(path)
Expand All @@ -60,7 +60,7 @@ def test_wav_multiple_channels(self, dtype, sample_rate, num_channels):
def test_mp3(self, sample_rate, num_channels, bit_rate):
"""`sox_io_backend.info` can check mp3 file correctly"""
duration = 1
path = self.get_temp_path(f'{sample_rate}_{num_channels}_{bit_rate}k.mp3')
path = self.get_temp_path('data.mp3')
sox_utils.gen_audio_file(
path, sample_rate, num_channels,
compression=bit_rate, duration=duration,
Expand All @@ -79,7 +79,7 @@ def test_mp3(self, sample_rate, num_channels, bit_rate):
def test_flac(self, sample_rate, num_channels, compression_level):
"""`sox_io_backend.info` can check flac file correctly"""
duration = 1
path = self.get_temp_path(f'{sample_rate}_{num_channels}_{compression_level}.flac')
path = self.get_temp_path('data.flac')
sox_utils.gen_audio_file(
path, sample_rate, num_channels,
compression=compression_level, duration=duration,
Expand All @@ -97,7 +97,7 @@ def test_flac(self, sample_rate, num_channels, compression_level):
def test_vorbis(self, sample_rate, num_channels, quality_level):
"""`sox_io_backend.info` can check vorbis file correctly"""
duration = 1
path = self.get_temp_path(f'{sample_rate}_{num_channels}_{quality_level}.vorbis')
path = self.get_temp_path('data.vorbis')
sox_utils.gen_audio_file(
path, sample_rate, num_channels,
compression=quality_level, duration=duration,
Expand Down
14 changes: 7 additions & 7 deletions test/sox_io_backend/test_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def assert_wav(self, dtype, sample_rate, num_channels, normalize, duration):
Wav data loaded with sox_io backend should match those with scipy
"""
path = self.get_temp_path(f'{dtype}_{sample_rate}_{num_channels}_{normalize}.wav')
path = self.get_temp_path('reference.wav')
data = get_wav_data(dtype, num_channels, normalize=normalize, num_frames=duration * sample_rate)
save_wav(path, data, sample_rate)
expected = load_wav(path, normalize=normalize)[0]
Expand Down Expand Up @@ -58,8 +58,8 @@ def assert_mp3(self, sample_rate, num_channels, bit_rate, duration):
By combining i & ii, step 2. and 4. allows to load reference mp3 data
without using torchaudio
"""
path = self.get_temp_path(f'{sample_rate}_{num_channels}_{bit_rate}_{duration}.mp3')
ref_path = f'{path}.wav'
path = self.get_temp_path('1.original.mp3')
ref_path = self.get_temp_path('2.reference.wav')

# 1. Generate mp3 with sox
sox_utils.gen_audio_file(
Expand All @@ -80,8 +80,8 @@ def assert_flac(self, sample_rate, num_channels, compression_level, duration):
This test takes the same strategy as mp3 to compare the result
"""
path = self.get_temp_path(f'{sample_rate}_{num_channels}_{compression_level}_{duration}.flac')
ref_path = f'{path}.wav'
path = self.get_temp_path('1.original.flac')
ref_path = self.get_temp_path('2.reference.wav')

# 1. Generate flac with sox
sox_utils.gen_audio_file(
Expand All @@ -102,8 +102,8 @@ def assert_vorbis(self, sample_rate, num_channels, quality_level, duration):
This test takes the same strategy as mp3 to compare the result
"""
path = self.get_temp_path(f'{sample_rate}_{num_channels}_{quality_level}_{duration}.vorbis')
ref_path = f'{path}.wav'
path = self.get_temp_path('1.original.vorbis')
ref_path = self.get_temp_path('2.reference.wav')

# 1. Generate vorbis with sox
sox_utils.gen_audio_file(
Expand Down
52 changes: 52 additions & 0 deletions test/sox_io_backend/test_roundtrip.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import itertools

from torchaudio.backend import sox_io_backend
from parameterized import parameterized

from ..common_utils import (
TempDirMixin,
PytorchTestCase,
skipIfNoExec,
skipIfNoExtension,
)
from .common import (
get_test_name,
get_wav_data,
)


@skipIfNoExec('sox')
@skipIfNoExtension
class TestRoundTripIO(TempDirMixin, PytorchTestCase):
    """save/load round trip should not degrade data for lossless formats

    Each test saves a Tensor with ``sox_io_backend.save``, loads it back with
    ``sox_io_backend.load``, feeds the loaded result into the next save, and
    repeats this ten times. The data is compared against the original Tensor
    on every pass.
    """

    @parameterized.expand(list(itertools.product(
        ['float32', 'int32', 'int16', 'uint8'],
        [8000, 16000],
        [1, 2],
    )), name_func=get_test_name)
    def test_wav(self, dtype, sample_rate, num_channels):
        """save/load round trip should not degrade data for wav formats"""
        original = get_wav_data(dtype, num_channels, normalize=False)
        data = original
        for i in range(10):
            # A distinct file name per pass keeps each intermediate result
            # separate instead of overwriting the previous one.
            path = self.get_temp_path(f'{i}.wav')
            sox_io_backend.save(path, data, sample_rate)
            data, sr = sox_io_backend.load(path, normalize=False)
            # Use the unittest assertion rather than a bare ``assert``: it is
            # not stripped under ``python -O`` and reports both values on failure.
            self.assertEqual(sr, sample_rate)
            # Checked on every pass so a failure points at the first pass
            # where the data diverges from the original.
            self.assertEqual(original, data)

    @parameterized.expand(list(itertools.product(
        [8000, 16000],
        [1, 2],
        list(range(9)),
    )), name_func=get_test_name)
    def test_flac(self, sample_rate, num_channels, compression_level):
        """save/load round trip should not degrade data for flac formats"""
        original = get_wav_data('float32', num_channels)
        data = original
        for i in range(10):
            # A distinct file name per pass keeps each intermediate result
            # separate instead of overwriting the previous one.
            path = self.get_temp_path(f'{i}.flac')
            sox_io_backend.save(path, data, sample_rate, compression=compression_level)
            data, sr = sox_io_backend.load(path)
            # Use the unittest assertion rather than a bare ``assert``: it is
            # not stripped under ``python -O`` and reports both values on failure.
            self.assertEqual(sr, sample_rate)
            # Checked on every pass so a failure points at the first pass
            # where the data diverges from the original.
            self.assertEqual(original, data)
Loading

0 comments on commit 3324283

Please sign in to comment.