Skip to content

Commit

Permalink
Support file-like object in save func (#1141)
Browse files Browse the repository at this point in the history
* Support file-like object in save func

* Disable CircleCI cache for TP artifacts for cleaner build
  • Loading branch information
mthrok authored Jan 15, 2021
1 parent 72b7680 commit f1d8d1e
Show file tree
Hide file tree
Showing 14 changed files with 390 additions and 74 deletions.
4 changes: 0 additions & 4 deletions .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,6 @@ jobs:
paths:
- conda
- env
- third_party/install
- run:
name: Install torchaudio
command: .circleci/unittest/linux/scripts/install.sh
Expand Down Expand Up @@ -456,7 +455,6 @@ jobs:
paths:
- conda
- env
- third_party/install
- run:
name: Install torchaudio
command: docker run -t --gpus all -e UPLOAD_CHANNEL -v $PWD:$PWD -w $PWD "${image_name}" .circleci/unittest/linux/scripts/install.sh
Expand Down Expand Up @@ -569,7 +567,6 @@ jobs:
paths:
- conda
- env
- third_party/install
- run:
name: Install torchaudio
command: .circleci/unittest/linux/scripts/install.sh
Expand Down Expand Up @@ -606,7 +603,6 @@ jobs:
paths:
- conda
- env
- third_party/install
- run:
name: Run style check
command: .circleci/unittest/linux/scripts/run_style_checks.sh
Expand Down
4 changes: 0 additions & 4 deletions .circleci/config.yml.in
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,6 @@ jobs:
paths:
- conda
- env
- third_party/install
- run:
name: Install torchaudio
command: .circleci/unittest/linux/scripts/install.sh
Expand Down Expand Up @@ -456,7 +455,6 @@ jobs:
paths:
- conda
- env
- third_party/install
- run:
name: Install torchaudio
command: docker run -t --gpus all -e UPLOAD_CHANNEL -v $PWD:$PWD -w $PWD "${image_name}" .circleci/unittest/linux/scripts/install.sh
Expand Down Expand Up @@ -569,7 +567,6 @@ jobs:
paths:
- conda
- env
- third_party/install
- run:
name: Install torchaudio
command: .circleci/unittest/linux/scripts/install.sh
Expand Down Expand Up @@ -606,7 +603,6 @@ jobs:
paths:
- conda
- env
- third_party/install
- run:
name: Run style check
command: .circleci/unittest/linux/scripts/run_style_checks.sh
Expand Down
42 changes: 41 additions & 1 deletion test/torchaudio_unittest/soundfile_backend/save_test.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import io
import itertools
from unittest.mock import patch

from torchaudio._internal import module_utils as _mod_utils
from torchaudio.backend import _soundfile_backend as soundfile_backend
from parameterized import parameterized

from torchaudio_unittest.common_utils import (
TempDirMixin,
Expand Down Expand Up @@ -209,3 +209,43 @@ def test_channels_first(self, channels_first):
found = load_wav(path)[0]
expected = data if channels_first else data.transpose(1, 0)
self.assertEqual(found, expected, atol=1e-4, rtol=1e-8)


@skipIfNoModule("soundfile")
class TestFileObject(TempDirMixin, PytorchTestCase):
    """Tests that ``soundfile_backend.save`` can write to a file-like object."""

    def _test_fileobj(self, ext):
        """Saving audio to file-like object works

        The same data is written twice: once straight to a temporary path with
        ``soundfile.write`` (the reference), and once into an in-memory ``BytesIO``
        through ``soundfile_backend.save``. Both are decoded with ``soundfile.read``
        and compared.

        Args:
            ext (str): Used both as the extension of the reference file and as the
                ``format`` argument given to ``save``.
        """
        sample_rate = 16000
        path = self.get_temp_path(f'test.{ext}')

        # The WAV reference is written with the FLOAT subtype; for every other
        # format no subtype is passed to soundfile.
        subtype = 'FLOAT' if ext == 'wav' else None
        data = get_wav_data('float32', num_channels=2)
        soundfile.write(path, data.numpy().T, sample_rate, subtype=subtype)
        expected = soundfile.read(path, dtype='float32')[0]

        fileobj = io.BytesIO()
        soundfile_backend.save(fileobj, data, sample_rate, format=ext)
        # Rewind so that the read below starts at the beginning of the buffer.
        fileobj.seek(0)
        found, sr = soundfile.read(fileobj, dtype='float32')

        # Use the TestCase assertion rather than a bare ``assert`` so the check is
        # not stripped under ``python -O`` and reports both values on failure.
        self.assertEqual(sr, sample_rate)
        self.assertEqual(expected, found)

    def test_fileobj_wav(self):
        """Saving audio via file-like object works"""
        self._test_fileobj('wav')

    @skipIfFormatNotSupported("FLAC")
    def test_fileobj_flac(self):
        """Saving audio via file-like object works"""
        self._test_fileobj('flac')

    @skipIfFormatNotSupported("NIST")
    def test_fileobj_nist(self):
        """Saving audio via file-like object works"""
        self._test_fileobj('NIST')

    @skipIfFormatNotSupported("OGG")
    def test_fileobj_ogg(self):
        """Saving audio via file-like object works"""
        self._test_fileobj('OGG')
86 changes: 86 additions & 0 deletions test/torchaudio_unittest/sox_io_backend/save_test.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import io
import itertools

from torchaudio.backend import sox_io_backend
Expand Down Expand Up @@ -417,3 +418,88 @@ def test_tensor_preserve(self, dtype):
sox_io_backend.save(path, data, 8000)

self.assertEqual(data, expected)


@skipIfNoExtension
@skipIfNoExec('sox')
class TestFileObject(SaveTestBase):
    """
    We compare the result of file-like object input against file path input because
    the `save` function is rigorously tested for file path inputs to match libsox's result.
    """
    # (format, compression) pairs shared by both tests below.
    _FORMAT_AND_COMPRESSION = [
        ('wav', None),
        ('mp3', 128),
        ('mp3', 320),
        ('flac', 0),
        ('flac', 5),
        ('flac', 8),
        ('vorbis', -1),
        ('vorbis', 10),
        ('amb', None),
    ]

    @parameterized.expand(_FORMAT_AND_COMPRESSION)
    def test_fileobj(self, ext, compression):
        """Saving audio to file object returns the same result as via file path."""
        sample_rate = 16000
        dtype = 'float32'
        num_channels = 2
        num_frames = 16000
        channels_first = True

        data = get_wav_data(dtype, num_channels, num_frames=num_frames)

        ref_path = self.get_temp_path(f'reference.{ext}')
        res_path = self.get_temp_path(f'test.{ext}')
        # Reference: save through a file path.
        sox_io_backend.save(
            ref_path, data, channels_first=channels_first,
            sample_rate=sample_rate, compression=compression)
        # Under test: save through an opened binary file object; the format is
        # passed explicitly via ``format``.
        with open(res_path, 'wb') as fileobj:
            sox_io_backend.save(
                fileobj, data, channels_first=channels_first,
                sample_rate=sample_rate, compression=compression, format=ext)

        expected_data, _ = sox_io_backend.load(ref_path)
        found_data, sr = sox_io_backend.load(res_path)

        # TestCase assertions are not stripped under ``python -O`` and report
        # both values on failure.
        self.assertEqual(sample_rate, sr)
        self.assertEqual(expected_data, found_data)

    @parameterized.expand(_FORMAT_AND_COMPRESSION)
    def test_bytesio(self, ext, compression):
        """Saving audio to BytesIO object returns the same result as via file path."""
        sample_rate = 16000
        dtype = 'float32'
        num_channels = 2
        num_frames = 16000
        channels_first = True

        data = get_wav_data(dtype, num_channels, num_frames=num_frames)

        ref_path = self.get_temp_path(f'reference.{ext}')
        res_path = self.get_temp_path(f'test.{ext}')
        # Reference: save through a file path.
        sox_io_backend.save(
            ref_path, data, channels_first=channels_first,
            sample_rate=sample_rate, compression=compression)
        # Under test: save into an in-memory buffer, then dump the buffer to disk
        # so that it can be loaded back via a path like the reference.
        fileobj = io.BytesIO()
        sox_io_backend.save(
            fileobj, data, channels_first=channels_first,
            sample_rate=sample_rate, compression=compression, format=ext)
        fileobj.seek(0)
        with open(res_path, 'wb') as file_:
            file_.write(fileobj.read())

        expected_data, _ = sox_io_backend.load(ref_path)
        found_data, sr = sox_io_backend.load(res_path)

        self.assertEqual(sample_rate, sr)
        self.assertEqual(expected_data, found_data)
18 changes: 13 additions & 5 deletions torchaudio/backend/_soundfile_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ def save(
sample_rate: int,
channels_first: bool = True,
compression: Optional[float] = None,
format: Optional[str] = None,
):
"""Save audio data to file.
Expand Down Expand Up @@ -168,6 +169,9 @@ def save(
otherwise ``[time, channel]``.
compression (Optional[float]):
Not used. It is here only for interface compatibility reasons with "sox_io" backend.
format (str, optional):
Output audio format. This is required when the output audio format cannot be inferred from
``filepath`` (such as file extension or ``name`` attribute of the given file object).
"""
if src.ndim != 2:
raise ValueError(f"Expected 2D Tensor, got {src.ndim}D.")
Expand All @@ -176,8 +180,13 @@ def save(
'`save` function of "soundfile" backend does not support "compression" parameter. '
"The argument is silently ignored."
)
if hasattr(filepath, 'write'):
if format is None:
raise RuntimeError('`format` is required when saving to file object.')
ext = format
else:
ext = str(filepath).split(".")[-1].lower()

ext = str(filepath).split(".")[-1].lower()
if ext != "wav":
subtype = None
elif src.dtype == torch.uint8:
Expand All @@ -193,17 +202,16 @@ def save(
else:
raise ValueError(f"Unsupported dtype for WAV: {src.dtype}")

format_ = None
# sph is an extension used in TED-LIUM but soundfile does not recognize it as NIST format,
# so we extend the extensions manually here
if ext in ["nis", "nist", "sph"]:
format_ = "NIST"
if ext in ["nis", "nist", "sph"] and format is None:
format = "NIST"

if channels_first:
src = src.t()

soundfile.write(
file=filepath, data=src, samplerate=sample_rate, subtype=subtype, format=format_
file=filepath, data=src, samplerate=sample_rate, subtype=subtype, format=format
)


Expand Down
44 changes: 28 additions & 16 deletions torchaudio/backend/sox_io_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,13 +134,33 @@ def load(
return signal.get_tensor(), signal.get_sample_rate()


@torch.jit.unused
def _save(
filepath: str,
src: torch.Tensor,
sample_rate: int,
channels_first: bool = True,
compression: Optional[float] = None,
format: Optional[str] = None,
):
if hasattr(filepath, 'write'):
if format is None:
raise RuntimeError('`format` is required when saving to file object.')
torchaudio._torchaudio.save_audio_fileobj(
filepath, src, sample_rate, channels_first, compression, format)
else:
torch.ops.torchaudio.sox_io_save_audio_file(
os.fspath(filepath), src, sample_rate, channels_first, compression, format)


@_mod_utils.requires_module('torchaudio._torchaudio')
def save(
filepath: str,
src: torch.Tensor,
sample_rate: int,
channels_first: bool = True,
compression: Optional[float] = None,
format: Optional[str] = None,
):
"""Save audio data to file.
Expand Down Expand Up @@ -184,23 +204,15 @@ def save(
| and lowest quality. Default: ``3``.
See the detail at http://sox.sourceforge.net/soxformat.html.
format (str, optional):
Output audio format. This is required when the output audio format cannot be inferred from
``filepath`` (such as file extension or ``name`` attribute of the given file object).
"""
# Cast to str in case type is `pathlib.Path`
filepath = str(filepath)
if compression is None:
ext = str(filepath).split('.')[-1].lower()
if ext in ['wav', 'sph', 'amb', 'amr-nb']:
compression = 0.
elif ext == 'mp3':
compression = -4.5
elif ext == 'flac':
compression = 8.
elif ext in ['ogg', 'vorbis']:
compression = 3.
else:
raise RuntimeError(f'Unsupported file type: "{ext}"')
signal = torch.classes.torchaudio.TensorSignal(src, sample_rate, channels_first)
torch.ops.torchaudio.sox_io_save_audio_file(filepath, signal, compression)
if not torch.jit.is_scripting():
_save(filepath, src, sample_rate, channels_first, compression, format)
return
torch.ops.torchaudio.sox_io_save_audio_file(
filepath, src, sample_rate, channels_first, compression, format)


@_mod_utils.requires_module('torchaudio._torchaudio')
Expand Down
4 changes: 4 additions & 0 deletions torchaudio/csrc/pybind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,4 +100,8 @@ PYBIND11_MODULE(_torchaudio, m) {
"load_audio_fileobj",
&torchaudio::sox_io::load_audio_fileobj,
"Load audio from file object.");
m.def(
"save_audio_fileobj",
&torchaudio::sox_io::save_audio_fileobj,
"Save audio to file obj.");
}
8 changes: 4 additions & 4 deletions torchaudio/csrc/sox/effects.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ c10::intrusive_ptr<TensorSignal> apply_effects_tensor(
// Create SoxEffectsChain
const auto dtype = in_tensor.dtype();
torchaudio::sox_effects_chain::SoxEffectsChain chain(
/*input_encoding=*/get_encodinginfo("wav", dtype, 0.),
/*output_encoding=*/get_encodinginfo("wav", dtype, 0.));
/*input_encoding=*/get_encodinginfo("wav", dtype),
/*output_encoding=*/get_encodinginfo("wav", dtype));

// Prepare output buffer
std::vector<sox_sample_t> out_buffer;
Expand Down Expand Up @@ -112,7 +112,7 @@ c10::intrusive_ptr<TensorSignal> apply_effects_file(
// Create and run SoxEffectsChain
torchaudio::sox_effects_chain::SoxEffectsChain chain(
/*input_encoding=*/sf->encoding,
/*output_encoding=*/get_encodinginfo("wav", dtype, 0.));
/*output_encoding=*/get_encodinginfo("wav", dtype));

chain.addInputFile(sf);
for (const auto& effect : effects) {
Expand Down Expand Up @@ -193,7 +193,7 @@ std::tuple<torch::Tensor, int64_t> apply_effects_fileobj(
const auto dtype = get_dtype(sf->encoding.encoding, sf->signal.precision);
torchaudio::sox_effects_chain::SoxEffectsChain chain(
/*input_encoding=*/sf->encoding,
/*output_encoding=*/get_encodinginfo("wav", dtype, 0.));
/*output_encoding=*/get_encodinginfo("wav", dtype));
chain.addInputFileObj(sf, in_buf, in_buffer_size, &fileobj);
for (const auto& effect : effects) {
chain.addEffect(effect);
Expand Down
Loading

0 comments on commit f1d8d1e

Please sign in to comment.