Skip to content

Batching for transforms #337

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 16 commits into from
Nov 18, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ Transforms expect and return the following dimensions.
* `MuLawDecode`: (channel, time) -> (channel, time)
* `Resample`: (channel, time) -> (channel, time)

Complex numbers are supported via tensors of dimension (..., 2), and torchaudio provides `complex_norm` and `angle` to convert such a tensor into its magnitude and phase.
Complex numbers are supported via tensors of dimension (..., 2), and torchaudio provides `complex_norm` and `angle` to convert such a tensor into its magnitude and phase. Here, and in the documentation, we use an ellipsis "..." as a placeholder for the rest of the dimensions of a tensor, e.g. optional batching and channel dimensions.

Contributing Guidelines
-----------------------
Expand Down
49 changes: 46 additions & 3 deletions test/test_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ class TestFunctional(unittest.TestCase):
number_of_trials = 100
specgram = torch.tensor([1., 2., 3., 4.])

test_dirpath, test_dir = common_utils.create_temp_assets_dir()
test_filepath = os.path.join(test_dirpath, 'assets',
'steam-train-whistle-daniel_simon.mp3')

def _test_compute_deltas(self, specgram, expected, win_length=3, atol=1e-6, rtol=1e-8):
computed = F.compute_deltas(specgram, win_length=win_length)
self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
Expand All @@ -46,6 +50,20 @@ def test_compute_deltas_randn(self):
computed = F.compute_deltas(specgram, win_length=win_length)
self.assertTrue(computed.shape == specgram.shape, (computed.shape, specgram.shape))

def test_batch_pitch(self):
    """Pitch detection on a batched waveform must match the tiled un-batched result."""
    single, rate = torchaudio.load(self.test_filepath)

    # Reference path: run on the un-batched waveform, then tile the
    # result along a new leading batch axis of size 3.
    reference = F.detect_pitch_frequency(single, rate)
    reference = reference.unsqueeze(0).repeat(3, 1, 1)

    # Batched path: tile the waveform first, then run on the whole batch.
    batch = single.unsqueeze(0).repeat(3, 1, 1)
    result = F.detect_pitch_frequency(batch, rate)

    # Shapes are reported in the failure message to ease debugging.
    shapes = (result.shape, reference.shape)
    self.assertTrue(shapes[0] == shapes[1], shapes)
    self.assertTrue(torch.allclose(result, reference))

def _compare_estimate(self, sound, estimate, atol=1e-6, rtol=1e-8):
# trim sound for case when constructed signal is shorter than original
sound = sound[..., :estimate.size(-1)]
Expand All @@ -58,13 +76,23 @@ def _test_istft_is_inverse_of_stft(self, kwargs):
# operation to check whether we can reconstruct signal
for data_size in self.data_sizes:
for i in range(self.number_of_trials):

# Non-batch
sound = common_utils.random_float_tensor(i, data_size)

stft = torch.stft(sound, **kwargs)
estimate = torchaudio.functional.istft(stft, length=sound.size(1), **kwargs)

self._compare_estimate(sound, estimate)

# Batch
stft = torch.stft(sound, **kwargs)
stft = stft.repeat(3, 1, 1, 1, 1)
sound = sound.repeat(3, 1, 1)

estimate = torchaudio.functional.istft(stft, length=sound.size(1), **kwargs)
self._compare_estimate(sound, estimate)

