Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

audio metrics: SNR, SI_SDR, SI_SNR #292

Merged
merged 126 commits into from
Jun 22, 2021
Merged
Show file tree
Hide file tree
Changes from 111 commits
Commits
Show all changes
126 commits
Select commit Hold shift + click to select a range
74b1b9b
add snr, si_sdr, si_snr
quancs Jun 13, 2021
a653056
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 13, 2021
70f82b4
format
quancs Jun 13, 2021
9e30f76
Merge branch 'audio-dev' of https://github.com/quancs/metrics into au…
quancs Jun 13, 2021
0331dd0
add noqa: F401 to __init__.py
quancs Jun 13, 2021
fcc827e
Update torchmetrics/functional/audio/si_sdr.py
quancs Jun 13, 2021
ef19181
Update torchmetrics/functional/audio/si_sdr.py
quancs Jun 13, 2021
80e6c53
Update torchmetrics/functional/audio/si_sdr.py
quancs Jun 13, 2021
32c0ce0
Update torchmetrics/functional/audio/si_snr.py
quancs Jun 13, 2021
489dba0
Update torchmetrics/functional/audio/snr.py
quancs Jun 13, 2021
249d848
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 13, 2021
017ba3a
remove types in doc, change estimate to preds, remove EPS
quancs Jun 13, 2021
c8aa372
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 13, 2021
c82a094
update functional.rst
quancs Jun 13, 2021
a33ff6e
update CHANGELOG.md
quancs Jun 13, 2021
282681d
switch preds and target
quancs Jun 13, 2021
2ba693c
switch preds and target in Example
quancs Jun 13, 2021
b0a8382
add SNR, SI_SNR, SI_SDR module implementation
quancs Jun 13, 2021
4e08d16
Merge branch 'audio-dev' of https://github.com/quancs/metrics into au…
quancs Jun 13, 2021
db57b13
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 13, 2021
fcfe0ac
add test
quancs Jun 13, 2021
1d77709
add module doc
quancs Jun 13, 2021
6a2d886
Update torchmetrics/audio/SI_SDR.py
quancs Jun 13, 2021
44a4a17
Update torchmetrics/audio/SI_SDR.py
quancs Jun 13, 2021
56ee58e
Update torchmetrics/audio/SI_SNR.py
quancs Jun 13, 2021
0f551b0
Update torchmetrics/audio/SI_SNR.py
quancs Jun 13, 2021
2a62352
Update torchmetrics/audio/SNR.py
quancs Jun 13, 2021
c4bd0c5
Update torchmetrics/audio/SNR.py
quancs Jun 13, 2021
3bd1e7b
Update torchmetrics/functional/audio/si_snr.py
quancs Jun 13, 2021
e028cd3
use _check_same_shape
quancs Jun 13, 2021
96c579a
to alphabetical order
quancs Jun 13, 2021
0848743
Merge branch 'audio-dev' of https://github.com/quancs/metrics into au…
quancs Jun 13, 2021
19e0f0b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 13, 2021
d7b5b0d
update test
quancs Jun 13, 2021
b179eca
Merge branch 'master' into audio-dev
SkafteNicki Jun 14, 2021
ea3aee4
Update docs/source/references/modules.rst
quancs Jun 14, 2021
fe6e6bc
move Base to the top of Audio
quancs Jun 14, 2021
67467cb
Merge branch 'master' into audio-dev
mergify[bot] Jun 14, 2021
58734ed
add soundfile
quancs Jun 14, 2021
f776693
Merge branch 'audio-dev' of https://github.com/quancs/metrics into au…
quancs Jun 14, 2021
868a9d3
gcc
Borda Jun 15, 2021
6ec17a8
fix cyclic import
Borda Jun 15, 2021
1b0d379
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 15, 2021
37a812f
pysndfile
Borda Jun 15, 2021
cd51a6c
Merge branch 'audio-dev' of https://github.com/quancs/metrics into au…
Borda Jun 15, 2021
7423b6e
v0.4.5
Borda Jun 15, 2021
0bc3956
pl
Borda Jun 15, 2021
587b26e
clang
Borda Jun 15, 2021
fdceaf4
Add FID metric (#213)
SkafteNicki Jun 15, 2021
01b5c2e
fix cyclic import `_reduce_stat_scores` (#296)
Borda Jun 15, 2021
b4ee3f8
Merge branch 'master' into audio-dev
Borda Jun 15, 2021
d500663
update test_snr
quancs Jun 15, 2021
3501dd0
update test_si_snr
quancs Jun 15, 2021
7ae7778
new snr: use torch.finfo(preds.dtype).eps
quancs Jun 15, 2021
e7dfbee
update test_snr.py
quancs Jun 15, 2021
de67913
new si_sdr imp
quancs Jun 15, 2021
6de666f
update test_si_sdr
quancs Jun 15, 2021
5fdbb0d
update test_si_snr
quancs Jun 15, 2021
608b21a
remove pb_bss_eval
quancs Jun 15, 2021
aea0091
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 15, 2021
0add577
add museval
quancs Jun 15, 2021
74d0550
update test files
quancs Jun 16, 2021
f7249b1
remove museval
quancs Jun 16, 2021
8f3e032
add funcs update return None annotation
quancs Jun 16, 2021
e876520
Merge branch 'master' into audio-dev
mergify[bot] Jun 16, 2021
be1ca0b
add 'Setup ffmpeg'
quancs Jun 16, 2021
4ce6b02
Merge branch 'audio-dev' of https://github.com/quancs/metrics into au…
quancs Jun 16, 2021
134d8cd
update "Setup ffmpeg"
quancs Jun 16, 2021
794d1ce
use setup-conda@v1
quancs Jun 16, 2021
2da22cd
multi-OS
Borda Jun 16, 2021
e786ed2
update atol to 1e-5
quancs Jun 16, 2021
cd9504e
Merge branch 'audio-dev' of https://github.com/quancs/metrics into au…
quancs Jun 16, 2021
662cf28
Apply suggestions from code review
Borda Jun 16, 2021
d44be30
change atol to 1e-2
quancs Jun 16, 2021
fe3385b
Merge branch 'audio-dev' of https://github.com/quancs/metrics into au…
quancs Jun 16, 2021
5e875ce
update
quancs Jun 16, 2021
5e3f5c8
fix 'Setup Linux' not activated
quancs Jun 16, 2021
f08916d
add sudo
quancs Jun 16, 2021
05226eb
Merge branch 'master' into audio-dev
mergify[bot] Jun 16, 2021
2f26cd6
reduce Time to 100 to reduce the test time
quancs Jun 17, 2021
ad593c7
increase timeoutInMinutes to 40
quancs Jun 17, 2021
2b2946b
Merge branch 'audio-dev' of https://github.com/quancs/metrics into au…
quancs Jun 17, 2021
377d4a4
install ffmpeg
quancs Jun 17, 2021
41fddfb
timeout-minutes to 55
quancs Jun 17, 2021
53a23c8
+git
quancs Jun 17, 2021
e8a6b6d
show-error-codes
quancs Jun 17, 2021
70a655e
.detach().cpu().numpy() first
quancs Jun 17, 2021
96e7280
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 17, 2021
a23e165
add numpy
quancs Jun 17, 2021
ddbd966
Merge branch 'audio-dev' of https://github.com/quancs/metrics into au…
quancs Jun 17, 2021
23b4683
numpy
Borda Jun 17, 2021
d726ba3
ignore_errors torchmetrics.audio.*
quancs Jun 17, 2021
0861c6e
solve mypy no-redef error
quancs Jun 17, 2021
1848f9e
remove --quiet
quancs Jun 17, 2021
8469dc8
pypesq
Borda Jun 17, 2021
4f3d41f
apt
Borda Jun 18, 2021
9358bfd
Inception Score (#299)
SkafteNicki Jun 17, 2021
fd3aa0a
Apply suggestions from code review
justusschock Jun 18, 2021
05e0f19
add # type: ignore
quancs Jun 18, 2021
a351c8d
try without test_si_snr & test_si_sdr
quancs Jun 18, 2021
67da485
test_import_speechmetrics
quancs Jun 18, 2021
64ffd9a
test_speechmetrics_si_sdr
quancs Jun 18, 2021
0b5fb27
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 18, 2021
85d9f72
test_si_sdr_functional
quancs Jun 18, 2021
f5afe0f
test audio only
quancs Jun 18, 2021
6e11386
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 18, 2021
fe9a7cf
install libsndfile1
quancs Jun 18, 2021
9491551
add sisnr sisdr test
quancs Jun 18, 2021
2236698
test all & add quiet & remove test_speechmetrics
quancs Jun 18, 2021
f5a0411
remove sudo & install libsndfile1
quancs Jun 18, 2021
3858397
apt
Borda Jun 18, 2021
3eb3d65
Update torchmetrics/functional/audio/si_sdr.py
quancs Jun 21, 2021
4c18540
[feat] Add _apply_sync to nn.Metric (#302)
tchaton Jun 21, 2021
3408927
v0.4.0rc0
Borda Jun 21, 2021
267380d
[pre-commit.ci] pre-commit suggestions (#306)
pre-commit-ci[bot] Jun 21, 2021
69c8fbc
adding KID metric (#301)
SkafteNicki Jun 21, 2021
e7753e1
Apply suggestions from code review
Borda Jun 22, 2021
e82bc96
SRMRpy
Borda Jun 22, 2021
bb045c9
pesq
Borda Jun 22, 2021
70687f1
gcc
Borda Jun 22, 2021
050bb7b
comment
Borda Jun 22, 2021
de034a3
env
Borda Jun 22, 2021
2fa46af
Merge branch 'master' into audio-dev
Borda Jun 22, 2021
cda3c6b
env
Borda Jun 22, 2021
47e0cf8
Merge branch 'audio-dev' of https://github.com/quancs/metrics into au…
Borda Jun 22, 2021
fec3598
env
Borda Jun 22, 2021
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .github/workflows/ci_test-conda.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ jobs:
pytorch-version: [1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9]

# Timeout: https://stackoverflow.com/a/59076067/4521646
timeout-minutes: 35
timeout-minutes: 55
steps:
- uses: actions/checkout@v2

Expand Down Expand Up @@ -54,9 +54,11 @@ jobs:

- name: Update Environment
run: |
sudo apt install libsndfile1
conda info
conda install mkl pytorch=${{ matrix.pytorch-version }} cpuonly
conda install cpuonly $(python ./requirements/adjust-versions.py conda)
conda install -c conda-forge ffmpeg
conda list
pip --version
python ./requirements/adjust-versions.py requirements.txt
Expand Down
13 changes: 10 additions & 3 deletions .github/workflows/ci_test-full.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ jobs:
requires: 'minimal'

# Timeout: https://stackoverflow.com/a/59076067/4521646
timeout-minutes: 35
timeout-minutes: 55

steps:
- uses: actions/checkout@v2
Expand All @@ -43,7 +43,15 @@ jobs:
- name: Setup macOS
if: runner.os == 'macOS'
run: |
brew install libomp # https://github.com/pytorch/pytorch/issues/20030
brew install gcc libomp ffmpeg # https://github.com/pytorch/pytorch/issues/20030
- name: Setup Linux
if: runner.os == 'Linux'
run: |
sudo apt install -y ffmpeg
- name: Setup Windows
if: runner.os == 'windows'
run: |
choco install ffmpeg

- name: Set min. dependencies
if: matrix.requires == 'minimal'
Expand All @@ -70,7 +78,6 @@ jobs:

- name: Install dependencies
run: |
python --version
pip --version
pip install --requirement requirements.txt --upgrade --find-links https://download.pytorch.org/whl/cpu/torch_stable.html
python ./requirements/adjust-versions.py requirements.txt
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/code-format.yml
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ jobs:
pip list
- name: mypy
run: |
mypy
mypy --show-error-codes

# format-check-yapf:
# runs-on: ubuntu-20.04
Expand Down
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `is_differentiable` property to `ConfusionMatrix`, `F1`, `FBeta`, `Hamming`, `Hinge`, `IOU`, `MatthewsCorrcoef`, `Precision`, `Recall`, `PrecisionRecallCurve`, `ROC`, `StatScores` ([#253](https://github.com/PyTorchLightning/metrics/pull/253))


- Added audio metrics: SNR, SI_SDR, SI_SNR ([#292](https://github.com/PyTorchLightning/metrics/pull/292))


- Added Inception Score metric to image module ([#299](https://github.com/PyTorchLightning/metrics/pull/299))


### Changed

- Forward cache is now reset when `reset` method is called ([#260](https://github.com/PyTorchLightning/metrics/pull/260))
Expand Down
9 changes: 6 additions & 3 deletions azure-pipelines.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,14 @@ pr:
jobs:
- job: pytest
# how long to run the job before automatically cancelling
timeoutInMinutes: 25
timeoutInMinutes: 45
# how much time to give 'run always even if cancelled tasks' before stopping them
cancelTimeoutInMinutes: 2

pool: gridai-spot-pool

container:
image: "pytorch/pytorch:1.7.1-cuda11.0-cudnn8-runtime"
image: "pytorch/pytorch:1.8.1-cuda11.1-cudnn8-runtime"
options: "--runtime=nvidia -e NVIDIA_VISIBLE_DEVICES=all"

workspace:
Expand All @@ -44,11 +44,14 @@ jobs:
displayName: 'Image info & NVIDIA'

- bash: |
#sudo apt-get install -y cmake
sudo apt-get update
sudo apt install -y cmake ffmpeg git libsndfile1
# python -m pip install "pip==20.1"
pip install --requirement ./requirements/devel.txt --upgrade-strategy only-if-needed
pip uninstall -y torchmetrics
pip list
env:
DEBIAN_FRONTEND: noninteractive
displayName: 'Install dependencies'

- bash: |
Expand Down
25 changes: 25 additions & 0 deletions docs/source/references/functional.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,31 @@
Functional metrics
##################

*************
Audio Metrics
*************

si_sdr [func]
~~~~~~~~~~~~~

.. autofunction:: torchmetrics.functional.si_sdr
:noindex:


si_snr [func]
~~~~~~~~~~~~~

.. autofunction:: torchmetrics.functional.si_snr
:noindex:


snr [func]
~~~~~~~~~~

.. autofunction:: torchmetrics.functional.snr
:noindex:


**********************
Classification Metrics
**********************
Expand Down
48 changes: 48 additions & 0 deletions docs/source/references/modules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,46 @@ your own metric type might be too burdensome.
.. autoclass:: torchmetrics.AverageMeter
:noindex:

*************
Audio Metrics
*************

About Audio Metrics
~~~~~~~~~~~~~~~~~~~

For the purposes of audio metrics, inputs (predictions, targets) must have the same size.
quancs marked this conversation as resolved.
Show resolved Hide resolved
If the input is 1D tensors the output will be a scalar. If the input is multi-dimensional with shape [..., time]` the metric will be computed over the `time` dimension.

.. doctest::

>>> import torch
>>> from torchmetrics import SNR
>>> target = torch.tensor([3.0, -0.5, 2.0, 7.0])
>>> preds = torch.tensor([2.5, 0.0, 2.0, 8.0])
>>> snr = SNR()
>>> snr_val = snr(preds, target)
>>> snr_val
tensor(16.1805)

SI_SDR
~~~~~~

.. autoclass:: torchmetrics.SI_SDR
:noindex:

SI_SNR
~~~~~~

.. autoclass:: torchmetrics.SI_SNR
:noindex:

SNR
~~~

.. autoclass:: torchmetrics.SNR
:noindex:


**********************
Classification Metrics
**********************
Expand Down Expand Up @@ -257,9 +297,17 @@ Image Quality Metrics
Image quality metrics can be used to access the quality of synthetic generated images from machine
learning algorithms such as `Generative Adverserial Networks (GANs) <https://en.wikipedia.org/wiki/Generative_adversarial_network>`_.

FID
~~~

.. autoclass:: torchmetrics.FID
:noindex:

IS
~~

.. autoclass:: torchmetrics.IS
:noindex:

******************
Regression Metrics
Expand Down
5 changes: 5 additions & 0 deletions requirements/test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,8 @@ nltk>=3.6

# add extra requirements
-r image.txt

# audio
pypesq
mir_eval>=0.6
https://github.com/aliutkus/speechmetrics/archive/refs/heads/master.zip
Borda marked this conversation as resolved.
Show resolved Hide resolved
Empty file added tests/audio/__init__.py
Empty file.
131 changes: 131 additions & 0 deletions tests/audio/test_si_sdr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import namedtuple
from functools import partial

import pytest
import speechmetrics
import torch
from torch import Tensor

from tests.helpers import seed_all
from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester
from torchmetrics.audio import SI_SDR
from torchmetrics.functional import si_sdr
from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6

seed_all(42)

Time = 100

Input = namedtuple('Input', ["preds", "target"])

inputs = Input(
preds=torch.rand(NUM_BATCHES, BATCH_SIZE, 1, Time),
target=torch.rand(NUM_BATCHES, BATCH_SIZE, 1, Time),
)
SkafteNicki marked this conversation as resolved.
Show resolved Hide resolved

speechmetrics_sisdr = speechmetrics.load('sisdr')


def speechmetrics_si_sdr(preds: Tensor, target: Tensor, zero_mean: bool):
# shape: preds [BATCH_SIZE, 1, Time] , target [BATCH_SIZE, 1, Time]
# or shape: preds [NUM_BATCHES*BATCH_SIZE, 1, Time] , target [NUM_BATCHES*BATCH_SIZE, 1, Time]
if zero_mean:
preds = preds - preds.mean(dim=2, keepdim=True)
target = target - target.mean(dim=2, keepdim=True)
target = target.detach().cpu().numpy()
preds = preds.detach().cpu().numpy()
mss = []
for i in range(preds.shape[0]):
ms = []
for j in range(preds.shape[1]):
metric = speechmetrics_sisdr(preds[i, j], target[i, j], rate=16000)
ms.append(metric['sisdr'][0])
mss.append(ms)
return torch.tensor(mss)


def average_metric(preds, target, metric_func):
# shape: preds [BATCH_SIZE, 1, Time] , target [BATCH_SIZE, 1, Time]
# or shape: preds [NUM_BATCHES*BATCH_SIZE, 1, Time] , target [NUM_BATCHES*BATCH_SIZE, 1, Time]
return metric_func(preds, target).mean()


speechmetrics_si_sdr_zero_mean = partial(speechmetrics_si_sdr, zero_mean=True)
speechmetrics_si_sdr_no_zero_mean = partial(speechmetrics_si_sdr, zero_mean=False)


@pytest.mark.parametrize(
"preds, target, sk_metric, zero_mean",
[
(inputs.preds, inputs.target, speechmetrics_si_sdr_zero_mean, True),
(inputs.preds, inputs.target, speechmetrics_si_sdr_no_zero_mean, False),
],
)
class TestSISDR(MetricTester):
atol = 1e-2

@pytest.mark.parametrize("ddp", [True, False])
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
def test_si_sdr(self, preds, target, sk_metric, zero_mean, ddp, dist_sync_on_step):
self.run_class_metric_test(
ddp,
preds,
target,
SI_SDR,
sk_metric=partial(average_metric, metric_func=sk_metric),
dist_sync_on_step=dist_sync_on_step,
metric_args=dict(zero_mean=zero_mean),
)

def test_si_sdr_functional(self, preds, target, sk_metric, zero_mean):
self.run_functional_metric_test(
preds,
target,
si_sdr,
sk_metric,
metric_args=dict(zero_mean=zero_mean),
)

def test_si_sdr_differentiability(self, preds, target, sk_metric, zero_mean):
self.run_differentiability_test(
preds=preds,
target=target,
metric_module=SI_SDR,
metric_functional=si_sdr,
metric_args={'zero_mean': zero_mean}
)

@pytest.mark.skipif(
not _TORCH_GREATER_EQUAL_1_6, reason='half support of core operations on not support before pytorch v1.6'
)
def test_si_sdr_half_cpu(self, preds, target, sk_metric, zero_mean):
pytest.xfail("SI-SDR metric does not support cpu + half precision")

@pytest.mark.skipif(not torch.cuda.is_available(), reason='test requires cuda')
def test_si_sdr_half_gpu(self, preds, target, sk_metric, zero_mean):
self.run_precision_test_gpu(
preds=preds,
target=target,
metric_module=SI_SDR,
metric_functional=si_sdr,
metric_args={'zero_mean': zero_mean}
)


def test_error_on_different_shape(metric_class=SI_SDR):
metric = metric_class()
with pytest.raises(RuntimeError, match='Predictions and targets are expected to have the same shape'):
metric(torch.randn(100, ), torch.randn(50, ))
Loading