Skip to content

Commit

Permalink
Tweak TempDirMixin
Browse files Browse the repository at this point in the history
  • Loading branch information
mthrok committed Jun 26, 2020
1 parent e57ea56 commit b9f8732
Show file tree
Hide file tree
Showing 6 changed files with 66 additions and 60 deletions.
36 changes: 21 additions & 15 deletions test/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,28 +92,34 @@ def set_audio_backend(backend):
class TempDirMixin:
"""Mixin to provide easy access to temp dir"""
temp_dir_ = None
base_temp_dir = None
temp_dir = None

def setUp(self):
super().setUp()
self._init_temp_dir()
@classmethod
def setUpClass(cls):
super().setUpClass()
# If TORCHAUDIO_TEST_TEMP_DIR is set, use it instead of temporary directory.
# this is handy for debugging.
key = 'TORCHAUDIO_TEST_TEMP_DIR'
if key in os.environ:
cls.base_temp_dir = os.environ[key]
else:
cls.temp_dir_ = tempfile.TemporaryDirectory()
cls.base_temp_dir = cls.temp_dir_.name

def tearDown(self):
@classmethod
def tearDownClass(cls):
super().tearDownClass()
self._clean_up_temp_dir()
if isinstance(cls.temp_dir_, tempfile.TemporaryDirectory):
cls.temp_dir_.cleanup()

def _init_temp_dir(self):
self.temp_dir_ = tempfile.TemporaryDirectory()
self.temp_dir = self.temp_dir_.name

def _clean_up_temp_dir(self):
if self.temp_dir_ is not None:
self.temp_dir_.cleanup()
self.temp_dir_ = None
self.temp_dir = None
def setUp(self):
self.temp_dir = os.path.join(self.base_temp_dir, self.id())

def get_temp_path(self, *paths):
return os.path.join(self.temp_dir, *paths)
path = os.path.join(self.temp_dir, *paths)
os.makedirs(os.path.dirname(path), exist_ok=True)
return path


