Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: reduce computations in backprop of lfilter #3831

Merged
merged 5 commits into from
Sep 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions docs/source/refs.bib
Original file line number Diff line number Diff line change
Expand Up @@ -633,3 +633,11 @@ @article{forgione2021dynonet
year={2021},
publisher={Wiley Online Library}
}

@inproceedings{ycy2024diffapf,
title={Differentiable All-pole Filters for Time-varying Audio Systems},
author={Chin-Yun Yu and Christopher Mitcheltree and Alistair Carson and Stefan Bilbao and Joshua D. Reiss and György Fazekas},
booktitle={International Conference on Digital Audio Effects (DAFx)},
year={2024},
pages={345--352},
}
31 changes: 14 additions & 17 deletions src/libtorchaudio/lfilter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,6 @@ class DifferentiableIIR : public torch::autograd::Function<DifferentiableIIR> {
auto a_coeffs_normalized = saved[1];
auto y = saved[2];

int64_t n_batch = x.size(0);
int64_t n_channel = x.size(1);
int64_t n_order = a_coeffs_normalized.size(1);

Expand All @@ -161,26 +160,24 @@ class DifferentiableIIR : public torch::autograd::Function<DifferentiableIIR> {

namespace F = torch::nn::functional;

if (a_coeffs_normalized.requires_grad()) {
auto dyda = F::pad(
DifferentiableIIR::apply(-y, a_coeffs_normalized),
F::PadFuncOptions({n_order - 1, 0}));

da = F::conv1d(
dyda.view({1, n_batch * n_channel, -1}),
dy.reshape({n_batch * n_channel, 1, -1}),
F::Conv1dFuncOptions().groups(n_batch * n_channel))
.view({n_batch, n_channel, -1})
.sum(0)
.flip(1);
}
auto tmp =
DifferentiableIIR::apply(dy.flip(2).contiguous(), a_coeffs_normalized)
.flip(2);

if (x.requires_grad()) {
dx =
DifferentiableIIR::apply(dy.flip(2).contiguous(), a_coeffs_normalized)
.flip(2);
dx = tmp;
}

if (a_coeffs_normalized.requires_grad()) {
da = -torch::matmul(
tmp.transpose(0, 1).reshape({n_channel, 1, -1}),
F::pad(y, F::PadFuncOptions({n_order - 1, 0}))
.unfold(2, n_order, 1)
.transpose(0, 1)
.reshape({n_channel, -1, n_order}))
.squeeze(1)
.flip(1);
}
return {dx, da};
}
};
Expand Down
3 changes: 2 additions & 1 deletion src/torchaudio/functional/filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -999,7 +999,8 @@ def _lfilter_core(

def lfilter(waveform: Tensor, a_coeffs: Tensor, b_coeffs: Tensor, clamp: bool = True, batching: bool = True) -> Tensor:
r"""Perform an IIR filter by evaluating difference equation, using differentiable implementation
developed independently by *Yu et al.* :cite:`ismir_YuF23` and *Forgione et al.* :cite:`forgione2021dynonet`.
developed separately by *Yu et al.* :cite:`ismir_YuF23` and *Forgione et al.* :cite:`forgione2021dynonet`.
The gradients of ``a_coeffs`` are computed based on a faster algorithm from :cite:`ycy2024diffapf`.

.. devices:: CPU CUDA

Expand Down
Loading