Skip to content

Commit 84d40bc

Browse files
committed
Initial commit for SoX logic in VCTK
1 parent 3dcf812 commit 84d40bc

File tree

3 files changed

+242
-12
lines changed

3 files changed

+242
-12
lines changed

test/test_functional.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,11 @@
55
import torch
66
import torchaudio
77
import torchaudio.functional as F
8+
import torchaudio.transforms as T
89
import pytest
910
import unittest
1011
import common_utils
12+
import os
1113

1214
from torchaudio.common_utils import IMPORT_LIBROSA
1315

@@ -20,6 +22,13 @@ class TestFunctional(unittest.TestCase):
2022
data_sizes = [(2, 20), (3, 15), (4, 10)]
2123
number_of_trials = 100
2224
specgram = torch.tensor([1., 2., 3., 4.])
25+
26+
test_dirpath, test_dir = common_utils.create_temp_assets_dir()
27+
test_filepath = os.path.join(test_dirpath, "assets", "sinewave.wav")
28+
waveform, sample_rate = torchaudio.load(test_filepath)
29+
30+
E = torchaudio.sox_effects.SoxEffectsChain()
31+
E.set_input_file(test_filepath)
2332

2433
def _test_compute_deltas(self, specgram, expected, win_length=3, atol=1e-6, rtol=1e-8):
2534
computed = F.compute_deltas(specgram, win_length=win_length)
@@ -311,6 +320,60 @@ def test_create_fb(self):
311320
self._test_create_fb(n_mels=56, fmin=1900.0, fmax=900.0)
312321
self._test_create_fb(n_mels=10, fmin=1900.0, fmax=900.0)
313322

323+
def test_gain(self):
324+
waveform_gain = F.gain(self.waveform, 5)
325+
self.assertTrue(waveform_gain.abs().max().item(), 1.)
326+
327+
self.E.append_effect_to_chain("gain", [5])
328+
sox_gain_waveform = self.E.sox_build_flow_effects()[0]
329+
330+
self.assertTrue(torch.allclose(waveform_gain, sox_gain_waveform))
331+
self.E.clear_chain()
332+
333+
def test_scale_to_interval(self):
334+
scaled = 5.5 # [-5.5, 5.5]
335+
waveform_scaled = F.scale_to_interval(self.waveform, scaled)
336+
337+
self.assertTrue(torch.max(waveform_scaled) <= scaled)
338+
self.assertTrue(torch.min(waveform_scaled) >= -scaled)
339+
340+
def test_dither(self):
341+
waveform_dithered = F.dither(self.waveform)
342+
waveform_dithered_noiseshaped = F.dither(self.waveform, noise_shaping=True)
343+
344+
self.E.append_effect_to_chain("dither", [])
345+
sox_dither_waveform = self.E.sox_build_flow_effects()[0]
346+
347+
self.assertTrue(torch.allclose(waveform_dithered, sox_dither_waveform, rtol=1e-03, atol=1e-03))
348+
self.E.clear_chain()
349+
350+
self.E.append_effect_to_chain("dither", ["-s"])
351+
sox_dither_waveform_ns = self.E.sox_build_flow_effects()[0]
352+
353+
self.assertTrue(torch.allclose(waveform_dithered_noiseshaped, sox_dither_waveform_ns, rtol=1e-03, atol=1e-03))
354+
self.E.clear_chain()
355+
356+
def test_vctk_transform_pipeline(self):
357+
test_filepath_vctk = os.path.join(self.test_dirpath, "assets/VCTK-Corpus/wav48/p224/", "p224_002.wav")
358+
wf_vctk, sr_vctk = torchaudio.load(test_filepath_vctk)
359+
360+
# rate
361+
sample = T.Resample(sr_vctk, 16000, resampling_method='sinc_interpolation')
362+
wf_vctk = sample(wf_vctk)
363+
# dither
364+
wf_vctk = F.dither(wf_vctk, noise_shaping=True)
365+
366+
self.E.set_input_file(test_filepath_vctk)
367+
self.E.append_effect_to_chain("gain", ["-h"])
368+
self.E.append_effect_to_chain("channels", [1])
369+
self.E.append_effect_to_chain("rate", [16000])
370+
self.E.append_effect_to_chain("gain", ["-rh"])
371+
self.E.append_effect_to_chain("dither", ["-s"])
372+
wf_vctk_sox = self.E.sox_build_flow_effects()[0]
373+
374+
self.assertTrue(torch.allclose(wf_vctk, wf_vctk_sox, rtol=1e-03, atol=1e-03))
375+
self.E.clear_chain()
376+
314377
def test_pitch(self):
315378

316379
test_dirpath, test_dir = common_utils.create_temp_assets_dir()