class TestBaseMixin:
Expand Down
10 changes: 5 additions & 5 deletions test/sox_io_backend/test_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class TestInfo(TempDirMixin, PytorchTestCase):
def test_wav(self, dtype, sample_rate, num_channels):
"""`sox_io_backend.info` can check wav file correctly"""
duration = 1
path = self.get_temp_path(f'{dtype}_{sample_rate}_{num_channels}.wav')
path = self.get_temp_path('data.wav')
data = get_wav_data(dtype, num_channels, normalize=False, num_frames=duration * sample_rate)
save_wav(path, data, sample_rate)
info = sox_io_backend.info(path)
Expand All @@ -44,7 +44,7 @@ def test_wav(self, dtype, sample_rate, num_channels):
def test_wav_multiple_channels(self, dtype, sample_rate, num_channels):
"""`sox_io_backend.info` can check wav file with channels more than 2 correctly"""
duration = 1
path = self.get_temp_path(f'{dtype}_{sample_rate}_{num_channels}.wav')
path = self.get_temp_path(f'data.wav')
data = get_wav_data(dtype, num_channels, normalize=False, num_frames=duration * sample_rate)
save_wav(path, data, sample_rate)
info = sox_io_backend.info(path)
Expand All @@ -60,7 +60,7 @@ def test_wav_multiple_channels(self, dtype, sample_rate, num_channels):
def test_mp3(self, sample_rate, num_channels, bit_rate):
"""`sox_io_backend.info` can check mp3 file correctly"""
duration = 1
path = self.get_temp_path(f'{sample_rate}_{num_channels}_{bit_rate}k.mp3')
path = self.get_temp_path(f'data.mp3')
sox_utils.gen_audio_file(
path, sample_rate, num_channels,
compression=bit_rate, duration=duration,
Expand All @@ -79,7 +79,7 @@ def test_mp3(self, sample_rate, num_channels, bit_rate):
def test_flac(self, sample_rate, num_channels, compression_level):
"""`sox_io_backend.info` can check flac file correctly"""
duration = 1
path = self.get_temp_path(f'{sample_rate}_{num_channels}_{compression_level}.flac')
path = self.get_temp_path(f'data.flac')
sox_utils.gen_audio_file(
path, sample_rate, num_channels,
compression=compression_level, duration=duration,
Expand All @@ -97,7 +97,7 @@ def test_flac(self, sample_rate, num_channels, compression_level):
def test_vorbis(self, sample_rate, num_channels, quality_level):
"""`sox_io_backend.info` can check vorbis file correctly"""
duration = 1
path = self.get_temp_path(f'{sample_rate}_{num_channels}_{quality_level}.vorbis')
path = self.get_temp_path(f'data.vorbis')
sox_utils.gen_audio_file(
path, sample_rate, num_channels,
compression=quality_level, duration=duration,
Expand Down
14 changes: 7 additions & 7 deletions test/sox_io_backend/test_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def assert_wav(self, dtype, sample_rate, num_channels, normalize, duration):
Wav data loaded with sox_io backend should match those with scipy
"""
path = self.get_temp_path(f'{dtype}_{sample_rate}_{num_channels}_{normalize}.wav')
path = self.get_temp_path(f'reference.wav')
data = get_wav_data(dtype, num_channels, normalize=normalize, num_frames=duration * sample_rate)
save_wav(path, data, sample_rate)
expected = load_wav(path, normalize=normalize)[0]
Expand Down Expand Up @@ -58,8 +58,8 @@ def assert_mp3(self, sample_rate, num_channels, bit_rate, duration):
By combining i & ii, step 2. and 4. allows to load reference mp3 data
without using torchaudio
"""
path = self.get_temp_path(f'{sample_rate}_{num_channels}_{bit_rate}_{duration}.mp3')
ref_path = f'{path}.wav'
path = self.get_temp_path('1.original.mp3')
ref_path = self.get_temp_path('2.reference.wav')

# 1. Generate mp3 with sox
sox_utils.gen_audio_file(
Expand All @@ -80,8 +80,8 @@ def assert_flac(self, sample_rate, num_channels, compression_level, duration):
This test takes the same strategy as mp3 to compare the result
"""
path = self.get_temp_path(f'{sample_rate}_{num_channels}_{compression_level}_{duration}.flac')
ref_path = f'{path}.wav'
path = self.get_temp_path('1.original.flac')
ref_path = self.get_temp_path('2.reference.wav')

# 1. Generate flac with sox
sox_utils.gen_audio_file(
Expand All @@ -102,8 +102,8 @@ def assert_vorbis(self, sample_rate, num_channels, quality_level, duration):
This test takes the same strategy as mp3 to compare the result
"""
path = self.get_temp_path(f'{sample_rate}_{num_channels}_{quality_level}_{duration}.vorbis')
ref_path = f'{path}.wav'
path = self.get_temp_path('1.original.vorbis')
ref_path = self.get_temp_path('2.reference.wav')

# 1. Generate vorbis with sox
sox_utils.gen_audio_file(
Expand Down
4 changes: 2 additions & 2 deletions test/sox_io_backend/test_roundtrip.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def test_roundtrip_wav(self, dtype, sample_rate, num_channels, normalize):
original = get_wav_data(dtype, num_channels, normalize=normalize)
data = original
for i in range(10):
path = self.get_temp_path(f'{dtype}_{sample_rate}_{num_channels}_{normalize}_{i}.wav')
path = self.get_temp_path(f'{i}.wav')
sox_io_backend.save(path, data, sample_rate)
data, sr = sox_io_backend.load(path, normalize=normalize)
assert sr == sample_rate
Expand All @@ -46,7 +46,7 @@ def test_roundtrip_flac(self, sample_rate, num_channels, compression_level):
original = get_wav_data('float32', num_channels)
data = original
for i in range(10):
path = self.get_temp_path(f'{sample_rate}_{num_channels}_{compression_level}_{i}.flac')
path = self.get_temp_path(f'{i}.flac')
sox_io_backend.save(path, data, sample_rate, compression=compression_level)
data, sr = sox_io_backend.load(path)
assert sr == sample_rate
Expand Down
54 changes: 27 additions & 27 deletions test/sox_io_backend/test_save.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
class SaveTestBase(TempDirMixin, PytorchTestCase):
def assert_wav(self, dtype, sample_rate, num_channels, num_frames):
"""`sox_io_backend.save` can save wav format."""
path = self.get_temp_path(f'test_wav_{dtype}_{sample_rate}_{num_channels}.wav')
path = self.get_temp_path('data.wav')
expected = get_wav_data(dtype, num_channels, num_frames=num_frames)
sox_io_backend.save(path, expected, sample_rate)
found = load_wav(path)[0]
Expand Down Expand Up @@ -55,11 +55,11 @@ def assert_mp3(self, sample_rate, num_channels, bit_rate, duration):
tensor -------> compare <--------- tensor
"""
src_path = self.get_temp_path(f'test_mp3_{sample_rate}_{num_channels}_{bit_rate}_{duration}.wav')
mp3_path = f'{src_path}.mp3'
wav_path = f'{mp3_path}.wav'
mp3_path_sox = f'{src_path}.sox.mp3'
wav_path_sox = f'{mp3_path_sox}.wav'
src_path = self.get_temp_path('1.reference.wav')
mp3_path = self.get_temp_path('2.1.torchaudio.mp3')
wav_path = self.get_temp_path('2.2.torchaudio.wav')
mp3_path_sox = self.get_temp_path('3.1.sox.mp3')
wav_path_sox = self.get_temp_path('3.2.sox.wav')

# 1. Generate original wav
data = get_wav_data('float32', num_channels, normalize=True, num_frames=duration * sample_rate)
Expand All @@ -86,29 +86,29 @@ def assert_flac(self, sample_rate, num_channels, compression_level, duration):
This test takes the same strategy as mp3 to compare the result
"""
src_path = self.get_temp_path(f'test_flac_{sample_rate}_{num_channels}_{compression_level}_{duration}.wav')
flac_path = f'{src_path}.flac'
wav_path = f'{flac_path}.wav'
flac_path_sox = f'{src_path}.sox.flac'
wav_path_sox = f'{flac_path_sox}.wav'
src_path = self.get_temp_path('1.reference.wav')
flc_path = self.get_temp_path('2.1.torchaudio.flac')
wav_path = self.get_temp_path('2.2.torchaudio.wav')
flc_path_sox = self.get_temp_path('3.1.sox.flac')
wav_path_sox = self.get_temp_path('3.2.sox.wav')

# 1. Generate original wav
data = get_wav_data('float32', num_channels, normalize=True, num_frames=duration * sample_rate)
save_wav(src_path, data, sample_rate)
# 2.1. Convert the original wav to flac with torchaudio
sox_io_backend.save(
flac_path, load_wav(src_path)[0], sample_rate, compression=compression_level)
flc_path, load_wav(src_path)[0], sample_rate, compression=compression_level)
# 2.2. Convert the flac to wav with Sox
# converting to 32 bit because flac file has 24 bit depth which scipy cannot handle.
sox_utils.convert_audio_file(flac_path, wav_path, bit_depth=32)
sox_utils.convert_audio_file(flc_path, wav_path, bit_depth=32)
# 2.3. Load
found = load_wav(wav_path)[0]

# 3.1. Convert the original wav to flac with SoX
sox_utils.convert_audio_file(src_path, flac_path_sox, compression=compression_level)
sox_utils.convert_audio_file(src_path, flc_path_sox, compression=compression_level)
# 3.2. Convert the flac to wav with Sox
# converting to 32 bit because flac file has 24 bit depth which scipy cannot handle.
sox_utils.convert_audio_file(flac_path_sox, wav_path_sox, bit_depth=32)
sox_utils.convert_audio_file(flc_path_sox, wav_path_sox, bit_depth=32)
# 3.3. Load
expected = load_wav(wav_path_sox)[0]

Expand All @@ -119,27 +119,27 @@ def _assert_vorbis(self, sample_rate, num_channels, quality_level, duration):
This test takes the same strategy as mp3 to compare the result
"""
src_path = self.get_temp_path(f'test_vorbis_{sample_rate}_{num_channels}_{quality_level}_{duration}.wav')
vorbis_path = f'{src_path}.vorbis'
wav_path = f'{vorbis_path}.wav'
vorbis_path_sox = f'{src_path}.sox.vorbis'
wav_path_sox = f'{vorbis_path_sox}.wav'
src_path = self.get_temp_path('1.reference.wav')
vbs_path = self.get_temp_path('2.1.torchaudio.vorbis')
wav_path = self.get_temp_path('2.2.torchaudio.wav')
vbs_path_sox = self.get_temp_path('3.1.sox.vorbis')
wav_path_sox = self.get_temp_path('3.2.sox.wav')

# 1. Generate original wav
data = get_wav_data('int16', num_channels, normalize=False, num_frames=duration * sample_rate)
save_wav(src_path, data, sample_rate)
# 2.1. Convert the original wav to vorbis with torchaudio
sox_io_backend.save(
vorbis_path, load_wav(src_path)[0], sample_rate, compression=quality_level)
vbs_path, load_wav(src_path)[0], sample_rate, compression=quality_level)
# 2.2. Convert the vorbis to wav with Sox
sox_utils.convert_audio_file(vorbis_path, wav_path)
sox_utils.convert_audio_file(vbs_path, wav_path)
# 2.3. Load
found = load_wav(wav_path)[0]

# 3.1. Convert the original wav to vorbis with SoX
sox_utils.convert_audio_file(src_path, vorbis_path_sox, compression=quality_level)
sox_utils.convert_audio_file(src_path, vbs_path_sox, compression=quality_level)
# 3.2. Convert the vorbis to wav with Sox
sox_utils.convert_audio_file(vorbis_path_sox, wav_path_sox)
sox_utils.convert_audio_file(vbs_path_sox, wav_path_sox)
# 3.3. Load
expected = load_wav(wav_path_sox)[0]

Expand Down Expand Up @@ -269,7 +269,7 @@ class TestSaveParams(TempDirMixin, PytorchTestCase):
@parameterized.expand([(True, ), (False, )], name_func=get_test_name)
def test_channels_first(self, channels_first):
"""channels_first swaps axes"""
path = self.get_temp_path('test_channel_first_{channels_first}.wav')
path = self.get_temp_path('data.wav')
data = get_wav_data('int32', 2, channels_first=channels_first)
sox_io_backend.save(
path, data, 8000, channels_first=channels_first)
Expand All @@ -282,7 +282,7 @@ def test_channels_first(self, channels_first):
], name_func=get_test_name)
def test_noncontiguous(self, dtype):
"""Noncontiguous tensors are saved correctly"""
path = self.get_temp_path('test_uncontiguous_{dtype}.wav')
path = self.get_temp_path('data.wav')
expected = get_wav_data(dtype, 4)[::2, ::2]
assert not expected.is_contiguous()
sox_io_backend.save(path, expected, 8000)
Expand All @@ -294,7 +294,7 @@ def test_noncontiguous(self, dtype):
])
def test_tensor_preserve(self, dtype):
"""save function should not alter Tensor"""
path = self.get_temp_path(f'test_preserve_{dtype}.wav')
path = self.get_temp_path('data.wav')
expected = get_wav_data(dtype, 4)[::2, ::2]

data = expected.clone()
Expand Down
8 changes: 4 additions & 4 deletions test/sox_io_backend/test_torchscript.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def test_info_wav(self, dtype, sample_rate, num_channels):
data = get_wav_data(dtype, num_channels, normalize=False, num_frames=1 * sample_rate)
save_wav(audio_path, data, sample_rate)

script_path = self.get_temp_path('info_func')
script_path = self.get_temp_path('info_func.zip')
torch.jit.script(py_info_func).save(script_path)
ts_info_func = torch.jit.load(script_path)

Expand All @@ -78,7 +78,7 @@ def test_load_wav(self, dtype, sample_rate, num_channels, normalize, channels_fi
data = get_wav_data(dtype, num_channels, normalize=False, num_frames=1 * sample_rate)
save_wav(audio_path, data, sample_rate)

script_path = self.get_temp_path('load_func')
script_path = self.get_temp_path('load_func.zip')
torch.jit.script(py_load_func).save(script_path)
ts_load_func = torch.jit.load(script_path)

Expand All @@ -96,7 +96,7 @@ def test_load_wav(self, dtype, sample_rate, num_channels, normalize, channels_fi
[1, 2],
)), name_func=get_test_name)
def test_save_wav(self, dtype, sample_rate, num_channels):
script_path = self.get_temp_path('save_func')
script_path = self.get_temp_path('save_func.zip')
torch.jit.script(py_save_func).save(script_path)
ts_save_func = torch.jit.load(script_path)

Expand All @@ -121,7 +121,7 @@ def test_save_wav(self, dtype, sample_rate, num_channels):
list(range(9)),
)), name_func=get_test_name)
def test_save_flac(self, sample_rate, num_channels, compression_level):
script_path = self.get_temp_path('save_func')
script_path = self.get_temp_path('save_func.zip')
torch.jit.script(py_save_func).save(script_path)
ts_save_func = torch.jit.load(script_path)

Expand Down

0 comments on commit b9f8732

Please sign in to comment.