Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Small changes for flexibility and performance #68

Merged
merged 3 commits into from
Feb 9, 2024

Conversation

simonschwaer
Copy link
Contributor

Here are a few changes I made to make the library even more useful for me... I can remove some commits or split them into separated PRs if desired.

1. Add option to use any window function available in scipy.signal

This allows to use additional window types for STFTLoss by providing a window type implemented in scipy.signal.windows.get_window() (docs) – or at least the ones that do not need additional arguments. The torch windows are still supported, so it does not change the default behaviour or other existing code.

2. Change magnitude calculation in STFTLoss for better interpretability of the eps parameter

This would change the default behaviour, since previously, all values were clamped at $10^{-4}$ by default as torch.clamp is applied inside torch.sqrt, resulting in a threshold of $\sqrt{10^{-8}}$. I think it would be more intuitive if the eps specified in STFTLoss would actually specify the resulting smallest value.

3. Allow for more flexible log compression

By using $\log(\gamma x + \varepsilon)$ for the log-compression in STFTMagnitudeLoss, the amount of compression and output value range becomes more controllable. For example, $\varepsilon = 1$ would lead to a minimum output value of $0$, and $\gamma = 10$ would compress the magnitudes more. With the given default values, the commit does not change the previous default behaviour.

4. Compute torch.angle of STFT only when needed

I noticed that the MSS loss was slower than expected on my M1 GPU and CPU. The torch profiler showed that torch.angle is taking a lot of time, even though it is not used in many cases. The proposed improvement just calculates the angle only if it is actually used for the loss.

With torch.angle:

---------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                       Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
---------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                aten::angle        29.54%        3.359s        52.33%        5.951s     299.177us         19890  
             aten::_fft_r2c        14.02%        1.594s        14.02%        1.594s     129.756us         12288  
                  aten::abs         9.31%        1.059s        22.43%        2.550s     108.683us         23462  
                  aten::log         8.94%        1.016s         8.94%        1.016s      82.694us         12288  
     aten::reflection_pad1d         8.29%     942.407ms         8.29%     942.407ms      76.693us         12288  
                aten::copy_         7.33%     833.787ms         7.33%     833.787ms       7.755us        107520  
                  aten::mul         5.68%     646.045ms         5.96%     677.840ms      18.388us         36864  
                 aten::stft         3.96%     450.563ms        24.64%        2.801s     172.056us         16282  
             aten::mse_loss         3.46%     393.617ms         6.08%     691.316ms      56.259us         12288  
                  aten::add         2.17%     246.182ms         2.39%     271.226ms       8.544us         31744  
                  aten::sum         2.14%     243.566ms         2.18%     247.356ms      13.420us         18432  
                aten::clamp         1.80%     205.020ms         1.80%     205.047ms      16.687us         12288  
---------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 11.370s

Without torch.angle:

---------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                       Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
---------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
             aten::_fft_r2c        21.45%        1.600s        21.46%        1.602s     130.347us         12288  
                  aten::abs        14.00%        1.045s        33.96%        2.535s     108.396us         23384  
                  aten::log        13.34%     995.217ms        13.34%     995.217ms      80.991us         12288  
     aten::reflection_pad1d        12.39%     924.513ms        12.39%     924.513ms      75.237us         12288  
                  aten::mul         8.89%     663.190ms         9.32%     695.856ms      18.876us         36864  
                 aten::stft         5.59%     417.491ms        36.85%        2.750s     172.311us         15960  
                aten::copy_         5.15%     384.563ms         5.15%     384.563ms       4.038us         95232  
             aten::mse_loss         5.02%     374.319ms         9.01%     672.708ms      54.745us         12288  
                  aten::sum         3.28%     245.045ms         3.34%     249.108ms      13.515us         18432  
                  aten::add         3.25%     242.866ms         3.59%     267.813ms       8.437us         31744  
                aten::clamp         2.83%     210.954ms         2.83%     210.959ms      17.168us         12288  
                  aten::pad         1.65%     122.975ms        12.47%     930.943ms      75.760us         12288
---------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 7.463s

This does not change the default behavior or other existing code, since the torch windows can still be accessed using the original arguments.
@csteinmetz1
Copy link
Owner

Thanks @simonschwaer ! I would love to merge these changes. I think everything is good that doesn't cause any change in default behavior. So, would it be possible to split change 2. dealing with eps to a separate PR?

By using `log(a*x+e)`, it is possible to control the amount of compression and the output value range. This does not change default behaviour.
This seems to be an expensive operation, at least on some architectures
@simonschwaer
Copy link
Contributor Author

simonschwaer commented Feb 9, 2024

Thanks @simonschwaer ! I would love to merge these changes. I think everything is good that doesn't cause any change in default behavior. So, would it be possible to split change 2. dealing with eps to a separate PR?

Done.

@csteinmetz1
Copy link
Owner

LGTM

@csteinmetz1 csteinmetz1 merged commit 0853e9e into csteinmetz1:main Feb 9, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants