From 6a8a05f806c48272491eb0dc927f6fca66683c98 Mon Sep 17 00:00:00 2001 From: Isaac Seessel Date: Mon, 15 Mar 2021 11:01:19 -0400 Subject: [PATCH] Add support for signed-24-bit-int wav files in sox_io_backend --- .../backend/sox_io/load_test.py | 56 ++++++++++++++++++- torchaudio/backend/sox_io_backend.py | 10 ++-- torchaudio/csrc/sox/utils.cpp | 5 +- 3 files changed, 63 insertions(+), 8 deletions(-) diff --git a/test/torchaudio_unittest/backend/sox_io/load_test.py b/test/torchaudio_unittest/backend/sox_io/load_test.py index 593098ecab..115efebbf9 100644 --- a/test/torchaudio_unittest/backend/sox_io/load_test.py +++ b/test/torchaudio_unittest/backend/sox_io/load_test.py @@ -42,6 +42,49 @@ def assert_wav(self, dtype, sample_rate, num_channels, normalize, duration): assert sr == sample_rate self.assertEqual(data, expected) + def assert_24bit_wav(self, sample_rate, num_channels, normalize, duration): + """ `sox_io_backend.load` can load 24-bit signed PCM wav format. Since torch does not support the ``int24`` dtype, + we implicitly cast the resulting tensor to the ``int32`` dtype. + + It is not possible to use #assert_wav method above, as #get_wav_data does not support + the 'int24' dtype. This is because torch does not support the ``int24`` dtype. + Hence, we must use the following workaround. + + x + | + | 1. Generate 24-bit wav with Sox. + | + v 2. Convert 24-bit wav to 32-bit wav with Sox. + wav(24-bit) ----------------------> wav(32-bit) + | | + | 3. Load 24-bit wav with torchaudio| 4. Load 32-bit wav with scipy + | | + v v + tensor ----------> x <----------- tensor + 5. Compare + + # Underlying assumptions are: + # i. Sox properly converts from 24-bit to 32-bit + # ii. Loading 32-bit wav file with scipy is correct. + """ + path = self.get_temp_path('1.original.wav') + ref_path = self.get_temp_path('2.reference.wav') + + # 1. 
Generate 24-bit signed wav with Sox + sox_utils.gen_audio_file( + path, sample_rate, num_channels, + bit_depth=24, duration=duration) + + # 2. Convert from 24-bit wav to 32-bit wav with sox + sox_utils.convert_audio_file(path, ref_path, bit_depth=32) + # 3. Load 24-bit wav with torchaudio + data, sr = sox_io_backend.load(path, normalize=normalize) + # 4. Load 32-bit wav with scipy + data_ref = load_wav(ref_path, normalize=normalize)[0] + # 5. Compare + assert sr == sample_rate + self.assertEqual(data, data_ref, atol=3e-03, rtol=1.3e-06) + def assert_mp3(self, sample_rate, num_channels, bit_rate, duration): """`sox_io_backend.load` can load mp3 format. @@ -50,7 +93,7 @@ def assert_mp3(self, sample_rate, num_channels, bit_rate, duration): x | - | 1. Generate mp3 with Sox + | 1. Generate mp3 with Sox | v 2. Convert to wav with Sox mp3 ------------------------------> wav @@ -61,7 +104,7 @@ def assert_mp3(self, sample_rate, num_channels, bit_rate, duration): tensor ----------> x <----------- tensor 5. Compare - Underlying assumptions are; + Underlying assumptions are: i. Conversion of mp3 to wav with Sox preserves data. ii. Loading wav file with scipy is correct. @@ -213,6 +256,15 @@ def test_wav(self, dtype, sample_rate, num_channels, normalize): """`sox_io_backend.load` can load wav format correctly.""" self.assert_wav(dtype, sample_rate, num_channels, normalize, duration=1) + @parameterized.expand(list(itertools.product( + [8000, 16000], + [1, 2], + [False, True], + )), name_func=name_func) + def test_24bit_wav(self, sample_rate, num_channels, normalize): + """`sox_io_backend.load` can load 24bit wav format correctly. 
Correctly casts it to ``int32`` tensor dtype.""" + self.assert_24bit_wav(sample_rate, num_channels, normalize, duration=1) + @parameterized.expand(list(itertools.product( ['int16'], [16000], diff --git a/torchaudio/backend/sox_io_backend.py b/torchaudio/backend/sox_io_backend.py index f821c968e6..65b3e8bd50 100644 --- a/torchaudio/backend/sox_io_backend.py +++ b/torchaudio/backend/sox_io_backend.py @@ -73,6 +73,7 @@ def load( * 32-bit floating-point * 32-bit signed integer + * 24-bit signed integer * 16-bit signed integer * 8-bit unsigned integer (WAV only) @@ -92,10 +93,11 @@ def load( The samples are normalized to fit in the range of ``[-1.0, 1.0]``. When the input format is WAV with integer type, such as 32-bit signed integer, 16-bit - signed integer and 8-bit unsigned integer (24-bit signed integer is not supported), - by providing ``normalize=False``, this function can return integer Tensor, where the samples - are expressed within the whole range of the corresponding dtype, that is, ``int32`` tensor - for 32-bit signed PCM, ``int16`` for 16-bit signed PCM and ``uint8`` for 8-bit unsigned PCM. + signed integer, 24-bit signed integer, and 8-bit unsigned integer, by providing ``normalize=False``, + this function can return integer Tensor, where the samples are expressed within the whole range + of the corresponding dtype, that is, ``int32`` tensor for 32-bit signed PCM, + ``int16`` for 16-bit signed PCM and ``uint8`` for 8-bit unsigned PCM. Since torch does not + support ``int24`` dtype, 24-bit signed PCM are converted to ``int32`` tensors. ``normalize`` parameter has no effect on 32-bit floating-point WAV and other formats, such as ``flac`` and ``mp3``.
diff --git a/torchaudio/csrc/sox/utils.cpp b/torchaudio/csrc/sox/utils.cpp index 66a5a80af8..a297b0c96c 100644 --- a/torchaudio/csrc/sox/utils.cpp +++ b/torchaudio/csrc/sox/utils.cpp @@ -118,15 +118,16 @@ caffe2::TypeMeta get_dtype( switch (encoding) { case SOX_ENCODING_UNSIGNED: // 8-bit PCM WAV return torch::kUInt8; - case SOX_ENCODING_SIGN2: // 16-bit or 32-bit PCM WAV + case SOX_ENCODING_SIGN2: // 16-bit, 24-bit, or 32-bit PCM WAV switch (precision) { case 16: return torch::kInt16; + case 24: // Cast 24-bit to 32-bit. case 32: return torch::kInt32; default: throw std::runtime_error( - "Only 16 and 32 bits are supported for signed PCM."); + "Only 16, 24, and 32 bits are supported for signed PCM."); } default: // default to float32 for the other formats, including