From 793eeab8745e17faac6ba0df85ceca39c9c20782 Mon Sep 17 00:00:00 2001 From: moto <855818+mthrok@users.noreply.github.com> Date: Thu, 25 Jun 2020 19:16:56 -0400 Subject: [PATCH] Add load function (#731) This is part of the PRs to add the new "sox_io" backend (#726) and depends on #718 and #728. This PR adds the `load` function to the "sox_io" backend, which is tested on the following audio formats: - `wav` - `mp3` - `flac` - `ogg/vorbis` * By default, the "sox_io" backend returns a Tensor with `float32` dtype and the shape `[channel, time]`. The samples are normalized to fit in the range of `[-1.0, 1.0]`. Unlike the existing "sox" backend, the new `load` function can handle WAV files natively. When the input format is WAV with an integer type (such as 32-bit signed integer, 16-bit signed integer and 8-bit unsigned integer), providing `normalize=False` makes this function return an integer Tensor, where the samples are expressed within the whole range of the corresponding dtype, that is, an `int32` tensor for `32-bit PCM`, `int16` for `16-bit PCM` and `uint8` for `8-bit PCM`. This behavior follows [scipy.io.wavfile.read](https://docs.scipy.org/doc/scipy/reference/generated/scipy.io.wavfile.read.html). The `normalize` parameter has no effect for other formats; for those, the `load` function always returns normalized values as a `float32` Tensor. __* Note__ The current binary distribution of torchaudio does not contain `ogg/vorbis` and `opus` codecs. To handle these files, one needs to build torchaudio from the source with the proper codecs in the system. __Note 2__ As of this PR, `scipy` is a required module for running the tests. 
--- test/sox_io_backend/common.py | 88 +++++++++ test/sox_io_backend/sox_utils.py | 18 +- test/sox_io_backend/test_info.py | 20 +- test/sox_io_backend/test_load.py | 245 ++++++++++++++++++++++++ test/sox_io_backend/test_torchscript.py | 42 +++- torchaudio/backend/sox_io_backend.py | 70 +++++++ torchaudio/csrc/register.cpp | 25 +++ torchaudio/csrc/sox_io.cpp | 115 +++++++---- torchaudio/csrc/sox_io.h | 16 +- torchaudio/csrc/sox_utils.cpp | 113 +++++++++++ torchaudio/csrc/sox_utils.h | 75 ++++++++ 11 files changed, 772 insertions(+), 55 deletions(-) create mode 100644 test/sox_io_backend/test_load.py create mode 100644 torchaudio/csrc/sox_utils.cpp create mode 100644 torchaudio/csrc/sox_utils.h diff --git a/test/sox_io_backend/common.py b/test/sox_io_backend/common.py index 688066fe64..d477852e12 100644 --- a/test/sox_io_backend/common.py +++ b/test/sox_io_backend/common.py @@ -1,2 +1,90 @@ +from typing import Optional + +import torch +import scipy.io.wavfile + + def get_test_name(func, _, params): return f'{func.__name__}_{"_".join(str(p) for p in params.args)}' + + +def normalize_wav(tensor: torch.Tensor) -> torch.Tensor: + if tensor.dtype == torch.float32: + pass + elif tensor.dtype == torch.int32: + tensor = tensor.to(torch.float32) + tensor[tensor > 0] /= 2147483647. + tensor[tensor < 0] /= 2147483648. + elif tensor.dtype == torch.int16: + tensor = tensor.to(torch.float32) + tensor[tensor > 0] /= 32767. + tensor[tensor < 0] /= 32768. + elif tensor.dtype == torch.uint8: + tensor = tensor.to(torch.float32) - 128 + tensor[tensor > 0] /= 127. + tensor[tensor < 0] /= 128. 
+ return tensor + + +def get_wav_data( + dtype: str, + num_channels: int, + *, + num_frames: Optional[int] = None, + normalize: bool = True, + channels_first: bool = True, +): + """Generate linear signal of the given dtype and num_channels + + Data range is + [-1.0, 1.0] for float32, + [-2147483648, 2147483647] for int32 + [-32768, 32767] for int16 + [0, 255] for uint8 + + num_frames allow to change the linear interpolation parameter. + Default values are 256 for uint8, else 1 << 16. + 1 << 16 as default is so that int16 value range is completely covered. + """ + dtype_ = getattr(torch, dtype) + + if num_frames is None: + if dtype == 'uint8': + num_frames = 256 + else: + num_frames = 1 << 16 + + if dtype == 'uint8': + base = torch.linspace(0, 255, num_frames, dtype=dtype_) + if dtype == 'float32': + base = torch.linspace(-1., 1., num_frames, dtype=dtype_) + if dtype == 'int32': + base = torch.linspace(-2147483648, 2147483647, num_frames, dtype=dtype_) + if dtype == 'int16': + base = torch.linspace(-32768, 32767, num_frames, dtype=dtype_) + data = base.repeat([num_channels, 1]) + if not channels_first: + data = data.transpose(1, 0) + if normalize: + data = normalize_wav(data) + return data + + +def load_wav(path: str, normalize=True, channels_first=True) -> torch.Tensor: + """Load wav file without torchaudio""" + sample_rate, data = scipy.io.wavfile.read(path) + data = torch.from_numpy(data.copy()) + if data.ndim == 1: + data = data.unsqueeze(1) + if normalize: + data = normalize_wav(data) + if channels_first: + data = data.transpose(1, 0) + return data, sample_rate + + +def save_wav(path, data, sample_rate, channels_first=True): + """Save wav file without torchaudio""" + if channels_first: + data = data.transpose(1, 0) + scipy.io.wavfile.write(path, sample_rate, data.numpy()) diff --git a/test/sox_io_backend/sox_utils.py b/test/sox_io_backend/sox_utils.py index 5887928a34..c30224158a 100644 --- a/test/sox_io_backend/sox_utils.py +++ 
b/test/sox_io_backend/sox_utils.py @@ -26,6 +26,9 @@ def gen_audio_file( *, encoding=None, bit_depth=None, compression=None, attenuation=None, duration=1, ): """Generate synthetic audio file with `sox` command.""" + if path.endswith('.wav'): + raise RuntimeError( + 'Use get_wav_data and save_wav to generate wav file for accurate result.') command = [ 'sox', '-V', # verbose @@ -51,4 +54,17 @@ def gen_audio_file( command += ['vol', f'-{attenuation}dB'] print(' '.join(command)) subprocess.run(command, check=True) - subprocess.run(['soxi', path], check=True) + + +def convert_audio_file( + src_path, dst_path, + *, bit_depth=None, compression=None): + """Convert audio file with `sox` command.""" + command = ['sox', '-V', str(src_path)] + if bit_depth is not None: + command += ['--bits', str(bit_depth)] + if compression is not None: + command += ['--compression', str(compression)] + command += [dst_path] + print(' '.join(command)) + subprocess.run(command, check=True) diff --git a/test/sox_io_backend/test_info.py b/test/sox_io_backend/test_info.py index 7954af782f..91c13278f6 100644 --- a/test/sox_io_backend/test_info.py +++ b/test/sox_io_backend/test_info.py @@ -10,7 +10,9 @@ skipIfNoExtension, ) from .common import ( - get_test_name + get_test_name, + get_wav_data, + save_wav, ) from . 
import sox_utils @@ -27,12 +29,8 @@ def test_wav(self, dtype, sample_rate, num_channels): """`sox_io_backend.info` can check wav file correctly""" duration = 1 path = self.get_temp_path(f'{dtype}_{sample_rate}_{num_channels}.wav') - sox_utils.gen_audio_file( - path, sample_rate, num_channels, - bit_depth=sox_utils.get_bit_depth(dtype), - encoding=sox_utils.get_encoding(dtype), - duration=duration, - ) + data = get_wav_data(dtype, num_channels, normalize=False, num_frames=duration * sample_rate) + save_wav(path, data, sample_rate) info = sox_io_backend.info(path) assert info.get_sample_rate() == sample_rate assert info.get_num_frames() == sample_rate * duration @@ -47,12 +45,8 @@ def test_wav_multiple_channels(self, dtype, sample_rate, num_channels): """`sox_io_backend.info` can check wav file with channels more than 2 correctly""" duration = 1 path = self.get_temp_path(f'{dtype}_{sample_rate}_{num_channels}.wav') - sox_utils.gen_audio_file( - path, sample_rate, num_channels, - bit_depth=sox_utils.get_bit_depth(dtype), - encoding=sox_utils.get_encoding(dtype), - duration=duration, - ) + data = get_wav_data(dtype, num_channels, normalize=False, num_frames=duration * sample_rate) + save_wav(path, data, sample_rate) info = sox_io_backend.info(path) assert info.get_sample_rate() == sample_rate assert info.get_num_frames() == sample_rate * duration diff --git a/test/sox_io_backend/test_load.py b/test/sox_io_backend/test_load.py new file mode 100644 index 0000000000..a04550a666 --- /dev/null +++ b/test/sox_io_backend/test_load.py @@ -0,0 +1,245 @@ +import itertools + +from torchaudio.backend import sox_io_backend +from parameterized import parameterized + +from ..common_utils import ( + TempDirMixin, + PytorchTestCase, + skipIfNoExec, + skipIfNoExtension, +) +from .common import ( + get_test_name, + get_wav_data, + load_wav, + save_wav, +) +from . 
import sox_utils + + +class LoadTestBase(TempDirMixin, PytorchTestCase): + def assert_wav(self, dtype, sample_rate, num_channels, normalize, duration): + """`sox_io_backend.load` can load wav format correctly. + + Wav data loaded with sox_io backend should match those with scipy + """ + path = self.get_temp_path(f'{dtype}_{sample_rate}_{num_channels}_{normalize}.wav') + data = get_wav_data(dtype, num_channels, normalize=normalize, num_frames=duration * sample_rate) + save_wav(path, data, sample_rate) + expected = load_wav(path, normalize=normalize)[0] + data, sr = sox_io_backend.load(path, normalize=normalize) + assert sr == sample_rate + self.assertEqual(data, expected) + + def assert_mp3(self, sample_rate, num_channels, bit_rate, duration): + """`sox_io_backend.load` can load mp3 format. + + mp3 encoding introduces delay and boundary effects so + we create reference wav file from mp3 + + x + | + | 1. Generate mp3 with Sox + | + v 2. Convert to wav with Sox + mp3 ------------------------------> wav + | | + | 3. Load with torchaudio | 4. Load with scipy + | | + v v + tensor ----------> x <----------- tensor + 5. Compare + + Underlying assumptions are; + i. Conversion of mp3 to wav with Sox preserves data. + ii. Loading wav file with scipy is correct. + + By combining i & ii, step 2. and 4. allows to load reference mp3 data + without using torchaudio + """ + path = self.get_temp_path(f'{sample_rate}_{num_channels}_{bit_rate}_{duration}.mp3') + ref_path = f'{path}.wav' + + # 1. Generate mp3 with sox + sox_utils.gen_audio_file( + path, sample_rate, num_channels, + compression=bit_rate, duration=duration) + # 2. Convert to wav with sox + sox_utils.convert_audio_file(path, ref_path) + # 3. Load mp3 with torchaudio + data, sr = sox_io_backend.load(path) + # 4. Load wav with scipy + data_ref = load_wav(ref_path)[0] + # 5. 
Compare + assert sr == sample_rate + self.assertEqual(data, data_ref, atol=3e-03, rtol=1.3e-06) + + def assert_flac(self, sample_rate, num_channels, compression_level, duration): + """`sox_io_backend.load` can load flac format. + + This test takes the same strategy as mp3 to compare the result + """ + path = self.get_temp_path(f'{sample_rate}_{num_channels}_{compression_level}_{duration}.flac') + ref_path = f'{path}.wav' + + # 1. Generate flac with sox + sox_utils.gen_audio_file( + path, sample_rate, num_channels, + compression=compression_level, bit_depth=16, duration=duration) + # 2. Convert to wav with sox + sox_utils.convert_audio_file(path, ref_path) + # 3. Load flac with torchaudio + data, sr = sox_io_backend.load(path) + # 4. Load wav with scipy + data_ref = load_wav(ref_path)[0] + # 5. Compare + assert sr == sample_rate + self.assertEqual(data, data_ref, atol=4e-05, rtol=1.3e-06) + + def assert_vorbis(self, sample_rate, num_channels, quality_level, duration): + """`sox_io_backend.load` can load vorbis format. + + This test takes the same strategy as mp3 to compare the result + """ + path = self.get_temp_path(f'{sample_rate}_{num_channels}_{quality_level}_{duration}.vorbis') + ref_path = f'{path}.wav' + + # 1. Generate vorbis with sox + sox_utils.gen_audio_file( + path, sample_rate, num_channels, + compression=quality_level, bit_depth=16, duration=duration) + # 2. Convert to wav with sox + sox_utils.convert_audio_file(path, ref_path) + # 3. Load vorbis with torchaudio + data, sr = sox_io_backend.load(path) + # 4. Load wav with scipy + data_ref = load_wav(ref_path)[0] + # 5. 
Compare + assert sr == sample_rate + self.assertEqual(data, data_ref, atol=4e-05, rtol=1.3e-06) + + +@skipIfNoExec('sox') +@skipIfNoExtension +class TestLoad(LoadTestBase): + """Test the correctness of `sox_io_backend.load` for various formats""" + @parameterized.expand(list(itertools.product( + ['float32', 'int32', 'int16', 'uint8'], + [8000, 16000], + [1, 2], + [False, True], + )), name_func=get_test_name) + def test_wav(self, dtype, sample_rate, num_channels, normalize): + """`sox_io_backend.load` can load wav format correctly.""" + self.assert_wav(dtype, sample_rate, num_channels, normalize, duration=1) + + @parameterized.expand(list(itertools.product( + ['int16'], + [16000], + [2], + [False], + )), name_func=get_test_name) + def test_wav_large(self, dtype, sample_rate, num_channels, normalize): + """`sox_io_backend.load` can load large wav file correctly.""" + two_hours = 2 * 60 * 60 + self.assert_wav(dtype, sample_rate, num_channels, normalize, two_hours) + + @parameterized.expand(list(itertools.product( + ['float32', 'int32', 'int16', 'uint8'], + [4, 8, 16, 32], + )), name_func=get_test_name) + def test_multiple_channels(self, dtype, num_channels): + """`sox_io_backend.load` can load wav file with more than 2 channels.""" + sample_rate = 8000 + normalize = False + self.assert_wav(dtype, sample_rate, num_channels, normalize, duration=1) + + @parameterized.expand(list(itertools.product( + [8000, 16000, 44100], + [1, 2], + [96, 128, 160, 192, 224, 256, 320], + )), name_func=get_test_name) + def test_mp3(self, sample_rate, num_channels, bit_rate): + """`sox_io_backend.load` can load mp3 format correctly.""" + self.assert_mp3(sample_rate, num_channels, bit_rate, duration=1) + + @parameterized.expand(list(itertools.product( + [16000], + [2], + [128], + )), name_func=get_test_name) + def test_mp3_large(self, sample_rate, num_channels, bit_rate): + """`sox_io_backend.load` can load large mp3 file correctly.""" + two_hours = 2 * 60 * 60 + self.assert_mp3(sample_rate, 
num_channels, bit_rate, two_hours) + + @parameterized.expand(list(itertools.product( + [8000, 16000], + [1, 2], + list(range(9)), + )), name_func=get_test_name) + def test_flac(self, sample_rate, num_channels, compression_level): + """`sox_io_backend.load` can load flac format correctly.""" + self.assert_flac(sample_rate, num_channels, compression_level, duration=1) + + @parameterized.expand(list(itertools.product( + [16000], + [2], + [0], + )), name_func=get_test_name) + def test_flac_large(self, sample_rate, num_channels, compression_level): + """`sox_io_backend.load` can load large flac file correctly.""" + two_hours = 2 * 60 * 60 + self.assert_flac(sample_rate, num_channels, compression_level, two_hours) + + @parameterized.expand(list(itertools.product( + [8000, 16000], + [1, 2], + [-1, 0, 1, 2, 3, 3.6, 5, 10], + )), name_func=get_test_name) + def test_vorbis(self, sample_rate, num_channels, quality_level): + """`sox_io_backend.load` can load vorbis format correctly.""" + self.assert_vorbis(sample_rate, num_channels, quality_level, duration=1) + + @parameterized.expand(list(itertools.product( + [16000], + [2], + [10], + )), name_func=get_test_name) + def test_vorbis_large(self, sample_rate, num_channels, quality_level): + """`sox_io_backend.load` can load large vorbis file correctly.""" + two_hours = 2 * 60 * 60 + self.assert_vorbis(sample_rate, num_channels, quality_level, two_hours) + + +@skipIfNoExec('sox') +@skipIfNoExtension +class TestLoadParams(TempDirMixin, PytorchTestCase): + """Test the correctness of frame parameters of `sox_io_backend.load`""" + original = None + path = None + + def setUp(self): + super().setUp() + sample_rate = 8000 + self.original = get_wav_data('float32', num_channels=2) + self.path = self.get_temp_path('test.wave') + save_wav(self.path, self.original, sample_rate) + + @parameterized.expand(list(itertools.product( + [0, 1, 10, 100, 1000], + [-1, 1, 10, 100, 1000], + )), name_func=get_test_name) + def test_frame(self, 
frame_offset, num_frames): + """num_frames and frame_offset correctly specify the region of data""" + found, _ = sox_io_backend.load(self.path, frame_offset, num_frames) + frame_end = None if num_frames == -1 else frame_offset + num_frames + self.assertEqual(found, self.original[:, frame_offset:frame_end]) + + @parameterized.expand([(True, ), (False, )], name_func=get_test_name) + def test_channels_first(self, channels_first): + """channels_first swaps axes""" + found, _ = sox_io_backend.load(self.path, channels_first=channels_first) + expected = self.original if channels_first else self.original.transpose(1, 0) + self.assertEqual(found, expected) diff --git a/test/sox_io_backend/test_torchscript.py b/test/sox_io_backend/test_torchscript.py index aff488126c..c6e9df41e1 100644 --- a/test/sox_io_backend/test_torchscript.py +++ b/test/sox_io_backend/test_torchscript.py @@ -12,29 +12,34 @@ ) from .common import ( get_test_name, + get_wav_data, + save_wav ) -from . import sox_utils def py_info_func(filepath: str) -> torch.classes.torchaudio.SignalInfo: return sox_io_backend.info(filepath) +def py_load_func(filepath: str, normalize: bool, channels_first: bool): + return sox_io_backend.load( + filepath, normalize=normalize, channels_first=channels_first) + + @skipIfNoExec('sox') @skipIfNoExtension class SoxIO(TempDirMixin, TorchaudioTestCase): + """TorchScript-ability Test suite for `sox_io_backend`""" @parameterized.expand(list(itertools.product( ['float32', 'int32', 'int16', 'uint8'], [8000, 16000], [1, 2], )), name_func=get_test_name) def test_info_wav(self, dtype, sample_rate, num_channels): + """`sox_io_backend.info` is torchscript-able and returns the same result""" audio_path = self.get_temp_path(f'{dtype}_{sample_rate}_{num_channels}.wav') - sox_utils.gen_audio_file( - audio_path, sample_rate, num_channels, - bit_depth=sox_utils.get_bit_depth(dtype), - encoding=sox_utils.get_encoding(dtype), - ) + data = get_wav_data(dtype, num_channels, normalize=False, 
num_frames=1 * sample_rate) + save_wav(audio_path, data, sample_rate) script_path = self.get_temp_path('info_func') torch.jit.script(py_info_func).save(script_path) @@ -46,3 +51,28 @@ def test_info_wav(self, dtype, sample_rate, num_channels): assert py_info.get_sample_rate() == ts_info.get_sample_rate() assert py_info.get_num_frames() == ts_info.get_num_frames() assert py_info.get_num_channels() == ts_info.get_num_channels() + + @parameterized.expand(list(itertools.product( + ['float32', 'int32', 'int16', 'uint8'], + [8000, 16000], + [1, 2], + [False, True], + [False, True], + )), name_func=get_test_name) + def test_load_wav(self, dtype, sample_rate, num_channels, normalize, channels_first): + """`sox_io_backend.load` is torchscript-able and returns the same result""" + audio_path = self.get_temp_path(f'test_load_{dtype}_{sample_rate}_{num_channels}_{normalize}.wav') + data = get_wav_data(dtype, num_channels, normalize=False, num_frames=1 * sample_rate) + save_wav(audio_path, data, sample_rate) + + script_path = self.get_temp_path('load_func') + torch.jit.script(py_load_func).save(script_path) + ts_load_func = torch.jit.load(script_path) + + py_data, py_sr = py_load_func( + audio_path, normalize=normalize, channels_first=channels_first) + ts_data, ts_sr = ts_load_func( + audio_path, normalize=normalize, channels_first=channels_first) + + self.assertEqual(py_sr, ts_sr) + self.assertEqual(py_data, ts_data) diff --git a/torchaudio/backend/sox_io_backend.py b/torchaudio/backend/sox_io_backend.py index 40254658e2..a9bcffdd3a 100644 --- a/torchaudio/backend/sox_io_backend.py +++ b/torchaudio/backend/sox_io_backend.py @@ -1,3 +1,5 @@ +from typing import Tuple + import torch from torchaudio._internal import ( module_utils as _mod_utils, @@ -8,3 +10,71 @@ def info(filepath: str) -> torch.classes.torchaudio.SignalInfo: """Get signal information of an audio file.""" return torch.ops.torchaudio.sox_io_get_info(filepath) + + 
+@_mod_utils.requires_module('torchaudio._torchaudio') +def load( + filepath: str, + frame_offset: int = 0, + num_frames: int = -1, + normalize: bool = True, + channels_first: bool = True, +) -> Tuple[torch.Tensor, int]: + """Load audio data from file. + + This function can handle all the codecs that underlying libsox can handle, however note the + followings. + + Note 1: + Current torchaudio's binary release only contains codecs for MP3, FLAC and OGG/VORBIS. + If you need other formats, you need to build torchaudio from source with libsox and + the corresponding codecs. Refer to README for this. + + Note 2: + This function is tested on the following formats; + - WAV + - 32-bit floating-point + - 32-bit signed integer + - 16-bit signed integer + - 8-bit unsigned integer + - MP3 + - FLAC + - OGG/VORBIS + + By default, this function returns Tensor with ``float32`` dtype and the shape of ``[channel, time]``. + The samples are normalized to fit in the range of ``[-1.0, 1.0]``. + + When the input format is WAV with integer type, such as 32-bit signed integer, 16-bit + signed integer and 8-bit unsigned integer (24-bit signed integer is not supported), + by providing ``normalize=False``, this function can return integer Tensor, where the samples + are expressed within the whole range of the corresponding dtype, that is, ``int32`` tensor + for 32-bit signed PCM, ``int16`` for 16-bit signed PCM and ``uint8`` for 8-bit unsigned PCM. + + ``normalize`` parameter has no effect on 32-bit floating-point WAV and other formats, such as + flac and mp3. For these formats, this function always returns ``float32`` Tensor with values + normalized to ``[-1.0, 1.0]``. + + Args: + filepath: Path to audio file + frame_offset: Number of frames to skip before start reading data. + num_frames: Maximum number of frames to read. -1 reads all the remaining samples, starting + from ``frame_offset``. 
This function may return the less number of frames if there is + not enough frames in the given file. + normalize: When ``True``, this function always return ``float32``, and sample values are + normalized to ``[-1.0, 1.0]``. If input file is integer WAV, giving ``False`` will change + the resulting Tensor type to integer type. This argument has no effect for formats other + than integer WAV type. + channels_first: When True, the returned Tensor has dimension ``[channel, time]``. + Otherwise, the returned Tensor's dimension is ``[time, channel]``. + + Returns: + torch.Tensor: If the input file has integer wav format and normalization is off, then it has + integer type, else ``float32`` type. If ``channels_first=True``, it has + ``[channel, time]`` else ``[time, channel]``. + """ + signal = torch.ops.torchaudio.sox_io_load_audio_file( + filepath, frame_offset, num_frames, normalize, channels_first) + return signal.get_tensor(), signal.get_sample_rate() + + +load_wav = load diff --git a/torchaudio/csrc/register.cpp b/torchaudio/csrc/register.cpp index baa13c031f..44f7826e5a 100644 --- a/torchaudio/csrc/register.cpp +++ b/torchaudio/csrc/register.cpp @@ -3,11 +3,15 @@ #include #include +#include #include namespace torchaudio { namespace { +//////////////////////////////////////////////////////////////////////////////// +// typedefs.h +//////////////////////////////////////////////////////////////////////////////// static auto registerSignalInfo = torch::class_("torchaudio", "SignalInfo") .def(torch::init()) @@ -15,12 +19,33 @@ static auto registerSignalInfo = .def("get_num_channels", &SignalInfo::getNumChannels) .def("get_num_frames", &SignalInfo::getNumFrames); +//////////////////////////////////////////////////////////////////////////////// +// sox_utils.h +//////////////////////////////////////////////////////////////////////////////// +static auto registerTensorSignal = + torch::class_("torchaudio", "TensorSignal") + .def(torch::init()) + .def("get_tensor", 
&sox_utils::TensorSignal::getTensor) + .def("get_sample_rate", &sox_utils::TensorSignal::getSampleRate) + .def("get_channels_first", &sox_utils::TensorSignal::getChannelsFirst); + +//////////////////////////////////////////////////////////////////////////////// +// sox_io.h +//////////////////////////////////////////////////////////////////////////////// static auto registerGetInfo = torch::RegisterOperators().op( torch::RegisterOperators::options() .schema( "torchaudio::sox_io_get_info(str path) -> __torch__.torch.classes.torchaudio.SignalInfo info") .catchAllKernel()); +static auto registerLoadAudioFile = torch::RegisterOperators().op( + torch::RegisterOperators::options() + .schema( + "torchaudio::sox_io_load_audio_file(str path, int frame_offset, int num_frames, bool normalize, bool channels_first) -> __torch__.torch.classes.torchaudio.TensorSignal signal") + .catchAllKernel< + decltype(sox_io::load_audio_file), + &sox_io::load_audio_file>()); + //////////////////////////////////////////////////////////////////////////////// // sox_effects.h //////////////////////////////////////////////////////////////////////////////// diff --git a/torchaudio/csrc/sox_io.cpp b/torchaudio/csrc/sox_io.cpp index 735f16dbd0..349e65c97d 100644 --- a/torchaudio/csrc/sox_io.cpp +++ b/torchaudio/csrc/sox_io.cpp @@ -1,54 +1,103 @@ #include #include +#include using namespace torch::indexing; +using namespace torchaudio::sox_utils; namespace torchaudio { namespace sox_io { -namespace { - -/// Helper struct to safely close the sox_format_t descriptor. 
-struct SoxDescriptor { - explicit SoxDescriptor(sox_format_t* fd) noexcept : fd_(fd) {} - SoxDescriptor(const SoxDescriptor& other) = delete; - SoxDescriptor(SoxDescriptor&& other) = delete; - SoxDescriptor& operator=(const SoxDescriptor& other) = delete; - SoxDescriptor& operator=(SoxDescriptor&& other) = delete; - ~SoxDescriptor() { - if (fd_ != nullptr) { - sox_close(fd_); - } - } - sox_format_t* operator->() noexcept { - return fd_; - } - sox_format_t* get() noexcept { - return fd_; +c10::intrusive_ptr get_info(const std::string& path) { + SoxFormat sf(sox_open_read( + path.c_str(), + /*signal=*/nullptr, + /*encoding=*/nullptr, + /*filetype=*/nullptr)); + + if (sf.get() == nullptr) { + throw std::runtime_error("Error opening audio file"); } - private: - sox_format_t* fd_; -}; + return c10::make_intrusive( + static_cast(sf->signal.rate), + static_cast(sf->signal.channels), + static_cast(sf->signal.length / sf->signal.channels)); +} -} // namespace +c10::intrusive_ptr load_audio_file( + const std::string& path, + const int64_t frame_offset, + const int64_t num_frames, + const bool normalize, + const bool channels_first) { + if (frame_offset < 0) { + throw std::runtime_error( + "Invalid argument: frame_offset must be non-negative."); + } + if (num_frames == 0 || num_frames < -1) { + throw std::runtime_error( + "Invalid argument: num_frames must be -1 or greater than 0."); + } -c10::intrusive_ptr<::torchaudio::SignalInfo> get_info( - const std::string& file_name) { - SoxDescriptor sd(sox_open_read( - file_name.c_str(), + SoxFormat sf(sox_open_read( + path.c_str(), /*signal=*/nullptr, /*encoding=*/nullptr, /*filetype=*/nullptr)); - if (sd.get() == nullptr) { - throw std::runtime_error("Error opening audio file"); + validate_input_file(sf); + + const int64_t num_channels = sf->signal.channels; + const int64_t num_total_samples = sf->signal.length; + const int64_t sample_start = sf->signal.channels * frame_offset; + + if (sox_seek(sf.get(), sample_start, 0) == 
SOX_EOF) { + throw std::runtime_error("Error reading audio file: offset past EOF."); + } + + const int64_t sample_end = [&]() { + if (num_frames == -1) + return num_total_samples; + const int64_t sample_end_ = num_channels * num_frames + sample_start; + if (num_total_samples < sample_end_) { + // For lossy encoding, it is difficult to predict exact size of buffer for + // reading the number of samples required. + // So we allocate buffer size of given `num_frames` and ask sox to read as + // much as possible. For lossless format, sox reads exact number of + // samples, but for lossy encoding, sox can end up reading less. (i.e. + // mp3) For the consistent behavior specification between lossy/lossless + // format, we allow users to provide `num_frames` value that exceeds #of + // available samples, and we adjust it here. + return num_total_samples; + } + return sample_end_; + }(); + + const int64_t max_samples = sample_end - sample_start; + + // Read samples into buffer + std::vector buffer; + buffer.reserve(max_samples); + const int64_t num_samples = sox_read(sf.get(), buffer.data(), max_samples); + if (num_samples == 0) { + throw std::runtime_error( + "Error reading audio file: empty file or read operation failed."); } + // NOTE: num_samples may be smaller than max_samples if the input + // format is compressed (i.e. mp3). 
+ + // Convert to Tensor + auto tensor = convert_to_tensor( + buffer.data(), + num_samples, + num_channels, + get_dtype(sf->encoding.encoding, sf->signal.precision), + normalize, + channels_first); - return c10::make_intrusive<::torchaudio::SignalInfo>( - static_cast(sd->signal.rate), - static_cast(sd->signal.channels), - static_cast(sd->signal.length / sd->signal.channels)); + return c10::make_intrusive( + tensor, static_cast(sf->signal.rate), channels_first); } } // namespace sox_io diff --git a/torchaudio/csrc/sox_io.h b/torchaudio/csrc/sox_io.h index 45eb3c1e93..3751f22cf5 100644 --- a/torchaudio/csrc/sox_io.h +++ b/torchaudio/csrc/sox_io.h @@ -1,11 +1,23 @@ +#ifndef TORCHAUDIO_SOX_IO_H +#define TORCHAUDIO_SOX_IO_H + #include +#include #include namespace torchaudio { namespace sox_io { -c10::intrusive_ptr<::torchaudio::SignalInfo> get_info( - const std::string& file_name); +c10::intrusive_ptr get_info(const std::string& path); + +c10::intrusive_ptr load_audio_file( + const std::string& path, + const int64_t frame_offset = 0, + const int64_t num_frames = -1, + const bool normalize = true, + const bool channels_first = true); } // namespace sox_io } // namespace torchaudio + +#endif diff --git a/torchaudio/csrc/sox_utils.cpp b/torchaudio/csrc/sox_utils.cpp new file mode 100644 index 0000000000..4a7b3d8014 --- /dev/null +++ b/torchaudio/csrc/sox_utils.cpp @@ -0,0 +1,113 @@ +#include +#include +#include + +namespace torchaudio { +namespace sox_utils { + +TensorSignal::TensorSignal( + torch::Tensor tensor_, + int64_t sample_rate_, + bool channels_first_) + : tensor(tensor_), + sample_rate(sample_rate_), + channels_first(channels_first_){}; + +torch::Tensor TensorSignal::getTensor() const { + return tensor; +} +int64_t TensorSignal::getSampleRate() const { + return sample_rate; +} +bool TensorSignal::getChannelsFirst() const { + return channels_first; +} + +SoxFormat::SoxFormat(sox_format_t* fd) noexcept : fd_(fd) {} +SoxFormat::~SoxFormat() { + if (fd_ != nullptr) { 
+ sox_close(fd_); + } +} +sox_format_t* SoxFormat::operator->() const noexcept { + return fd_; +} +sox_format_t* SoxFormat::get() const noexcept { + return fd_; +} + +void validate_input_file(const SoxFormat& sf) { + if (sf.get() == nullptr) { + throw std::runtime_error("Error loading audio file: failed to open file."); + } + if (sf->encoding.encoding == SOX_ENCODING_UNKNOWN) { + throw std::runtime_error("Error loading audio file: unknown encoding."); + } + if (sf->signal.length == 0) { + throw std::runtime_error("Error reading audio file: unkown length."); + } +} + +caffe2::TypeMeta get_dtype( + const sox_encoding_t encoding, + const unsigned precision) { + const auto dtype = [&]() { + switch (encoding) { + case SOX_ENCODING_UNSIGNED: // 8-bit PCM WAV + return torch::kUInt8; + case SOX_ENCODING_SIGN2: // 16-bit or 32-bit PCM WAV + switch (precision) { + case 16: + return torch::kInt16; + case 32: + return torch::kInt32; + default: + throw std::runtime_error( + "Only 16 and 32 bits are supported for signed PCM."); + } + default: + // default to float32 for the other formats, including + // 32-bit flaoting-point WAV, + // MP3, + // FLAC, + // VORBIS etc... + return torch::kFloat32; + } + }(); + return c10::scalarTypeToTypeMeta(dtype); +} + +torch::Tensor convert_to_tensor( + sox_sample_t* buffer, + const int32_t num_samples, + const int32_t num_channels, + const caffe2::TypeMeta dtype, + const bool normalize, + const bool channels_first) { + auto t = torch::from_blob( + buffer, {num_samples / num_channels, num_channels}, torch::kInt32); + // Note: Tensor created from_blob does not own data but borrwos + // So make sure to create a new copy after processing samples. + if (normalize || dtype == torch::kFloat32) { + t = t.to(torch::kFloat32); + t *= (t > 0) / 2147483647. 
+ (t < 0) / 2147483648.; + } else if (dtype == torch::kInt32) { + t = t.clone(); + } else if (dtype == torch::kInt16) { + t.floor_divide_(1 << 16); + t = t.to(torch::kInt16); + } else if (dtype == torch::kUInt8) { + t.floor_divide_(1 << 24); + t += 128; + t = t.to(torch::kUInt8); + } else { + throw std::runtime_error("Unsupported dtype."); + } + if (channels_first) { + t = t.transpose(1, 0); + } + return t.contiguous(); +} + +} // namespace sox_utils +} // namespace torchaudio diff --git a/torchaudio/csrc/sox_utils.h b/torchaudio/csrc/sox_utils.h new file mode 100644 index 0000000000..cc61d67c77 --- /dev/null +++ b/torchaudio/csrc/sox_utils.h @@ -0,0 +1,75 @@ +#ifndef TORCHAUDIO_SOX_UTILS_H +#define TORCHAUDIO_SOX_UTILS_H + +#include +#include + +namespace torchaudio { +namespace sox_utils { + +struct TensorSignal : torch::CustomClassHolder { + torch::Tensor tensor; + int64_t sample_rate; + bool channels_first; + + TensorSignal( + torch::Tensor tensor_, + int64_t sample_rate_, + bool channels_first_); + + torch::Tensor getTensor() const; + int64_t getSampleRate() const; + bool getChannelsFirst() const; +}; + +/// helper class to automatically close sox_format_t* +struct SoxFormat { + explicit SoxFormat(sox_format_t* fd) noexcept; + SoxFormat(const SoxFormat& other) = delete; + SoxFormat(SoxFormat&& other) = delete; + SoxFormat& operator=(const SoxFormat& other) = delete; + SoxFormat& operator=(SoxFormat&& other) = delete; + ~SoxFormat(); + sox_format_t* operator->() const noexcept; + sox_format_t* get() const noexcept; + + private: + sox_format_t* fd_; +}; + +/// +/// Verify that input file is found, has known encoding, and not empty +void validate_input_file(const SoxFormat& sf); + +/// +/// Get target dtype for the given encoding and precision. 
+caffe2::TypeMeta get_dtype( + const sox_encoding_t encoding, + const unsigned precision); + +/// +/// Convert sox_sample_t buffer to uint8/int16/int32/float32 Tensor +/// NOTE: This function might modify the values in the input buffer to +/// reduce the number of memory copy. +/// @param buffer Pointer to buffer that contains audio data. +/// @param num_samples The number of samples to read. +/// @param num_channels The number of channels. Used to reshape the resulting +/// Tensor. +/// @param dtype Target dtype. Determines the output dtype and value range in +/// conjunction with normalization. +/// @param noramlize Perform normalization. Only effective when dtype is not +/// kFloat32. When effective, the output tensor is kFloat32 type and value range +/// is [-1.0, 1.0] +/// @param channels_first When True, output Tensor has shape of [num_channels, +/// num_frames]. +torch::Tensor convert_to_tensor( + sox_sample_t* buffer, + const int32_t num_samples, + const int32_t num_channels, + const caffe2::TypeMeta dtype, + const bool normalize, + const bool channels_first); + +} // namespace sox_utils +} // namespace torchaudio +#endif