Speed up resample with kernel generation modification #2553
Conversation
This PR makes the following changes to resample kernel generation:
1. Loops are replaced with broadcasting.
2. The device cast happens at the end of kernel creation.
3. dtype defaults to float32 (float64 adds nothing of value here).
4. Defaults for beta and dtype are moved to the function declaration.

cc @adefossez
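A minimal sketch of the broadcasting idea in points 1-4, using illustrative names and defaults (not the exact torchaudio implementation):

import math
import torch

def make_sinc_kernels(orig_freq, new_freq, lowpass_filter_width=6, rolloff=0.99,
                      beta=14.769656459379492, dtype=torch.float32, device="cpu"):
    # Hypothetical helper, not the torchaudio function itself.
    gcd = math.gcd(int(orig_freq), int(new_freq))
    orig_freq, new_freq = orig_freq // gcd, new_freq // gcd
    base_freq = min(orig_freq, new_freq) * rolloff
    width = math.ceil(lowpass_filter_width * orig_freq / base_freq)
    # One row of tap positions (in units of the original sampling period) ...
    idx = torch.arange(-width, width + orig_freq, dtype=dtype)[None, None] / orig_freq
    # ... broadcast against one column per output phase, giving shape
    # (new_freq, 1, num_taps) and replacing the per-phase Python loop.
    t = torch.arange(0, -new_freq, -1, dtype=dtype)[:, None, None] / new_freq + idx
    t *= base_freq
    t = t.clamp_(-lowpass_filter_width, lowpass_filter_width)
    # Kaiser window evaluated explicitly at these positions.
    beta_tensor = torch.tensor(float(beta))
    window = torch.i0(beta_tensor * torch.sqrt(1 - (t / lowpass_filter_width) ** 2)) / torch.i0(beta_tensor)
    t *= math.pi
    scale = base_freq / orig_freq
    kernels = torch.where(t == 0, torch.tensor(1.0).to(t), t.sin() / t)
    kernels *= window * scale
    # The target device is applied only at the end, after the CPU-side computation.
    return kernels.to(device)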
# we do not use built in torch windows here as we need to evaluate the window
can you leave this comment in?
torchaudio/functional/functional.py
beta_tensor = torch.tensor(float(beta))
window = torch.i0(beta_tensor * torch.sqrt(1 - (t / lowpass_filter_width) ** 2)) / torch.i0(beta_tensor)
t *= torch.pi
we can leave it using math.pi
torchaudio/functional/functional.py
kernels.append(kernel)
t = torch.arange(0, -new_freq, -1, dtype=dtype)[:, None, None] / new_freq + idx
t *= base_freq
t = t.clamp_(-lowpass_filter_width, lowpass_filter_width)
scale = base_freq / orig_freq
scale can be moved down closer to where it's actually being used, after the window parts
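A sketch of the suggested ordering, continuing from the quoted snippet (t, beta_tensor, base_freq, and orig_freq as above); illustrative, not the final diff:

# window computed first ...
window = torch.i0(beta_tensor * torch.sqrt(1 - (t / lowpass_filter_width) ** 2)) / torch.i0(beta_tensor)
t *= math.pi
# ... then scale defined right where it is consumed
scale = base_freq / orig_freq
kernels = torch.where(t == 0, torch.tensor(1.0).to(t), t.sin() / t)
kernels *= window * scale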
torchaudio/functional/functional.py
kernels = torch.where(t == 0, torch.tensor(1.0).to(t), t.sin() / t)
kernels *= window * scale
kernels.to(device)
Here kernels is being moved to the device after computation, rather than starting out on that device (originally the device was set in idx). The GH issue mentions this is preferred since computation is faster on CPU than GPU, but can you additionally run benchmarks where device=gpu to compare this? If you'd like, you could also first remove this change and merge the rest of this PR, which looks good, and follow up on this change separately.
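A rough sketch of the benchmark being requested, assuming the hypothetical make_sinc_kernels helper sketched near the top of this conversation and a CUDA-enabled build:

import torch
import torch.utils.benchmark as benchmark

def run(device):
    # New behaviour: build the kernel on CPU, cast to the target device at the end.
    make_sinc_kernels(44100, 16000, device=device)
    if device == "cuda":
        torch.cuda.synchronize()

for device in ["cpu"] + (["cuda"] if torch.cuda.is_available() else []):
    timer = benchmark.Timer(stmt="run(device)", globals={"run": run, "device": device})
    print(device, timer.blocked_autorange())
# Pointing the same harness at the pre-PR construction (device set when idx is
# created) gives the direct CPU-vs-GPU comparison asked for here.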
Looking good to me :)
LGTM
@skim0514 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
missing the second line to the comment but otherwise looks good!
t *= base_freq
t = t.clamp_(-lowpass_filter_width, lowpass_filter_width)
# we do not use built in torch windows here as we need to evaluate the window
there's a second line to this comment missing
@skim0514 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
Modification from pull request #2415 to improve resample.
Benchmarked at an 89% time reduction compared to the original resample method.
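For reference, a minimal sketch of how such a timing comparison could be reproduced (illustrative signal parameters; the baseline number would come from running the same snippet on a checkout without this PR):

import torch
import torch.utils.benchmark as benchmark
import torchaudio.functional as F

waveform = torch.randn(1, 44100 * 10)  # 10 s of audio at 44.1 kHz
timer = benchmark.Timer(
    stmt="F.resample(waveform, 44100, 16000)",
    globals={"F": F, "waveform": waveform},
)
print(timer.blocked_autorange())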