torchaudio/datasets/vctk.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,18 +21,16 @@ def load_vctk_item(
2121

2222
# Read wav
2323
file_audio = os.path.join(path, folder_audio, speaker_id, fileid + ext_audio)
24+
waveform, sample_rate = torchaudio.load(file_audio)
2425
if downsample:
25-
# Legacy
26-
E = torchaudio.sox_effects.SoxEffectsChain()
27-
E.set_input_file(file_audio)
28-
E.append_effect_to_chain("gain", ["-h"])
29-
E.append_effect_to_chain("channels", [1])
30-
E.append_effect_to_chain("rate", [16000])
31-
E.append_effect_to_chain("gain", ["-rh"])
32-
E.append_effect_to_chain("dither", ["-s"])
33-
waveform, sample_rate = E.sox_build_flow_effects()
34-
else:
35-
waveform, sample_rate = torchaudio.load(file_audio)
26+
# TODO Remove this parameter after deprecation
27+
F = torchaudio.functional
28+
T = torchaudio.transforms
29+
# rate
30+
sample = T.Resample(sample_rate, 16000, resampling_method='sinc_interpolation')
31+
waveform = sample(waveform)
32+
# dither
33+
waveform = F.dither(waveform, noise_shaping=True)
3634

3735
return waveform, sample_rate, utterance, speaker_id, utterance_id
3836

torchaudio/functional.py

Lines changed: 170 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -703,7 +703,7 @@ def equalizer_biquad(waveform, sample_rate, center_freq, gain, Q=0.707):
703703
Args:
704704
waveform (torch.Tensor): audio waveform of dimension of `(channel, time)`
705705
sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz)
706-
center_freq (float): filters central frequency
706+
center_freq (float): filter's central frequency
707707
gain (float): desired gain at the boost (or attenuation) in dB
708708
q_factor (float): https://en.wikipedia.org/wiki/Q_factor
709709
@@ -844,6 +844,175 @@ def compute_deltas(specgram, win_length=5, mode="replicate"):
844844

845845