def test_istft_is_inverse_of_stft1(self):
# hann_window, centered, normalized, onesided
kwargs1 = {
Expand Down Expand Up @@ -326,15 +354,30 @@ def test_pitch(self):
for filename, freq_ref in tests:
waveform, sample_rate = torchaudio.load(filename)

# Convert to stereo for testing purposes
waveform = waveform.repeat(2, 1, 1)

freq = torchaudio.functional.detect_pitch_frequency(waveform, sample_rate)

threshold = 1
s = ((freq - freq_ref).abs() > threshold).sum()
self.assertFalse(s)

# Convert to stereo and batch for testing purposes
freq = freq.repeat(3, 2, 1, 1)
waveform = waveform.repeat(3, 2, 1, 1)

freq2 = torchaudio.functional.detect_pitch_frequency(waveform, sample_rate)

assert torch.allclose(freq, freq2, atol=1e-5)

def _test_batch(self, functional):
    """Check that ``functional`` gives the same result batched and un-batched.

    The functional is applied once to the un-batched test waveform (result
    tiled three times along a new leading axis) and once to the waveform
    tiled three times; the two results must agree in shape and value.

    Args:
        functional (callable): Takes a waveform tensor and returns a tensor.
    """
    waveform, _ = torchaudio.load(self.test_filepath)  # (2, 278756), 44100

    # Single then transform then batch
    expected = functional(waveform)
    # One repeat factor per output dimension, so the helper is not tied to
    # functionals whose output happens to be 3-D.
    expected = expected.unsqueeze(0).repeat(3, *([1] * expected.dim()))

    # Batch then transform
    waveform = waveform.unsqueeze(0).repeat(3, 1, 1)
    computed = functional(waveform)

    # Without these checks the helper computes both tensors and verifies nothing.
    self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
    self.assertTrue(torch.allclose(computed, expected))


def _num_stft_bins(signal_len, fft_len, hop_length, pad):
return (signal_len + 2 * pad - fft_len + hop_length) // hop_length
Expand Down
39 changes: 39 additions & 0 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,45 @@ def test_compute_deltas_twochannel(self):
computed = transform(specgram)
self.assertTrue(computed.shape == specgram.shape, (computed.shape, specgram.shape))

def test_batch_compute_deltas(self):
    """ComputeDeltas on a batch must match the tiled result of the un-batched input."""
    spec = torch.randn(2, 31, 2786)
    tiling = (3, 1, 1, 1)

    # Reference path: transform the un-batched spectrogram, then tile it
    # into a batch of three.
    reference = transforms.ComputeDeltas()(spec).repeat(*tiling)

    # Batched path: tile first, then transform the whole batch.
    batch = spec.repeat(*tiling)
    result = transforms.ComputeDeltas()(batch)

    # The tiled input is (3, 2, 31, 2786); both results should share one shape.
    self.assertTrue(result.shape == reference.shape, (result.shape, reference.shape))
    self.assertTrue(torch.allclose(result, reference))

def test_batch_mulaw(self):
    """Mu-law encoding and decoding must agree between batched and un-batched input."""
    waveform, _ = torchaudio.load(self.test_filepath)  # (2, 278756), 44100

    def tile(tensor):
        # Add a leading batch axis and repeat it three times.
        return tensor.unsqueeze(0).repeat(3, 1, 1)

    # --- Encoding ---
    # Reference: encode the un-batched waveform, then tile.
    encoded_single = transforms.MuLawEncoding()(waveform)
    reference = tile(encoded_single)

    # Batched: tile the waveform, then encode the whole batch.
    batch = tile(waveform)
    result = transforms.MuLawEncoding()(batch)

    # Both sides are 3-D: (batch=3, channel, time).
    self.assertTrue(result.shape == reference.shape, (result.shape, reference.shape))
    self.assertTrue(torch.allclose(result, reference))

    # --- Decoding ---
    # Reference: decode the un-batched encoding, then tile.
    decoded_single = transforms.MuLawDecoding()(encoded_single)
    reference = tile(decoded_single)

    # Batched: decode the batched encoding produced above.
    result = transforms.MuLawDecoding()(result)

    # Both sides are 3-D: (batch=3, channel, time).
    self.assertTrue(result.shape == reference.shape, (result.shape, reference.shape))
    self.assertTrue(torch.allclose(result, reference))

def test_batch_spectrogram(self):
waveform, sample_rate = torchaudio.load(self.test_filepath)

Expand Down
75 changes: 55 additions & 20 deletions torchaudio/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,16 +114,20 @@ def istft(
original signal length). (Default: whole signal)

Returns:
torch.Tensor: Least squares estimation of the original signal of size
(channel, signal_length) or (signal_length)
torch.Tensor: Least squares estimation of the original signal of size (..., signal_length)
"""
stft_matrix_dim = stft_matrix.dim()
assert 3 <= stft_matrix_dim <= 4, "Incorrect stft dimension: %d" % (stft_matrix_dim)
assert 3 <= stft_matrix_dim, "Incorrect stft dimension: %d" % (stft_matrix_dim)
assert stft_matrix.nelement() > 0

if stft_matrix_dim == 3:
# add a channel dimension
stft_matrix = stft_matrix.unsqueeze(0)

# pack batch
shape = stft_matrix.size()
stft_matrix = stft_matrix.reshape(-1, *shape[-3:])

dtype = stft_matrix.dtype
device = stft_matrix.device
fft_size = stft_matrix.size(1)
Expand Down Expand Up @@ -208,8 +212,12 @@ def istft(

y = (y / window_envelop).squeeze(1) # size (channel, expected_signal_len)

# unpack batch
y = y.reshape(shape[:-3] + y.shape[-1:])

if stft_matrix_dim == 3: # remove the channel dimension
y = y.squeeze(0)

return y


Expand Down Expand Up @@ -514,14 +522,14 @@ def phase_vocoder(complex_specgrams, rate, phase_advance):
dtype=complex_specgrams.dtype)

alphas = time_steps % 1.0
phase_0 = angle(complex_specgrams[:, :, :1])
phase_0 = angle(complex_specgrams[..., :1, :])

# Time Padding
complex_specgrams = torch.nn.functional.pad(complex_specgrams, [0, 0, 0, 2])

# (new_bins, freq, 2)
complex_specgrams_0 = complex_specgrams[:, :, time_steps.long()]
complex_specgrams_1 = complex_specgrams[:, :, (time_steps + 1).long()]
complex_specgrams_0 = complex_specgrams.index_select(-2, time_steps.long())
complex_specgrams_1 = complex_specgrams.index_select(-2, (time_steps + 1).long())

angle_0 = angle(complex_specgrams_0)
angle_1 = angle(complex_specgrams_1)
Expand All @@ -534,7 +542,7 @@ def phase_vocoder(complex_specgrams, rate, phase_advance):

# Compute Phase Accum
phase = phase + phase_advance
phase = torch.cat([phase_0, phase[:, :, :-1]], dim=-1)
phase = torch.cat([phase_0, phase[..., :-1]], dim=-1)
phase_acc = torch.cumsum(phase, -1)

mag = alphas * norm_1 + (1 - alphas) * norm_0
Expand All @@ -554,7 +562,7 @@ def lfilter(waveform, a_coeffs, b_coeffs):
Performs an IIR filter by evaluating difference equation.

Args:
waveform (torch.Tensor): audio waveform of dimension of `(channel, time)`. Must be normalized to -1 to 1.
waveform (torch.Tensor): audio waveform of dimension of `(..., time)`. Must be normalized to -1 to 1.
a_coeffs (torch.Tensor): denominator coefficients of difference equation of dimension of `(n_order + 1)`.
Lower delays coefficients are first, e.g. `[a0, a1, a2, ...]`.
Must be same size as b_coeffs (pad with 0's as necessary).
Expand All @@ -563,10 +571,16 @@ def lfilter(waveform, a_coeffs, b_coeffs):
Must be same size as a_coeffs (pad with 0's as necessary).

Returns:
output_waveform (torch.Tensor): Dimension of `(channel, time)`. Output will be clipped to -1 to 1.
output_waveform (torch.Tensor): Dimension of `(..., time)`. Output will be clipped to -1 to 1.

"""

dim = waveform.dim()

# pack batch
shape = waveform.size()
waveform = waveform.reshape(-1, shape[-1])

assert(a_coeffs.size(0) == b_coeffs.size(0))
assert(len(waveform.size()) == 2)
assert(waveform.device == a_coeffs.device)
Expand Down Expand Up @@ -606,7 +620,14 @@ def lfilter(waveform, a_coeffs, b_coeffs):

padded_output_waveform[:, i_sample + n_order - 1] = o0

return torch.min(ones, torch.max(ones * -1, padded_output_waveform[:, (n_order - 1):]))
output = torch.min(
ones, torch.max(ones * -1, padded_output_waveform[:, (n_order - 1):])
)

# unpack batch
output = output.reshape(shape[:-1] + output.shape[-1:])

return output


@torch.jit.script
Expand Down Expand Up @@ -817,22 +838,24 @@ def compute_deltas(specgram, win_length=5, mode="replicate"):
:math:`N` is (`win_length`-1)//2.

Args:
specgram (torch.Tensor): Tensor of audio of dimension (channel, freq, time)
specgram (torch.Tensor): Tensor of audio of dimension (..., freq, time)
win_length (int): The window length used for computing delta
mode (str): Mode parameter passed to padding

Returns:
deltas (torch.Tensor): Tensor of audio of dimension (channel, freq, time)
deltas (torch.Tensor): Tensor of audio of dimension (..., freq, time)

Example
>>> specgram = torch.randn(1, 40, 1000)
>>> delta = compute_deltas(specgram)
>>> delta2 = compute_deltas(delta)
"""

# pack batch
shape = specgram.size()
specgram = specgram.reshape(1, -1, shape[-1])

assert win_length >= 3
assert specgram.dim() == 3
assert not specgram.shape[1] % specgram.shape[0]

n = (win_length - 1) // 2

Expand All @@ -844,12 +867,15 @@ def compute_deltas(specgram, win_length=5, mode="replicate"):
kernel = (
torch
.arange(-n, n + 1, 1, device=specgram.device, dtype=specgram.dtype)
.repeat(specgram.shape[1], specgram.shape[0], 1)
.repeat(specgram.shape[1], 1, 1)
)

return torch.nn.functional.conv1d(
specgram, kernel, groups=specgram.shape[1] // specgram.shape[0]
) / denom
output = torch.nn.functional.conv1d(specgram, kernel, groups=specgram.shape[1]) / denom

# unpack batch
output = output.reshape(shape)

return output


@torch.jit.script
Expand Down Expand Up @@ -982,16 +1008,22 @@ def detect_pitch_frequency(
It is implemented using normalized cross-correlation function and median smoothing.

Args:
waveform (torch.Tensor): Tensor of audio of dimension (channel, freq, time)
waveform (torch.Tensor): Tensor of audio of dimension (..., freq, time)
sample_rate (int): The sample rate of the waveform (Hz)
win_length (int): The window length for median smoothing (in number of frames)
freq_low (int): Lowest frequency that can be detected (Hz)
freq_high (int): Highest frequency that can be detected (Hz)

Returns:
freq (torch.Tensor): Tensor of audio of dimension (channel, frame)
freq (torch.Tensor): Tensor of audio of dimension (..., frame)
"""

dim = waveform.dim()

# pack batch
shape = waveform.size()
waveform = waveform.reshape([-1] + shape[-1:])

nccf = _compute_nccf(waveform, sample_rate, frame_time, freq_low)
indices = _find_max_per_frame(nccf, sample_rate, freq_high)
indices = _median_smoothing(indices, win_length)
Expand All @@ -1000,4 +1032,7 @@ def detect_pitch_frequency(
EPSILON = 10 ** (-9)
freq = sample_rate / (EPSILON + indices.to(torch.float))

# unpack batch
freq = freq.reshape(shape[:-1] + freq.shape[-1:])

return freq