Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add wpe_v8 to torch, i.e., a loopy implementation #79

Merged
merged 1 commit into from
Sep 10, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 82 additions & 0 deletions nara_wpe/torch_wpe.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,3 +225,85 @@ def wpe_v6(Y, taps=10, delay=3, iterations=3, psd_context=0, statistics_mode='fu
X = Y - torch.matmul(hermite(G), Y_tilde)

return X


def wpe_v8(
Y,
taps=10,
delay=3,
iterations=3,
psd_context=0,
statistics_mode='full',
inplace=False
):
"""
Loopy Multiple Input Multiple Output Weighted Prediction Error [1, 2] implementation

For numpy it is often the fastest Numpy implementation, torch has to be
profiled.
It loops over the independent axes. This reduces the memory footprint.

Args:
Y: Complex valued STFT signal with shape (..., D, T).
taps: Filter order
delay: Delay as a guard interval, such that X does not become zero.
iterations:
psd_context: Defines the number of elements in the time window
to improve the power estimation. Total number of elements will
be (psd_context + 1 + psd_context).
statistics_mode: Either 'full' or 'valid'.
'full': Pad the observation with zeros on the left for the
estimation of the correlation matrix and vector.
'valid': Only calculate correlation matrix and vector on valid
slices of the observation.
inplace: Whether to change Y inplace. Has only advantages, when Y has
independent axes, because the core WPE algorithm does not support
an inplace modification of the observation.
This option may be relevant, when Y is so large, that you do not
want to double the memory consumption (i.e. save Y and the
dereverberated signal in the memory).

Returns:
Estimated signal with the same shape as Y

[1] "Generalization of multi-channel linear prediction methods for blind MIMO
impulse response shortening", Yoshioka, Takuya and Nakatani, Tomohiro, 2012
[2] NARA-WPE: A Python package for weighted prediction error dereverberation in
Numpy and Tensorflow for online and offline processing, Drude, Lukas and
Heymann, Jahn and Boeddeker, Christoph and Haeb-Umbach, Reinhold, 2018

"""
ndim = Y.ndim
if ndim == 2:
out = wpe_v6(
Y,
taps=taps,
delay=delay,
iterations=iterations,
psd_context=psd_context,
statistics_mode=statistics_mode
)
if inplace:
Y[...] = out
return out
elif ndim >= 3:
if inplace:
out = Y
else:
out = torch.empty_like(Y)

for index in np.ndindex(Y.shape[:-2]):
out[index] = wpe_v6(
Y=Y[index],
taps=taps,
delay=delay,
iterations=iterations,
psd_context=psd_context,
statistics_mode=statistics_mode,
)
return out
else:
raise NotImplementedError(
'Input shape has to be (..., D, T) and not {}.'.format(Y.shape)
)

Loading