Skip to content

Commit

Permalink
Add support for signed-24-bit-int wav files in sox_io_backend
Browse files Browse the repository at this point in the history
  • Loading branch information
iseessel committed Mar 13, 2021
1 parent f2b7542 commit 803cc7a
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 14 deletions.
57 changes: 55 additions & 2 deletions test/torchaudio_unittest/backend/sox_io/load_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,50 @@ def assert_wav(self, dtype, sample_rate, num_channels, normalize, duration):
assert sr == sample_rate
self.assertEqual(data, expected)

def assert_24bit_wav(self, sample_rate, num_channels, normalize, duration):
    """Check that `sox_io_backend.load` reads a 24-bit signed PCM wav file correctly.

    torch has no ``int24`` dtype, so the loaded samples are carried in an
    ``int32`` tensor. For the same reason the reference data cannot be built
    the way `assert_wav` does it (`get_wav_data` does not support ``int24``).
    Instead, Sox widens the generated file to 32-bit and that copy is loaded
    with scipy to serve as the reference:

         x
         |
         |  1. Generate 24-bit wav with Sox.
         |
         v       2. Convert 24-bit wav to 32-bit wav with Sox.
      wav(24-bit) ----------------------> wav(32-bit)
         |                                    |
         | 3. Load 24-bit wav with torchaudio | 4. Load 32-bit wav with scipy
         |                                    |
         v                                    v
       tensor ----------> x <----------- tensor
                      5. Compare

    The check relies on two assumptions:
      i. Sox converts 24-bit data to 32-bit data properly.
      ii. scipy loads the 32-bit wav file correctly.
    """
    wav24_path = self.get_temp_path('1.original.wav')
    wav32_path = self.get_temp_path('2.reference.wav')

    # Step 1: have Sox synthesize the 24-bit signed wav under test.
    sox_utils.gen_audio_file(
        wav24_path,
        sample_rate,
        num_channels,
        bit_depth=24,
        duration=duration,
    )

    # Step 2: have Sox widen it to 32-bit; each sample is implicitly
    # left-shifted by 8 bits in the process.
    sox_utils.convert_audio_file(wav24_path, wav32_path, bit_depth=32)

    # Step 3: the code under test loads the 24-bit file.
    loaded = sox_io_backend.load(wav24_path, normalize=normalize)
    found, sr = loaded

    # Step 4: scipy loads the 32-bit reference; only the sample data is kept.
    reference = load_wav(wav32_path, normalize=normalize)
    expected = reference[0]

    # Step 5: the sample rate must match exactly, the samples within tolerance.
    assert sr == sample_rate
    self.assertEqual(found, expected, atol=3e-03, rtol=1.3e-06)

