Make resampling simpler and faster (#1087)
adefossez authored Jan 19, 2021
1 parent f1d8d1e commit e43a8e7
Showing 1 changed file with 64 additions and 201 deletions.
265 changes: 64 additions & 201 deletions torchaudio/compliance/kaldi.py
@@ -3,6 +3,7 @@
import math
import torch
from torch import Tensor
from torch.nn import functional as F

import torchaudio
import torchaudio._internal.fft
@@ -752,141 +753,54 @@ def mfcc(
return feature


def _get_LR_indices_and_weights(orig_freq: float,
new_freq: float,
output_samples_in_unit: int,
window_width: float,
lowpass_cutoff: float,
lowpass_filter_width: int,
device: torch.device,
dtype: torch.dtype) -> Tuple[Tensor, Tensor]:
r"""Based on LinearResample::SetIndexesAndWeights where it retrieves the weights for
resampling as well as the indices in which they are valid. LinearResample (LR) means
that the output signal is at linearly spaced intervals (i.e the output signal has a
frequency of ``new_freq``). It uses sinc/bandlimited interpolation to upsample/downsample
the signal.
The reason why the same filter is not used for multiple convolutions is because the
sinc function could sampled at different points in time. For example, suppose
a signal is sampled at the timestamps (seconds)
0 16 32
and we want it to be sampled at the timestamps (seconds)
0 5 10 15 20 25 30 35
at the timestamp of 16, the delta timestamps are
16 11 6 1 4 9 14 19
at the timestamp of 32, the delta timestamps are
32 27 22 17 12 7 2 3
As we can see from the deltas, the sinc function is sampled at different points in time
assuming the center of the sinc function is at 0, 16, and 32 (the deltas [..., 6, 1, 4, ....]
for 16 vs [...., 2, 3, ....] for 32)
For example, when ``orig_freq`` and ``new_freq`` are multiples of each other, only one
filter is needed.
A windowed filter function (i.e. Hanning * sinc) is used because the ideal sinc function
has infinite support (it is non-zero for all values), so instead it is truncated and multiplied by
a window function, which gives it less-than-perfect rolloff [1].
[1] Chapter 16: Windowed-Sinc Filters, https://www.dspguide.com/ch16/1.htm
Args:
orig_freq (float): The original frequency of the signal
new_freq (float): The desired frequency
output_samples_in_unit (int): The number of output samples in the smallest repeating unit:
num_samp_out = new_freq / Gcd(orig_freq, new_freq)
window_width (float): The width of the window which is nonzero
lowpass_cutoff (float): The filter cutoff in Hz. The filter cutoff needs to be less
than samp_rate_in_hz/2 and less than samp_rate_out_hz/2.
lowpass_filter_width (int): Controls the sharpness of the filter, more == sharper but less
efficient. We suggest around 4 to 10 for normal use
Returns:
(Tensor, Tensor): A tuple of ``min_input_index`` (the minimum indices at which the
window is valid, of size (``output_samples_in_unit``)) and ``weights`` (the weights
corresponding to ``min_input_index``, of size (``output_samples_in_unit``, ``max_weight_width``)).
"""
assert lowpass_cutoff < min(orig_freq, new_freq) / 2
output_t = torch.arange(0., output_samples_in_unit, device=device, dtype=dtype) / new_freq
min_t = output_t - window_width
max_t = output_t + window_width

min_input_index = torch.ceil(min_t * orig_freq) # size (output_samples_in_unit)
max_input_index = torch.floor(max_t * orig_freq) # size (output_samples_in_unit)
num_indices = max_input_index - min_input_index + 1 # size (output_samples_in_unit)

max_weight_width = num_indices.max()
# create a group of weights of size (output_samples_in_unit, max_weight_width)
j = torch.arange(max_weight_width, device=device, dtype=dtype).unsqueeze(0)
input_index = min_input_index.unsqueeze(1) + j
delta_t = (input_index / orig_freq) - output_t.unsqueeze(1)

weights = torch.zeros_like(delta_t)
inside_window_indices = delta_t.abs().lt(window_width)
# raised-cosine (Hanning) window with width `window_width`
weights[inside_window_indices] = 0.5 * (1 + torch.cos(2 * math.pi * lowpass_cutoff /
lowpass_filter_width * delta_t[inside_window_indices]))

t_eq_zero_indices = delta_t.eq(0.0)
t_not_eq_zero_indices = ~t_eq_zero_indices
# sinc filter function
weights[t_not_eq_zero_indices] *= torch.sin(
2 * math.pi * lowpass_cutoff * delta_t[t_not_eq_zero_indices]) / (math.pi * delta_t[t_not_eq_zero_indices])
# limit of the function at t = 0
weights[t_eq_zero_indices] *= 2 * lowpass_cutoff

weights /= orig_freq # size (output_samples_in_unit, max_weight_width)
return min_input_index, weights
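
As a quick numeric check of the delta-timestamp example in the docstring above, here is a tiny standalone sketch (editorial illustration, not part of this diff) that recomputes the deltas for sinc centers at 16 s and 32 s against output timestamps spaced 5 s apart:

import torch

# Output timestamps 0, 5, ..., 35 seconds and two sinc centers.
output_t = torch.arange(0., 40., 5.)
for center in (16., 32.):
    deltas = (center - output_t).abs()
    print(int(center), [int(d) for d in deltas])
# 16 [16, 11, 6, 1, 4, 9, 14, 19]
# 32 [32, 27, 22, 17, 12, 7, 2, 3]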


def _lcm(a: int, b: int) -> int:
return abs(a * b) // math.gcd(a, b)


def _get_num_LR_output_samples(input_num_samp: int,
samp_rate_in: float,
samp_rate_out: float) -> int:
r"""Based on LinearResample::GetNumOutputSamples. LinearResample (LR) means that
the output signal is at linearly spaced intervals (i.e. the output signal has a
frequency of ``new_freq``). It uses sinc/bandlimited interpolation to upsample/downsample
the signal.
Args:
input_num_samp (int): The number of samples in the input
samp_rate_in (float): The original frequency of the signal
samp_rate_out (float): The desired frequency
Returns:
int: The number of output samples
"""
# For exact computation, we measure time in "ticks" of 1.0 / tick_freq,
# where tick_freq is the least common multiple of samp_rate_in and
# samp_rate_out.
samp_rate_in = int(samp_rate_in)
samp_rate_out = int(samp_rate_out)

tick_freq = _lcm(samp_rate_in, samp_rate_out)
ticks_per_input_period = tick_freq // samp_rate_in

# work out the number of ticks in the time interval
# [ 0, input_num_samp/samp_rate_in ).
interval_length_in_ticks = input_num_samp * ticks_per_input_period
if interval_length_in_ticks <= 0:
return 0
ticks_per_output_period = tick_freq // samp_rate_out
# Get the last output-sample in the closed interval, i.e. replacing [ ) with
# [ ]. Note: integer division rounds down. See
# http://en.wikipedia.org/wiki/Interval_(mathematics) for an explanation of
# the notation.
last_output_samp = interval_length_in_ticks // ticks_per_output_period
# We need the last output-sample in the open interval, so if it takes us to
# the end of the interval exactly, subtract one.
if last_output_samp * ticks_per_output_period == interval_length_in_ticks:
last_output_samp -= 1
# First output-sample index is zero, so the number of output samples
# is the last output-sample plus one.
num_output_samp = last_output_samp + 1
return num_output_samp
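
To make the tick arithmetic above concrete, here is a hypothetical worked example (editorial sketch, not part of the diff) for one second of 44100 Hz audio resampled to 16000 Hz:

import math

samp_rate_in, samp_rate_out = 44100, 16000
input_num_samp = 44100                                                # one second of input

tick_freq = (samp_rate_in * samp_rate_out) // math.gcd(samp_rate_in, samp_rate_out)  # lcm = 7_056_000
ticks_per_input_period = tick_freq // samp_rate_in                    # 160
ticks_per_output_period = tick_freq // samp_rate_out                  # 441

interval_length_in_ticks = input_num_samp * ticks_per_input_period    # 7_056_000
last_output_samp = interval_length_in_ticks // ticks_per_output_period  # 16000
# The interval is half-open, so an exact hit on its right edge is excluded.
if last_output_samp * ticks_per_output_period == interval_length_in_ticks:
    last_output_samp -= 1                                             # 15999
num_output_samp = last_output_samp + 1                                # 16000 samples, one second at 16 kHz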
def _get_sinc_resample_kernel(orig_freq: int, new_freq: int, lowpass_filter_width: int,
device: torch.device, dtype: torch.dtype):
assert lowpass_filter_width > 0
kernels = []
base_freq = min(orig_freq, new_freq)
# This will perform antialiasing filtering by removing the highest frequencies.
# At first I thought I only needed this when downsampling, but when upsampling
# you will get edge artifacts without this, as the edge is equivalent to zero padding,
# which will add high freq artifacts.
base_freq *= 0.99

# The key idea of the algorithm is that x(t) can be exactly reconstructed from x[i] (tensor)
# using the sinc interpolation formula:
# x(t) = sum_i x[i] sinc(pi * orig_freq * (i / orig_freq - t))
# We can then sample the function x(t) with a different sample rate:
# y[j] = x(j / new_freq)
# or,
# y[j] = sum_i x[i] sinc(pi * orig_freq * (i / orig_freq - j / new_freq))

# We see here that y[j] is the convolution of x[i] with a specific filter, for which
# we take an FIR approximation, stopping when we see at least `lowpass_filter_width` zero crossings.
# But y[j+1] is going to have a different set of weights and so on, until y[j + new_freq].
# Indeed:
# y[j + new_freq] = sum_i x[i] sinc(pi * orig_freq * (i / orig_freq - (j + new_freq) / new_freq))
#                 = sum_i x[i] sinc(pi * orig_freq * ((i - orig_freq) / orig_freq - j / new_freq))
#                 = sum_i x[i + orig_freq] sinc(pi * orig_freq * (i / orig_freq - j / new_freq))
# so y[j + new_freq] uses the same filter as y[j], but on a version of x shifted by `orig_freq`.
# This explains the F.conv1d below, with a stride of orig_freq.
width = math.ceil(lowpass_filter_width * orig_freq / base_freq)
# If orig_freq is still big after GCD reduction, most filters will be very unbalanced, i.e.,
# they will have a lot of almost zero values to the left or to the right...
# There is probably a way to evaluate those filters more efficiently, but this is kept for
# future work.
idx = torch.arange(-width, width + orig_freq, device=device, dtype=dtype)

for i in range(new_freq):
t = (-i / new_freq + idx / orig_freq) * base_freq
t = t.clamp_(-lowpass_filter_width, lowpass_filter_width)
t *= math.pi
# we do not use torch.hann_window here as we need to evaluate the window
# at specific positions, not over a regular grid.
window = torch.cos(t / lowpass_filter_width / 2)**2
kernel = torch.where(t == 0, torch.tensor(1.).to(t), torch.sin(t) / t)
kernel.mul_(window)
kernels.append(kernel)

scale = base_freq / orig_freq
return torch.stack(kernels).view(new_freq, 1, -1).mul_(scale), width
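
A minimal sketch of the kernel's shape, assuming _get_sinc_resample_kernel above is in scope (the concrete rates are only an illustration): after GCD reduction, a 48 kHz -> 32 kHz resample becomes orig_freq = 3, new_freq = 2, and the function returns new_freq filters, each of length 2 * width + orig_freq.

import torch

orig_freq, new_freq, lowpass_filter_width = 3, 2, 6
kernel, width = _get_sinc_resample_kernel(orig_freq, new_freq, lowpass_filter_width,
                                          torch.device("cpu"), torch.float32)
print(width)         # 10 == ceil(6 * 3 / (0.99 * 2))
print(kernel.shape)  # torch.Size([2, 1, 23]) == (new_freq, 1, 2 * width + orig_freq)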


def resample_waveform(waveform: Tensor,
@@ -912,72 +826,21 @@ def resample_waveform(waveform: Tensor,
Returns:
Tensor: The waveform at the new frequency
"""
device, dtype = waveform.device, waveform.dtype

assert waveform.dim() == 2
assert orig_freq > 0.0 and new_freq > 0.0

min_freq = min(orig_freq, new_freq)
lowpass_cutoff = 0.99 * 0.5 * min_freq

assert lowpass_cutoff * 2 <= min_freq

base_freq = math.gcd(int(orig_freq), int(new_freq))
input_samples_in_unit = int(orig_freq) // base_freq
output_samples_in_unit = int(new_freq) // base_freq

window_width = lowpass_filter_width / (2.0 * lowpass_cutoff)
first_indices, weights = _get_LR_indices_and_weights(
orig_freq, new_freq, output_samples_in_unit,
window_width, lowpass_cutoff, lowpass_filter_width, device, dtype)

assert first_indices.dim() == 1
# TODO figure a better way to do this. conv1d reaches every element i*stride + padding
# all the weights have the same stride but have different padding.
# Current implementation takes the input and applies the various padding before
# doing a conv1d for that specific weight.
conv_stride = input_samples_in_unit
conv_transpose_stride = output_samples_in_unit
num_channels, wave_len = waveform.size()
window_size = weights.size(1)
tot_output_samp = _get_num_LR_output_samples(wave_len, orig_freq, new_freq)
output = torch.zeros((num_channels, tot_output_samp),
device=device, dtype=dtype)
# eye size: (num_channels, num_channels, 1)
eye = torch.eye(num_channels, device=device, dtype=dtype).unsqueeze(2)
for i in range(first_indices.size(0)):
wave_to_conv = waveform
first_index = int(first_indices[i].item())
if first_index >= 0:
# trim the signal as the filter will not be applied before the first_index
wave_to_conv = wave_to_conv[..., first_index:]

# pad the right of the signal to allow partial convolutions meaning compute
# values for partial windows (e.g. end of the window is outside the signal length)
max_unit_index = (tot_output_samp - 1) // output_samples_in_unit
end_index_of_last_window = max_unit_index * conv_stride + window_size
current_wave_len = wave_len - first_index
right_padding = max(0, end_index_of_last_window + 1 - current_wave_len)

left_padding = max(0, -first_index)
if left_padding != 0 or right_padding != 0:
wave_to_conv = torch.nn.functional.pad(wave_to_conv, (left_padding, right_padding))

conv_wave = torch.nn.functional.conv1d(
wave_to_conv.unsqueeze(0), weights[i].repeat(num_channels, 1, 1),
stride=conv_stride, groups=num_channels)

# we want conv_wave[:, i] to be at output[:, i + n*conv_transpose_stride]
dilated_conv_wave = torch.nn.functional.conv_transpose1d(
conv_wave, eye, stride=conv_transpose_stride).squeeze(0)

# pad dilated_conv_wave so it reaches the output length if needed.
dilated_conv_wave_len = dilated_conv_wave.size(-1)
left_padding = i
right_padding = max(0, tot_output_samp - (left_padding + dilated_conv_wave_len))
dilated_conv_wave = torch.nn.functional.pad(
dilated_conv_wave, (left_padding, right_padding))[..., :tot_output_samp]

output += dilated_conv_wave

return output
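
As a side note on the interleaving trick used above (editorial sketch, not part of the diff): conv_transpose1d with an identity kernel and a stride places input element i at output position i * stride and fills the gaps with zeros, which is what lets each phase's conv1d output land on every conv_transpose_stride-th sample of the final signal.

import torch
import torch.nn.functional as F

x = torch.tensor([[[1., 2., 3.]]])    # (batch=1, channels=1, time=3)
eye = torch.eye(1).unsqueeze(2)       # identity kernel of shape (1, 1, 1)
print(F.conv_transpose1d(x, eye, stride=3).squeeze())
# tensor([1., 0., 0., 2., 0., 0., 3.])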
orig_freq = int(orig_freq)
new_freq = int(new_freq)
gcd = math.gcd(orig_freq, new_freq)
orig_freq = orig_freq // gcd
new_freq = new_freq // gcd

kernel, width = _get_sinc_resample_kernel(orig_freq, new_freq, lowpass_filter_width,
waveform.device, waveform.dtype)

num_wavs, length = waveform.shape
waveform = F.pad(waveform, (width, width + orig_freq))
resampled = F.conv1d(waveform[:, None], kernel, stride=orig_freq)
resampled = resampled.transpose(1, 2).reshape(num_wavs, -1)
target_length = int(math.ceil(new_freq * length / orig_freq))
return resampled[..., :target_length]
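
A minimal usage sketch of the new path (editorial example; it assumes the torchaudio version of this commit, where the public entry point is torchaudio.compliance.kaldi.resample_waveform):

import torch
import torchaudio.compliance.kaldi as kaldi

waveform = torch.randn(1, 16000)                              # 1 channel, one second at 16 kHz
resampled = kaldi.resample_waveform(waveform, 16000., 8000.)  # orig_freq, new_freq
print(resampled.shape)                                        # torch.Size([1, 8000])

After GCD reduction (orig_freq = 2, new_freq = 1) this is a single F.conv1d with stride 2 and one filter, so the whole resample is one batched convolution rather than a Python loop over phases.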