846846
@torch.jit.script
847+
def gain(waveform, gain_db=1.0):
848+
# type: (Tensor, float) -> Tensor
849+
r"""Apply amplification or attenuation to the whole waveform.
850+
851+
Args:
852+
waveform (torch.Tensor): Tensor of audio of dimension (channel, time).
853+
gain_db (float) Gain adjustment in decibels (dB) (Default: `1.0`).
854+
855+
Returns:
856+
torch.Tensor: the whole waveform amplified by gain_db.
857+
"""
858+
if (gain_db == 0):
859+
return waveform
860+
861+
ratio = 10 ** (gain_db / 20)
862+
863+
return waveform * ratio
864+
865+
866+
@torch.jit.script
867+
def scale_to_interval(waveform, interval=1.0):
868+
# type: (Tensor, float) -> Tensor
869+
r"""Scales the whole waveform to an interval.
870+
871+
Args:
872+
waveform (torch.Tensor): Tensor of audio of dimension (channel, time).
873+
interval (float): The bounds of the interval, where the float indicates
874+
the upper bound and the negative of the float indicates the lower
875+
bound (Default: `1.0`).
876+
Example: interval=1.0 -> [-1.0, 1.0]
877+
878+
Returns:
879+
torch.Tensor: the whole waveform scaled to interval.
880+
"""
881+
abs_max = torch.max(torch.abs(waveform))
882+
ratio = abs_max / interval
883+
waveform /= ratio
884+
885+
return waveform
886+
887+
888+
def _add_noise_shaping(dithered_waveform, waveform):
889+
r"""Noise shaping is calculated by error:
890+
error[n] = dithered[n] - original[n]
891+
noise_shaped_waveform[n] = dithered[n] + error[n-1]
892+
"""
893+
wf_shape = waveform.size()
894+
waveform = waveform.reshape(-1, wf_shape[-1])
895+
896+
dithered_shape = dithered_waveform.size()
897+
dithered_waveform = dithered_waveform.reshape(-1, dithered_shape[-1])
898+
899+
error = dithered_waveform - waveform
900+
901+
# add error[n-1] to dithered_waveform[n], so offset the error by 1 index
902+
for index in range(error.size()[0]):
903+
err = error[index]
904+
error_offset = torch.cat((torch.zeros(1), err))
905+
error[index] = error_offset[:waveform.size()[1]]
906+
907+
noise_shaped = dithered_waveform + error
908+
return noise_shaped.reshape(dithered_shape[:-1] + noise_shaped.shape[-1:])
909+
910+
911+
@torch.jit.script
912+
def probability_distribution(waveform, density_function="TPDF"):
913+
# type: (Tensor, str) -> Tensor
914+
r"""Apply a probability distribution function on a waveform.
915+
916+
Triangular probability density function (TPDF) dither noise has a
917+
triangular distribution; values in the center of the range have a higher
918+
probability of occurring.
919+
920+
Rectangular probability density function (RPDF) dither noise has a
921+
uniform distribution; any value in the specified range has the same
922+
probability of occurring.
923+
924+
Gaussian probability density function (GPDF) has a normal distribution.
925+
The relationship of probabilities of results follows a bell-shaped,
926+
or Gaussian curve, typical of dither generated by analog sources.
927+
Args:
928+
waveform (torch.Tensor): Tensor of audio of dimension (channel, time)
929+
probability_density_function (string): The density function of a
930+
continuous random variable (Default: `TPDF`)
931+
Options: Triangular Probability Density Function - `TPDF`
932+
Rectangular Probability Density Function - `RPDF`
933+
Gaussian Probability Density Function - `GPDF`
934+
Returns:
935+
torch.Tensor: waveform dithered with TPDF
936+
"""
937+
shape = waveform.size()
938+
waveform = waveform.reshape(-1, shape[-1])
939+
940+
channel_size = waveform.size()[0] - 1
941+
time_size = waveform.size()[-1] - 1
942+
943+
random_channel = int(torch.randint(channel_size, [1, ]).item()) if channel_size > 0 else 0
944+
random_time = int(torch.randint(time_size, [1, ]).item()) if time_size > 0 else 0
945+
946+
number_of_bits = 16
947+
up_scaling = 2 ** (number_of_bits - 1) - 2
948+
signal_scaled = waveform * up_scaling
949+
down_scaling = 2 ** (number_of_bits - 1)
950+
951+
signal_scaled_dis = waveform
952+
if (density_function == "RPDF"):
953+
RPDF = waveform[random_channel][random_time] - 0.5
954+
955+
signal_scaled_dis = signal_scaled + RPDF
956+
elif (density_function == "GPDF"):
957+
# TODO Replace by distribution code once
958+
# https://github.com/pytorch/pytorch/issues/29843 is resolved
959+
# gaussian = torch.distributions.normal.Normal(torch.mean(waveform), 1).sample()
960+
961+
EPOCH = 6
962+
963+
gaussian = waveform[random_channel][random_time]
964+
for ws in EPOCH * [time_size]:
965+
rand_chan = int(torch.randint(channel_size, [1, ]).item())
966+
gaussian += waveform[rand_chan][int(torch.randint(ws, [1, ]).item())]
967+
968+
signal_scaled_dis = signal_scaled + gaussian
969+
else:
970+
TPDF = torch.bartlett_window(time_size + 1)
971+
972+
signal_scaled_dis = signal_scaled
973+
for index in range(channel_size + 1):
974+
signal_scaled_dis[index] += TPDF
975+
976+
quantised_signal_scaled = torch.round(signal_scaled_dis)
977+
quantised_signal = quantised_signal_scaled / down_scaling
978+
return quantised_signal.reshape(shape[:-1] + quantised_signal.shape[-1:])
979+
980+
981+
@torch.jit.script
982+
def dither(waveform, probability_density_function="TPDF", noise_shaping=False, ns_filter=""):
983+
# type: (Tensor, str, bool, str) -> Tensor
984+
r"""Dither increases the perceived dynamic range of audio stored at a
985+
particular bit-depth by eliminating nonlinear truncation distortion
986+
(i.e. adding minimally perceived noise to mask distortion caused by quantization).
987+
Args:
988+
waveform (torch.Tensor): Tensor of audio of dimension (channel, time)
989+
probability_density_function (string): The density function of a
990+
continuous random variable (Default: `TPDF`)
991+
Options: Triangular Probability Density Function - `TPDF`
992+
Rectangular Probability Density Function - `RPDF`
993+
Gaussian Probability Density Function - `GPDF`
994+
noise_shaping (boolean): a filtering process that shapes the spectral
995+
energy of quantisation error (Default: `False`)
996+
ns_filter (string): TODO The noise shaping filter (Default: `""`)
997+
Options: Lipshitz - `L`
998+
F-Weighted - `FW`
999+
Modified-E-Weighted - `MEW`
1000+
Improved-E-Weighted - `IEW`
1001+
Gesemann - `G`
1002+
Shibata - `S`
1003+
Low-Shibata - `LS`
1004+
High-Shibata - `HS`
1005+
Returns:
1006+
torch.Tensor: waveform dithered
1007+
"""
1008+
dithered = probability_distribution(waveform, density_function=probability_density_function)
1009+
1010+
if noise_shaping:
1011+
return _add_noise_shaping(dithered, waveform)
1012+
else:
1013+
return dithered
1014+
1015+
8471016
def _compute_nccf(waveform, sample_rate, frame_time, freq_low):
8481017
# type: (Tensor, int, float, int) -> Tensor
8491018
r"""

0 commit comments

Comments
 (0)