def assert_mp3(self, sample_rate, num_channels, bit_rate, duration):
"""`sox_io_backend.load` can load mp3 format.
Expand All @@ -50,7 +94,7 @@ def assert_mp3(self, sample_rate, num_channels, bit_rate, duration):
x
|
| 1. Generate mp3 with Sox
| 1. Generate mp3 with Sox
|
v 2. Convert to wav with Sox
mp3 ------------------------------> wav
Expand All @@ -61,7 +105,7 @@ def assert_mp3(self, sample_rate, num_channels, bit_rate, duration):
tensor ----------> x <----------- tensor
5. Compare
Underlying assumptions are;
Underlying assumptions are:
i. Conversion of mp3 to wav with Sox preserves data.
ii. Loading wav file with scipy is correct.
Expand Down Expand Up @@ -213,6 +257,15 @@ def test_wav(self, dtype, sample_rate, num_channels, normalize):
"""`sox_io_backend.load` can load wav format correctly."""
self.assert_wav(dtype, sample_rate, num_channels, normalize, duration=1)

@parameterized.expand(
    list(itertools.product([8000, 16000], [1, 2], [False, True])),
    name_func=name_func,
)
def test_24bit_wav(self, sample_rate, num_channels, normalize):
    """`sox_io_backend.load` can load 24-bit signed PCM wav files correctly.

    Runs over every combination of sample rate (8000, 16000), channel count
    (1, 2) and ``normalize`` flag, using a one-second clip each time.
    """
    self.assert_24bit_wav(
        sample_rate=sample_rate,
        num_channels=num_channels,
        normalize=normalize,
        duration=1,
    )

@parameterized.expand(list(itertools.product(
['int16'],
[16000],
Expand Down
13 changes: 9 additions & 4 deletions torchaudio/backend/sox_io_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,12 @@ def load(
* 32-bit floating-point
* 32-bit signed integer
* 24-bit signed integer
* 16-bit signed integer
* 8-bit unsigned integer (WAV only)
Note: 24-bit signed integers are converted to ``int32`` tensors.
* MP3
* FLAC
* OGG/VORBIS
Expand All @@ -92,10 +95,12 @@ def load(
The samples are normalized to fit in the range of ``[-1.0, 1.0]``.
When the input format is WAV with integer type, such as 32-bit signed integer, 16-bit
signed integer and 8-bit unsigned integer (24-bit signed integer is not supported),
by providing ``normalize=False``, this function can return integer Tensor, where the samples
are expressed within the whole range of the corresponding dtype, that is, ``int32`` tensor
for 32-bit signed PCM, ``int16`` for 16-bit signed PCM and ``uint8`` for 8-bit unsigned PCM.
signed integer, 24-bit signed integer, and 8-bit unsigned integer, by providing ``normalize=False``,
this function can return integer Tensor, where the samples are expressed within the whole range
of the corresponding dtype, that is, ``int32`` tensor for 32-bit signed PCM,
``int16`` for 16-bit signed PCM and ``uint8`` for 8-bit unsigned PCM. Since torch does not
support an ``int24`` dtype, 24-bit signed PCM is converted to an ``int32`` tensor
by implicitly left-shifting each sample by 8 bits.
``normalize`` parameter has no effect on 32-bit floating-point WAV and other formats, such as
``flac`` and ``mp3``.
Expand Down
25 changes: 17 additions & 8 deletions torchaudio/csrc/sox/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,15 +118,19 @@ caffe2::TypeMeta get_dtype(
switch (encoding) {
case SOX_ENCODING_UNSIGNED: // 8-bit PCM WAV
return torch::kUInt8;
case SOX_ENCODING_SIGN2: // 16-bit or 32-bit PCM WAV
case SOX_ENCODING_SIGN2: // 16-bit, 24-bit, or 32-bit PCM WAV
switch (precision) {
case 16:
return torch::kInt16;
case 24: // Cast 24-bit to 32-bit. That is each data chunk will
// implicitly perform 8 left-shifts.
std::cout << "Warning: Casting 24-bit signed PCM to torch::kInt32";
return torch::kInt32;
case 32:
return torch::kInt32;
default:
throw std::runtime_error(
"Only 16 and 32 bits are supported for signed PCM.");
"Only 16, 24, and 32 bits are supported for signed PCM.");
}
default:
// default to float32 for the other formats, including
Expand All @@ -142,19 +146,24 @@ caffe2::TypeMeta get_dtype(

// Map a dtype name ("uint8", "int16", "int24", "int32", "float32",
// "float64") to the corresponding torch type meta.
//
// torch has no 24-bit integer type, so "int24" maps to torch::kInt32 and a
// warning is emitted. Any other name raises std::runtime_error.
caffe2::TypeMeta get_dtype_from_str(const std::string dtype) {
  const auto tgt_dtype = [&]() {
    if (dtype == "uint8") {
      return torch::kUInt8;
    } else if (dtype == "int16") {
      return torch::kInt16;
    } else if (dtype == "int24") {
      // Diagnostics go to stderr, newline-terminated, so they neither
      // pollute the program's stdout nor run into the next line of output.
      std::cerr << "Warning: Casting 24-bit signed PCM to torch::kInt32"
                << '\n';
      return torch::kInt32;
    } else if (dtype == "int32") {
      return torch::kInt32;
    } else if (dtype == "float32") {
      return torch::kFloat32;
    } else if (dtype == "float64") {
      return torch::kFloat64;
    } else {
      // Name the rejected value so the caller can see what was passed in.
      throw std::runtime_error("Unsupported dtype: " + dtype);
    }
  }();

  return c10::scalarTypeToTypeMeta(tgt_dtype);
}

Expand Down

0 comments on commit 803cc7a

Please sign in to comment.