From 74b1b9ba58e25648678927f44ce7280b5edfb214 Mon Sep 17 00:00:00 2001 From: quancs Date: Sun, 13 Jun 2021 13:32:57 +0800 Subject: [PATCH 001/109] add snr, si_sdr, si_snr --- torchmetrics/functional/audio/__init__.py | 16 ++++++ torchmetrics/functional/audio/si_sdr.py | 59 +++++++++++++++++++++++ torchmetrics/functional/audio/si_snr.py | 44 +++++++++++++++++ torchmetrics/functional/audio/snr.py | 57 ++++++++++++++++++++++ 4 files changed, 176 insertions(+) create mode 100644 torchmetrics/functional/audio/__init__.py create mode 100644 torchmetrics/functional/audio/si_sdr.py create mode 100644 torchmetrics/functional/audio/si_snr.py create mode 100644 torchmetrics/functional/audio/snr.py diff --git a/torchmetrics/functional/audio/__init__.py b/torchmetrics/functional/audio/__init__.py new file mode 100644 index 00000000000..c2230dd0926 --- /dev/null +++ b/torchmetrics/functional/audio/__init__.py @@ -0,0 +1,16 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from torchmetrics.functional.audio.snr import snr +from torchmetrics.functional.audio.si_sdr import si_sdr +from torchmetrics.functional.audio.si_snr import si_snr diff --git a/torchmetrics/functional/audio/si_sdr.py b/torchmetrics/functional/audio/si_sdr.py new file mode 100644 index 00000000000..98d5295027f --- /dev/null +++ b/torchmetrics/functional/audio/si_sdr.py @@ -0,0 +1,59 @@ +# Copyright The PyTorch Lightning team. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch + + +def si_sdr(target, estimate, zero_mean=False, EPS=1e-8): + """ scale-invariant signal-to-distortion ratio (SI-SDR) + + Args: + target (Tensor): shape [..., time] + estimate (Tensor): shape [..., time] + zero_mean (Bool): if to zero mean target and estimate or not + EPS (float, optional): a small value for numerical stability. Defaults to 1e-8. + + Raises: + TypeError: if target and estimate have a different shape + + Returns: + Tensor: si-sdr value has a shape of [...] + + Example: + >>> from torchmetrics.functional.audio import si_sdr + >>> target = torch.tensor([3.0, -0.5, 2.0, 7.0]) + >>> estimate = torch.tensor([2.5, 0.0, 2.0, 8.0]) + >>> si_sdr_val = si_sdr(target,estimate) + >>> si_sdr_val + tensor(18.4030) + + References: + [1] Le Roux, Jonathan, et al. "SDR half-baked or well done." IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP) 2019. 
+ """ + + if target.shape != estimate.shape: + raise TypeError(f"Inputs must be of shape [..., time], got {target.shape} and {estimate.shape} instead") + + if zero_mean: + target = target - torch.mean(target, dim=-1, keepdim=True) + estimate = estimate - torch.mean(estimate, dim=-1, keepdim=True) + + α = torch.sum(estimate * target, dim=-1, keepdim=True) / (torch.sum(target**2, dim=-1, keepdim=True) + EPS) + target_scaled = α * target + + noise = target_scaled - estimate + + si_sdr_value = torch.sum(target_scaled**2, dim=-1) / (torch.sum(noise**2, dim=-1) + EPS) + si_sdr_value = 10 * torch.log10(si_sdr_value + EPS) + + return si_sdr_value diff --git a/torchmetrics/functional/audio/si_snr.py b/torchmetrics/functional/audio/si_snr.py new file mode 100644 index 00000000000..dc74246de3b --- /dev/null +++ b/torchmetrics/functional/audio/si_snr.py @@ -0,0 +1,44 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +from .si_sdr import si_sdr + + +def si_snr(target, estimate, EPS=1e-8): + """ scale-invariant signal-to-noise ratio (SI-SNR) + + Args: + target (Tensor): shape [..., time] + estimate (Tensor): shape [..., time] + EPS (float, optional): a small value for numerical stability. Defaults to 1e-8. + + Raises: + TypeError: if target and estimate have a different shape + + Returns: + Tensor: si-snr value has a shape of [...] 
+ + Example: + >>> from torchmetrics.functional.audio import si_snr + >>> target = torch.tensor([3.0, -0.5, 2.0, 7.0]) + >>> estimate = torch.tensor([2.5, 0.0, 2.0, 8.0]) + >>> si_snr_val = si_snr(target,estimate) + >>> si_snr_val + tensor(15.0918) + + References: + [1] Y. Luo and N. Mesgarani, "TaSNet: Time-Domain Audio Separation Network for Real-Time, Single-Channel Speech Separation," 2018 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP), 2018, pp. 696-700, doi: 10.1109/ICASSP.2018.8462116. + """ + + return si_sdr(target=target, estimate=estimate, zero_mean=True, EPS=EPS) diff --git a/torchmetrics/functional/audio/snr.py b/torchmetrics/functional/audio/snr.py new file mode 100644 index 00000000000..ccdaa59df77 --- /dev/null +++ b/torchmetrics/functional/audio/snr.py @@ -0,0 +1,57 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +from torch import Tensor + + +def snr(target, estimate, zero_mean=False, EPS=1e-8) -> Tensor: + """ signal-to-noise ratio (SNR) + + Args: + target (Tensor): shape [..., time] + estimate (Tensor): shape [..., time] + zero_mean (Bool): if to zero mean target and estimate or not + EPS (float, optional): a small value for numerical stability. Defaults to 1e-8. + + Raises: + TypeError: if target and estimate have a different shape + + Returns: + Tensor: snr value has a shape of [...] 
+ + Example: + >>> from torchmetrics.functional.audio import snr + >>> target = torch.tensor([3.0, -0.5, 2.0, 7.0]) + >>> estimate = torch.tensor([2.5, 0.0, 2.0, 8.0]) + >>> snr_val = snr(target,estimate) + >>> snr_val + tensor(16.1805) + + References: + [1] Le Roux, Jonathan, et al. "SDR half-baked or well done." IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP) 2019. + """ + + if target.shape != estimate.shape: + raise TypeError(f"Inputs must be of shape [..., time], got {target.shape} and {estimate.shape} instead") + + if zero_mean: + target = target - torch.mean(target, dim=-1, keepdim=True) + estimate = estimate - torch.mean(estimate, dim=-1, keepdim=True) + + noise = target - estimate + + snr_value = torch.sum(target**2, dim=-1) / (torch.sum(noise**2, dim=-1) + EPS) + snr_value = 10 * torch.log10(snr_value + EPS) + + return snr_value From a653056ea7e3b1b9cbdf7e42dc825a26144e3ee5 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 13 Jun 2021 05:51:21 +0000 Subject: [PATCH 002/109] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- torchmetrics/functional/audio/__init__.py | 2 +- torchmetrics/functional/audio/si_sdr.py | 4 ++-- torchmetrics/functional/audio/si_snr.py | 5 +++-- torchmetrics/functional/audio/snr.py | 2 +- 4 files changed, 7 insertions(+), 6 deletions(-) diff --git a/torchmetrics/functional/audio/__init__.py b/torchmetrics/functional/audio/__init__.py index c2230dd0926..191e8a04dae 100644 --- a/torchmetrics/functional/audio/__init__.py +++ b/torchmetrics/functional/audio/__init__.py @@ -11,6 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from torchmetrics.functional.audio.snr import snr from torchmetrics.functional.audio.si_sdr import si_sdr from torchmetrics.functional.audio.si_snr import si_snr +from torchmetrics.functional.audio.snr import snr diff --git a/torchmetrics/functional/audio/si_sdr.py b/torchmetrics/functional/audio/si_sdr.py index 98d5295027f..a5ad3ec8537 100644 --- a/torchmetrics/functional/audio/si_sdr.py +++ b/torchmetrics/functional/audio/si_sdr.py @@ -28,7 +28,7 @@ def si_sdr(target, estimate, zero_mean=False, EPS=1e-8): Returns: Tensor: si-sdr value has a shape of [...] - + Example: >>> from torchmetrics.functional.audio import si_sdr >>> target = torch.tensor([3.0, -0.5, 2.0, 7.0]) @@ -36,7 +36,7 @@ def si_sdr(target, estimate, zero_mean=False, EPS=1e-8): >>> si_sdr_val = si_sdr(target,estimate) >>> si_sdr_val tensor(18.4030) - + References: [1] Le Roux, Jonathan, et al. "SDR half-baked or well done." IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP) 2019. """ diff --git a/torchmetrics/functional/audio/si_snr.py b/torchmetrics/functional/audio/si_snr.py index dc74246de3b..3cb907f2ffb 100644 --- a/torchmetrics/functional/audio/si_snr.py +++ b/torchmetrics/functional/audio/si_snr.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import torch + from .si_sdr import si_sdr @@ -28,7 +29,7 @@ def si_snr(target, estimate, EPS=1e-8): Returns: Tensor: si-snr value has a shape of [...] - + Example: >>> from torchmetrics.functional.audio import si_snr >>> target = torch.tensor([3.0, -0.5, 2.0, 7.0]) @@ -36,7 +37,7 @@ def si_snr(target, estimate, EPS=1e-8): >>> si_snr_val = si_snr(target,estimate) >>> si_snr_val tensor(15.0918) - + References: [1] Y. Luo and N. Mesgarani, "TaSNet: Time-Domain Audio Separation Network for Real-Time, Single-Channel Speech Separation," 2018 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP), 2018, pp. 
696-700, doi: 10.1109/ICASSP.2018.8462116. """ diff --git a/torchmetrics/functional/audio/snr.py b/torchmetrics/functional/audio/snr.py index ccdaa59df77..327ef20a128 100644 --- a/torchmetrics/functional/audio/snr.py +++ b/torchmetrics/functional/audio/snr.py @@ -29,7 +29,7 @@ def snr(target, estimate, zero_mean=False, EPS=1e-8) -> Tensor: Returns: Tensor: snr value has a shape of [...] - + Example: >>> from torchmetrics.functional.audio import snr >>> target = torch.tensor([3.0, -0.5, 2.0, 7.0]) From 70f82b43ab2dc55fbc273d011db3b4f208b32991 Mon Sep 17 00:00:00 2001 From: quancs Date: Sun, 13 Jun 2021 15:31:56 +0800 Subject: [PATCH 003/109] format --- torchmetrics/functional/audio/si_sdr.py | 7 ++++--- torchmetrics/functional/audio/si_snr.py | 8 +++++--- torchmetrics/functional/audio/snr.py | 5 +++-- 3 files changed, 12 insertions(+), 8 deletions(-) diff --git a/torchmetrics/functional/audio/si_sdr.py b/torchmetrics/functional/audio/si_sdr.py index 98d5295027f..536e2001bc8 100644 --- a/torchmetrics/functional/audio/si_sdr.py +++ b/torchmetrics/functional/audio/si_sdr.py @@ -28,7 +28,7 @@ def si_sdr(target, estimate, zero_mean=False, EPS=1e-8): Returns: Tensor: si-sdr value has a shape of [...] - + Example: >>> from torchmetrics.functional.audio import si_sdr >>> target = torch.tensor([3.0, -0.5, 2.0, 7.0]) @@ -36,9 +36,10 @@ def si_sdr(target, estimate, zero_mean=False, EPS=1e-8): >>> si_sdr_val = si_sdr(target,estimate) >>> si_sdr_val tensor(18.4030) - + References: - [1] Le Roux, Jonathan, et al. "SDR half-baked or well done." IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP) 2019. + [1] Le Roux, Jonathan, et al. "SDR half-baked or well done." IEEE International Conference on Acoustics, Speech + and Signal Processing (ICASSP) 2019. 
""" if target.shape != estimate.shape: diff --git a/torchmetrics/functional/audio/si_snr.py b/torchmetrics/functional/audio/si_snr.py index dc74246de3b..83372bc486e 100644 --- a/torchmetrics/functional/audio/si_snr.py +++ b/torchmetrics/functional/audio/si_snr.py @@ -28,7 +28,7 @@ def si_snr(target, estimate, EPS=1e-8): Returns: Tensor: si-snr value has a shape of [...] - + Example: >>> from torchmetrics.functional.audio import si_snr >>> target = torch.tensor([3.0, -0.5, 2.0, 7.0]) @@ -36,9 +36,11 @@ def si_snr(target, estimate, EPS=1e-8): >>> si_snr_val = si_snr(target,estimate) >>> si_snr_val tensor(15.0918) - + References: - [1] Y. Luo and N. Mesgarani, "TaSNet: Time-Domain Audio Separation Network for Real-Time, Single-Channel Speech Separation," 2018 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP), 2018, pp. 696-700, doi: 10.1109/ICASSP.2018.8462116. + [1] Y. Luo and N. Mesgarani, "TaSNet: Time-Domain Audio Separation Network for Real-Time, Single-Channel Speech + Separation," 2018 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP), 2018, pp. + 696-700, doi: 10.1109/ICASSP.2018.8462116. """ return si_sdr(target=target, estimate=estimate, zero_mean=True, EPS=EPS) diff --git a/torchmetrics/functional/audio/snr.py b/torchmetrics/functional/audio/snr.py index ccdaa59df77..d5322f58b05 100644 --- a/torchmetrics/functional/audio/snr.py +++ b/torchmetrics/functional/audio/snr.py @@ -29,7 +29,7 @@ def snr(target, estimate, zero_mean=False, EPS=1e-8) -> Tensor: Returns: Tensor: snr value has a shape of [...] - + Example: >>> from torchmetrics.functional.audio import snr >>> target = torch.tensor([3.0, -0.5, 2.0, 7.0]) @@ -39,7 +39,8 @@ def snr(target, estimate, zero_mean=False, EPS=1e-8) -> Tensor: tensor(16.1805) References: - [1] Le Roux, Jonathan, et al. "SDR half-baked or well done." IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP) 2019. + [1] Le Roux, Jonathan, et al. 
"SDR half-baked or well done." IEEE International Conference on Acoustics, Speech + and Signal Processing (ICASSP) 2019. """ if target.shape != estimate.shape: From 0331dd0312b9850266b2a476aa33c740d1d317cc Mon Sep 17 00:00:00 2001 From: quancs Date: Sun, 13 Jun 2021 16:34:26 +0800 Subject: [PATCH 004/109] add noqa: F401 to __init__.py --- torchmetrics/functional/audio/__init__.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchmetrics/functional/audio/__init__.py b/torchmetrics/functional/audio/__init__.py index 191e8a04dae..d5bb919a914 100644 --- a/torchmetrics/functional/audio/__init__.py +++ b/torchmetrics/functional/audio/__init__.py @@ -11,6 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from torchmetrics.functional.audio.si_sdr import si_sdr -from torchmetrics.functional.audio.si_snr import si_snr -from torchmetrics.functional.audio.snr import snr +from torchmetrics.functional.audio.si_sdr import si_sdr # noqa: F401 +from torchmetrics.functional.audio.si_snr import si_snr # noqa: F401 +from torchmetrics.functional.audio.snr import snr # noqa: F401 From fcc827eeb1bbc3b73c8803bddbf600d68cc597ac Mon Sep 17 00:00:00 2001 From: quancs <1017241746@qq.com> Date: Sun, 13 Jun 2021 21:31:52 +0800 Subject: [PATCH 005/109] Update torchmetrics/functional/audio/si_sdr.py Co-authored-by: Nicki Skafte --- torchmetrics/functional/audio/si_sdr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchmetrics/functional/audio/si_sdr.py b/torchmetrics/functional/audio/si_sdr.py index 536e2001bc8..b00bec8efd4 100644 --- a/torchmetrics/functional/audio/si_sdr.py +++ b/torchmetrics/functional/audio/si_sdr.py @@ -19,7 +19,7 @@ def si_sdr(target, estimate, zero_mean=False, EPS=1e-8): Args: target (Tensor): shape [..., time] - estimate (Tensor): shape [..., time] + preds (Tensor): shape [..., time] 
zero_mean (Bool): if to zero mean target and estimate or not EPS (float, optional): a small value for numerical stability. Defaults to 1e-8. From ef19181319d514a5e8ecb870d1458b8667ac7526 Mon Sep 17 00:00:00 2001 From: quancs <1017241746@qq.com> Date: Sun, 13 Jun 2021 21:33:31 +0800 Subject: [PATCH 006/109] Update torchmetrics/functional/audio/si_sdr.py Co-authored-by: Nicki Skafte --- torchmetrics/functional/audio/si_sdr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchmetrics/functional/audio/si_sdr.py b/torchmetrics/functional/audio/si_sdr.py index b00bec8efd4..85677e2db7f 100644 --- a/torchmetrics/functional/audio/si_sdr.py +++ b/torchmetrics/functional/audio/si_sdr.py @@ -14,7 +14,7 @@ import torch -def si_sdr(target, estimate, zero_mean=False, EPS=1e-8): +def si_sdr(target: Tensor, estimate: Tensor, zero_mean: bool = False, EPS: bool = 1e-8) -> Tensor: """ scale-invariant signal-to-distortion ratio (SI-SDR) Args: From 80e6c53ce697391904e43623ce2c3584cfcc4c75 Mon Sep 17 00:00:00 2001 From: quancs <1017241746@qq.com> Date: Sun, 13 Jun 2021 21:33:41 +0800 Subject: [PATCH 007/109] Update torchmetrics/functional/audio/si_sdr.py Co-authored-by: Nicki Skafte --- torchmetrics/functional/audio/si_sdr.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchmetrics/functional/audio/si_sdr.py b/torchmetrics/functional/audio/si_sdr.py index 85677e2db7f..3d9e87dd560 100644 --- a/torchmetrics/functional/audio/si_sdr.py +++ b/torchmetrics/functional/audio/si_sdr.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import torch +from torch import Tensor def si_sdr(target: Tensor, estimate: Tensor, zero_mean: bool = False, EPS: bool = 1e-8) -> Tensor: From 32c0ce08d8947a08bf2b2e94639a21dd6819efc6 Mon Sep 17 00:00:00 2001 From: quancs <1017241746@qq.com> Date: Sun, 13 Jun 2021 21:34:31 +0800 Subject: [PATCH 008/109] Update torchmetrics/functional/audio/si_snr.py Co-authored-by: Nicki Skafte --- torchmetrics/functional/audio/si_snr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchmetrics/functional/audio/si_snr.py b/torchmetrics/functional/audio/si_snr.py index e97d13e6096..b6709dcd222 100644 --- a/torchmetrics/functional/audio/si_snr.py +++ b/torchmetrics/functional/audio/si_snr.py @@ -16,7 +16,7 @@ from .si_sdr import si_sdr -def si_snr(target, estimate, EPS=1e-8): +def si_snr(target: Tensor, estimate: Tensor, EPS: bool =1e-8) -> Tensor: """ scale-invariant signal-to-noise ratio (SI-SNR) Args: From 489dba0cf3b642ce33ab001488fc0ade21010967 Mon Sep 17 00:00:00 2001 From: quancs <1017241746@qq.com> Date: Sun, 13 Jun 2021 21:34:38 +0800 Subject: [PATCH 009/109] Update torchmetrics/functional/audio/snr.py Co-authored-by: Nicki Skafte --- torchmetrics/functional/audio/snr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchmetrics/functional/audio/snr.py b/torchmetrics/functional/audio/snr.py index d5322f58b05..e29a136deeb 100644 --- a/torchmetrics/functional/audio/snr.py +++ b/torchmetrics/functional/audio/snr.py @@ -15,7 +15,7 @@ from torch import Tensor -def snr(target, estimate, zero_mean=False, EPS=1e-8) -> Tensor: +def snr(target: Tensor, estimate: Tensor, zero_mean: bool = False, EPS: bool = 1e-8) -> Tensor: """ signal-to-noise ratio (SNR) Args: From 249d848c07ac61f78ea9e0e0debc77281feb9299 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 13 Jun 2021 13:35:08 +0000 Subject: [PATCH 010/109] [pre-commit.ci] auto fixes from pre-commit.com hooks for more 
information, see https://pre-commit.ci --- torchmetrics/functional/audio/si_snr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchmetrics/functional/audio/si_snr.py b/torchmetrics/functional/audio/si_snr.py index b6709dcd222..b139cfa57e7 100644 --- a/torchmetrics/functional/audio/si_snr.py +++ b/torchmetrics/functional/audio/si_snr.py @@ -16,7 +16,7 @@ from .si_sdr import si_sdr -def si_snr(target: Tensor, estimate: Tensor, EPS: bool =1e-8) -> Tensor: +def si_snr(target: Tensor, estimate: Tensor, EPS: bool = 1e-8) -> Tensor: """ scale-invariant signal-to-noise ratio (SI-SNR) Args: From 017ba3a61a6993f8650040f1b40ba154ddd5c2e7 Mon Sep 17 00:00:00 2001 From: quancs Date: Sun, 13 Jun 2021 21:57:46 +0800 Subject: [PATCH 011/109] remove types in doc, change estimate to preds, remove EPS --- torchmetrics/functional/__init__.py | 3 +++ torchmetrics/functional/audio/si_sdr.py | 35 ++++++++++++++----------- torchmetrics/functional/audio/si_snr.py | 23 ++++++++-------- torchmetrics/functional/audio/snr.py | 33 ++++++++++++----------- 4 files changed, 52 insertions(+), 42 deletions(-) diff --git a/torchmetrics/functional/__init__.py b/torchmetrics/functional/__init__.py index 939987268b5..767e8d43525 100644 --- a/torchmetrics/functional/__init__.py +++ b/torchmetrics/functional/__init__.py @@ -50,3 +50,6 @@ from torchmetrics.functional.retrieval.recall import retrieval_recall # noqa: F401 from torchmetrics.functional.retrieval.reciprocal_rank import retrieval_reciprocal_rank # noqa: F401 from torchmetrics.functional.self_supervised import embedding_similarity # noqa: F401 +from torchmetrics.functional.audio.si_sdr import si_sdr # noqa: F401 +from torchmetrics.functional.audio.si_snr import si_snr # noqa: F401 +from torchmetrics.functional.audio.snr import snr # noqa: F401 diff --git a/torchmetrics/functional/audio/si_sdr.py b/torchmetrics/functional/audio/si_sdr.py index 3d9e87dd560..58c7976ccf1 100644 --- a/torchmetrics/functional/audio/si_sdr.py +++ 
b/torchmetrics/functional/audio/si_sdr.py @@ -15,26 +15,29 @@ from torch import Tensor -def si_sdr(target: Tensor, estimate: Tensor, zero_mean: bool = False, EPS: bool = 1e-8) -> Tensor: +def si_sdr(target: Tensor, preds: Tensor, zero_mean: bool = False) -> Tensor: """ scale-invariant signal-to-distortion ratio (SI-SDR) Args: - target (Tensor): shape [..., time] - preds (Tensor): shape [..., time] - zero_mean (Bool): if to zero mean target and estimate or not - EPS (float, optional): a small value for numerical stability. Defaults to 1e-8. + target: + shape [..., time] + preds: + shape [..., time] + zero_mean: + if to zero mean target and preds or not Raises: - TypeError: if target and estimate have a different shape + TypeError: + if target and preds have a different shape Returns: - Tensor: si-sdr value has a shape of [...] + si-sdr value of shape [...] Example: >>> from torchmetrics.functional.audio import si_sdr >>> target = torch.tensor([3.0, -0.5, 2.0, 7.0]) - >>> estimate = torch.tensor([2.5, 0.0, 2.0, 8.0]) - >>> si_sdr_val = si_sdr(target,estimate) + >>> preds = torch.tensor([2.5, 0.0, 2.0, 8.0]) + >>> si_sdr_val = si_sdr(target, preds) >>> si_sdr_val tensor(18.4030) @@ -43,19 +46,19 @@ def si_sdr(target: Tensor, estimate: Tensor, zero_mean: bool = False, EPS: bool and Signal Processing (ICASSP) 2019. 
""" - if target.shape != estimate.shape: - raise TypeError(f"Inputs must be of shape [..., time], got {target.shape} and {estimate.shape} instead") + if target.shape != preds.shape: + raise TypeError(f"Inputs must be of shape [..., time], got {target.shape} and {preds.shape} instead") if zero_mean: target = target - torch.mean(target, dim=-1, keepdim=True) - estimate = estimate - torch.mean(estimate, dim=-1, keepdim=True) + preds = preds - torch.mean(preds, dim=-1, keepdim=True) - α = torch.sum(estimate * target, dim=-1, keepdim=True) / (torch.sum(target**2, dim=-1, keepdim=True) + EPS) + α = torch.sum(preds * target, dim=-1, keepdim=True) / (torch.sum(target**2, dim=-1, keepdim=True) + 1e-8) target_scaled = α * target - noise = target_scaled - estimate + noise = target_scaled - preds - si_sdr_value = torch.sum(target_scaled**2, dim=-1) / (torch.sum(noise**2, dim=-1) + EPS) - si_sdr_value = 10 * torch.log10(si_sdr_value + EPS) + si_sdr_value = torch.sum(target_scaled**2, dim=-1) / (torch.sum(noise**2, dim=-1) + 1e-8) + si_sdr_value = 10 * torch.log10(si_sdr_value + 1e-8) return si_sdr_value diff --git a/torchmetrics/functional/audio/si_snr.py b/torchmetrics/functional/audio/si_snr.py index b139cfa57e7..b78714989fa 100644 --- a/torchmetrics/functional/audio/si_snr.py +++ b/torchmetrics/functional/audio/si_snr.py @@ -11,30 +11,31 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import torch - +from torch import Tensor from .si_sdr import si_sdr -def si_snr(target: Tensor, estimate: Tensor, EPS: bool = 1e-8) -> Tensor: +def si_snr(target: Tensor, preds: Tensor) -> Tensor: """ scale-invariant signal-to-noise ratio (SI-SNR) Args: - target (Tensor): shape [..., time] - estimate (Tensor): shape [..., time] - EPS (float, optional): a small value for numerical stability. Defaults to 1e-8. 
+ target: + shape [..., time] + preds: + shape [..., time] Raises: - TypeError: if target and estimate have a different shape + TypeError: + if target and preds have a different shape Returns: - Tensor: si-snr value has a shape of [...] + si-snr value of shape [...] Example: >>> from torchmetrics.functional.audio import si_snr >>> target = torch.tensor([3.0, -0.5, 2.0, 7.0]) - >>> estimate = torch.tensor([2.5, 0.0, 2.0, 8.0]) - >>> si_snr_val = si_snr(target,estimate) + >>> preds = torch.tensor([2.5, 0.0, 2.0, 8.0]) + >>> si_snr_val = si_snr(target, preds) >>> si_snr_val tensor(15.0918) @@ -44,4 +45,4 @@ def si_snr(target: Tensor, estimate: Tensor, EPS: bool = 1e-8) -> Tensor: 696-700, doi: 10.1109/ICASSP.2018.8462116. """ - return si_sdr(target=target, estimate=estimate, zero_mean=True, EPS=EPS) + return si_sdr(target=target, preds=preds, zero_mean=True) diff --git a/torchmetrics/functional/audio/snr.py b/torchmetrics/functional/audio/snr.py index e29a136deeb..17878406538 100644 --- a/torchmetrics/functional/audio/snr.py +++ b/torchmetrics/functional/audio/snr.py @@ -15,26 +15,29 @@ from torch import Tensor -def snr(target: Tensor, estimate: Tensor, zero_mean: bool = False, EPS: bool = 1e-8) -> Tensor: +def snr(target: Tensor, preds: Tensor, zero_mean: bool = False) -> Tensor: """ signal-to-noise ratio (SNR) Args: - target (Tensor): shape [..., time] - estimate (Tensor): shape [..., time] - zero_mean (Bool): if to zero mean target and estimate or not - EPS (float, optional): a small value for numerical stability. Defaults to 1e-8. + target: + shape [..., time] + preds: + shape [..., time] + zero_mean: + if to zero mean target and preds or not Raises: - TypeError: if target and estimate have a different shape + TypeError: + if target and preds have a different shape Returns: - Tensor: snr value has a shape of [...] + snr value of shape [...] 
Example: >>> from torchmetrics.functional.audio import snr >>> target = torch.tensor([3.0, -0.5, 2.0, 7.0]) - >>> estimate = torch.tensor([2.5, 0.0, 2.0, 8.0]) - >>> snr_val = snr(target,estimate) + >>> preds = torch.tensor([2.5, 0.0, 2.0, 8.0]) + >>> snr_val = snr(target, preds) >>> snr_val tensor(16.1805) @@ -43,16 +46,16 @@ def snr(target: Tensor, estimate: Tensor, zero_mean: bool = False, EPS: bool = 1 and Signal Processing (ICASSP) 2019. """ - if target.shape != estimate.shape: - raise TypeError(f"Inputs must be of shape [..., time], got {target.shape} and {estimate.shape} instead") + if target.shape != preds.shape: + raise TypeError(f"Inputs must be of shape [..., time], got {target.shape} and {preds.shape} instead") if zero_mean: target = target - torch.mean(target, dim=-1, keepdim=True) - estimate = estimate - torch.mean(estimate, dim=-1, keepdim=True) + preds = preds - torch.mean(preds, dim=-1, keepdim=True) - noise = target - estimate + noise = target - preds - snr_value = torch.sum(target**2, dim=-1) / (torch.sum(noise**2, dim=-1) + EPS) - snr_value = 10 * torch.log10(snr_value + EPS) + snr_value = torch.sum(target**2, dim=-1) / (torch.sum(noise**2, dim=-1) + 1e-8) + snr_value = 10 * torch.log10(snr_value + 1e-8) return snr_value From c8aa372d4728dd5eeaf186cf2e7aadc190980e61 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 13 Jun 2021 13:58:26 +0000 Subject: [PATCH 012/109] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- torchmetrics/functional/__init__.py | 6 +++--- torchmetrics/functional/audio/si_snr.py | 1 + 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/torchmetrics/functional/__init__.py b/torchmetrics/functional/__init__.py index 767e8d43525..df25a3e7269 100644 --- a/torchmetrics/functional/__init__.py +++ b/torchmetrics/functional/__init__.py @@ -11,6 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, 
either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from torchmetrics.functional.audio.si_sdr import si_sdr # noqa: F401 +from torchmetrics.functional.audio.si_snr import si_snr # noqa: F401 +from torchmetrics.functional.audio.snr import snr # noqa: F401 from torchmetrics.functional.classification.accuracy import accuracy # noqa: F401 from torchmetrics.functional.classification.auc import auc # noqa: F401 from torchmetrics.functional.classification.auroc import auroc # noqa: F401 @@ -50,6 +53,3 @@ from torchmetrics.functional.retrieval.recall import retrieval_recall # noqa: F401 from torchmetrics.functional.retrieval.reciprocal_rank import retrieval_reciprocal_rank # noqa: F401 from torchmetrics.functional.self_supervised import embedding_similarity # noqa: F401 -from torchmetrics.functional.audio.si_sdr import si_sdr # noqa: F401 -from torchmetrics.functional.audio.si_snr import si_snr # noqa: F401 -from torchmetrics.functional.audio.snr import snr # noqa: F401 diff --git a/torchmetrics/functional/audio/si_snr.py b/torchmetrics/functional/audio/si_snr.py index b78714989fa..abcaf0caa69 100644 --- a/torchmetrics/functional/audio/si_snr.py +++ b/torchmetrics/functional/audio/si_snr.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
from torch import Tensor + from .si_sdr import si_sdr From c82a09468c06d01957e53ee8f6354c25341dd2ae Mon Sep 17 00:00:00 2001 From: quancs Date: Sun, 13 Jun 2021 22:08:48 +0800 Subject: [PATCH 013/109] update functional.rst --- docs/source/references/functional.rst | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/docs/source/references/functional.rst b/docs/source/references/functional.rst index 7de8d14c065..c5d7a6eac80 100644 --- a/docs/source/references/functional.rst +++ b/docs/source/references/functional.rst @@ -308,3 +308,29 @@ retrieval_normalized_dcg [func] .. autofunction:: torchmetrics.functional.retrieval_normalized_dcg :noindex: + + +****************** +Audio Metrics +****************** + +snr [func] +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autofunction:: torchmetrics.functional.snr + :noindex: + + +si_sdr [func] +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autofunction:: torchmetrics.functional.si_sdr + :noindex: + + +si_snr [func] +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. 
autofunction:: torchmetrics.functional.si_snr + :noindex: + From a33ff6ec3a1c51e4fb1747f9bc0406dcddf5d3a7 Mon Sep 17 00:00:00 2001 From: quancs Date: Sun, 13 Jun 2021 22:11:22 +0800 Subject: [PATCH 014/109] update CHANGELOG.md --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5955815102b..a3d801e0612 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -33,6 +33,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added `is_differentiable` property to `ConfusionMatrix`, `F1`, `FBeta`, `Hamming`, `Hinge`, `IOU`, `MatthewsCorrcoef`, `Precision`, `Recall`, `PrecisionRecallCurve`, `ROC`, `StatScores` ([#253](https://github.com/PyTorchLightning/metrics/pull/253)) +- Added audio metrics: SNR, SI_SDR, SI_SNR ([#292](https://github.com/PyTorchLightning/metrics/pull/292)) + + ### Changed - Forward cache is now reset when `reset` method is called ([#260](https://github.com/PyTorchLightning/metrics/pull/260)) From 282681dc46eb8d70e994e8c466c260c6e16725a6 Mon Sep 17 00:00:00 2001 From: quancs Date: Sun, 13 Jun 2021 22:27:34 +0800 Subject: [PATCH 015/109] switch preds and target --- torchmetrics/functional/audio/si_sdr.py | 6 +++--- torchmetrics/functional/audio/si_snr.py | 7 ++++--- torchmetrics/functional/audio/snr.py | 6 +++--- 3 files changed, 10 insertions(+), 9 deletions(-) diff --git a/torchmetrics/functional/audio/si_sdr.py b/torchmetrics/functional/audio/si_sdr.py index 58c7976ccf1..00d0cd745c5 100644 --- a/torchmetrics/functional/audio/si_sdr.py +++ b/torchmetrics/functional/audio/si_sdr.py @@ -15,14 +15,14 @@ from torch import Tensor -def si_sdr(target: Tensor, preds: Tensor, zero_mean: bool = False) -> Tensor: +def si_sdr(preds: Tensor, target: Tensor, zero_mean: bool = False) -> Tensor: """ scale-invariant signal-to-distortion ratio (SI-SDR) Args: - target: - shape [..., time] preds: shape [..., time] + target: + shape [..., time] zero_mean: if to zero mean target and preds or not 
diff --git a/torchmetrics/functional/audio/si_snr.py b/torchmetrics/functional/audio/si_snr.py index b78714989fa..5b819ef765a 100644 --- a/torchmetrics/functional/audio/si_snr.py +++ b/torchmetrics/functional/audio/si_snr.py @@ -15,14 +15,14 @@ from .si_sdr import si_sdr -def si_snr(target: Tensor, preds: Tensor) -> Tensor: +def si_snr(preds: Tensor, target: Tensor) -> Tensor: """ scale-invariant signal-to-noise ratio (SI-SNR) Args: - target: - shape [..., time] preds: shape [..., time] + target: + shape [..., time] Raises: TypeError: @@ -32,6 +32,7 @@ def si_snr(target: Tensor, preds: Tensor) -> Tensor: si-snr value of shape [...] Example: + >>> import torch >>> from torchmetrics.functional.audio import si_snr >>> target = torch.tensor([3.0, -0.5, 2.0, 7.0]) >>> preds = torch.tensor([2.5, 0.0, 2.0, 8.0]) diff --git a/torchmetrics/functional/audio/snr.py b/torchmetrics/functional/audio/snr.py index 17878406538..c1a2c93c7d0 100644 --- a/torchmetrics/functional/audio/snr.py +++ b/torchmetrics/functional/audio/snr.py @@ -15,14 +15,14 @@ from torch import Tensor -def snr(target: Tensor, preds: Tensor, zero_mean: bool = False) -> Tensor: +def snr(preds: Tensor, target: Tensor, zero_mean: bool = False) -> Tensor: """ signal-to-noise ratio (SNR) Args: - target: - shape [..., time] preds: shape [..., time] + target: + shape [..., time] zero_mean: if to zero mean target and preds or not From 2ba693c264e7965ca8f518f968201495b5719abd Mon Sep 17 00:00:00 2001 From: quancs Date: Sun, 13 Jun 2021 23:39:54 +0800 Subject: [PATCH 016/109] switch preds and target in Example --- torchmetrics/functional/audio/si_sdr.py | 2 +- torchmetrics/functional/audio/si_snr.py | 2 +- torchmetrics/functional/audio/snr.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/torchmetrics/functional/audio/si_sdr.py b/torchmetrics/functional/audio/si_sdr.py index 00d0cd745c5..930a4f6292d 100644 --- a/torchmetrics/functional/audio/si_sdr.py +++ b/torchmetrics/functional/audio/si_sdr.py 
@@ -37,7 +37,7 @@ def si_sdr(preds: Tensor, target: Tensor, zero_mean: bool = False) -> Tensor: >>> from torchmetrics.functional.audio import si_sdr >>> target = torch.tensor([3.0, -0.5, 2.0, 7.0]) >>> preds = torch.tensor([2.5, 0.0, 2.0, 8.0]) - >>> si_sdr_val = si_sdr(target, preds) + >>> si_sdr_val = si_sdr(preds, target) >>> si_sdr_val tensor(18.4030) diff --git a/torchmetrics/functional/audio/si_snr.py b/torchmetrics/functional/audio/si_snr.py index 5b819ef765a..c9ce564f77b 100644 --- a/torchmetrics/functional/audio/si_snr.py +++ b/torchmetrics/functional/audio/si_snr.py @@ -36,7 +36,7 @@ def si_snr(preds: Tensor, target: Tensor) -> Tensor: >>> from torchmetrics.functional.audio import si_snr >>> target = torch.tensor([3.0, -0.5, 2.0, 7.0]) >>> preds = torch.tensor([2.5, 0.0, 2.0, 8.0]) - >>> si_snr_val = si_snr(target, preds) + >>> si_snr_val = si_snr(preds, target) >>> si_snr_val tensor(15.0918) diff --git a/torchmetrics/functional/audio/snr.py b/torchmetrics/functional/audio/snr.py index c1a2c93c7d0..e84cf50d130 100644 --- a/torchmetrics/functional/audio/snr.py +++ b/torchmetrics/functional/audio/snr.py @@ -37,7 +37,7 @@ def snr(preds: Tensor, target: Tensor, zero_mean: bool = False) -> Tensor: >>> from torchmetrics.functional.audio import snr >>> target = torch.tensor([3.0, -0.5, 2.0, 7.0]) >>> preds = torch.tensor([2.5, 0.0, 2.0, 8.0]) - >>> snr_val = snr(target, preds) + >>> snr_val = snr(preds, target) >>> snr_val tensor(16.1805) From b0a83820c996fd727d20914aef8cc2b9b3845fef Mon Sep 17 00:00:00 2001 From: quancs Date: Mon, 14 Jun 2021 00:01:50 +0800 Subject: [PATCH 017/109] add SNR, SI_SNR, SI_SDR module implementation --- torchmetrics/__init__.py | 5 ++ torchmetrics/audio/SI_SDR.py | 106 +++++++++++++++++++++++++++++++++ torchmetrics/audio/SI_SNR.py | 101 +++++++++++++++++++++++++++++++ torchmetrics/audio/SNR.py | 104 ++++++++++++++++++++++++++++++++ torchmetrics/audio/__init__.py | 16 +++++ 5 files changed, 332 insertions(+) create mode 100644 
torchmetrics/audio/SI_SDR.py create mode 100644 torchmetrics/audio/SI_SNR.py create mode 100644 torchmetrics/audio/SNR.py create mode 100644 torchmetrics/audio/__init__.py diff --git a/torchmetrics/__init__.py b/torchmetrics/__init__.py index 52a48638a6f..e626ae99613 100644 --- a/torchmetrics/__init__.py +++ b/torchmetrics/__init__.py @@ -58,3 +58,8 @@ RetrievalRecall, ) from torchmetrics.wrappers import BootStrapper # noqa: F401 E402 +from torchmetrics.audio import ( # noqa: F401 E402 + SNR, + SI_SNR, + SI_SDR, +) diff --git a/torchmetrics/audio/SI_SDR.py b/torchmetrics/audio/SI_SDR.py new file mode 100644 index 00000000000..774f60a2c64 --- /dev/null +++ b/torchmetrics/audio/SI_SDR.py @@ -0,0 +1,106 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Callable, Optional +from torch import Tensor, tensor +from torchmetrics.functional.audio.si_sdr import si_sdr +from torchmetrics.metric import Metric + + +class SI_SDR(Metric): + """ scale-invariant signal-to-distortion ratio (SI-SDR) + + Forward accepts + + - ``preds`` (Tensor): ``shape [..., time]`` + - ``target`` (Tensor): ``shape [..., time]`` + + Args: + zero_mean: + if to zero mean target and preds or not + compute_on_step: + Forward only calls ``update()`` and returns None if this is set to False. default: True + dist_sync_on_step: + Synchronize metric state across processes at each ``forward()`` + before returning the value at the step. 
+ process_group: + Specify the process group on which synchronization is called. default: None (which selects the entire world) + dist_sync_fn: + Callback that performs the allgather operation on the metric state. When `None`, DDP + will be used to perform the allgather. + + Raises: + TypeError: + if target and preds have a different shape + + Returns: + average si-sdr value + + Example: + >>> import torch + >>> from torchmetrics import SI_SDR + >>> target = torch.tensor([3.0, -0.5, 2.0, 7.0]) + >>> preds = torch.tensor([2.5, 0.0, 2.0, 8.0]) + >>> si_sdr = SI_SDR() + >>> si_sdr_val = si_sdr(preds, target) + >>> si_sdr_val + tensor(18.4030) + + References: + [1] Le Roux, Jonathan, et al. "SDR half-baked or well done." IEEE International Conference on Acoustics, Speech + and Signal Processing (ICASSP) 2019. + """ + + def __init__( + self, + zero_mean: bool = False, + compute_on_step: bool = True, + dist_sync_on_step: bool = False, + process_group: Optional[Any] = None, + dist_sync_fn: Callable = None, + ): + super().__init__( + compute_on_step=compute_on_step, + dist_sync_on_step=dist_sync_on_step, + process_group=process_group, + dist_sync_fn=dist_sync_fn, + ) + self.zero_mean = zero_mean + + self.add_state("sum_si_sdr", default=tensor(0.0), dist_reduce_fx="sum") + self.add_state("total", default=tensor(0), dist_reduce_fx="sum") + + def update(self, preds: Tensor, target: Tensor): + """ + Update state with predictions and targets. + + Args: + preds: Predictions from model + target: Ground truth values + """ + si_sdr_batch = si_sdr(preds=preds, + target=target, + zero_mean=self.zero_mean) + + self.sum_si_sdr += si_sdr_batch.sum() + self.total += si_sdr_batch.numel() + + def compute(self): + """ + Computes average SI-SDR. 
+ """ + return self.sum_si_sdr / self.total + + @property + def is_differentiable(self) -> bool: + return True diff --git a/torchmetrics/audio/SI_SNR.py b/torchmetrics/audio/SI_SNR.py new file mode 100644 index 00000000000..c9107f5b12f --- /dev/null +++ b/torchmetrics/audio/SI_SNR.py @@ -0,0 +1,101 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Callable, Optional +from torch import Tensor, tensor +from torchmetrics.functional.audio.si_snr import si_snr +from torchmetrics.metric import Metric + + +class SI_SNR(Metric): + """ scale-invariant signal-to-noise ratio (SI-SNR) + + Forward accepts + + - ``preds`` (Tensor): ``shape [..., time]`` + - ``target`` (Tensor): ``shape [..., time]`` + + Args: + compute_on_step: + Forward only calls ``update()`` and returns None if this is set to False. default: True + dist_sync_on_step: + Synchronize metric state across processes at each ``forward()`` + before returning the value at the step. + process_group: + Specify the process group on which synchronization is called. default: None (which selects the entire world) + dist_sync_fn: + Callback that performs the allgather operation on the metric state. When `None`, DDP + will be used to perform the allgather. 
+ + Raises: + TypeError: + if target and preds have a different shape + + Returns: + average si-snr value + + Example: + >>> import torch + >>> from torchmetrics import SI_SNR + >>> target = torch.tensor([3.0, -0.5, 2.0, 7.0]) + >>> preds = torch.tensor([2.5, 0.0, 2.0, 8.0]) + >>> si_snr = SI_SNR() + >>> si_snr_val = si_snr(preds, target) + >>> si_snr_val + tensor(15.0918) + + References: + [1] Y. Luo and N. Mesgarani, "TaSNet: Time-Domain Audio Separation Network for Real-Time, Single-Channel Speech + Separation," 2018 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP), 2018, pp. + 696-700, doi: 10.1109/ICASSP.2018.8462116. + """ + + def __init__( + self, + compute_on_step: bool = True, + dist_sync_on_step: bool = False, + process_group: Optional[Any] = None, + dist_sync_fn: Callable = None, + ): + super().__init__( + compute_on_step=compute_on_step, + dist_sync_on_step=dist_sync_on_step, + process_group=process_group, + dist_sync_fn=dist_sync_fn, + ) + + self.add_state("sum_si_snr", default=tensor(0.0), dist_reduce_fx="sum") + self.add_state("total", default=tensor(0), dist_reduce_fx="sum") + + def update(self, preds: Tensor, target: Tensor): + """ + Update state with predictions and targets. + + Args: + preds: Predictions from model + target: Ground truth values + """ + si_snr_batch = si_snr(preds=preds, target=target) + + self.sum_si_snr += si_snr_batch.sum() + self.total += si_snr_batch.numel() + + def compute(self): + """ + Computes average SI-SNR. + """ + return self.sum_si_snr / self.total + + @property + def is_differentiable(self) -> bool: + return True diff --git a/torchmetrics/audio/SNR.py b/torchmetrics/audio/SNR.py new file mode 100644 index 00000000000..a31232eb5c7 --- /dev/null +++ b/torchmetrics/audio/SNR.py @@ -0,0 +1,104 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Callable, Optional +from torch import Tensor, tensor +from torchmetrics.functional.audio.snr import snr +from torchmetrics.metric import Metric + + +class SNR(Metric): + """ signal-to-noise ratio (SNR) + + Forward accepts + + - ``preds`` (Tensor): ``shape [..., time]`` + - ``target`` (Tensor): ``shape [..., time]`` + + Args: + zero_mean: + if to zero mean target and preds or not + compute_on_step: + Forward only calls ``update()`` and returns None if this is set to False. default: True + dist_sync_on_step: + Synchronize metric state across processes at each ``forward()`` + before returning the value at the step. + process_group: + Specify the process group on which synchronization is called. default: None (which selects the entire world) + dist_sync_fn: + Callback that performs the allgather operation on the metric state. When `None`, DDP + will be used to perform the allgather. + + Raises: + TypeError: + if target and preds have a different shape + + Returns: + average snr value + + Example: + >>> import torch + >>> from torchmetrics import SNR + >>> target = torch.tensor([3.0, -0.5, 2.0, 7.0]) + >>> preds = torch.tensor([2.5, 0.0, 2.0, 8.0]) + >>> snr = SNR() + >>> snr_val = snr(preds, target) + >>> snr_val + tensor(16.1805) + + References: + [1] Le Roux, Jonathan, et al. "SDR half-baked or well done." IEEE International Conference on Acoustics, Speech + and Signal Processing (ICASSP) 2019. 
+ """ + + def __init__( + self, + zero_mean: bool = False, + compute_on_step: bool = True, + dist_sync_on_step: bool = False, + process_group: Optional[Any] = None, + dist_sync_fn: Callable = None, + ): + super().__init__( + compute_on_step=compute_on_step, + dist_sync_on_step=dist_sync_on_step, + process_group=process_group, + dist_sync_fn=dist_sync_fn, + ) + self.zero_mean = zero_mean + + self.add_state("sum_snr", default=tensor(0.0), dist_reduce_fx="sum") + self.add_state("total", default=tensor(0), dist_reduce_fx="sum") + + def update(self, preds: Tensor, target: Tensor): + """ + Update state with predictions and targets. + + Args: + preds: Predictions from model + target: Ground truth values + """ + snr_batch = snr(preds=preds, target=target, zero_mean=self.zero_mean) + + self.sum_snr += snr_batch.sum() + self.total += snr_batch.numel() + + def compute(self): + """ + Computes average SNR. + """ + return self.sum_snr / self.total + + @property + def is_differentiable(self) -> bool: + return True diff --git a/torchmetrics/audio/__init__.py b/torchmetrics/audio/__init__.py new file mode 100644 index 00000000000..073bf633ded --- /dev/null +++ b/torchmetrics/audio/__init__.py @@ -0,0 +1,16 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from torchmetrics.audio.SNR import SNR # noqa: F401 +from torchmetrics.audio.SI_SNR import SI_SNR # noqa: F401 +from torchmetrics.audio.SI_SDR import SI_SDR # noqa: F401 From db57b13664ab65d23d2d580d20ec193e7e77b9b3 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 13 Jun 2021 16:02:27 +0000 Subject: [PATCH 018/109] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- docs/source/references/functional.rst | 1 - torchmetrics/__init__.py | 6 +----- torchmetrics/audio/SI_SDR.py | 6 +++--- torchmetrics/audio/SI_SNR.py | 2 ++ torchmetrics/audio/SNR.py | 2 ++ torchmetrics/audio/__init__.py | 4 ++-- 6 files changed, 10 insertions(+), 11 deletions(-) diff --git a/docs/source/references/functional.rst b/docs/source/references/functional.rst index c5d7a6eac80..85ae7f7c81c 100644 --- a/docs/source/references/functional.rst +++ b/docs/source/references/functional.rst @@ -333,4 +333,3 @@ si_snr [func] .. 
autofunction:: torchmetrics.functional.si_snr :noindex: - diff --git a/torchmetrics/__init__.py b/torchmetrics/__init__.py index e626ae99613..8e31018a634 100644 --- a/torchmetrics/__init__.py +++ b/torchmetrics/__init__.py @@ -11,6 +11,7 @@ _PACKAGE_ROOT = os.path.dirname(__file__) _PROJECT_ROOT = os.path.dirname(_PACKAGE_ROOT) +from torchmetrics.audio import SI_SDR, SI_SNR, SNR # noqa: F401 E402 from torchmetrics.average import AverageMeter # noqa: F401 E402 from torchmetrics.classification import ( # noqa: F401 E402 AUC, @@ -58,8 +59,3 @@ RetrievalRecall, ) from torchmetrics.wrappers import BootStrapper # noqa: F401 E402 -from torchmetrics.audio import ( # noqa: F401 E402 - SNR, - SI_SNR, - SI_SDR, -) diff --git a/torchmetrics/audio/SI_SDR.py b/torchmetrics/audio/SI_SDR.py index 774f60a2c64..28c9872c522 100644 --- a/torchmetrics/audio/SI_SDR.py +++ b/torchmetrics/audio/SI_SDR.py @@ -12,7 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. from typing import Any, Callable, Optional + from torch import Tensor, tensor + from torchmetrics.functional.audio.si_sdr import si_sdr from torchmetrics.metric import Metric @@ -88,9 +90,7 @@ def update(self, preds: Tensor, target: Tensor): preds: Predictions from model target: Ground truth values """ - si_sdr_batch = si_sdr(preds=preds, - target=target, - zero_mean=self.zero_mean) + si_sdr_batch = si_sdr(preds=preds, target=target, zero_mean=self.zero_mean) self.sum_si_sdr += si_sdr_batch.sum() self.total += si_sdr_batch.numel() diff --git a/torchmetrics/audio/SI_SNR.py b/torchmetrics/audio/SI_SNR.py index c9107f5b12f..458f70dd59f 100644 --- a/torchmetrics/audio/SI_SNR.py +++ b/torchmetrics/audio/SI_SNR.py @@ -12,7 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
from typing import Any, Callable, Optional + from torch import Tensor, tensor + from torchmetrics.functional.audio.si_snr import si_snr from torchmetrics.metric import Metric diff --git a/torchmetrics/audio/SNR.py b/torchmetrics/audio/SNR.py index a31232eb5c7..31e3a000116 100644 --- a/torchmetrics/audio/SNR.py +++ b/torchmetrics/audio/SNR.py @@ -12,7 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. from typing import Any, Callable, Optional + from torch import Tensor, tensor + from torchmetrics.functional.audio.snr import snr from torchmetrics.metric import Metric diff --git a/torchmetrics/audio/__init__.py b/torchmetrics/audio/__init__.py index 073bf633ded..035f1cd38fb 100644 --- a/torchmetrics/audio/__init__.py +++ b/torchmetrics/audio/__init__.py @@ -11,6 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from torchmetrics.audio.SNR import SNR # noqa: F401 -from torchmetrics.audio.SI_SNR import SI_SNR # noqa: F401 from torchmetrics.audio.SI_SDR import SI_SDR # noqa: F401 +from torchmetrics.audio.SI_SNR import SI_SNR # noqa: F401 +from torchmetrics.audio.SNR import SNR # noqa: F401 From fcfe0ac76e576a5fe925e5f06ce4174ad1c9269d Mon Sep 17 00:00:00 2001 From: quancs Date: Mon, 14 Jun 2021 02:32:53 +0800 Subject: [PATCH 019/109] add test --- tests/audio/__init__.py | 0 tests/audio/test_si_sdr.py | 121 ++++++++++++++++++++++++++++++++++++ tests/audio/test_si_snr.py | 107 ++++++++++++++++++++++++++++++++ tests/audio/test_snr.py | 124 +++++++++++++++++++++++++++++++++++++ 4 files changed, 352 insertions(+) create mode 100644 tests/audio/__init__.py create mode 100644 tests/audio/test_si_sdr.py create mode 100644 tests/audio/test_si_snr.py create mode 100644 tests/audio/test_snr.py diff --git a/tests/audio/__init__.py b/tests/audio/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/audio/test_si_sdr.py b/tests/audio/test_si_sdr.py new file mode 100644 index 00000000000..a178f71d68b --- /dev/null +++ b/tests/audio/test_si_sdr.py @@ -0,0 +1,121 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from collections import namedtuple +from functools import partial + +import pytest +import torch +from asteroid.losses import PairwiseNegSDR + +from tests.helpers import seed_all +from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester +from torchmetrics.functional import si_sdr +from torchmetrics.audio import SI_SDR +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6 + +seed_all(42) + +Time = 1000 + +Input = namedtuple('Input', ["preds", "target"]) + +inputs = Input( + preds=torch.rand(NUM_BATCHES, BATCH_SIZE, 1, Time), + target=torch.rand(NUM_BATCHES, BATCH_SIZE, 1, Time), +) + + +def asteroid_metric(preds, target, asteroid_loss_func): + # shape: preds [BATCH_SIZE, 1, Time] , target [BATCH_SIZE, 1, Time] + metric = -asteroid_loss_func(preds, target) + return metric.view(BATCH_SIZE, 1) + + +def average_metric(preds, target, metric_func): + # shape: preds [BATCH_SIZE, 1, Time] , target [BATCH_SIZE, 1, Time] + return metric_func(preds, target).mean() + + +asteroid_sisdr_zero_mean = partial(asteroid_metric, + asteroid_loss_func=PairwiseNegSDR("sisdr")) +asteroid_sisdr_no_zero_mean = partial(asteroid_metric, + asteroid_loss_func=PairwiseNegSDR( + "sisdr", zero_mean=False)) + + +@pytest.mark.parametrize( + "preds, target, sk_metric, zero_mean", + [ + (inputs.preds, inputs.target, asteroid_sisdr_zero_mean, True), + (inputs.preds, inputs.target, asteroid_sisdr_no_zero_mean, False), + ], +) +class TestSISDR(MetricTester): + + @pytest.mark.parametrize("ddp", [True, False]) + @pytest.mark.parametrize("dist_sync_on_step", [True, False]) + def test_si_sdr(self, preds, target, sk_metric, zero_mean, ddp, + dist_sync_on_step): + self.run_class_metric_test( + ddp, + preds, + target, + SI_SDR, + sk_metric=sk_metric, + dist_sync_on_step=dist_sync_on_step, + metric_args=dict(zero_mean=zero_mean), + ) + + def test_si_sdr_functional(self, preds, target, sk_metric, zero_mean): + self.run_functional_metric_test( + preds, + target, + si_sdr, + sk_metric, 
+ metric_args=dict(zero_mean=zero_mean), + ) + + def test_si_sdr_differentiability(self, preds, target, sk_metric, + zero_mean): + self.run_differentiability_test(preds=preds, + target=target, + metric_module=SI_SDR, + metric_functional=si_sdr, + metric_args={'zero_mean': zero_mean}) + + @pytest.mark.skipif( + not _TORCH_GREATER_EQUAL_1_6, + reason= + 'half support of core operations on not support before pytorch v1.6') + def test_si_sdr_half_cpu(self, preds, target, sk_metric, zero_mean): + self.run_precision_test_cpu(preds=preds, + target=target, + metric_module=SI_SDR, + metric_functional=si_sdr, + metric_args={'zero_mean': zero_mean}) + + @pytest.mark.skipif(not torch.cuda.is_available(), + reason='test requires cuda') + def test_si_sdr_half_gpu(self, preds, target, sk_metric, zero_mean): + self.run_precision_test_gpu(preds=preds, + target=target, + metric_module=SI_SDR, + metric_functional=si_sdr, + metric_args={'zero_mean': zero_mean}) + + +def test_error_on_different_shape(metric_class=SI_SDR): + metric = metric_class() + with pytest.raises(ValueError, match='Inputs must be of shape*'): + metric(torch.randn(100,), torch.randn(50,)) diff --git a/tests/audio/test_si_snr.py b/tests/audio/test_si_snr.py new file mode 100644 index 00000000000..0a57f92ea65 --- /dev/null +++ b/tests/audio/test_si_snr.py @@ -0,0 +1,107 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from collections import namedtuple +from functools import partial + +import pytest +import torch +from asteroid.losses import pairwise_neg_sisdr + +from tests.helpers import seed_all +from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester +from torchmetrics.functional import si_snr +from torchmetrics.audio import SI_SNR +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6 + +seed_all(42) + +Time = 1000 + +Input = namedtuple('Input', ["preds", "target"]) + +inputs = Input( + preds=torch.rand(NUM_BATCHES, BATCH_SIZE, 1, Time), + target=torch.rand(NUM_BATCHES, BATCH_SIZE, 1, Time), +) + + +def asteroid_si_snr(preds, target): + # shape: preds [BATCH_SIZE, 1, Time] , target [BATCH_SIZE, 1, Time] + si_snr_v = -pairwise_neg_sisdr(preds, target) + return si_snr_v.view(BATCH_SIZE, 1) + + +def average_metric(preds, target, metric_func): + # shape: preds [BATCH_SIZE, 1, Time] , target [BATCH_SIZE, 1, Time] + return metric_func(preds, target).mean() + + +@pytest.mark.parametrize( + "preds, target, sk_metric", + [ + (inputs.preds, inputs.target, asteroid_si_snr), + ], +) +class TestSISNR(MetricTester): + + @pytest.mark.parametrize("ddp", [True, False]) + @pytest.mark.parametrize("dist_sync_on_step", [True, False]) + def test_si_snr(self, preds, target, sk_metric, ddp, + dist_sync_on_step): + self.run_class_metric_test( + ddp, + preds, + target, + SI_SNR, + sk_metric=partial(average_metric, metric_func=sk_metric), + dist_sync_on_step=dist_sync_on_step, + ) + + def test_si_snr_functional(self, preds, target, sk_metric): + self.run_functional_metric_test( + preds, + target, + si_snr, + sk_metric, + ) + + def test_si_snr_differentiability(self, preds, target, sk_metric): + self.run_differentiability_test(preds=preds, + target=target, + metric_module=SI_SNR, + metric_functional=si_snr) + + @pytest.mark.skipif( + not _TORCH_GREATER_EQUAL_1_6, + reason= + 'half support of core operations on not support before pytorch v1.6') + def 
test_si_snr_half_cpu(self, preds, target, sk_metric): + self.run_precision_test_cpu(preds=preds, + target=target, + metric_module=SI_SNR, + metric_functional=si_snr) + + @pytest.mark.skipif(not torch.cuda.is_available(), + reason='test requires cuda') + def test_si_snr_half_gpu(self, preds, target, sk_metric): + self.run_precision_test_gpu(preds=preds, + target=target, + metric_module=SI_SNR, + metric_functional=si_snr) + + +def test_error_on_different_shape(metric_class=SI_SNR): + metric = metric_class() + with pytest.raises(ValueError, match='Inputs must be of shape*'): + metric(torch.randn(100,), torch.randn(50,)) diff --git a/tests/audio/test_snr.py b/tests/audio/test_snr.py new file mode 100644 index 00000000000..0d8e28a7272 --- /dev/null +++ b/tests/audio/test_snr.py @@ -0,0 +1,124 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from collections import namedtuple +from functools import partial + +import pytest +import torch +from asteroid.losses import pairwise_neg_snr + +from tests.helpers import seed_all +from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester +from torchmetrics.functional import snr +from torchmetrics.audio import SNR +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6 +from mir_eval.separation import bss_eval_images + +seed_all(42) + +Time = 1000 + +Input = namedtuple('Input', ["preds", "target"]) + +inputs = Input( + preds=torch.rand(NUM_BATCHES, BATCH_SIZE, 1, Time), + target=torch.rand(NUM_BATCHES, BATCH_SIZE, 1, Time), +) + + +def asteroid_snr(preds, target): + # shape: preds [BATCH_SIZE, 1, Time] , target [BATCH_SIZE, 1, Time] + snr_v = -pairwise_neg_snr(preds, target) + return snr_v.view(BATCH_SIZE, 1) + + +def bss_eval_images_snr(preds, target): + # shape: preds [BATCH_SIZE, 1, Time] , target [BATCH_SIZE, 1, Time] + snr_vb = [] + for j in range(BATCH_SIZE): + snr_v = bss_eval_images([target[j].view(-1).numpy()], + [preds[j].view(-1).numpy()])[0][0][0] + snr_vb.append(snr_v) + return torch.tensor(snr_vb) + + +def average_metric(preds, target, metric_func): + # shape: preds [BATCH_SIZE, 1, Time] , target [BATCH_SIZE, 1, Time] + return metric_func(preds, target).mean() + + +@pytest.mark.parametrize( + "preds, target, sk_metric, zero_mean", + [ + (inputs.preds, inputs.target, asteroid_snr, True), + (inputs.preds, inputs.target, bss_eval_images_snr, False), + ], +) +class TestSNR(MetricTester): + + @pytest.mark.parametrize("ddp", [True, False]) + @pytest.mark.parametrize("dist_sync_on_step", [True, False]) + def test_snr(self, preds, target, sk_metric, zero_mean, ddp, + dist_sync_on_step): + self.run_class_metric_test( + ddp, + preds, + target, + SNR, + sk_metric=partial(average_metric, metric_func=sk_metric), + dist_sync_on_step=dist_sync_on_step, + metric_args=dict(zero_mean=zero_mean), + ) + + def test_snr_functional(self, preds, 
target, sk_metric, zero_mean): + self.run_functional_metric_test( + preds, + target, + snr, + sk_metric, + metric_args=dict(zero_mean=zero_mean), + ) + + def test_snr_differentiability(self, preds, target, sk_metric, zero_mean): + self.run_differentiability_test(preds=preds, + target=target, + metric_module=SNR, + metric_functional=snr, + metric_args={'zero_mean': zero_mean}) + + @pytest.mark.skipif( + not _TORCH_GREATER_EQUAL_1_6, + reason= + 'half support of core operations on not support before pytorch v1.6') + def test_snr_half_cpu(self, preds, target, sk_metric, zero_mean): + self.run_precision_test_cpu(preds=preds, + target=target, + metric_module=SNR, + metric_functional=snr, + metric_args={'zero_mean': zero_mean}) + + @pytest.mark.skipif(not torch.cuda.is_available(), + reason='test requires cuda') + def test_snr_half_gpu(self, preds, target, sk_metric, zero_mean): + self.run_precision_test_gpu(preds=preds, + target=target, + metric_module=SNR, + metric_functional=snr, + metric_args={'zero_mean': zero_mean}) + + +def test_error_on_different_shape(metric_class=SNR): + metric = metric_class() + with pytest.raises(ValueError, match='Inputs must be of shape*'): + metric(torch.randn(100,), torch.randn(50,)) From 1d77709fa7d24e9929d8259cbf812a1069865f64 Mon Sep 17 00:00:00 2001 From: quancs Date: Mon, 14 Jun 2021 02:50:25 +0800 Subject: [PATCH 020/109] add module doc --- docs/source/references/modules.rst | 38 ++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/docs/source/references/modules.rst b/docs/source/references/modules.rst index 1220cc413ce..e9ca37020a8 100644 --- a/docs/source/references/modules.rst +++ b/docs/source/references/modules.rst @@ -2,6 +2,44 @@ Module metrics ############## +************* +Audio Metrics +************* + +Input details +~~~~~~~~~~~~~ + +For the purposes of audio metrics, inputs (predictions, targets) must have the same size. + +.. 
doctest:: + + >>> import torch + >>> from torchmetrics import SNR + >>> target = torch.tensor([3.0, -0.5, 2.0, 7.0]) + >>> preds = torch.tensor([2.5, 0.0, 2.0, 8.0]) + >>> snr = SNR() + >>> snr_val = snr(preds, target) + >>> snr_val + tensor(16.1805) + +SI_SDR +~~~~~~ + +.. autoclass:: torchmetrics.SI_SDR + :noindex: + +SI_SNR +~~~~~~ + +.. autoclass:: torchmetrics.SI_SNR + :noindex: + +SNR +~~~ + +.. autoclass:: torchmetrics.SNR + :noindex: + ********** Base class ********** From 6a2d8860013fb73696ff6335ae4d621e08b6c23b Mon Sep 17 00:00:00 2001 From: quancs <1017241746@qq.com> Date: Mon, 14 Jun 2021 02:50:44 +0800 Subject: [PATCH 021/109] Update torchmetrics/audio/SI_SDR.py Co-authored-by: Nicki Skafte --- torchmetrics/audio/SI_SDR.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchmetrics/audio/SI_SDR.py b/torchmetrics/audio/SI_SDR.py index 28c9872c522..c86e361c28b 100644 --- a/torchmetrics/audio/SI_SDR.py +++ b/torchmetrics/audio/SI_SDR.py @@ -24,8 +24,8 @@ class SI_SDR(Metric): Forward accepts - - ``preds`` (Tensor): ``shape [..., time]`` - - ``target`` (Tensor): ``shape [..., time]`` + - ``preds``: ``shape [..., time]`` + - ``target``: ``shape [..., time]`` Args: zero_mean: From 44a4a1773371260bfa5827a673b775f26a7bf04d Mon Sep 17 00:00:00 2001 From: quancs <1017241746@qq.com> Date: Mon, 14 Jun 2021 02:50:51 +0800 Subject: [PATCH 022/109] Update torchmetrics/audio/SI_SDR.py Co-authored-by: Nicki Skafte --- torchmetrics/audio/SI_SDR.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchmetrics/audio/SI_SDR.py b/torchmetrics/audio/SI_SDR.py index c86e361c28b..22bb4f1a3ec 100644 --- a/torchmetrics/audio/SI_SDR.py +++ b/torchmetrics/audio/SI_SDR.py @@ -95,7 +95,7 @@ def update(self, preds: Tensor, target: Tensor): self.sum_si_sdr += si_sdr_batch.sum() self.total += si_sdr_batch.numel() - def compute(self): + def compute(self) -> Tensor: """ Computes average SI-SDR. 
""" From 56ee58e45f929914b4560bad66219baf9d8da672 Mon Sep 17 00:00:00 2001 From: quancs <1017241746@qq.com> Date: Mon, 14 Jun 2021 02:51:00 +0800 Subject: [PATCH 023/109] Update torchmetrics/audio/SI_SNR.py Co-authored-by: Nicki Skafte --- torchmetrics/audio/SI_SNR.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchmetrics/audio/SI_SNR.py b/torchmetrics/audio/SI_SNR.py index 458f70dd59f..99f0b476563 100644 --- a/torchmetrics/audio/SI_SNR.py +++ b/torchmetrics/audio/SI_SNR.py @@ -24,8 +24,8 @@ class SI_SNR(Metric): Forward accepts - - ``preds`` (Tensor): ``shape [..., time]`` - - ``target`` (Tensor): ``shape [..., time]`` + - ``preds``: ``shape [..., time]`` + - ``target``: ``shape [..., time]`` Args: compute_on_step: From 0f551b030b2896e50cd8f0e3cb73005864ce605f Mon Sep 17 00:00:00 2001 From: quancs <1017241746@qq.com> Date: Mon, 14 Jun 2021 02:51:08 +0800 Subject: [PATCH 024/109] Update torchmetrics/audio/SI_SNR.py Co-authored-by: Nicki Skafte --- torchmetrics/audio/SI_SNR.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchmetrics/audio/SI_SNR.py b/torchmetrics/audio/SI_SNR.py index 99f0b476563..98c67e835d7 100644 --- a/torchmetrics/audio/SI_SNR.py +++ b/torchmetrics/audio/SI_SNR.py @@ -92,7 +92,7 @@ def update(self, preds: Tensor, target: Tensor): self.sum_si_snr += si_snr_batch.sum() self.total += si_snr_batch.numel() - def compute(self): + def compute(self) -> Tensor: """ Computes average SI-SNR. 
""" From 2a6235224ef2ef28948da246a38ffeb4b873f7e5 Mon Sep 17 00:00:00 2001 From: quancs <1017241746@qq.com> Date: Mon, 14 Jun 2021 02:51:15 +0800 Subject: [PATCH 025/109] Update torchmetrics/audio/SNR.py Co-authored-by: Nicki Skafte --- torchmetrics/audio/SNR.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchmetrics/audio/SNR.py b/torchmetrics/audio/SNR.py index 31e3a000116..c6cbdbb0cea 100644 --- a/torchmetrics/audio/SNR.py +++ b/torchmetrics/audio/SNR.py @@ -24,8 +24,8 @@ class SNR(Metric): Forward accepts - - ``preds`` (Tensor): ``shape [..., time]`` - - ``target`` (Tensor): ``shape [..., time]`` + - ``preds``: ``shape [..., time]`` + - ``target``: ``shape [..., time]`` Args: zero_mean: From c4bd0c55217dff1be6886f434442a72ee02cca33 Mon Sep 17 00:00:00 2001 From: quancs <1017241746@qq.com> Date: Mon, 14 Jun 2021 02:51:21 +0800 Subject: [PATCH 026/109] Update torchmetrics/audio/SNR.py Co-authored-by: Nicki Skafte --- torchmetrics/audio/SNR.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchmetrics/audio/SNR.py b/torchmetrics/audio/SNR.py index c6cbdbb0cea..77510b43be0 100644 --- a/torchmetrics/audio/SNR.py +++ b/torchmetrics/audio/SNR.py @@ -95,7 +95,7 @@ def update(self, preds: Tensor, target: Tensor): self.sum_snr += snr_batch.sum() self.total += snr_batch.numel() - def compute(self): + def compute(self) -> Tensor: """ Computes average SNR. 
""" From 3bd1e7bd3aa077b1b6988fa217471251219ddce4 Mon Sep 17 00:00:00 2001 From: quancs <1017241746@qq.com> Date: Mon, 14 Jun 2021 02:51:43 +0800 Subject: [PATCH 027/109] Update torchmetrics/functional/audio/si_snr.py Co-authored-by: Nicki Skafte --- torchmetrics/functional/audio/si_snr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchmetrics/functional/audio/si_snr.py b/torchmetrics/functional/audio/si_snr.py index 97216bc5683..c16b2099f1a 100644 --- a/torchmetrics/functional/audio/si_snr.py +++ b/torchmetrics/functional/audio/si_snr.py @@ -13,7 +13,7 @@ # limitations under the License. from torch import Tensor -from .si_sdr import si_sdr +from torchmetrics.functional.audio.si_sdr import si_sdr def si_snr(preds: Tensor, target: Tensor) -> Tensor: From e028cd3cc188e1f46c09dcd43a35d6fdaf458059 Mon Sep 17 00:00:00 2001 From: quancs Date: Mon, 14 Jun 2021 02:58:40 +0800 Subject: [PATCH 028/109] use _check_same_shape --- tests/audio/test_si_sdr.py | 5 ++++- tests/audio/test_si_snr.py | 5 ++++- tests/audio/test_snr.py | 5 ++++- torchmetrics/functional/audio/si_sdr.py | 10 +++------- torchmetrics/functional/audio/si_snr.py | 4 ---- torchmetrics/functional/audio/snr.py | 10 +++------- 6 files changed, 18 insertions(+), 21 deletions(-) diff --git a/tests/audio/test_si_sdr.py b/tests/audio/test_si_sdr.py index a178f71d68b..5ebd9ff303c 100644 --- a/tests/audio/test_si_sdr.py +++ b/tests/audio/test_si_sdr.py @@ -117,5 +117,8 @@ def test_si_sdr_half_gpu(self, preds, target, sk_metric, zero_mean): def test_error_on_different_shape(metric_class=SI_SDR): metric = metric_class() - with pytest.raises(ValueError, match='Inputs must be of shape*'): + with pytest.raises( + RuntimeError, + match='Predictions and targets are expected to have the same shape' + ): metric(torch.randn(100,), torch.randn(50,)) diff --git a/tests/audio/test_si_snr.py b/tests/audio/test_si_snr.py index 0a57f92ea65..cc9ec1fbd5b 100644 --- a/tests/audio/test_si_snr.py +++ 
b/tests/audio/test_si_snr.py @@ -103,5 +103,8 @@ def test_si_snr_half_gpu(self, preds, target, sk_metric): def test_error_on_different_shape(metric_class=SI_SNR): metric = metric_class() - with pytest.raises(ValueError, match='Inputs must be of shape*'): + with pytest.raises( + RuntimeError, + match='Predictions and targets are expected to have the same shape' + ): metric(torch.randn(100,), torch.randn(50,)) diff --git a/tests/audio/test_snr.py b/tests/audio/test_snr.py index 0d8e28a7272..a32d47ddfa6 100644 --- a/tests/audio/test_snr.py +++ b/tests/audio/test_snr.py @@ -120,5 +120,8 @@ def test_snr_half_gpu(self, preds, target, sk_metric, zero_mean): def test_error_on_different_shape(metric_class=SNR): metric = metric_class() - with pytest.raises(ValueError, match='Inputs must be of shape*'): + with pytest.raises( + RuntimeError, + match='Predictions and targets are expected to have the same shape' + ): metric(torch.randn(100,), torch.randn(50,)) diff --git a/torchmetrics/functional/audio/si_sdr.py b/torchmetrics/functional/audio/si_sdr.py index 930a4f6292d..34396acd31e 100644 --- a/torchmetrics/functional/audio/si_sdr.py +++ b/torchmetrics/functional/audio/si_sdr.py @@ -14,6 +14,8 @@ import torch from torch import Tensor +from torchmetrics.utilities.checks import _check_same_shape + def si_sdr(preds: Tensor, target: Tensor, zero_mean: bool = False) -> Tensor: """ scale-invariant signal-to-distortion ratio (SI-SDR) @@ -26,10 +28,6 @@ def si_sdr(preds: Tensor, target: Tensor, zero_mean: bool = False) -> Tensor: zero_mean: if to zero mean target and preds or not - Raises: - TypeError: - if target and preds have a different shape - Returns: si-sdr value of shape [...] @@ -45,9 +43,7 @@ def si_sdr(preds: Tensor, target: Tensor, zero_mean: bool = False) -> Tensor: [1] Le Roux, Jonathan, et al. "SDR half-baked or well done." IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP) 2019. 
""" - - if target.shape != preds.shape: - raise TypeError(f"Inputs must be of shape [..., time], got {target.shape} and {preds.shape} instead") + _check_same_shape(preds, target) if zero_mean: target = target - torch.mean(target, dim=-1, keepdim=True) diff --git a/torchmetrics/functional/audio/si_snr.py b/torchmetrics/functional/audio/si_snr.py index 97216bc5683..eff721967fe 100644 --- a/torchmetrics/functional/audio/si_snr.py +++ b/torchmetrics/functional/audio/si_snr.py @@ -25,10 +25,6 @@ def si_snr(preds: Tensor, target: Tensor) -> Tensor: target: shape [..., time] - Raises: - TypeError: - if target and preds have a different shape - Returns: si-snr value of shape [...] diff --git a/torchmetrics/functional/audio/snr.py b/torchmetrics/functional/audio/snr.py index e84cf50d130..4f04e5902a8 100644 --- a/torchmetrics/functional/audio/snr.py +++ b/torchmetrics/functional/audio/snr.py @@ -14,6 +14,8 @@ import torch from torch import Tensor +from torchmetrics.utilities.checks import _check_same_shape + def snr(preds: Tensor, target: Tensor, zero_mean: bool = False) -> Tensor: """ signal-to-noise ratio (SNR) @@ -26,10 +28,6 @@ def snr(preds: Tensor, target: Tensor, zero_mean: bool = False) -> Tensor: zero_mean: if to zero mean target and preds or not - Raises: - TypeError: - if target and preds have a different shape - Returns: snr value of shape [...] @@ -45,9 +43,7 @@ def snr(preds: Tensor, target: Tensor, zero_mean: bool = False) -> Tensor: [1] Le Roux, Jonathan, et al. "SDR half-baked or well done." IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP) 2019. 
""" - - if target.shape != preds.shape: - raise TypeError(f"Inputs must be of shape [..., time], got {target.shape} and {preds.shape} instead") + _check_same_shape(preds, target) if zero_mean: target = target - torch.mean(target, dim=-1, keepdim=True) From 96c579a7726c1993e1aa02088738121b47f8f927 Mon Sep 17 00:00:00 2001 From: quancs Date: Mon, 14 Jun 2021 03:03:01 +0800 Subject: [PATCH 029/109] to alphabetical order --- docs/source/references/functional.rst | 50 +++++++++++++-------------- 1 file changed, 25 insertions(+), 25 deletions(-) diff --git a/docs/source/references/functional.rst b/docs/source/references/functional.rst index c5d7a6eac80..0610ffe36c7 100644 --- a/docs/source/references/functional.rst +++ b/docs/source/references/functional.rst @@ -5,6 +5,31 @@ Functional metrics ################## +************* +Audio Metrics +************* + +si_sdr [func] +~~~~~~~~~~~~~ + +.. autofunction:: torchmetrics.functional.si_sdr + :noindex: + + +si_snr [func] +~~~~~~~~~~~~~ + +.. autofunction:: torchmetrics.functional.si_snr + :noindex: + + +snr [func] +~~~~~~~~~~ + +.. autofunction:: torchmetrics.functional.snr + :noindex: + + ********************** Classification Metrics ********************** @@ -309,28 +334,3 @@ retrieval_normalized_dcg [func] .. autofunction:: torchmetrics.functional.retrieval_normalized_dcg :noindex: - -****************** -Audio Metrics -****************** - -snr [func] -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. autofunction:: torchmetrics.functional.snr - :noindex: - - -si_sdr [func] -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. autofunction:: torchmetrics.functional.si_sdr - :noindex: - - -si_snr [func] -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. 
autofunction:: torchmetrics.functional.si_snr - :noindex: - From 19e0f0b773a52109d7cfc1478ca11ee09ea52c84 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 13 Jun 2021 19:10:27 +0000 Subject: [PATCH 030/109] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- docs/source/references/functional.rst | 1 - tests/audio/test_si_sdr.py | 68 +++++++++++++-------------- tests/audio/test_si_snr.py | 35 ++++---------- tests/audio/test_snr.py | 51 ++++++++------------ 4 files changed, 61 insertions(+), 94 deletions(-) diff --git a/docs/source/references/functional.rst b/docs/source/references/functional.rst index 0610ffe36c7..60c1d48911a 100644 --- a/docs/source/references/functional.rst +++ b/docs/source/references/functional.rst @@ -333,4 +333,3 @@ retrieval_normalized_dcg [func] .. autofunction:: torchmetrics.functional.retrieval_normalized_dcg :noindex: - diff --git a/tests/audio/test_si_sdr.py b/tests/audio/test_si_sdr.py index 5ebd9ff303c..20a301c9345 100644 --- a/tests/audio/test_si_sdr.py +++ b/tests/audio/test_si_sdr.py @@ -20,8 +20,8 @@ from tests.helpers import seed_all from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester -from torchmetrics.functional import si_sdr from torchmetrics.audio import SI_SDR +from torchmetrics.functional import si_sdr from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6 seed_all(42) @@ -47,11 +47,8 @@ def average_metric(preds, target, metric_func): return metric_func(preds, target).mean() -asteroid_sisdr_zero_mean = partial(asteroid_metric, - asteroid_loss_func=PairwiseNegSDR("sisdr")) -asteroid_sisdr_no_zero_mean = partial(asteroid_metric, - asteroid_loss_func=PairwiseNegSDR( - "sisdr", zero_mean=False)) +asteroid_sisdr_zero_mean = partial(asteroid_metric, asteroid_loss_func=PairwiseNegSDR("sisdr")) +asteroid_sisdr_no_zero_mean = partial(asteroid_metric, 
asteroid_loss_func=PairwiseNegSDR("sisdr", zero_mean=False)) @pytest.mark.parametrize( @@ -65,8 +62,7 @@ class TestSISDR(MetricTester): @pytest.mark.parametrize("ddp", [True, False]) @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_si_sdr(self, preds, target, sk_metric, zero_mean, ddp, - dist_sync_on_step): + def test_si_sdr(self, preds, target, sk_metric, zero_mean, ddp, dist_sync_on_step): self.run_class_metric_test( ddp, preds, @@ -86,39 +82,39 @@ def test_si_sdr_functional(self, preds, target, sk_metric, zero_mean): metric_args=dict(zero_mean=zero_mean), ) - def test_si_sdr_differentiability(self, preds, target, sk_metric, - zero_mean): - self.run_differentiability_test(preds=preds, - target=target, - metric_module=SI_SDR, - metric_functional=si_sdr, - metric_args={'zero_mean': zero_mean}) + def test_si_sdr_differentiability(self, preds, target, sk_metric, zero_mean): + self.run_differentiability_test( + preds=preds, + target=target, + metric_module=SI_SDR, + metric_functional=si_sdr, + metric_args={'zero_mean': zero_mean} + ) @pytest.mark.skipif( - not _TORCH_GREATER_EQUAL_1_6, - reason= - 'half support of core operations on not support before pytorch v1.6') + not _TORCH_GREATER_EQUAL_1_6, reason='half support of core operations on not support before pytorch v1.6' + ) def test_si_sdr_half_cpu(self, preds, target, sk_metric, zero_mean): - self.run_precision_test_cpu(preds=preds, - target=target, - metric_module=SI_SDR, - metric_functional=si_sdr, - metric_args={'zero_mean': zero_mean}) - - @pytest.mark.skipif(not torch.cuda.is_available(), - reason='test requires cuda') + self.run_precision_test_cpu( + preds=preds, + target=target, + metric_module=SI_SDR, + metric_functional=si_sdr, + metric_args={'zero_mean': zero_mean} + ) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason='test requires cuda') def test_si_sdr_half_gpu(self, preds, target, sk_metric, zero_mean): - self.run_precision_test_gpu(preds=preds, - target=target, - 
metric_module=SI_SDR, - metric_functional=si_sdr, - metric_args={'zero_mean': zero_mean}) + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=SI_SDR, + metric_functional=si_sdr, + metric_args={'zero_mean': zero_mean} + ) def test_error_on_different_shape(metric_class=SI_SDR): metric = metric_class() - with pytest.raises( - RuntimeError, - match='Predictions and targets are expected to have the same shape' - ): - metric(torch.randn(100,), torch.randn(50,)) + with pytest.raises(RuntimeError, match='Predictions and targets are expected to have the same shape'): + metric(torch.randn(100, ), torch.randn(50, )) diff --git a/tests/audio/test_si_snr.py b/tests/audio/test_si_snr.py index cc9ec1fbd5b..e9d7bcfa8d8 100644 --- a/tests/audio/test_si_snr.py +++ b/tests/audio/test_si_snr.py @@ -20,8 +20,8 @@ from tests.helpers import seed_all from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester -from torchmetrics.functional import si_snr from torchmetrics.audio import SI_SNR +from torchmetrics.functional import si_snr from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6 seed_all(42) @@ -57,8 +57,7 @@ class TestSISNR(MetricTester): @pytest.mark.parametrize("ddp", [True, False]) @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_si_snr(self, preds, target, sk_metric, ddp, - dist_sync_on_step): + def test_si_snr(self, preds, target, sk_metric, ddp, dist_sync_on_step): self.run_class_metric_test( ddp, preds, @@ -77,34 +76,20 @@ def test_si_snr_functional(self, preds, target, sk_metric): ) def test_si_snr_differentiability(self, preds, target, sk_metric): - self.run_differentiability_test(preds=preds, - target=target, - metric_module=SI_SNR, - metric_functional=si_snr) + self.run_differentiability_test(preds=preds, target=target, metric_module=SI_SNR, metric_functional=si_snr) @pytest.mark.skipif( - not _TORCH_GREATER_EQUAL_1_6, - reason= - 'half support of core operations on not support before pytorch 
v1.6') + not _TORCH_GREATER_EQUAL_1_6, reason='half support of core operations on not support before pytorch v1.6' + ) def test_si_snr_half_cpu(self, preds, target, sk_metric): - self.run_precision_test_cpu(preds=preds, - target=target, - metric_module=SI_SNR, - metric_functional=si_snr) + self.run_precision_test_cpu(preds=preds, target=target, metric_module=SI_SNR, metric_functional=si_snr) - @pytest.mark.skipif(not torch.cuda.is_available(), - reason='test requires cuda') + @pytest.mark.skipif(not torch.cuda.is_available(), reason='test requires cuda') def test_si_snr_half_gpu(self, preds, target, sk_metric): - self.run_precision_test_gpu(preds=preds, - target=target, - metric_module=SI_SNR, - metric_functional=si_snr) + self.run_precision_test_gpu(preds=preds, target=target, metric_module=SI_SNR, metric_functional=si_snr) def test_error_on_different_shape(metric_class=SI_SNR): metric = metric_class() - with pytest.raises( - RuntimeError, - match='Predictions and targets are expected to have the same shape' - ): - metric(torch.randn(100,), torch.randn(50,)) + with pytest.raises(RuntimeError, match='Predictions and targets are expected to have the same shape'): + metric(torch.randn(100, ), torch.randn(50, )) diff --git a/tests/audio/test_snr.py b/tests/audio/test_snr.py index a32d47ddfa6..170f3384ff9 100644 --- a/tests/audio/test_snr.py +++ b/tests/audio/test_snr.py @@ -17,13 +17,13 @@ import pytest import torch from asteroid.losses import pairwise_neg_snr +from mir_eval.separation import bss_eval_images from tests.helpers import seed_all from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester -from torchmetrics.functional import snr from torchmetrics.audio import SNR +from torchmetrics.functional import snr from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6 -from mir_eval.separation import bss_eval_images seed_all(42) @@ -47,8 +47,7 @@ def bss_eval_images_snr(preds, target): # shape: preds [BATCH_SIZE, 1, Time] , target 
[BATCH_SIZE, 1, Time] snr_vb = [] for j in range(BATCH_SIZE): - snr_v = bss_eval_images([target[j].view(-1).numpy()], - [preds[j].view(-1).numpy()])[0][0][0] + snr_v = bss_eval_images([target[j].view(-1).numpy()], [preds[j].view(-1).numpy()])[0][0][0] snr_vb.append(snr_v) return torch.tensor(snr_vb) @@ -69,8 +68,7 @@ class TestSNR(MetricTester): @pytest.mark.parametrize("ddp", [True, False]) @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_snr(self, preds, target, sk_metric, zero_mean, ddp, - dist_sync_on_step): + def test_snr(self, preds, target, sk_metric, zero_mean, ddp, dist_sync_on_step): self.run_class_metric_test( ddp, preds, @@ -91,37 +89,26 @@ def test_snr_functional(self, preds, target, sk_metric, zero_mean): ) def test_snr_differentiability(self, preds, target, sk_metric, zero_mean): - self.run_differentiability_test(preds=preds, - target=target, - metric_module=SNR, - metric_functional=snr, - metric_args={'zero_mean': zero_mean}) + self.run_differentiability_test( + preds=preds, target=target, metric_module=SNR, metric_functional=snr, metric_args={'zero_mean': zero_mean} + ) @pytest.mark.skipif( - not _TORCH_GREATER_EQUAL_1_6, - reason= - 'half support of core operations on not support before pytorch v1.6') + not _TORCH_GREATER_EQUAL_1_6, reason='half support of core operations on not support before pytorch v1.6' + ) def test_snr_half_cpu(self, preds, target, sk_metric, zero_mean): - self.run_precision_test_cpu(preds=preds, - target=target, - metric_module=SNR, - metric_functional=snr, - metric_args={'zero_mean': zero_mean}) - - @pytest.mark.skipif(not torch.cuda.is_available(), - reason='test requires cuda') + self.run_precision_test_cpu( + preds=preds, target=target, metric_module=SNR, metric_functional=snr, metric_args={'zero_mean': zero_mean} + ) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason='test requires cuda') def test_snr_half_gpu(self, preds, target, sk_metric, zero_mean): - 
self.run_precision_test_gpu(preds=preds, - target=target, - metric_module=SNR, - metric_functional=snr, - metric_args={'zero_mean': zero_mean}) + self.run_precision_test_gpu( + preds=preds, target=target, metric_module=SNR, metric_functional=snr, metric_args={'zero_mean': zero_mean} + ) def test_error_on_different_shape(metric_class=SNR): metric = metric_class() - with pytest.raises( - RuntimeError, - match='Predictions and targets are expected to have the same shape' - ): - metric(torch.randn(100,), torch.randn(50,)) + with pytest.raises(RuntimeError, match='Predictions and targets are expected to have the same shape'): + metric(torch.randn(100, ), torch.randn(50, )) From d7b5b0dcb4794b2625c5bedeb2b8d3e4faff53eb Mon Sep 17 00:00:00 2001 From: quancs Date: Mon, 14 Jun 2021 03:28:46 +0800 Subject: [PATCH 031/109] update test --- requirements/test.txt | 3 +++ 1 file changed, 3 insertions(+) diff --git a/requirements/test.txt b/requirements/test.txt index f501671f387..17047114c88 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -18,3 +18,6 @@ cloudpickle>=1.3 scikit-learn>=0.24 scikit-image>0.17.1 nltk>=3.6 + +asteroid>=0.5.1 +mir_eval>=0.6 From ea3aee420de2d98d0e5c9c8496d7974751436240 Mon Sep 17 00:00:00 2001 From: quancs <1017241746@qq.com> Date: Mon, 14 Jun 2021 21:02:55 +0800 Subject: [PATCH 032/109] Update docs/source/references/modules.rst Co-authored-by: Nicki Skafte --- docs/source/references/modules.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source/references/modules.rst b/docs/source/references/modules.rst index e9ca37020a8..922ec72b486 100644 --- a/docs/source/references/modules.rst +++ b/docs/source/references/modules.rst @@ -10,6 +10,7 @@ Input details ~~~~~~~~~~~~~ For the purposes of audio metrics, inputs (predictions, targets) must have the same size. +If the input is 1D tensors the output will be a scalar. If the input is multi-dimensional with shape [..., time]` the metric will be computed over the `time` dimension. 
.. doctest:: From fe6e6bcf294f22fd5a94a7121d5fd179602b660e Mon Sep 17 00:00:00 2001 From: quancs Date: Mon, 14 Jun 2021 21:19:54 +0800 Subject: [PATCH 033/109] move Base to the top of Audio --- docs/source/references/modules.rst | 31 +++++++++++++++--------------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/docs/source/references/modules.rst b/docs/source/references/modules.rst index 922ec72b486..99da18ab053 100644 --- a/docs/source/references/modules.rst +++ b/docs/source/references/modules.rst @@ -2,6 +2,22 @@ Module metrics ############## +********** +Base class +********** + +The base ``Metric`` class is an abstract base class that are used as the building block for all other Module +metrics. + +.. autoclass:: torchmetrics.Metric + :noindex: + +We also have an ``AverageMeter`` class that is helpful for defining ad-hoc metrics, when creating +your own metric type might be too burdensome. + +.. autoclass:: torchmetrics.AverageMeter + :noindex: + ************* Audio Metrics ************* @@ -41,21 +57,6 @@ SNR .. autoclass:: torchmetrics.SNR :noindex: -********** -Base class -********** - -The base ``Metric`` class is an abstract base class that are used as the building block for all other Module -metrics. - -.. autoclass:: torchmetrics.Metric - :noindex: - -We also have an ``AverageMeter`` class that is helpful for defining ad-hoc metrics, when creating -your own metric type might be too burdensome. - -.. 
autoclass:: torchmetrics.AverageMeter - :noindex: ********************** Classification Metrics From 58734edbef62b4ed7f214cc01ac3fe7fb72b08e6 Mon Sep 17 00:00:00 2001 From: quancs Date: Tue, 15 Jun 2021 02:19:51 +0800 Subject: [PATCH 034/109] add soundfile --- requirements/test.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements/test.txt b/requirements/test.txt index 17047114c88..8a36d562903 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -21,3 +21,4 @@ nltk>=3.6 asteroid>=0.5.1 mir_eval>=0.6 +soundfile From 868a9d36ee906ecf86016143ac00170fc333375a Mon Sep 17 00:00:00 2001 From: Jirka Date: Tue, 15 Jun 2021 10:05:59 +0200 Subject: [PATCH 035/109] gcc --- .github/workflows/ci_test-full.yml | 2 +- azure-pipelines.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci_test-full.yml b/.github/workflows/ci_test-full.yml index 1f172d66cda..aaedd86cc5a 100644 --- a/.github/workflows/ci_test-full.yml +++ b/.github/workflows/ci_test-full.yml @@ -70,7 +70,7 @@ jobs: - name: Install dependencies run: | - python --version + sudo apt-get install -y cmake gcc pip --version pip install --requirement requirements/devel.txt --upgrade --find-links https://download.pytorch.org/whl/cpu/torch_stable.html pip uninstall -y torchmetrics diff --git a/azure-pipelines.yml b/azure-pipelines.yml index ad47cc3a360..61e04195f09 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -44,7 +44,7 @@ jobs: displayName: 'Image info & NVIDIA' - bash: | - #sudo apt-get install -y cmake + sudo apt-get install -y cmake gcc # python -m pip install "pip==20.1" pip install --requirement ./requirements/devel.txt --upgrade-strategy only-if-needed pip uninstall -y torchmetrics From 6ec17a87333157d81d8d35a1e72eabb3737c1fb5 Mon Sep 17 00:00:00 2001 From: Jirka Date: Tue, 15 Jun 2021 10:15:49 +0200 Subject: [PATCH 036/109] fix cyclic import --- tests/classification/test_specificity.py | 2 +- torchmetrics/classification/stat_scores.py | 
66 ------------------ .../functional/classification/accuracy.py | 2 +- .../functional/classification/f_beta.py | 3 +- .../classification/precision_recall.py | 3 +- .../functional/classification/specificity.py | 3 +- .../functional/classification/stat_scores.py | 67 +++++++++++++++++++ 7 files changed, 72 insertions(+), 74 deletions(-) diff --git a/tests/classification/test_specificity.py b/tests/classification/test_specificity.py index ac4b97ed779..c8b209efae6 100644 --- a/tests/classification/test_specificity.py +++ b/tests/classification/test_specificity.py @@ -31,7 +31,7 @@ from tests.helpers import seed_all from tests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester from torchmetrics import Metric, Specificity -from torchmetrics.classification.stat_scores import _reduce_stat_scores +from torchmetrics.functional.classification.stat_scores import _reduce_stat_scores from torchmetrics.functional import specificity from torchmetrics.utilities.checks import _input_format_classification from torchmetrics.utilities.enums import AverageMethod diff --git a/torchmetrics/classification/stat_scores.py b/torchmetrics/classification/stat_scores.py index 6fbe5bd461b..8abf71513f9 100644 --- a/torchmetrics/classification/stat_scores.py +++ b/torchmetrics/classification/stat_scores.py @@ -281,69 +281,3 @@ def compute(self) -> Tensor: @property def is_differentiable(self) -> bool: return False - - -def _reduce_stat_scores( - numerator: Tensor, - denominator: Tensor, - weights: Optional[Tensor], - average: str, - mdmc_average: Optional[str], - zero_division: int = 0, -) -> Tensor: - """ - Reduces scores of type ``numerator/denominator`` or - ``weights * (numerator/denominator)``, if ``average='weighted'``. - - Args: - numerator: A tensor with numerator numbers. - denominator: A tensor with denominator numbers. If a denominator is - negative, the class will be ignored (if averaging), or its score - will be returned as ``nan`` (if ``average=None``). 
- If the denominator is zero, then ``zero_division`` score will be - used for those elements. - weights: - A tensor of weights to be used if ``average='weighted'``. - average: - The method to average the scores. Should be one of ``'micro'``, ``'macro'``, - ``'weighted'``, ``'none'``, ``None`` or ``'samples'``. The behavior - corresponds to `sklearn averaging methods `__. - mdmc_average: - The method to average the scores if inputs were multi-dimensional multi-class (MDMC). - Should be either ``'global'`` or ``'samplewise'``. If inputs were not - multi-dimensional multi-class, it should be ``None`` (default). - zero_division: - The value to use for the score if denominator equals zero. - """ - numerator, denominator = numerator.float(), denominator.float() - zero_div_mask = denominator == 0 - ignore_mask = denominator < 0 - - if weights is None: - weights = torch.ones_like(denominator) - else: - weights = weights.float() - - numerator = torch.where(zero_div_mask, tensor(float(zero_division), device=numerator.device), numerator) - denominator = torch.where(zero_div_mask | ignore_mask, tensor(1.0, device=denominator.device), denominator) - weights = torch.where(ignore_mask, tensor(0.0, device=weights.device), weights) - - if average not in (AverageMethod.MICRO, AverageMethod.NONE, None): - weights = weights / weights.sum(dim=-1, keepdim=True) - - scores = weights * (numerator / denominator) - - # This is in case where sum(weights) = 0, which happens if we ignore the only present class with average='weighted' - scores = torch.where(torch.isnan(scores), tensor(float(zero_division), device=scores.device), scores) - - if mdmc_average == MDMCAverageMethod.SAMPLEWISE: - scores = scores.mean(dim=0) - ignore_mask = ignore_mask.sum(dim=0).bool() - - if average in (AverageMethod.NONE, None): - scores = torch.where(ignore_mask, tensor(float('nan'), device=scores.device), scores) - else: - scores = scores.sum() - - return scores diff --git 
a/torchmetrics/functional/classification/accuracy.py b/torchmetrics/functional/classification/accuracy.py index 876eaef66a3..a8eac8a7b22 100644 --- a/torchmetrics/functional/classification/accuracy.py +++ b/torchmetrics/functional/classification/accuracy.py @@ -16,7 +16,7 @@ import torch from torch import Tensor, tensor -from torchmetrics.classification.stat_scores import _reduce_stat_scores +from torchmetrics.functional.classification.stat_scores import _reduce_stat_scores from torchmetrics.functional.classification.stat_scores import _stat_scores_update from torchmetrics.utilities.checks import _check_classification_inputs, _input_format_classification, _input_squeeze from torchmetrics.utilities.enums import AverageMethod, DataType, MDMCAverageMethod diff --git a/torchmetrics/functional/classification/f_beta.py b/torchmetrics/functional/classification/f_beta.py index 548cd59c343..5c03313fe4b 100644 --- a/torchmetrics/functional/classification/f_beta.py +++ b/torchmetrics/functional/classification/f_beta.py @@ -16,8 +16,7 @@ import torch from torch import Tensor -from torchmetrics.classification.stat_scores import _reduce_stat_scores -from torchmetrics.functional.classification.stat_scores import _stat_scores_update +from torchmetrics.functional.classification.stat_scores import _reduce_stat_scores, _stat_scores_update from torchmetrics.utilities import _deprecation_warn_arg_multilabel from torchmetrics.utilities.enums import AverageMethod, MDMCAverageMethod diff --git a/torchmetrics/functional/classification/precision_recall.py b/torchmetrics/functional/classification/precision_recall.py index 0e8e517c5e7..3956fe04c1c 100644 --- a/torchmetrics/functional/classification/precision_recall.py +++ b/torchmetrics/functional/classification/precision_recall.py @@ -16,8 +16,7 @@ import torch from torch import Tensor -from torchmetrics.classification.stat_scores import _reduce_stat_scores -from torchmetrics.functional.classification.stat_scores import _stat_scores_update 
+from torchmetrics.functional.classification.stat_scores import _reduce_stat_scores, _stat_scores_update from torchmetrics.utilities import _deprecation_warn_arg_is_multiclass, _deprecation_warn_arg_multilabel from torchmetrics.utilities.enums import AverageMethod, MDMCAverageMethod diff --git a/torchmetrics/functional/classification/specificity.py b/torchmetrics/functional/classification/specificity.py index 7f97e32b637..e8128b5db09 100644 --- a/torchmetrics/functional/classification/specificity.py +++ b/torchmetrics/functional/classification/specificity.py @@ -16,8 +16,7 @@ import torch from torch import Tensor -from torchmetrics.classification.stat_scores import _reduce_stat_scores -from torchmetrics.functional.classification.stat_scores import _stat_scores_update +from torchmetrics.functional.classification.stat_scores import _reduce_stat_scores, _stat_scores_update from torchmetrics.utilities.enums import AverageMethod, MDMCAverageMethod diff --git a/torchmetrics/functional/classification/stat_scores.py b/torchmetrics/functional/classification/stat_scores.py index c2e62f69aa7..c524cedac2a 100644 --- a/torchmetrics/functional/classification/stat_scores.py +++ b/torchmetrics/functional/classification/stat_scores.py @@ -18,6 +18,7 @@ from torchmetrics.utilities import _deprecation_warn_arg_is_multiclass from torchmetrics.utilities.checks import _input_format_classification +from torchmetrics.utilities.enums import AverageMethod, MDMCAverageMethod def _del_column(tensor: Tensor, index: int): @@ -138,6 +139,72 @@ def _stat_scores_compute(tp: Tensor, fp: Tensor, tn: Tensor, fn: Tensor) -> Tens return outputs +def _reduce_stat_scores( + numerator: Tensor, + denominator: Tensor, + weights: Optional[Tensor], + average: str, + mdmc_average: Optional[str], + zero_division: int = 0, +) -> Tensor: + """ + Reduces scores of type ``numerator/denominator`` or + ``weights * (numerator/denominator)``, if ``average='weighted'``. 
+ + Args: + numerator: A tensor with numerator numbers. + denominator: A tensor with denominator numbers. If a denominator is + negative, the class will be ignored (if averaging), or its score + will be returned as ``nan`` (if ``average=None``). + If the denominator is zero, then ``zero_division`` score will be + used for those elements. + weights: + A tensor of weights to be used if ``average='weighted'``. + average: + The method to average the scores. Should be one of ``'micro'``, ``'macro'``, + ``'weighted'``, ``'none'``, ``None`` or ``'samples'``. The behavior + corresponds to `sklearn averaging methods `__. + mdmc_average: + The method to average the scores if inputs were multi-dimensional multi-class (MDMC). + Should be either ``'global'`` or ``'samplewise'``. If inputs were not + multi-dimensional multi-class, it should be ``None`` (default). + zero_division: + The value to use for the score if denominator equals zero. + """ + numerator, denominator = numerator.float(), denominator.float() + zero_div_mask = denominator == 0 + ignore_mask = denominator < 0 + + if weights is None: + weights = torch.ones_like(denominator) + else: + weights = weights.float() + + numerator = torch.where(zero_div_mask, tensor(float(zero_division), device=numerator.device), numerator) + denominator = torch.where(zero_div_mask | ignore_mask, tensor(1.0, device=denominator.device), denominator) + weights = torch.where(ignore_mask, tensor(0.0, device=weights.device), weights) + + if average not in (AverageMethod.MICRO, AverageMethod.NONE, None): + weights = weights / weights.sum(dim=-1, keepdim=True) + + scores = weights * (numerator / denominator) + + # This is in case where sum(weights) = 0, which happens if we ignore the only present class with average='weighted' + scores = torch.where(torch.isnan(scores), tensor(float(zero_division), device=scores.device), scores) + + if mdmc_average == MDMCAverageMethod.SAMPLEWISE: + scores = scores.mean(dim=0) + ignore_mask = 
ignore_mask.sum(dim=0).bool() + + if average in (AverageMethod.NONE, None): + scores = torch.where(ignore_mask, tensor(float('nan'), device=scores.device), scores) + else: + scores = scores.sum() + + return scores + + def stat_scores( preds: Tensor, target: Tensor, From 1b0d379984330b7820e9ab0f8e97f319f189a99c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 15 Jun 2021 08:16:30 +0000 Subject: [PATCH 037/109] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/classification/test_specificity.py | 2 +- torchmetrics/functional/classification/accuracy.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/classification/test_specificity.py b/tests/classification/test_specificity.py index c8b209efae6..e7d1cfd3d40 100644 --- a/tests/classification/test_specificity.py +++ b/tests/classification/test_specificity.py @@ -31,8 +31,8 @@ from tests.helpers import seed_all from tests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester from torchmetrics import Metric, Specificity -from torchmetrics.functional.classification.stat_scores import _reduce_stat_scores from torchmetrics.functional import specificity +from torchmetrics.functional.classification.stat_scores import _reduce_stat_scores from torchmetrics.utilities.checks import _input_format_classification from torchmetrics.utilities.enums import AverageMethod diff --git a/torchmetrics/functional/classification/accuracy.py b/torchmetrics/functional/classification/accuracy.py index a8eac8a7b22..bfd2d897794 100644 --- a/torchmetrics/functional/classification/accuracy.py +++ b/torchmetrics/functional/classification/accuracy.py @@ -16,8 +16,7 @@ import torch from torch import Tensor, tensor -from torchmetrics.functional.classification.stat_scores import _reduce_stat_scores -from torchmetrics.functional.classification.stat_scores import _stat_scores_update +from 
torchmetrics.functional.classification.stat_scores import _reduce_stat_scores, _stat_scores_update from torchmetrics.utilities.checks import _check_classification_inputs, _input_format_classification, _input_squeeze from torchmetrics.utilities.enums import AverageMethod, DataType, MDMCAverageMethod From 37a812fb9065330840e5d6a9d944008ef615ebde Mon Sep 17 00:00:00 2001 From: Jirka Date: Tue, 15 Jun 2021 10:25:41 +0200 Subject: [PATCH 038/109] pysndfile --- .github/workflows/ci_test-full.yml | 3 +-- requirements/test.txt | 1 + 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci_test-full.yml b/.github/workflows/ci_test-full.yml index aaedd86cc5a..58f1c3d83fb 100644 --- a/.github/workflows/ci_test-full.yml +++ b/.github/workflows/ci_test-full.yml @@ -43,7 +43,7 @@ jobs: - name: Setup macOS if: runner.os == 'macOS' run: | - brew install libomp # https://github.com/pytorch/pytorch/issues/20030 + brew install gcc libomp # https://github.com/pytorch/pytorch/issues/20030 - name: Set min. 
dependencies if: matrix.requires == 'minimal' @@ -70,7 +70,6 @@ jobs: - name: Install dependencies run: | - sudo apt-get install -y cmake gcc pip --version pip install --requirement requirements/devel.txt --upgrade --find-links https://download.pytorch.org/whl/cpu/torch_stable.html pip uninstall -y torchmetrics diff --git a/requirements/test.txt b/requirements/test.txt index 8a36d562903..cd864f645bc 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -22,3 +22,4 @@ nltk>=3.6 asteroid>=0.5.1 mir_eval>=0.6 soundfile +pysndfile From 7423b6eb9439fffb1a5b319e665935aa3570e5d6 Mon Sep 17 00:00:00 2001 From: Jirka Date: Tue, 15 Jun 2021 10:40:58 +0200 Subject: [PATCH 039/109] v0.4.5 --- requirements/test.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/test.txt b/requirements/test.txt index cd864f645bc..a584bb06f30 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -19,7 +19,7 @@ scikit-learn>=0.24 scikit-image>0.17.1 nltk>=3.6 -asteroid>=0.5.1 +asteroid==0.4.5 # v0.5 drop support for PT<1.8 mir_eval>=0.6 soundfile pysndfile From 0bc3956f3851ff74a38dd9743b8881a436223061 Mon Sep 17 00:00:00 2001 From: Jirka Date: Tue, 15 Jun 2021 10:52:44 +0200 Subject: [PATCH 040/109] pl --- requirements/integrate.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/integrate.txt b/requirements/integrate.txt index 5c2802a7f46..1164d62fc6a 100644 --- a/requirements/integrate.txt +++ b/requirements/integrate.txt @@ -1 +1 @@ -pytorch-lightning>=1.0 +pytorch-lightning>=1.0.1 # because if asteroid depends From 587b26e983bfc1fe92f63c21b5e392b9fffd1157 Mon Sep 17 00:00:00 2001 From: Jirka Date: Tue, 15 Jun 2021 11:04:28 +0200 Subject: [PATCH 041/109] clang --- azure-pipelines.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/azure-pipelines.yml b/azure-pipelines.yml index 61e04195f09..76bb72a9f56 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -44,7 +44,7 @@ jobs: 
displayName: 'Image info & NVIDIA' - bash: | - sudo apt-get install -y cmake gcc + sudo apt-get install -y cmake gcc clang # python -m pip install "pip==20.1" pip install --requirement ./requirements/devel.txt --upgrade-strategy only-if-needed pip uninstall -y torchmetrics From fdceaf4a499e4219e7a35bb4e3b878665758ef96 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Tue, 15 Jun 2021 11:27:42 +0200 Subject: [PATCH 042/109] Add FID metric (#213) * change * tests * nearly done * fid * working * so close * new update * text * pep8 * Update CHANGELOG.md * move scipy * update requirements for more information, see https://pre-commit.ci * typing * Update tests/image_quality/test_fid.py Co-authored-by: Jirka Borovec * setup for more information, see https://pre-commit.ci * fix mocking * image * doctest * mypy * fix requirements * fix dtype * something * update * revert * Update requirements/devel.txt * check conda issue * install * manifest * adjust * Apply suggestions from code review * ci * PT 1.3 * synthetic data * tv version for nightly * add test * revert tv version * numpy * synthetic data * tv version for nightly * add test * revert tv version * Format code with yapf, autopep8 and isort * doctest * np * base * Format code with yapf, autopep8 and isort * reduce memory * batch_size * rename * format * url Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Jirka Borovec Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com> Co-authored-by: Jirka Co-authored-by: deepsource-autofix[bot] <62050782+deepsource-autofix[bot]@users.noreply.github.com> --- .github/workflows/ci_test-base.yml | 7 +- .github/workflows/ci_test-conda.yml | 6 +- .github/workflows/ci_test-full.yml | 3 + CHANGELOG.md | 3 + MANIFEST.in | 3 +- docs/source/references/modules.rst | 11 ++ requirements.txt | 1 + requirements/adjust-versions.py | 71 ++++++++ requirements/image.txt | 3 + requirements/test.txt | 5 +- setup.py | 10 + 
tests/image/__init__.py | 0 tests/image/test_fid.py | 148 +++++++++++++++ torchmetrics/__init__.py | 1 + torchmetrics/image/__init__.py | 14 ++ torchmetrics/image/fid.py | 273 ++++++++++++++++++++++++++++ torchmetrics/setup_tools.py | 2 +- torchmetrics/utilities/imports.py | 1 + 18 files changed, 553 insertions(+), 9 deletions(-) create mode 100644 requirements/adjust-versions.py create mode 100644 requirements/image.txt create mode 100644 tests/image/__init__.py create mode 100644 tests/image/test_fid.py create mode 100644 torchmetrics/image/__init__.py create mode 100644 torchmetrics/image/fid.py diff --git a/.github/workflows/ci_test-base.yml b/.github/workflows/ci_test-base.yml index f276d222f0d..903860f592e 100644 --- a/.github/workflows/ci_test-base.yml +++ b/.github/workflows/ci_test-base.yml @@ -55,11 +55,12 @@ jobs: - name: Install dependencies run: | - python -m pip install --upgrade --user pip - pip install --requirement ./requirements.txt --find-links https://download.pytorch.org/whl/cpu/torch_stable.html - pip install "pytest>6.0" "pytest-cov>2.10" --upgrade-strategy only-if-needed python --version pip --version + pip install --requirement requirements.txt --upgrade --find-links https://download.pytorch.org/whl/cpu/torch_stable.html + python ./requirements/adjust-versions.py requirements/image.txt + pip install --requirement requirements/devel.txt --upgrade --find-links https://download.pytorch.org/whl/cpu/torch_stable.html + pip uninstall -y torchmetrics pip list shell: bash diff --git a/.github/workflows/ci_test-conda.yml b/.github/workflows/ci_test-conda.yml index 90b589e98d8..24cb2e1c775 100644 --- a/.github/workflows/ci_test-conda.yml +++ b/.github/workflows/ci_test-conda.yml @@ -22,7 +22,6 @@ jobs: - uses: actions/checkout@v2 - run: echo "::set-output name=period::$(python -c 'import time ; days = time.time() / 60 / 60 / 24 ; print(int(days / 7))' 2>&1)" - if: matrix.requires == 'latest' id: times - name: Cache conda @@ -57,10 +56,13 @@ jobs: 
run: | conda info conda install mkl pytorch=${{ matrix.pytorch-version }} cpuonly + conda install cpuonly $(python ./requirements/adjust-versions.py conda) conda list pip --version + python ./requirements/adjust-versions.py requirements.txt + python ./requirements/adjust-versions.py requirements/image.txt pip install --requirement requirements.txt --upgrade-strategy only-if-needed --quiet - pip install --requirement requirements/test.txt --upgrade-strategy only-if-needed --quiet + pip install --requirement requirements/test.txt --upgrade-strategy only-if-needed --quiet --find-links https://download.pytorch.org/whl/cpu/torch_stable.html pip list python -c "import torch; assert torch.__version__[:3] == '${{ matrix.pytorch-version }}', torch.__version__" shell: bash -l {0} diff --git a/.github/workflows/ci_test-full.yml b/.github/workflows/ci_test-full.yml index 58f1c3d83fb..93558bfa376 100644 --- a/.github/workflows/ci_test-full.yml +++ b/.github/workflows/ci_test-full.yml @@ -71,6 +71,9 @@ jobs: - name: Install dependencies run: | pip --version + pip install --requirement requirements.txt --upgrade --find-links https://download.pytorch.org/whl/cpu/torch_stable.html + python ./requirements/adjust-versions.py requirements.txt + python ./requirements/adjust-versions.py requirements/image.txt pip install --requirement requirements/devel.txt --upgrade --find-links https://download.pytorch.org/whl/cpu/torch_stable.html pip uninstall -y torchmetrics pip list diff --git a/CHANGELOG.md b/CHANGELOG.md index e0f689768be..5bd43f0dc14 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -30,6 +30,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added `squared` argument to `MeanSquaredError` for computing `RMSE` ([#249](https://github.com/PyTorchLightning/metrics/pull/249)) +- Added FID metric ([#213](https://github.com/PyTorchLightning/metrics/pull/213)) + + - Added `is_differentiable` property to `ConfusionMatrix`, `F1`, `FBeta`, `Hamming`, 
`Hinge`, `IOU`, `MatthewsCorrcoef`, `Precision`, `Recall`, `PrecisionRecallCurve`, `ROC`, `StatScores` ([#253](https://github.com/PyTorchLightning/metrics/pull/253)) diff --git a/MANIFEST.in b/MANIFEST.in index 9ee72796c3f..efa086e1f0f 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -28,7 +28,8 @@ exclude docs # Include the Requirements include requirements.txt -recursive-exclude requirements *.txt +recursive-include requirements *.txt +recursive-exclude requirements *.py # Exclude build configs exclude *.yml diff --git a/docs/source/references/modules.rst b/docs/source/references/modules.rst index 99da18ab053..7fee845bf18 100644 --- a/docs/source/references/modules.rst +++ b/docs/source/references/modules.rst @@ -290,6 +290,17 @@ StatScores :noindex: +********************* +Image Quality Metrics +********************* + +Image quality metrics can be used to access the quality of synthetic generated images from machine +learning algorithms such as `Generative Adverserial Networks (GANs) `_. + +.. 
autoclass:: torchmetrics.FID + :noindex: + + ****************** Regression Metrics ****************** diff --git a/requirements.txt b/requirements.txt index f14c8d7a062..b47541f6e94 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,3 @@ +numpy>=1.17.2 torch>=1.3.1 packaging diff --git a/requirements/adjust-versions.py b/requirements/adjust-versions.py new file mode 100644 index 00000000000..f21d3ba6749 --- /dev/null +++ b/requirements/adjust-versions.py @@ -0,0 +1,71 @@ +import logging +import os +import re +import sys +from typing import Dict, Optional + +VERSIONS = [ + dict(torch="1.9.0", torchvision="0.10.0", torchtext=""), # nightly + dict(torch="1.8.1", torchvision="0.9.1", torchtext="0.9.1"), + dict(torch="1.8.0", torchvision="0.9.0", torchtext="0.9.0"), + dict(torch="1.7.1", torchvision="0.8.2", torchtext="0.8.1"), + dict(torch="1.7.0", torchvision="0.8.1", torchtext="0.8.0"), + dict(torch="1.6.0", torchvision="0.7.0", torchtext="0.7"), + dict(torch="1.5.1", torchvision="0.6.1", torchtext="0.6"), + dict(torch="1.5.0", torchvision="0.6.0", torchtext="0.6"), + dict(torch="1.4.0", torchvision="0.5.0", torchtext="0.5"), + dict(torch="1.3.1", torchvision="0.4.2", torchtext="0.4"), + dict(torch="1.3.0", torchvision="0.4.1", torchtext="0.4"), +] +VERSIONS.sort(key=lambda v: v["torch"], reverse=True) + + +def find_latest(ver: str) -> Dict[str, str]: + # drop all except semantic version + ver = re.search(r'([\.\d]+)', ver).groups()[0] + # in case there remaining dot at the end - e.g "1.9.0.dev20210504" + ver = ver[:-1] if ver[-1] == '.' 
else ver + logging.info(f"finding ecosystem versions for: {ver}") + + # find first match + for option in VERSIONS: + if option["torch"].startswith(ver): + return option + + raise ValueError(f"Missing {ver} in {VERSIONS}") + + +def main(path_req: str, torch_version: Optional[str] = None) -> None: + if not torch_version: + import torch + torch_version = torch.__version__ + assert torch_version, f"invalid torch: {torch_version}" + latest = find_latest(torch_version) + + if path_req == "conda": + # this is a special case when we need to get the remaining lib versions + req = " ".join([f"{lib}={ver}" if ver else lib for lib, ver in latest.items() if lib != "torch"]) + print(req) + return + + with open(path_req, "r") as fp: + req = fp.readlines() + # remove comments + req = [r[:r.index("#")] if "#" in r else r for r in req] + req = [r.strip() for r in req] + + for lib, ver in latest.items(): + for i, ln in enumerate(req): + m = re.search(r"(\w\d-_)*?[>=]{0,2}.*", ln) + if m and m.group() == lib: + req[i] = f"{lib}=={ver}" if ver else lib + + req = [r + os.linesep for r in req] + logging.info(req) # on purpose - to debug + with open(path_req, "w") as fp: + fp.writelines(req) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO) + main(*sys.argv[1:]) diff --git a/requirements/image.txt b/requirements/image.txt new file mode 100644 index 00000000000..462520fd6b6 --- /dev/null +++ b/requirements/image.txt @@ -0,0 +1,3 @@ +scipy +torchvision # this is needed to internally set TV version according installed PT +torch-fidelity diff --git a/requirements/test.txt b/requirements/test.txt index a584bb06f30..f65d1d70196 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -1,5 +1,3 @@ -numpy - coverage>5.2 codecov>=2.1 pytest>=6.0 @@ -19,6 +17,9 @@ scikit-learn>=0.24 scikit-image>0.17.1 nltk>=3.6 +# add extra requirements +-r image.txt + asteroid==0.4.5 # v0.5 drop support for PT<1.8 mir_eval>=0.6 soundfile diff --git a/setup.py b/setup.py index 
2e3b7547ff8..b8f29583dd3 100755 --- a/setup.py +++ b/setup.py @@ -5,6 +5,7 @@ from setuptools import find_packages, setup _PATH_ROOT = os.path.realpath(os.path.dirname(__file__)) +_PATH_REQUIRE = os.path.join(_PATH_ROOT, 'requirements') def _load_py_module(fname, pkg="torchmetrics"): @@ -22,6 +23,14 @@ def _load_py_module(fname, pkg="torchmetrics"): version=f'v{about.__version__}', ) + +def _prepare_extras(): + extras = { + 'image': setup_tools._load_requirements(path_dir=_PATH_REQUIRE, file_name='image.txt'), + } + return extras + + # https://packaging.python.org/discussions/install-requires-vs-requirements / # keep the meta-data here for simplicity in reading this file... it's not obvious # what happens and to non-engineers they won't know to look in init ... @@ -72,4 +81,5 @@ def _load_py_module(fname, pkg="torchmetrics"): 'Programming Language :: Python :: 3.8', 'Programming Language :: Python :: 3.9', ], + extras_require=_prepare_extras(), ) diff --git a/tests/image/__init__.py b/tests/image/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/image/test_fid.py b/tests/image/test_fid.py new file mode 100644 index 00000000000..7987441520f --- /dev/null +++ b/tests/image/test_fid.py @@ -0,0 +1,148 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import pickle + +import pytest +import torch +from scipy.linalg import sqrtm as scipy_sqrtm +from torch.utils.data import Dataset + +from torchmetrics.image.fid import FID, sqrtm +from torchmetrics.utilities.imports import _TORCH_FIDELITY_AVAILABLE + +torch.manual_seed(42) + + +@pytest.mark.parametrize("matrix_size", [2, 10, 100, 500]) +def test_matrix_sqrt(matrix_size): + """ test that metrix sqrt function works as expected """ + + def generate_cov(n): + data = torch.randn(2 * n, n) + return (data - data.mean(dim=0)).T @ (data - data.mean(dim=0)) + + cov1 = generate_cov(matrix_size) + cov2 = generate_cov(matrix_size) + + scipy_res = scipy_sqrtm((cov1 @ cov2).numpy()).real + tm_res = sqrtm(cov1 @ cov2) + assert torch.allclose(torch.tensor(scipy_res).float(), tm_res, atol=1e-3) + + +@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="test requires torch-fidelity") +def test_no_train(): + """ Assert that metric never leaves evaluation mode """ + + class MyModel(torch.nn.Module): + + def __init__(self): + super().__init__() + self.metric = FID() + + def forward(self, x): + return x + + model = MyModel() + model.train() + assert model.training + assert not model.metric.inception.training, 'FID metric was changed to training mode which should not happen' + + +@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason='test requires torch-fidelity') +def test_fid_pickle(): + """ Assert that we can initialize the metric and pickle it""" + metric = FID() + assert metric + + # verify metrics work after being loaded from pickled state + pickled_metric = pickle.dumps(metric) + metric = pickle.loads(pickled_metric) + + +def test_fid_raises_errors_and_warnings(): + """ Test that expected warnings and errors are raised """ + with pytest.warns( + UserWarning, + match='Metric `FID` will save all extracted features in buffer.' + ' For large datasets this may lead to large memory footprint.' 
+ ): + _ = FID() + + if _TORCH_FIDELITY_AVAILABLE: + with pytest.raises(ValueError, match='Integer input to argument `feature` must be one of .*'): + _ = FID(feature=2) + else: + with pytest.raises( + ValueError, + match='FID metric requires that Torch-fidelity is installed.' + 'Either install as `pip install torchmetrics[image-quality]`' + ' or `pip install torch-fidelity`' + ): + _ = FID() + + +@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason='test requires torch-fidelity') +def test_fid_same_input(): + """ if real and fake are update on the same data the fid score should be 0 """ + metric = FID(feature=192) + + for _ in range(2): + img = torch.randint(0, 255, (5, 3, 299, 299), dtype=torch.uint8) + metric.update(img, real=True) + metric.update(img, real=False) + + assert torch.allclose(torch.cat(metric.real_features, dim=0), torch.cat(metric.fake_features, dim=0)) + + val = metric.compute() + assert torch.allclose(val, torch.zeros_like(val), atol=1e-3) + + +class _ImgDataset(Dataset): + + def __init__(self, imgs): + self.imgs = imgs + + def __getitem__(self, idx): + return self.imgs[idx] + + def __len__(self): + return self.imgs.shape[0] + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason='test is too slow without gpu') +@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason='test requires torch-fidelity') +def test_compare_fid(tmpdir, feature=2048): + """ check that the hole pipeline give the same result as torch-fidelity """ + from torch_fidelity import calculate_metrics + + metric = FID(feature=feature).cuda() + + # Generate some synthetic data + img1 = torch.randint(0, 180, (100, 3, 299, 299), dtype=torch.uint8) + img2 = torch.randint(100, 255, (100, 3, 299, 299), dtype=torch.uint8) + + batch_size = 10 + for i in range(img1.shape[0] // batch_size): + metric.update(img1[batch_size * i:batch_size * (i + 1)].cuda(), real=True) + + for i in range(img2.shape[0] // batch_size): + metric.update(img2[batch_size * i:batch_size * (i + 1)].cuda(), 
real=False) + + torch_fid = calculate_metrics( + _ImgDataset(img1), _ImgDataset(img2), fid=True, feature_layer_fid=str(feature), batch_size=batch_size + ) + + tm_res = metric.compute() + + assert torch.allclose(tm_res.cpu(), torch.tensor([torch_fid['frechet_inception_distance']]), atol=1e-3) diff --git a/torchmetrics/__init__.py b/torchmetrics/__init__.py index 8e31018a634..ee0b44f10bd 100644 --- a/torchmetrics/__init__.py +++ b/torchmetrics/__init__.py @@ -37,6 +37,7 @@ StatScores, ) from torchmetrics.collections import MetricCollection # noqa: F401 E402 +from torchmetrics.image import FID # noqa: F401 E402 from torchmetrics.metric import Metric # noqa: F401 E402 from torchmetrics.regression import ( # noqa: F401 E402 PSNR, diff --git a/torchmetrics/image/__init__.py b/torchmetrics/image/__init__.py new file mode 100644 index 00000000000..ec4d70fa511 --- /dev/null +++ b/torchmetrics/image/__init__.py @@ -0,0 +1,14 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from torchmetrics.image.fid import FID # noqa: F401 diff --git a/torchmetrics/image/fid.py b/torchmetrics/image/fid.py new file mode 100644 index 00000000000..3b692564575 --- /dev/null +++ b/torchmetrics/image/fid.py @@ -0,0 +1,273 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Callable, List, Optional, Union + +import numpy as np +import torch +from torch import Tensor +from torch.autograd import Function + +from torchmetrics.metric import Metric +from torchmetrics.utilities import rank_zero_info, rank_zero_warn +from torchmetrics.utilities.imports import _TORCH_FIDELITY_AVAILABLE + +if _TORCH_FIDELITY_AVAILABLE: + from torch_fidelity.feature_extractor_inceptionv3 import FeatureExtractorInceptionV3 +else: + + class FeatureExtractorInceptionV3(torch.nn.Module): # type:ignore + pass + + +class NoTrainInceptionV3(FeatureExtractorInceptionV3): + + def __init__( + self, + name: str, + features_list: List[str], + feature_extractor_weights_path: Optional[str] = None, + ) -> None: + super().__init__(name, features_list, feature_extractor_weights_path) + # put into evaluation mode + self.eval() + + def train(self, mode: bool) -> 'NoTrainInceptionV3': + """ the inception network should not be able to be switched away from evaluation mode """ + return super().train(False) + + def forward(self, x: Tensor) -> Tensor: + out = super().forward(x) + return out[0].reshape(x.shape[0], -1) + + +class MatrixSquareRoot(Function): + """Square root of a positive definite matrix. 
+ All credit to: + https://github.com/steveli/pytorch-sqrtm/blob/master/sqrtm.py + + """ + + @staticmethod + def forward(ctx: Any, input: Tensor) -> Tensor: + import scipy + + # TODO: update whenever pytorch gets an matrix square root function + # Issue: https://github.com/pytorch/pytorch/issues/9983 + m = input.detach().cpu().numpy().astype(np.float_) + scipy_res, _ = scipy.linalg.sqrtm(m, disp=False) + sqrtm = torch.from_numpy(scipy_res.real).to(input) + ctx.save_for_backward(sqrtm) + return sqrtm + + @staticmethod + def backward(ctx: Any, grad_output: Tensor) -> Tensor: + import scipy + grad_input = None + if ctx.needs_input_grad[0]: + sqrtm, = ctx.saved_tensors + sqrtm = sqrtm.data.cpu().numpy().astype(np.float_) + gm = grad_output.data.cpu().numpy().astype(np.float_) + + # Given a positive semi-definite matrix X, + # since X = X^{1/2}X^{1/2}, we can compute the gradient of the + # matrix square root dX^{1/2} by solving the Sylvester equation: + # dX = (d(X^{1/2})X^{1/2} + X^{1/2}(dX^{1/2}). + grad_sqrtm = scipy.linalg.solve_sylvester(sqrtm, sqrtm, gm) + + grad_input = torch.from_numpy(grad_sqrtm).to(grad_output) + return grad_input + + +sqrtm = MatrixSquareRoot.apply + + +def _compute_fid(mu1: Tensor, sigma1: Tensor, mu2: Tensor, sigma2: Tensor, eps: float = 1e-6) -> Tensor: + r""" + Adjusted version of https://github.com/photosynthesis-team/piq/blob/master/piq/fid.py + + The Frechet Inception Distance between two multivariate Gaussians X_x ~ N(mu_1, sigm_1) + and X_y ~ N(mu_2, sigm_2) is d^2 = ||mu_1 - mu_2||^2 + Tr(sigm_1 + sigm_2 - 2*sqrt(sigm_1*sigm_2)). + + Args: + mu1: mean of activations calculated on predicted (x) samples + sigma1: covariance matrix over activations calculated on predicted (x) samples + mu2: mean of activations calculated on target (y) samples + sigma2: covariance matrix over activations calculated on target (y) samples + eps: offset constant. 
used if sigma_1 @ sigma_2 matrix is singular + + Returns: + Scalar value of the distance between sets. + """ + diff = mu1 - mu2 + + covmean = sqrtm(sigma1.mm(sigma2)) + # Product might be almost singular + if not torch.isfinite(covmean).all(): + rank_zero_info(f'FID calculation produces singular product; adding {eps} to diagonal of covariance estimates') + offset = torch.eye(sigma1.size(0), device=mu1.device, dtype=mu1.dtype) * eps + covmean = sqrtm((sigma1 + offset).mm(sigma2 + offset)) + + tr_covmean = torch.trace(covmean) + return diff.dot(diff) + torch.trace(sigma1) + torch.trace(sigma2) - 2 * tr_covmean + + +class FID(Metric): + r""" + Calculates `Fréchet inception distance (FID) `_ + which is used to access the quality of generated images. Given by + + .. math:: + FID = |\mu - \mu_w| + tr(\Sigma + \Sigma_w - 2(\Sigma \Sigma_w)^{\frac{1}{2}}) + + where :math:`\mathcal{N}(\mu, \Sigma)` is the multivariate normal distribution estimated from Inception v3 [1] + features calculated on real life images and :math:`\mathcal{N}(\mu_w, \Sigma_w)` is the multivariate normal + distribution estimated from Inception v3 features calculated on generated (fake) images. The metric was + originally proposed in [1]. + + Using the default feature extraction (Inception v3 using the original weights from [2]), the input is + expected to be mini-batches of 3-channel RGB images of shape (3 x H x W) with dtype uint8. All images + will be resized to 299 x 299 which is the size of the original training data. The boolian flag ``real`` + determines if the images should update the statistics of the real distribution or the fake distribution. + + .. note:: using this metric with the default feature extractor requires that ``torch-fidelity`` + is installed. Either install as ``pip install torchmetrics[image-quality]`` or + ``pip install torch-fidelity`` + + .. 
note:: the ``forward`` method can be used but ``compute_on_step`` is disabled by default (opposite of
+        all other metrics) as this metric does not really make sense to calculate on a single batch. This
+        means that by default ``forward`` will just call ``update`` underneath.
+
+    Args:
+        feature:
+            Either an integer or ``nn.Module``:
+
+            - an integer will indicate the inceptionv3 feature layer to choose. Can be one of the following:
+              64, 192, 768, 2048
+            - an ``nn.Module`` for using a custom feature extractor. Expects that its forward method returns
+              an ``[N,d]`` matrix where ``N`` is the batch size and ``d`` is the feature size.
+
+        compute_on_step:
+            Forward only calls ``update()`` and return ``None`` if this is set to ``False``.
+        dist_sync_on_step:
+            Synchronize metric state across processes at each ``forward()``
+            before returning the value at the step
+        process_group:
+            Specify the process group on which synchronization is called.
+            default: ``None`` (which selects the entire world)
+        dist_sync_fn:
+            Callback that performs the allgather operation on the metric state. 
When ``None``, DDP + will be used to perform the allgather + + [1] Rethinking the Inception Architecture for Computer Vision + Christian Szegedy, Vincent Vanhoucke, Sergey Ioffe, Jonathon Shlens, Zbigniew Wojna + https://arxiv.org/abs/1512.00567 + + [2] GANs Trained by a Two Time-Scale Update Rule Converge to a Local Nash Equilibrium, + Martin Heusel, Hubert Ramsauer, Thomas Unterthiner, Bernhard Nessler, Sepp Hochreiter + https://arxiv.org/abs/1706.08500 + + Raises: + ValueError: + If ``feature`` is set to an ``int`` (default settings) and ``torch-fidelity`` is not installed + ValueError: + If ``feature`` is set to an ``int`` not in [64, 192, 768, 2048] + + Example: + >>> import torch + >>> _ = torch.manual_seed(123) + >>> from torchmetrics import FID + >>> fid = FID(feature=64) # doctest: +SKIP + >>> # generate two slightly overlapping image intensity distributions + >>> imgs_dist1 = torch.randint(0, 200, (100, 3, 299, 299), dtype=torch.uint8) # doctest: +SKIP + >>> imgs_dist2 = torch.randint(100, 255, (100, 3, 299, 299), dtype=torch.uint8) # doctest: +SKIP + >>> fid.update(imgs_dist1, real=True) # doctest: +SKIP + >>> fid.update(imgs_dist2, real=False) # doctest: +SKIP + >>> fid.compute() # doctest: +SKIP + tensor(12.7202) + + """ + + def __init__( + self, + feature: Union[int, torch.nn.Module] = 2048, + compute_on_step: bool = False, + dist_sync_on_step: bool = False, + process_group: Optional[Any] = None, + dist_sync_fn: Callable = None + ): + super().__init__( + compute_on_step=compute_on_step, + dist_sync_on_step=dist_sync_on_step, + process_group=process_group, + dist_sync_fn=dist_sync_fn, + ) + + rank_zero_warn( + 'Metric `FID` will save all extracted features in buffer.' + ' For large datasets this may lead to large memory footprint.', UserWarning + ) + + if isinstance(feature, int): + if not _TORCH_FIDELITY_AVAILABLE: + raise ValueError( + 'FID metric requires that Torch-fidelity is installed.' 
+ 'Either install as `pip install torchmetrics[image-quality]` or `pip install torch-fidelity`' + ) + valid_int_input = [64, 192, 768, 2048] + if feature not in valid_int_input: + raise ValueError( + f'Integer input to argument `feature` must be one of {valid_int_input}, but got {feature}.' + ) + + self.inception = NoTrainInceptionV3(name='inception-v3-compat', features_list=[str(feature)]) + else: + self.inception = feature + + self.add_state("real_features", [], dist_reduce_fx=None) + self.add_state("fake_features", [], dist_reduce_fx=None) + + def update(self, imgs: Tensor, real: bool) -> None: # type: ignore + """ Update the state with extracted features + + Args: + imgs: tensor with images feed to the feature extractor + real: bool indicating if imgs belong to the real or the fake distribution + """ + features = self.inception(imgs) + + if real: + self.real_features.append(features) + else: + self.fake_features.append(features) + + def compute(self) -> Tensor: + """ Calculate FID score based on accumulated extracted features from the two distributions """ + real_features = torch.cat(self.real_features, dim=0) + fake_features = torch.cat(self.fake_features, dim=0) + # computation is extremely sensitive so it needs to happen in double precision + orig_dtype = real_features.dtype + real_features = real_features.double() + fake_features = fake_features.double() + + # calculate mean and covariance + n = real_features.shape[0] + mean1 = real_features.mean(dim=0) + mean2 = fake_features.mean(dim=0) + diff1 = real_features - mean1 + diff2 = fake_features - mean2 + cov1 = 1.0 / (n - 1) * diff1.t().mm(diff1) + cov2 = 1.0 / (n - 1) * diff2.t().mm(diff2) + + # compute fid + return _compute_fid(mean1, cov1, mean2, cov2).to(orig_dtype) diff --git a/torchmetrics/setup_tools.py b/torchmetrics/setup_tools.py index 208b3392c61..d26d689ca18 100644 --- a/torchmetrics/setup_tools.py +++ b/torchmetrics/setup_tools.py @@ -22,7 +22,7 @@ def _load_requirements(path_dir: str, 
file_name: str = 'requirements.txt', comme """Load requirements from a file >>> _load_requirements(_PROJECT_ROOT) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE - ['torch...'] + ['numpy...', 'torch...'] """ with open(os.path.join(path_dir, file_name), 'r') as file: lines = [ln.strip() for ln in file.readlines()] diff --git a/torchmetrics/utilities/imports.py b/torchmetrics/utilities/imports.py index b71e98da241..2bd4a67719a 100644 --- a/torchmetrics/utilities/imports.py +++ b/torchmetrics/utilities/imports.py @@ -74,3 +74,4 @@ def _compare_version(package: str, op, version) -> Optional[bool]: _TORCH_GREATER_EQUAL_1_6 = _compare_version("torch", operator.ge, "1.6.0") _TORCH_GREATER_EQUAL_1_7 = _compare_version("torch", operator.ge, "1.7.0") _LIGHTNING_AVAILABLE = _module_available("pytorch_lightning") +_TORCH_FIDELITY_AVAILABLE = _module_available("torch_fidelity") From 01b5c2e2e7abb66b16da3905799738127437c9fe Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Tue, 15 Jun 2021 11:56:25 +0200 Subject: [PATCH 043/109] fix cyclic import `_reduce_stat_scores` (#296) * fix cyclic import * flake8 Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .github/set-minimal-versions.py | 22 +++++++++++++++++----- torchmetrics/classification/stat_scores.py | 2 +- 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/.github/set-minimal-versions.py b/.github/set-minimal-versions.py index 210eeda49cc..978149aad90 100644 --- a/.github/set-minimal-versions.py +++ b/.github/set-minimal-versions.py @@ -1,3 +1,4 @@ +import os import re import sys @@ -5,22 +6,33 @@ '3.8': '1.4', '3.9': '1.7.1', } +REQUIREMENTS_FILES = ( + 'requirements.txt', + os.path.join('requirements', 'test.txt'), + os.path.join('requirements', 'integrate.txt'), +) def set_min_torch_by_python(fpath: str = 'requirements.txt') -> None: py_ver = f'{sys.version_info.major}.{sys.version_info.minor}' if py_ver not in LUT_PYTHON_TORCH: return - req = re.sub(r'torch>=[\d\.]+', 
f'torch>={LUT_PYTHON_TORCH[py_ver]}', open(fpath).read()) - open(fpath, 'w').write(req) + with open(fpath) as fp: + req = fp.read() + req = re.sub(r'torch>=[\d\.]+', f'torch>={LUT_PYTHON_TORCH[py_ver]}', req) + with open(fpath, 'w') as fp: + fp.write(req) def replace_min_requirements(fpath: str) -> None: - req = open(fpath).read().replace('>=', '==') - open(fpath, 'w').write(req) + with open(fpath) as fp: + req = fp.read() + req = req.replace('>=', '==') + with open(fpath, 'w') as fp: + fp.write(req) if __name__ == '__main__': set_min_torch_by_python() - for fpath in ('requirements.txt', 'requirements/test.txt', 'requirements/integrate.txt'): + for fpath in REQUIREMENTS_FILES: replace_min_requirements(fpath) diff --git a/torchmetrics/classification/stat_scores.py b/torchmetrics/classification/stat_scores.py index 8abf71513f9..d8c1b05063e 100644 --- a/torchmetrics/classification/stat_scores.py +++ b/torchmetrics/classification/stat_scores.py @@ -14,7 +14,7 @@ from typing import Any, Callable, Optional, Tuple import torch -from torch import Tensor, tensor +from torch import Tensor from torchmetrics.functional.classification.stat_scores import _stat_scores_compute, _stat_scores_update from torchmetrics.metric import Metric From d500663e70705580df1af06e244f9b601a744d7a Mon Sep 17 00:00:00 2001 From: quancs Date: Wed, 16 Jun 2021 02:21:32 +0800 Subject: [PATCH 044/109] update test_snr --- requirements/test.txt | 5 +-- tests/audio/test_snr.py | 95 ++++++++++++++++++++++++++++------------- 2 files changed, 67 insertions(+), 33 deletions(-) diff --git a/requirements/test.txt b/requirements/test.txt index f65d1d70196..a3e9efbde80 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -20,7 +20,6 @@ nltk>=3.6 # add extra requirements -r image.txt -asteroid==0.4.5 # v0.5 drop support for PT<1.8 mir_eval>=0.6 -soundfile -pysndfile +https://github.com/mpariente/pb_bss/archive/refs/heads/master.zip 
+https://github.com/aliutkus/speechmetrics/archive/refs/heads/master.zip diff --git a/tests/audio/test_snr.py b/tests/audio/test_snr.py index 170f3384ff9..2419dbace5b 100644 --- a/tests/audio/test_snr.py +++ b/tests/audio/test_snr.py @@ -13,11 +13,13 @@ # limitations under the License. from collections import namedtuple from functools import partial +from typing import Callable import pytest import torch -from asteroid.losses import pairwise_neg_snr -from mir_eval.separation import bss_eval_images +from mir_eval.separation import bss_eval_images as mir_eval_bss_eval_images +from museval.metrics import bss_eval_images as museval_bss_eval_images +from torch.tensor import Tensor from tests.helpers import seed_all from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester @@ -37,38 +39,64 @@ ) -def asteroid_snr(preds, target): - # shape: preds [BATCH_SIZE, 1, Time] , target [BATCH_SIZE, 1, Time] - snr_v = -pairwise_neg_snr(preds, target) - return snr_v.view(BATCH_SIZE, 1) - - -def bss_eval_images_snr(preds, target): +def bss_eval_images_snr(preds: Tensor, target: Tensor, metric_func: Callable, + zero_mean: bool): # shape: preds [BATCH_SIZE, 1, Time] , target [BATCH_SIZE, 1, Time] + # or shape: preds [NUM_BATCHES*BATCH_SIZE, 1, Time] , target [NUM_BATCHES*BATCH_SIZE, 1, Time] snr_vb = [] - for j in range(BATCH_SIZE): - snr_v = bss_eval_images([target[j].view(-1).numpy()], [preds[j].view(-1).numpy()])[0][0][0] + for j in range(preds.shape[0]): + if zero_mean: + t = target[j] - target[j].mean() + e = preds[j] - preds[j].mean() + else: + t = target[j] + e = preds[j] + if metric_func == mir_eval_bss_eval_images: + snr_v = metric_func([t.view(-1).numpy()], + [e.view(-1).numpy()])[0][0] + else: + snr_v = metric_func([t.view(-1).numpy()], + [e.view(-1).numpy()])[0][0][0] snr_vb.append(snr_v) - return torch.tensor(snr_vb) + return torch.tensor(snr_vb).view(-1, 1) -def average_metric(preds, target, metric_func): +def average_metric(preds: Tensor, target: Tensor, 
metric_func: Callable): # shape: preds [BATCH_SIZE, 1, Time] , target [BATCH_SIZE, 1, Time] + # or shape: preds [NUM_BATCHES*BATCH_SIZE, 1, Time] , target [NUM_BATCHES*BATCH_SIZE, 1, Time] return metric_func(preds, target).mean() +mireval_snr_zeromean = partial(bss_eval_images_snr, + metric_func=mir_eval_bss_eval_images, + zero_mean=True) +mireval_snr_nozeromean = partial(bss_eval_images_snr, + metric_func=mir_eval_bss_eval_images, + zero_mean=False) +museval_snr_zeromean = partial(bss_eval_images_snr, + metric_func=museval_bss_eval_images, + zero_mean=True) +museval_snr_nozeromean = partial(bss_eval_images_snr, + metric_func=museval_bss_eval_images, + zero_mean=False) + + @pytest.mark.parametrize( "preds, target, sk_metric, zero_mean", [ - (inputs.preds, inputs.target, asteroid_snr, True), - (inputs.preds, inputs.target, bss_eval_images_snr, False), + (inputs.preds, inputs.target, mireval_snr_zeromean, True), + (inputs.preds, inputs.target, mireval_snr_nozeromean, False), + (inputs.preds, inputs.target, museval_snr_zeromean, True), + (inputs.preds, inputs.target, museval_snr_nozeromean, False), ], ) class TestSNR(MetricTester): + atol = 1e-5 @pytest.mark.parametrize("ddp", [True, False]) @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_snr(self, preds, target, sk_metric, zero_mean, ddp, dist_sync_on_step): + def test_snr(self, preds, target, sk_metric, zero_mean, ddp, + dist_sync_on_step): self.run_class_metric_test( ddp, preds, @@ -89,26 +117,33 @@ def test_snr_functional(self, preds, target, sk_metric, zero_mean): ) def test_snr_differentiability(self, preds, target, sk_metric, zero_mean): - self.run_differentiability_test( - preds=preds, target=target, metric_module=SNR, metric_functional=snr, metric_args={'zero_mean': zero_mean} - ) + self.run_differentiability_test(preds=preds, + target=target, + metric_module=SNR, + metric_functional=snr, + metric_args={'zero_mean': zero_mean}) @pytest.mark.skipif( - not _TORCH_GREATER_EQUAL_1_6, 
reason='half support of core operations on not support before pytorch v1.6' - ) + not _TORCH_GREATER_EQUAL_1_6, + reason= + 'half support of core operations on not support before pytorch v1.6') def test_snr_half_cpu(self, preds, target, sk_metric, zero_mean): - self.run_precision_test_cpu( - preds=preds, target=target, metric_module=SNR, metric_functional=snr, metric_args={'zero_mean': zero_mean} - ) + pytest.xfail("SNR metric does not support cpu + half precision") - @pytest.mark.skipif(not torch.cuda.is_available(), reason='test requires cuda') + @pytest.mark.skipif(not torch.cuda.is_available(), + reason='test requires cuda') def test_snr_half_gpu(self, preds, target, sk_metric, zero_mean): - self.run_precision_test_gpu( - preds=preds, target=target, metric_module=SNR, metric_functional=snr, metric_args={'zero_mean': zero_mean} - ) + self.run_precision_test_gpu(preds=preds, + target=target, + metric_module=SNR, + metric_functional=snr, + metric_args={'zero_mean': zero_mean}) def test_error_on_different_shape(metric_class=SNR): metric = metric_class() - with pytest.raises(RuntimeError, match='Predictions and targets are expected to have the same shape'): - metric(torch.randn(100, ), torch.randn(50, )) + with pytest.raises( + RuntimeError, + match='Predictions and targets are expected to have the same shape' + ): + metric(torch.randn(100,), torch.randn(50,)) From 3501dd056f4bc4036c5ba448892e9d7bffed613c Mon Sep 17 00:00:00 2001 From: quancs Date: Wed, 16 Jun 2021 02:22:12 +0800 Subject: [PATCH 045/109] update test_si_snr --- tests/audio/test_si_snr.py | 46 +++++++++++++++++++++++++++----------- 1 file changed, 33 insertions(+), 13 deletions(-) diff --git a/tests/audio/test_si_snr.py b/tests/audio/test_si_snr.py index e9d7bcfa8d8..6c7c3131705 100644 --- a/tests/audio/test_si_snr.py +++ b/tests/audio/test_si_snr.py @@ -16,13 +16,14 @@ import pytest import torch -from asteroid.losses import pairwise_neg_sisdr +from torch.tensor import Tensor from tests.helpers import 
seed_all from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester from torchmetrics.audio import SI_SNR from torchmetrics.functional import si_snr from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6 +from pb_bss_eval import OutputMetrics seed_all(42) @@ -36,21 +37,29 @@ ) -def asteroid_si_snr(preds, target): +def pb_bss_eval_si_snr(preds: Tensor, target: Tensor): # shape: preds [BATCH_SIZE, 1, Time] , target [BATCH_SIZE, 1, Time] - si_snr_v = -pairwise_neg_sisdr(preds, target) - return si_snr_v.view(BATCH_SIZE, 1) + # or shape: preds [NUM_BATCHES*BATCH_SIZE, 1, Time] , target [NUM_BATCHES*BATCH_SIZE, 1, Time] + preds = preds - preds.mean(dim=2, keepdim=True) + target = target - target.mean(dim=2, keepdim=True) + vs = [] + for i in range(preds.shape[0]): + om = OutputMetrics(preds[i], target[i], enable_si_sdr=True, compute_permutation=False) + si_snr_v = om['si_sdr'] + vs.append(si_snr_v) + return torch.tensor(vs).view(-1, 1) def average_metric(preds, target, metric_func): # shape: preds [BATCH_SIZE, 1, Time] , target [BATCH_SIZE, 1, Time] + # or shape: preds [NUM_BATCHES*BATCH_SIZE, 1, Time] , target [NUM_BATCHES*BATCH_SIZE, 1, Time] return metric_func(preds, target).mean() @pytest.mark.parametrize( "preds, target, sk_metric", [ - (inputs.preds, inputs.target, asteroid_si_snr), + (inputs.preds, inputs.target, pb_bss_eval_si_snr), ], ) class TestSISNR(MetricTester): @@ -76,20 +85,31 @@ def test_si_snr_functional(self, preds, target, sk_metric): ) def test_si_snr_differentiability(self, preds, target, sk_metric): - self.run_differentiability_test(preds=preds, target=target, metric_module=SI_SNR, metric_functional=si_snr) + self.run_differentiability_test(preds=preds, + target=target, + metric_module=SI_SNR, + metric_functional=si_snr) @pytest.mark.skipif( - not _TORCH_GREATER_EQUAL_1_6, reason='half support of core operations on not support before pytorch v1.6' - ) + not _TORCH_GREATER_EQUAL_1_6, + reason= + 'half support of core 
operations on not support before pytorch v1.6') def test_si_snr_half_cpu(self, preds, target, sk_metric): - self.run_precision_test_cpu(preds=preds, target=target, metric_module=SI_SNR, metric_functional=si_snr) + pytest.xfail("SI-SNR metric does not support cpu + half precision") - @pytest.mark.skipif(not torch.cuda.is_available(), reason='test requires cuda') + @pytest.mark.skipif(not torch.cuda.is_available(), + reason='test requires cuda') def test_si_snr_half_gpu(self, preds, target, sk_metric): - self.run_precision_test_gpu(preds=preds, target=target, metric_module=SI_SNR, metric_functional=si_snr) + self.run_precision_test_gpu(preds=preds, + target=target, + metric_module=SI_SNR, + metric_functional=si_snr) def test_error_on_different_shape(metric_class=SI_SNR): metric = metric_class() - with pytest.raises(RuntimeError, match='Predictions and targets are expected to have the same shape'): - metric(torch.randn(100, ), torch.randn(50, )) + with pytest.raises( + RuntimeError, + match='Predictions and targets are expected to have the same shape' + ): + metric(torch.randn(100,), torch.randn(50,)) From 7ae77789428390ab352082753a25fd2ea01f77c6 Mon Sep 17 00:00:00 2001 From: quancs Date: Wed, 16 Jun 2021 03:55:22 +0800 Subject: [PATCH 046/109] new snr: use torch.finfo(preds.dtype).eps --- torchmetrics/functional/audio/snr.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torchmetrics/functional/audio/snr.py b/torchmetrics/functional/audio/snr.py index 4f04e5902a8..08658c164c7 100644 --- a/torchmetrics/functional/audio/snr.py +++ b/torchmetrics/functional/audio/snr.py @@ -44,6 +44,7 @@ def snr(preds: Tensor, target: Tensor, zero_mean: bool = False) -> Tensor: and Signal Processing (ICASSP) 2019. 
""" _check_same_shape(preds, target) + EPS = torch.finfo(preds.dtype).eps if zero_mean: target = target - torch.mean(target, dim=-1, keepdim=True) @@ -51,7 +52,7 @@ def snr(preds: Tensor, target: Tensor, zero_mean: bool = False) -> Tensor: noise = target - preds - snr_value = torch.sum(target**2, dim=-1) / (torch.sum(noise**2, dim=-1) + 1e-8) - snr_value = 10 * torch.log10(snr_value + 1e-8) + snr_value = (torch.sum(target**2, dim=-1) + EPS) / (torch.sum(noise**2, dim=-1) + EPS) + snr_value = 10 * torch.log10(snr_value) return snr_value From e7dfbee2cdd2be2f19785494a1ad1aabfffa98ad Mon Sep 17 00:00:00 2001 From: quancs Date: Wed, 16 Jun 2021 03:55:37 +0800 Subject: [PATCH 047/109] update test_snr.py --- tests/audio/test_snr.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/tests/audio/test_snr.py b/tests/audio/test_snr.py index 2419dbace5b..7aae8f78bbf 100644 --- a/tests/audio/test_snr.py +++ b/tests/audio/test_snr.py @@ -43,14 +43,13 @@ def bss_eval_images_snr(preds: Tensor, target: Tensor, metric_func: Callable, zero_mean: bool): # shape: preds [BATCH_SIZE, 1, Time] , target [BATCH_SIZE, 1, Time] # or shape: preds [NUM_BATCHES*BATCH_SIZE, 1, Time] , target [NUM_BATCHES*BATCH_SIZE, 1, Time] + if zero_mean: + target = target - torch.mean(target, dim=-1, keepdim=True) + preds = preds - torch.mean(preds, dim=-1, keepdim=True) snr_vb = [] for j in range(preds.shape[0]): - if zero_mean: - t = target[j] - target[j].mean() - e = preds[j] - preds[j].mean() - else: - t = target[j] - e = preds[j] + t = target[j] + e = preds[j] if metric_func == mir_eval_bss_eval_images: snr_v = metric_func([t.view(-1).numpy()], [e.view(-1).numpy()])[0][0] From de679134d862f6841a3ac450283e23111561065a Mon Sep 17 00:00:00 2001 From: quancs Date: Wed, 16 Jun 2021 04:08:17 +0800 Subject: [PATCH 048/109] new si_sdr imp --- torchmetrics/functional/audio/si_sdr.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git 
a/torchmetrics/functional/audio/si_sdr.py b/torchmetrics/functional/audio/si_sdr.py index 34396acd31e..5428405ca07 100644 --- a/torchmetrics/functional/audio/si_sdr.py +++ b/torchmetrics/functional/audio/si_sdr.py @@ -44,17 +44,20 @@ def si_sdr(preds: Tensor, target: Tensor, zero_mean: bool = False) -> Tensor: and Signal Processing (ICASSP) 2019. """ _check_same_shape(preds, target) + EPS = torch.finfo(preds.dtype).eps if zero_mean: target = target - torch.mean(target, dim=-1, keepdim=True) preds = preds - torch.mean(preds, dim=-1, keepdim=True) - α = torch.sum(preds * target, dim=-1, keepdim=True) / (torch.sum(target**2, dim=-1, keepdim=True) + 1e-8) + α = (torch.sum(preds * target, dim=-1, keepdim=True) + + EPS) / (torch.sum(target**2, dim=-1, keepdim=True) + EPS) target_scaled = α * target noise = target_scaled - preds - si_sdr_value = torch.sum(target_scaled**2, dim=-1) / (torch.sum(noise**2, dim=-1) + 1e-8) - si_sdr_value = 10 * torch.log10(si_sdr_value + 1e-8) + si_sdr_value = (torch.sum(target_scaled**2, dim=-1) + + EPS) / (torch.sum(noise**2, dim=-1) + EPS) + si_sdr_value = 10 * torch.log10(si_sdr_value) return si_sdr_value From 6de666ff079a0e81132a5be829bfd2601bf9a392 Mon Sep 17 00:00:00 2001 From: quancs Date: Wed, 16 Jun 2021 04:37:07 +0800 Subject: [PATCH 049/109] update test_si_sdr --- tests/audio/test_si_sdr.py | 90 ++++++++++++++++++++++---------------- 1 file changed, 53 insertions(+), 37 deletions(-) diff --git a/tests/audio/test_si_sdr.py b/tests/audio/test_si_sdr.py index 20a301c9345..48d62038b76 100644 --- a/tests/audio/test_si_sdr.py +++ b/tests/audio/test_si_sdr.py @@ -16,7 +16,7 @@ import pytest import torch -from asteroid.losses import PairwiseNegSDR +from torch.tensor import Tensor from tests.helpers import seed_all from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester @@ -24,6 +24,8 @@ from torchmetrics.functional import si_sdr from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6 +import speechmetrics + 
seed_all(42) Time = 1000 @@ -35,40 +37,58 @@ target=torch.rand(NUM_BATCHES, BATCH_SIZE, 1, Time), ) +speechmetrics_sisdr = speechmetrics.load('sisdr') + -def asteroid_metric(preds, target, asteroid_loss_func): +def speechmetrics_si_sdr(preds: Tensor, target: Tensor, zero_mean: bool): # shape: preds [BATCH_SIZE, 1, Time] , target [BATCH_SIZE, 1, Time] - metric = -asteroid_loss_func(preds, target) - return metric.view(BATCH_SIZE, 1) + # or shape: preds [NUM_BATCHES*BATCH_SIZE, 1, Time] , target [NUM_BATCHES*BATCH_SIZE, 1, Time] + if zero_mean: + preds = preds - preds.mean(dim=2, keepdim=True) + target = target - target.mean(dim=2, keepdim=True) + mss = [] + for i in range(preds.shape[0]): + ms = [] + for j in range(preds.shape[1]): + metric = speechmetrics_sisdr(preds[i, j].numpy(), + target[i, j].numpy(), + rate=16000) + ms.append(metric['sisdr'][0]) + mss.append(ms) + return torch.tensor(mss) def average_metric(preds, target, metric_func): # shape: preds [BATCH_SIZE, 1, Time] , target [BATCH_SIZE, 1, Time] + # or shape: preds [NUM_BATCHES*BATCH_SIZE, 1, Time] , target [NUM_BATCHES*BATCH_SIZE, 1, Time] return metric_func(preds, target).mean() -asteroid_sisdr_zero_mean = partial(asteroid_metric, asteroid_loss_func=PairwiseNegSDR("sisdr")) -asteroid_sisdr_no_zero_mean = partial(asteroid_metric, asteroid_loss_func=PairwiseNegSDR("sisdr", zero_mean=False)) +speechmetrics_si_sdr_zero_mean = partial(speechmetrics_si_sdr, zero_mean=True) +speechmetrics_si_sdr_no_zero_mean = partial(speechmetrics_si_sdr, + zero_mean=False) @pytest.mark.parametrize( "preds, target, sk_metric, zero_mean", [ - (inputs.preds, inputs.target, asteroid_sisdr_zero_mean, True), - (inputs.preds, inputs.target, asteroid_sisdr_no_zero_mean, False), + (inputs.preds, inputs.target, speechmetrics_si_sdr_zero_mean, True), + (inputs.preds, inputs.target, speechmetrics_si_sdr_no_zero_mean, + False), ], ) class TestSISDR(MetricTester): @pytest.mark.parametrize("ddp", [True, False]) 
@pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_si_sdr(self, preds, target, sk_metric, zero_mean, ddp, dist_sync_on_step): + def test_si_sdr(self, preds, target, sk_metric, zero_mean, ddp, + dist_sync_on_step): self.run_class_metric_test( ddp, preds, target, SI_SDR, - sk_metric=sk_metric, + sk_metric=partial(average_metric, metric_func=sk_metric), dist_sync_on_step=dist_sync_on_step, metric_args=dict(zero_mean=zero_mean), ) @@ -82,39 +102,35 @@ def test_si_sdr_functional(self, preds, target, sk_metric, zero_mean): metric_args=dict(zero_mean=zero_mean), ) - def test_si_sdr_differentiability(self, preds, target, sk_metric, zero_mean): - self.run_differentiability_test( - preds=preds, - target=target, - metric_module=SI_SDR, - metric_functional=si_sdr, - metric_args={'zero_mean': zero_mean} - ) + def test_si_sdr_differentiability(self, preds, target, sk_metric, + zero_mean): + self.run_differentiability_test(preds=preds, + target=target, + metric_module=SI_SDR, + metric_functional=si_sdr, + metric_args={'zero_mean': zero_mean}) @pytest.mark.skipif( - not _TORCH_GREATER_EQUAL_1_6, reason='half support of core operations on not support before pytorch v1.6' - ) + not _TORCH_GREATER_EQUAL_1_6, + reason= + 'half support of core operations on not support before pytorch v1.6') def test_si_sdr_half_cpu(self, preds, target, sk_metric, zero_mean): - self.run_precision_test_cpu( - preds=preds, - target=target, - metric_module=SI_SDR, - metric_functional=si_sdr, - metric_args={'zero_mean': zero_mean} - ) + pytest.xfail("SI-SDR metric does not support cpu + half precision") - @pytest.mark.skipif(not torch.cuda.is_available(), reason='test requires cuda') + @pytest.mark.skipif(not torch.cuda.is_available(), + reason='test requires cuda') def test_si_sdr_half_gpu(self, preds, target, sk_metric, zero_mean): - self.run_precision_test_gpu( - preds=preds, - target=target, - metric_module=SI_SDR, - metric_functional=si_sdr, - metric_args={'zero_mean': zero_mean} - 
) + self.run_precision_test_gpu(preds=preds, + target=target, + metric_module=SI_SDR, + metric_functional=si_sdr, + metric_args={'zero_mean': zero_mean}) def test_error_on_different_shape(metric_class=SI_SDR): metric = metric_class() - with pytest.raises(RuntimeError, match='Predictions and targets are expected to have the same shape'): - metric(torch.randn(100, ), torch.randn(50, )) + with pytest.raises( + RuntimeError, + match='Predictions and targets are expected to have the same shape' + ): + metric(torch.randn(100,), torch.randn(50,)) From 5fdbb0d7d6d33dc7a4cd069e64c0c159d6d59ea5 Mon Sep 17 00:00:00 2001 From: quancs Date: Wed, 16 Jun 2021 04:39:17 +0800 Subject: [PATCH 050/109] update test_si_snr --- tests/audio/test_si_snr.py | 30 ++++++++++++++++++++---------- 1 file changed, 20 insertions(+), 10 deletions(-) diff --git a/tests/audio/test_si_snr.py b/tests/audio/test_si_snr.py index 6c7c3131705..e3b1a73a4b1 100644 --- a/tests/audio/test_si_snr.py +++ b/tests/audio/test_si_snr.py @@ -23,7 +23,8 @@ from torchmetrics.audio import SI_SNR from torchmetrics.functional import si_snr from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6 -from pb_bss_eval import OutputMetrics + +import speechmetrics seed_all(42) @@ -36,18 +37,27 @@ target=torch.rand(NUM_BATCHES, BATCH_SIZE, 1, Time), ) +speechmetrics_sisdr = speechmetrics.load('sisdr') + -def pb_bss_eval_si_snr(preds: Tensor, target: Tensor): +def speechmetrics_si_sdr(preds: Tensor, + target: Tensor, + zero_mean: bool = True): # shape: preds [BATCH_SIZE, 1, Time] , target [BATCH_SIZE, 1, Time] # or shape: preds [NUM_BATCHES*BATCH_SIZE, 1, Time] , target [NUM_BATCHES*BATCH_SIZE, 1, Time] - preds = preds - preds.mean(dim=2, keepdim=True) - target = target - target.mean(dim=2, keepdim=True) - vs = [] + if zero_mean: + preds = preds - preds.mean(dim=2, keepdim=True) + target = target - target.mean(dim=2, keepdim=True) + mss = [] for i in range(preds.shape[0]): - om = OutputMetrics(preds[i], target[i], 
enable_si_sdr=True, compute_permutation=False) - si_snr_v = om['si_sdr'] - vs.append(si_snr_v) - return torch.tensor(vs).view(-1, 1) + ms = [] + for j in range(preds.shape[1]): + metric = speechmetrics_sisdr(preds[i, j].numpy(), + target[i, j].numpy(), + rate=16000) + ms.append(metric['sisdr'][0]) + mss.append(ms) + return torch.tensor(mss) def average_metric(preds, target, metric_func): @@ -59,7 +69,7 @@ def average_metric(preds, target, metric_func): @pytest.mark.parametrize( "preds, target, sk_metric", [ - (inputs.preds, inputs.target, pb_bss_eval_si_snr), + (inputs.preds, inputs.target, speechmetrics_si_sdr), ], ) class TestSISNR(MetricTester): From 608b21aacd74bce9ad210234c99e9327ffc31bb6 Mon Sep 17 00:00:00 2001 From: quancs Date: Wed, 16 Jun 2021 04:40:15 +0800 Subject: [PATCH 051/109] remove pb_bss_eval --- requirements/test.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/requirements/test.txt b/requirements/test.txt index a3e9efbde80..5251a06f3be 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -21,5 +21,4 @@ nltk>=3.6 -r image.txt mir_eval>=0.6 -https://github.com/mpariente/pb_bss/archive/refs/heads/master.zip https://github.com/aliutkus/speechmetrics/archive/refs/heads/master.zip From aea0091f1a26f5ae2d59c5772cc6ebdfc0d349c1 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 15 Jun 2021 20:41:10 +0000 Subject: [PATCH 052/109] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/audio/test_si_sdr.py | 58 +++++++++++------------- tests/audio/test_si_snr.py | 36 +++++---------- tests/audio/test_snr.py | 59 ++++++++----------------- torchmetrics/functional/audio/si_sdr.py | 6 +-- 4 files changed, 56 insertions(+), 103 deletions(-) diff --git a/tests/audio/test_si_sdr.py b/tests/audio/test_si_sdr.py index 48d62038b76..dc017e807c2 100644 --- a/tests/audio/test_si_sdr.py +++ b/tests/audio/test_si_sdr.py @@ -15,6 
+15,7 @@ from functools import partial import pytest +import speechmetrics import torch from torch.tensor import Tensor @@ -24,8 +25,6 @@ from torchmetrics.functional import si_sdr from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6 -import speechmetrics - seed_all(42) Time = 1000 @@ -50,9 +49,7 @@ def speechmetrics_si_sdr(preds: Tensor, target: Tensor, zero_mean: bool): for i in range(preds.shape[0]): ms = [] for j in range(preds.shape[1]): - metric = speechmetrics_sisdr(preds[i, j].numpy(), - target[i, j].numpy(), - rate=16000) + metric = speechmetrics_sisdr(preds[i, j].numpy(), target[i, j].numpy(), rate=16000) ms.append(metric['sisdr'][0]) mss.append(ms) return torch.tensor(mss) @@ -65,24 +62,21 @@ def average_metric(preds, target, metric_func): speechmetrics_si_sdr_zero_mean = partial(speechmetrics_si_sdr, zero_mean=True) -speechmetrics_si_sdr_no_zero_mean = partial(speechmetrics_si_sdr, - zero_mean=False) +speechmetrics_si_sdr_no_zero_mean = partial(speechmetrics_si_sdr, zero_mean=False) @pytest.mark.parametrize( "preds, target, sk_metric, zero_mean", [ (inputs.preds, inputs.target, speechmetrics_si_sdr_zero_mean, True), - (inputs.preds, inputs.target, speechmetrics_si_sdr_no_zero_mean, - False), + (inputs.preds, inputs.target, speechmetrics_si_sdr_no_zero_mean, False), ], ) class TestSISDR(MetricTester): @pytest.mark.parametrize("ddp", [True, False]) @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_si_sdr(self, preds, target, sk_metric, zero_mean, ddp, - dist_sync_on_step): + def test_si_sdr(self, preds, target, sk_metric, zero_mean, ddp, dist_sync_on_step): self.run_class_metric_test( ddp, preds, @@ -102,35 +96,33 @@ def test_si_sdr_functional(self, preds, target, sk_metric, zero_mean): metric_args=dict(zero_mean=zero_mean), ) - def test_si_sdr_differentiability(self, preds, target, sk_metric, - zero_mean): - self.run_differentiability_test(preds=preds, - target=target, - metric_module=SI_SDR, - 
metric_functional=si_sdr, - metric_args={'zero_mean': zero_mean}) + def test_si_sdr_differentiability(self, preds, target, sk_metric, zero_mean): + self.run_differentiability_test( + preds=preds, + target=target, + metric_module=SI_SDR, + metric_functional=si_sdr, + metric_args={'zero_mean': zero_mean} + ) @pytest.mark.skipif( - not _TORCH_GREATER_EQUAL_1_6, - reason= - 'half support of core operations on not support before pytorch v1.6') + not _TORCH_GREATER_EQUAL_1_6, reason='half support of core operations on not support before pytorch v1.6' + ) def test_si_sdr_half_cpu(self, preds, target, sk_metric, zero_mean): pytest.xfail("SI-SDR metric does not support cpu + half precision") - @pytest.mark.skipif(not torch.cuda.is_available(), - reason='test requires cuda') + @pytest.mark.skipif(not torch.cuda.is_available(), reason='test requires cuda') def test_si_sdr_half_gpu(self, preds, target, sk_metric, zero_mean): - self.run_precision_test_gpu(preds=preds, - target=target, - metric_module=SI_SDR, - metric_functional=si_sdr, - metric_args={'zero_mean': zero_mean}) + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=SI_SDR, + metric_functional=si_sdr, + metric_args={'zero_mean': zero_mean} + ) def test_error_on_different_shape(metric_class=SI_SDR): metric = metric_class() - with pytest.raises( - RuntimeError, - match='Predictions and targets are expected to have the same shape' - ): - metric(torch.randn(100,), torch.randn(50,)) + with pytest.raises(RuntimeError, match='Predictions and targets are expected to have the same shape'): + metric(torch.randn(100, ), torch.randn(50, )) diff --git a/tests/audio/test_si_snr.py b/tests/audio/test_si_snr.py index e3b1a73a4b1..dff6710441f 100644 --- a/tests/audio/test_si_snr.py +++ b/tests/audio/test_si_snr.py @@ -15,6 +15,7 @@ from functools import partial import pytest +import speechmetrics import torch from torch.tensor import Tensor @@ -24,8 +25,6 @@ from torchmetrics.functional import si_snr from 
torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6 -import speechmetrics - seed_all(42) Time = 1000 @@ -40,9 +39,7 @@ speechmetrics_sisdr = speechmetrics.load('sisdr') -def speechmetrics_si_sdr(preds: Tensor, - target: Tensor, - zero_mean: bool = True): +def speechmetrics_si_sdr(preds: Tensor, target: Tensor, zero_mean: bool = True): # shape: preds [BATCH_SIZE, 1, Time] , target [BATCH_SIZE, 1, Time] # or shape: preds [NUM_BATCHES*BATCH_SIZE, 1, Time] , target [NUM_BATCHES*BATCH_SIZE, 1, Time] if zero_mean: @@ -52,9 +49,7 @@ def speechmetrics_si_sdr(preds: Tensor, for i in range(preds.shape[0]): ms = [] for j in range(preds.shape[1]): - metric = speechmetrics_sisdr(preds[i, j].numpy(), - target[i, j].numpy(), - rate=16000) + metric = speechmetrics_sisdr(preds[i, j].numpy(), target[i, j].numpy(), rate=16000) ms.append(metric['sisdr'][0]) mss.append(ms) return torch.tensor(mss) @@ -95,31 +90,20 @@ def test_si_snr_functional(self, preds, target, sk_metric): ) def test_si_snr_differentiability(self, preds, target, sk_metric): - self.run_differentiability_test(preds=preds, - target=target, - metric_module=SI_SNR, - metric_functional=si_snr) + self.run_differentiability_test(preds=preds, target=target, metric_module=SI_SNR, metric_functional=si_snr) @pytest.mark.skipif( - not _TORCH_GREATER_EQUAL_1_6, - reason= - 'half support of core operations on not support before pytorch v1.6') + not _TORCH_GREATER_EQUAL_1_6, reason='half support of core operations on not support before pytorch v1.6' + ) def test_si_snr_half_cpu(self, preds, target, sk_metric): pytest.xfail("SI-SNR metric does not support cpu + half precision") - @pytest.mark.skipif(not torch.cuda.is_available(), - reason='test requires cuda') + @pytest.mark.skipif(not torch.cuda.is_available(), reason='test requires cuda') def test_si_snr_half_gpu(self, preds, target, sk_metric): - self.run_precision_test_gpu(preds=preds, - target=target, - metric_module=SI_SNR, - metric_functional=si_snr) + 
self.run_precision_test_gpu(preds=preds, target=target, metric_module=SI_SNR, metric_functional=si_snr) def test_error_on_different_shape(metric_class=SI_SNR): metric = metric_class() - with pytest.raises( - RuntimeError, - match='Predictions and targets are expected to have the same shape' - ): - metric(torch.randn(100,), torch.randn(50,)) + with pytest.raises(RuntimeError, match='Predictions and targets are expected to have the same shape'): + metric(torch.randn(100, ), torch.randn(50, )) diff --git a/tests/audio/test_snr.py b/tests/audio/test_snr.py index 7aae8f78bbf..33072632314 100644 --- a/tests/audio/test_snr.py +++ b/tests/audio/test_snr.py @@ -39,8 +39,7 @@ ) -def bss_eval_images_snr(preds: Tensor, target: Tensor, metric_func: Callable, - zero_mean: bool): +def bss_eval_images_snr(preds: Tensor, target: Tensor, metric_func: Callable, zero_mean: bool): # shape: preds [BATCH_SIZE, 1, Time] , target [BATCH_SIZE, 1, Time] # or shape: preds [NUM_BATCHES*BATCH_SIZE, 1, Time] , target [NUM_BATCHES*BATCH_SIZE, 1, Time] if zero_mean: @@ -51,11 +50,9 @@ def bss_eval_images_snr(preds: Tensor, target: Tensor, metric_func: Callable, t = target[j] e = preds[j] if metric_func == mir_eval_bss_eval_images: - snr_v = metric_func([t.view(-1).numpy()], - [e.view(-1).numpy()])[0][0] + snr_v = metric_func([t.view(-1).numpy()], [e.view(-1).numpy()])[0][0] else: - snr_v = metric_func([t.view(-1).numpy()], - [e.view(-1).numpy()])[0][0][0] + snr_v = metric_func([t.view(-1).numpy()], [e.view(-1).numpy()])[0][0][0] snr_vb.append(snr_v) return torch.tensor(snr_vb).view(-1, 1) @@ -66,18 +63,10 @@ def average_metric(preds: Tensor, target: Tensor, metric_func: Callable): return metric_func(preds, target).mean() -mireval_snr_zeromean = partial(bss_eval_images_snr, - metric_func=mir_eval_bss_eval_images, - zero_mean=True) -mireval_snr_nozeromean = partial(bss_eval_images_snr, - metric_func=mir_eval_bss_eval_images, - zero_mean=False) -museval_snr_zeromean = partial(bss_eval_images_snr, - 
metric_func=museval_bss_eval_images, - zero_mean=True) -museval_snr_nozeromean = partial(bss_eval_images_snr, - metric_func=museval_bss_eval_images, - zero_mean=False) +mireval_snr_zeromean = partial(bss_eval_images_snr, metric_func=mir_eval_bss_eval_images, zero_mean=True) +mireval_snr_nozeromean = partial(bss_eval_images_snr, metric_func=mir_eval_bss_eval_images, zero_mean=False) +museval_snr_zeromean = partial(bss_eval_images_snr, metric_func=museval_bss_eval_images, zero_mean=True) +museval_snr_nozeromean = partial(bss_eval_images_snr, metric_func=museval_bss_eval_images, zero_mean=False) @pytest.mark.parametrize( @@ -94,8 +83,7 @@ class TestSNR(MetricTester): @pytest.mark.parametrize("ddp", [True, False]) @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_snr(self, preds, target, sk_metric, zero_mean, ddp, - dist_sync_on_step): + def test_snr(self, preds, target, sk_metric, zero_mean, ddp, dist_sync_on_step): self.run_class_metric_test( ddp, preds, @@ -116,33 +104,24 @@ def test_snr_functional(self, preds, target, sk_metric, zero_mean): ) def test_snr_differentiability(self, preds, target, sk_metric, zero_mean): - self.run_differentiability_test(preds=preds, - target=target, - metric_module=SNR, - metric_functional=snr, - metric_args={'zero_mean': zero_mean}) + self.run_differentiability_test( + preds=preds, target=target, metric_module=SNR, metric_functional=snr, metric_args={'zero_mean': zero_mean} + ) @pytest.mark.skipif( - not _TORCH_GREATER_EQUAL_1_6, - reason= - 'half support of core operations on not support before pytorch v1.6') + not _TORCH_GREATER_EQUAL_1_6, reason='half support of core operations on not support before pytorch v1.6' + ) def test_snr_half_cpu(self, preds, target, sk_metric, zero_mean): pytest.xfail("SNR metric does not support cpu + half precision") - @pytest.mark.skipif(not torch.cuda.is_available(), - reason='test requires cuda') + @pytest.mark.skipif(not torch.cuda.is_available(), reason='test requires cuda') 
def test_snr_half_gpu(self, preds, target, sk_metric, zero_mean): - self.run_precision_test_gpu(preds=preds, - target=target, - metric_module=SNR, - metric_functional=snr, - metric_args={'zero_mean': zero_mean}) + self.run_precision_test_gpu( + preds=preds, target=target, metric_module=SNR, metric_functional=snr, metric_args={'zero_mean': zero_mean} + ) def test_error_on_different_shape(metric_class=SNR): metric = metric_class() - with pytest.raises( - RuntimeError, - match='Predictions and targets are expected to have the same shape' - ): - metric(torch.randn(100,), torch.randn(50,)) + with pytest.raises(RuntimeError, match='Predictions and targets are expected to have the same shape'): + metric(torch.randn(100, ), torch.randn(50, )) diff --git a/torchmetrics/functional/audio/si_sdr.py b/torchmetrics/functional/audio/si_sdr.py index 5428405ca07..229aadf1c0c 100644 --- a/torchmetrics/functional/audio/si_sdr.py +++ b/torchmetrics/functional/audio/si_sdr.py @@ -50,14 +50,12 @@ def si_sdr(preds: Tensor, target: Tensor, zero_mean: bool = False) -> Tensor: target = target - torch.mean(target, dim=-1, keepdim=True) preds = preds - torch.mean(preds, dim=-1, keepdim=True) - α = (torch.sum(preds * target, dim=-1, keepdim=True) + - EPS) / (torch.sum(target**2, dim=-1, keepdim=True) + EPS) + α = (torch.sum(preds * target, dim=-1, keepdim=True) + EPS) / (torch.sum(target**2, dim=-1, keepdim=True) + EPS) target_scaled = α * target noise = target_scaled - preds - si_sdr_value = (torch.sum(target_scaled**2, dim=-1) + - EPS) / (torch.sum(noise**2, dim=-1) + EPS) + si_sdr_value = (torch.sum(target_scaled**2, dim=-1) + EPS) / (torch.sum(noise**2, dim=-1) + EPS) si_sdr_value = 10 * torch.log10(si_sdr_value) return si_sdr_value From 0add5779d1a9c9ea9603cef568f8d617f200db01 Mon Sep 17 00:00:00 2001 From: quancs Date: Wed, 16 Jun 2021 04:42:03 +0800 Subject: [PATCH 053/109] add museval --- requirements/test.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements/test.txt 
b/requirements/test.txt index 5251a06f3be..e0772defd95 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -21,4 +21,5 @@ nltk>=3.6 -r image.txt mir_eval>=0.6 +museval https://github.com/aliutkus/speechmetrics/archive/refs/heads/master.zip From 74d0550fa9b519e79ba678985d73b800ec3f0198 Mon Sep 17 00:00:00 2001 From: quancs Date: Wed, 16 Jun 2021 12:08:38 +0800 Subject: [PATCH 054/109] update test files --- tests/audio/test_si_sdr.py | 2 +- tests/audio/test_si_snr.py | 2 +- tests/audio/test_snr.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/audio/test_si_sdr.py b/tests/audio/test_si_sdr.py index dc017e807c2..36f950e9da2 100644 --- a/tests/audio/test_si_sdr.py +++ b/tests/audio/test_si_sdr.py @@ -17,7 +17,7 @@ import pytest import speechmetrics import torch -from torch.tensor import Tensor +from torch import Tensor from tests.helpers import seed_all from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester diff --git a/tests/audio/test_si_snr.py b/tests/audio/test_si_snr.py index dff6710441f..526c9f7882d 100644 --- a/tests/audio/test_si_snr.py +++ b/tests/audio/test_si_snr.py @@ -17,7 +17,7 @@ import pytest import speechmetrics import torch -from torch.tensor import Tensor +from torch import Tensor from tests.helpers import seed_all from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester diff --git a/tests/audio/test_snr.py b/tests/audio/test_snr.py index 33072632314..7695f1137fd 100644 --- a/tests/audio/test_snr.py +++ b/tests/audio/test_snr.py @@ -19,7 +19,7 @@ import torch from mir_eval.separation import bss_eval_images as mir_eval_bss_eval_images from museval.metrics import bss_eval_images as museval_bss_eval_images -from torch.tensor import Tensor +from torch import Tensor from tests.helpers import seed_all from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester From f7249b1806221194ba23662271511da4892d99c8 Mon Sep 17 00:00:00 2001 From: quancs Date: Wed, 16 Jun 2021 
12:16:02 +0800 Subject: [PATCH 055/109] remove museval --- requirements/test.txt | 1 - tests/audio/test_snr.py | 5 ----- 2 files changed, 6 deletions(-) diff --git a/requirements/test.txt b/requirements/test.txt index e0772defd95..5251a06f3be 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -21,5 +21,4 @@ nltk>=3.6 -r image.txt mir_eval>=0.6 -museval https://github.com/aliutkus/speechmetrics/archive/refs/heads/master.zip diff --git a/tests/audio/test_snr.py b/tests/audio/test_snr.py index 7695f1137fd..3a228b5ea3f 100644 --- a/tests/audio/test_snr.py +++ b/tests/audio/test_snr.py @@ -18,7 +18,6 @@ import pytest import torch from mir_eval.separation import bss_eval_images as mir_eval_bss_eval_images -from museval.metrics import bss_eval_images as museval_bss_eval_images from torch import Tensor from tests.helpers import seed_all @@ -65,8 +64,6 @@ def average_metric(preds: Tensor, target: Tensor, metric_func: Callable): mireval_snr_zeromean = partial(bss_eval_images_snr, metric_func=mir_eval_bss_eval_images, zero_mean=True) mireval_snr_nozeromean = partial(bss_eval_images_snr, metric_func=mir_eval_bss_eval_images, zero_mean=False) -museval_snr_zeromean = partial(bss_eval_images_snr, metric_func=museval_bss_eval_images, zero_mean=True) -museval_snr_nozeromean = partial(bss_eval_images_snr, metric_func=museval_bss_eval_images, zero_mean=False) @pytest.mark.parametrize( @@ -74,8 +71,6 @@ def average_metric(preds: Tensor, target: Tensor, metric_func: Callable): [ (inputs.preds, inputs.target, mireval_snr_zeromean, True), (inputs.preds, inputs.target, mireval_snr_nozeromean, False), - (inputs.preds, inputs.target, museval_snr_zeromean, True), - (inputs.preds, inputs.target, museval_snr_nozeromean, False), ], ) class TestSNR(MetricTester): From 8f3e032984751d4d8b344e6f1c57137302f52f1c Mon Sep 17 00:00:00 2001 From: quancs Date: Wed, 16 Jun 2021 12:18:38 +0800 Subject: [PATCH 056/109] add funcs update return None annotation --- torchmetrics/audio/SI_SDR.py | 
2 +- torchmetrics/audio/SI_SNR.py | 2 +- torchmetrics/audio/SNR.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/torchmetrics/audio/SI_SDR.py b/torchmetrics/audio/SI_SDR.py index 22bb4f1a3ec..2f42c2a0ca5 100644 --- a/torchmetrics/audio/SI_SDR.py +++ b/torchmetrics/audio/SI_SDR.py @@ -82,7 +82,7 @@ def __init__( self.add_state("sum_si_sdr", default=tensor(0.0), dist_reduce_fx="sum") self.add_state("total", default=tensor(0), dist_reduce_fx="sum") - def update(self, preds: Tensor, target: Tensor): + def update(self, preds: Tensor, target: Tensor) -> None: """ Update state with predictions and targets. diff --git a/torchmetrics/audio/SI_SNR.py b/torchmetrics/audio/SI_SNR.py index 98c67e835d7..10c9d840174 100644 --- a/torchmetrics/audio/SI_SNR.py +++ b/torchmetrics/audio/SI_SNR.py @@ -79,7 +79,7 @@ def __init__( self.add_state("sum_si_snr", default=tensor(0.0), dist_reduce_fx="sum") self.add_state("total", default=tensor(0), dist_reduce_fx="sum") - def update(self, preds: Tensor, target: Tensor): + def update(self, preds: Tensor, target: Tensor) -> None: """ Update state with predictions and targets. diff --git a/torchmetrics/audio/SNR.py b/torchmetrics/audio/SNR.py index 77510b43be0..223063cc1ac 100644 --- a/torchmetrics/audio/SNR.py +++ b/torchmetrics/audio/SNR.py @@ -82,7 +82,7 @@ def __init__( self.add_state("sum_snr", default=tensor(0.0), dist_reduce_fx="sum") self.add_state("total", default=tensor(0), dist_reduce_fx="sum") - def update(self, preds: Tensor, target: Tensor): + def update(self, preds: Tensor, target: Tensor) -> None: """ Update state with predictions and targets. 
From be1ca0b0763d040c042a13f8f8aadda82b0d9884 Mon Sep 17 00:00:00 2001 From: quancs Date: Wed, 16 Jun 2021 17:08:48 +0800 Subject: [PATCH 057/109] add 'Setup ffmpeg' --- .github/workflows/ci_test-full.yml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.github/workflows/ci_test-full.yml b/.github/workflows/ci_test-full.yml index 93558bfa376..ad40dc65498 100644 --- a/.github/workflows/ci_test-full.yml +++ b/.github/workflows/ci_test-full.yml @@ -45,6 +45,10 @@ jobs: run: | brew install gcc libomp # https://github.com/pytorch/pytorch/issues/20030 + - name: Setup ffmpeg # for speechmetrics + run: | + conda install -c conda-forge ffmpeg + - name: Set min. dependencies if: matrix.requires == 'minimal' run: | From 134d8cdbe48686f3a4a8a15849f04a1ccd891967 Mon Sep 17 00:00:00 2001 From: quancs Date: Wed, 16 Jun 2021 17:16:25 +0800 Subject: [PATCH 058/109] update "Setup ffmpeg" --- .github/workflows/ci_test-full.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci_test-full.yml b/.github/workflows/ci_test-full.yml index ad40dc65498..266104ce9be 100644 --- a/.github/workflows/ci_test-full.yml +++ b/.github/workflows/ci_test-full.yml @@ -47,7 +47,7 @@ jobs: - name: Setup ffmpeg # for speechmetrics run: | - conda install -c conda-forge ffmpeg + $CONDA/bin/conda install -c conda-forge ffmpeg - name: Set min. 
dependencies if: matrix.requires == 'minimal' From 794d1ce469c9a2661bafc8a3d5b778a1e89a84f7 Mon Sep 17 00:00:00 2001 From: quancs Date: Wed, 16 Jun 2021 17:21:38 +0800 Subject: [PATCH 059/109] use setup-conda@v1 --- .github/workflows/ci_test-full.yml | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/.github/workflows/ci_test-full.yml b/.github/workflows/ci_test-full.yml index 266104ce9be..fb6e7a4fd65 100644 --- a/.github/workflows/ci_test-full.yml +++ b/.github/workflows/ci_test-full.yml @@ -45,9 +45,10 @@ jobs: run: | brew install gcc libomp # https://github.com/pytorch/pytorch/issues/20030 - - name: Setup ffmpeg # for speechmetrics - run: | - $CONDA/bin/conda install -c conda-forge ffmpeg + - uses: s-weigand/setup-conda@v1 + with: + activate-conda: true + - run: conda install -c conda-forge ffmpeg # for speechmetrics - name: Set min. dependencies if: matrix.requires == 'minimal' From 2da22cdfa4c01c1d7d38eebaf913793baa6826b0 Mon Sep 17 00:00:00 2001 From: Jirka Date: Wed, 16 Jun 2021 11:55:49 +0200 Subject: [PATCH 060/109] multi-OS --- .github/workflows/ci_test-full.yml | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/.github/workflows/ci_test-full.yml b/.github/workflows/ci_test-full.yml index fb6e7a4fd65..0b8264c33f7 100644 --- a/.github/workflows/ci_test-full.yml +++ b/.github/workflows/ci_test-full.yml @@ -43,12 +43,15 @@ jobs: - name: Setup macOS if: runner.os == 'macOS' run: | - brew install gcc libomp # https://github.com/pytorch/pytorch/issues/20030 - - - uses: s-weigand/setup-conda@v1 - with: - activate-conda: true - - run: conda install -c conda-forge ffmpeg # for speechmetrics + brew install gcc libomp ffmpeg # https://github.com/pytorch/pytorch/issues/20030 + - name: Setup Linux + if: runner.os == 'ubuntu' + run: | + apt install -y ffmpeg + - name: Setup Windows + if: runner.os == 'windows' + run: | + choco install ffmpeg - name: Set min. 
dependencies if: matrix.requires == 'minimal' From e786ed2af021afd85dbc8e6cd75fd563dabe9241 Mon Sep 17 00:00:00 2001 From: quancs Date: Wed, 16 Jun 2021 18:06:54 +0800 Subject: [PATCH 061/109] update atol to 1e-5 --- tests/audio/test_si_sdr.py | 1 + tests/audio/test_si_snr.py | 1 + 2 files changed, 2 insertions(+) diff --git a/tests/audio/test_si_sdr.py b/tests/audio/test_si_sdr.py index 36f950e9da2..bbfd8c169b6 100644 --- a/tests/audio/test_si_sdr.py +++ b/tests/audio/test_si_sdr.py @@ -73,6 +73,7 @@ def average_metric(preds, target, metric_func): ], ) class TestSISDR(MetricTester): + atol = 1e-5 @pytest.mark.parametrize("ddp", [True, False]) @pytest.mark.parametrize("dist_sync_on_step", [True, False]) diff --git a/tests/audio/test_si_snr.py b/tests/audio/test_si_snr.py index 526c9f7882d..e0b3cbc335c 100644 --- a/tests/audio/test_si_snr.py +++ b/tests/audio/test_si_snr.py @@ -68,6 +68,7 @@ def average_metric(preds, target, metric_func): ], ) class TestSISNR(MetricTester): + atol = 1e-5 @pytest.mark.parametrize("ddp", [True, False]) @pytest.mark.parametrize("dist_sync_on_step", [True, False]) From 662cf2807f9594cc2c3d932cc6c79bbcd7d95032 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Wed, 16 Jun 2021 12:29:16 +0200 Subject: [PATCH 062/109] Apply suggestions from code review --- azure-pipelines.yml | 2 +- requirements/integrate.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/azure-pipelines.yml b/azure-pipelines.yml index 76bb72a9f56..cbabc9dae66 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -44,7 +44,7 @@ jobs: displayName: 'Image info & NVIDIA' - bash: | - sudo apt-get install -y cmake gcc clang + sudo apt-get install -y cmake ffmpeg # python -m pip install "pip==20.1" pip install --requirement ./requirements/devel.txt --upgrade-strategy only-if-needed pip uninstall -y torchmetrics diff --git a/requirements/integrate.txt b/requirements/integrate.txt index 1164d62fc6a..5c2802a7f46 100644 --- a/requirements/integrate.txt 
+++ b/requirements/integrate.txt @@ -1 +1 @@ -pytorch-lightning>=1.0.1 # because if asteroid depends +pytorch-lightning>=1.0 From d44be308630bc95e13bb37147971b018aca3775c Mon Sep 17 00:00:00 2001 From: quancs Date: Wed, 16 Jun 2021 18:51:41 +0800 Subject: [PATCH 063/109] change atol to 1e-2 --- tests/audio/test_si_sdr.py | 2 +- tests/audio/test_si_snr.py | 2 +- tests/audio/test_snr.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/audio/test_si_sdr.py b/tests/audio/test_si_sdr.py index bbfd8c169b6..18403ebace4 100644 --- a/tests/audio/test_si_sdr.py +++ b/tests/audio/test_si_sdr.py @@ -73,7 +73,7 @@ def average_metric(preds, target, metric_func): ], ) class TestSISDR(MetricTester): - atol = 1e-5 + atol = 1e-2 @pytest.mark.parametrize("ddp", [True, False]) @pytest.mark.parametrize("dist_sync_on_step", [True, False]) diff --git a/tests/audio/test_si_snr.py b/tests/audio/test_si_snr.py index e0b3cbc335c..2d0af81679e 100644 --- a/tests/audio/test_si_snr.py +++ b/tests/audio/test_si_snr.py @@ -68,7 +68,7 @@ def average_metric(preds, target, metric_func): ], ) class TestSISNR(MetricTester): - atol = 1e-5 + atol = 1e-2 @pytest.mark.parametrize("ddp", [True, False]) @pytest.mark.parametrize("dist_sync_on_step", [True, False]) diff --git a/tests/audio/test_snr.py b/tests/audio/test_snr.py index 3a228b5ea3f..0aa4580a2e0 100644 --- a/tests/audio/test_snr.py +++ b/tests/audio/test_snr.py @@ -74,7 +74,7 @@ def average_metric(preds: Tensor, target: Tensor, metric_func: Callable): ], ) class TestSNR(MetricTester): - atol = 1e-5 + atol = 1e-2 @pytest.mark.parametrize("ddp", [True, False]) @pytest.mark.parametrize("dist_sync_on_step", [True, False]) From 5e875ce63b8f8b16f6c3f7b7508ceddc240e4b9d Mon Sep 17 00:00:00 2001 From: quancs Date: Wed, 16 Jun 2021 19:40:29 +0800 Subject: [PATCH 064/109] update --- docs/source/references/modules.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/references/modules.rst 
b/docs/source/references/modules.rst index 7fee845bf18..0d4aa7dd935 100644 --- a/docs/source/references/modules.rst +++ b/docs/source/references/modules.rst @@ -22,8 +22,8 @@ your own metric type might be too burdensome. Audio Metrics ************* -Input details -~~~~~~~~~~~~~ +About Audio Metrics +~~~~~~~~~~~~~~~~~~~ For the purposes of audio metrics, inputs (predictions, targets) must have the same size. If the input is 1D tensors the output will be a scalar. If the input is multi-dimensional with shape [..., time]` the metric will be computed over the `time` dimension. From 5e3f5c85b9310b83bed717195f266d8cbfd64c3b Mon Sep 17 00:00:00 2001 From: quancs <1017241746@qq.com> Date: Thu, 17 Jun 2021 02:03:39 +0800 Subject: [PATCH 065/109] fix 'Setup Linux' not activated --- .github/workflows/ci_test-full.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci_test-full.yml b/.github/workflows/ci_test-full.yml index 0b8264c33f7..002892fbb17 100644 --- a/.github/workflows/ci_test-full.yml +++ b/.github/workflows/ci_test-full.yml @@ -45,7 +45,7 @@ jobs: run: | brew install gcc libomp ffmpeg # https://github.com/pytorch/pytorch/issues/20030 - name: Setup Linux - if: runner.os == 'ubuntu' + if: runner.os == 'Linux' run: | apt install -y ffmpeg - name: Setup Windows From f08916defade39f27a31b84105399238927d992f Mon Sep 17 00:00:00 2001 From: quancs <1017241746@qq.com> Date: Thu, 17 Jun 2021 02:14:05 +0800 Subject: [PATCH 066/109] add sudo --- .github/workflows/ci_test-full.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci_test-full.yml b/.github/workflows/ci_test-full.yml index 002892fbb17..5f9fcb102ed 100644 --- a/.github/workflows/ci_test-full.yml +++ b/.github/workflows/ci_test-full.yml @@ -47,7 +47,7 @@ jobs: - name: Setup Linux if: runner.os == 'Linux' run: | - apt install -y ffmpeg + sudo apt install -y ffmpeg - name: Setup Windows if: runner.os == 'windows' run: | From 
2f26cd61335d97067fa91342ac25e8f172e39ee1 Mon Sep 17 00:00:00 2001 From: quancs <1017241746@qq.com> Date: Thu, 17 Jun 2021 11:01:06 +0800 Subject: [PATCH 067/109] reduce Time to 100 to reduce the test time --- tests/audio/test_si_sdr.py | 2 +- tests/audio/test_si_snr.py | 2 +- tests/audio/test_snr.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/audio/test_si_sdr.py b/tests/audio/test_si_sdr.py index 18403ebace4..dff72942d26 100644 --- a/tests/audio/test_si_sdr.py +++ b/tests/audio/test_si_sdr.py @@ -27,7 +27,7 @@ seed_all(42) -Time = 1000 +Time = 100 Input = namedtuple('Input', ["preds", "target"]) diff --git a/tests/audio/test_si_snr.py b/tests/audio/test_si_snr.py index 2d0af81679e..7f7fb52d298 100644 --- a/tests/audio/test_si_snr.py +++ b/tests/audio/test_si_snr.py @@ -27,7 +27,7 @@ seed_all(42) -Time = 1000 +Time = 100 Input = namedtuple('Input', ["preds", "target"]) diff --git a/tests/audio/test_snr.py b/tests/audio/test_snr.py index 0aa4580a2e0..5a1a017a99c 100644 --- a/tests/audio/test_snr.py +++ b/tests/audio/test_snr.py @@ -28,7 +28,7 @@ seed_all(42) -Time = 1000 +Time = 100 Input = namedtuple('Input', ["preds", "target"]) From ad593c7913a30ac2cc9477e7c4ab269d863e0bbd Mon Sep 17 00:00:00 2001 From: quancs <1017241746@qq.com> Date: Thu, 17 Jun 2021 11:05:08 +0800 Subject: [PATCH 068/109] increase timeoutInMinutes to 40 --- azure-pipelines.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/azure-pipelines.yml b/azure-pipelines.yml index cbabc9dae66..213b8a5d939 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -19,7 +19,7 @@ pr: jobs: - job: pytest # how long to run the job before automatically cancelling - timeoutInMinutes: 25 + timeoutInMinutes: 40 # how much time to give 'run always even if cancelled tasks' before stopping them cancelTimeoutInMinutes: 2 From 377d4a4c4b5e17a1528ec26316601f63b8c37b4a Mon Sep 17 00:00:00 2001 From: quancs <1017241746@qq.com> Date: Thu, 17 Jun 2021 11:34:08 +0800 
Subject: [PATCH 069/109] install ffmpeg --- .github/workflows/ci_test-conda.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/ci_test-conda.yml b/.github/workflows/ci_test-conda.yml index 24cb2e1c775..f3dced40913 100644 --- a/.github/workflows/ci_test-conda.yml +++ b/.github/workflows/ci_test-conda.yml @@ -57,6 +57,7 @@ jobs: conda info conda install mkl pytorch=${{ matrix.pytorch-version }} cpuonly conda install cpuonly $(python ./requirements/adjust-versions.py conda) + conda install -c conda-forge ffmpeg conda list pip --version python ./requirements/adjust-versions.py requirements.txt From 41fddfb752260448feb785b741a6a1a93731a565 Mon Sep 17 00:00:00 2001 From: quancs <1017241746@qq.com> Date: Thu, 17 Jun 2021 11:34:52 +0800 Subject: [PATCH 070/109] timeout-minutes to 55 --- .github/workflows/ci_test-conda.yml | 2 +- .github/workflows/ci_test-full.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci_test-conda.yml b/.github/workflows/ci_test-conda.yml index f3dced40913..20332f7a502 100644 --- a/.github/workflows/ci_test-conda.yml +++ b/.github/workflows/ci_test-conda.yml @@ -17,7 +17,7 @@ jobs: pytorch-version: [1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9] # Timeout: https://stackoverflow.com/a/59076067/4521646 - timeout-minutes: 35 + timeout-minutes: 55 steps: - uses: actions/checkout@v2 diff --git a/.github/workflows/ci_test-full.yml b/.github/workflows/ci_test-full.yml index 5f9fcb102ed..bc250af7cab 100644 --- a/.github/workflows/ci_test-full.yml +++ b/.github/workflows/ci_test-full.yml @@ -26,7 +26,7 @@ jobs: requires: 'minimal' # Timeout: https://stackoverflow.com/a/59076067/4521646 - timeout-minutes: 35 + timeout-minutes: 55 steps: - uses: actions/checkout@v2 From 53a23c8d6dd85b593b252ee8f8a634e914aeb6da Mon Sep 17 00:00:00 2001 From: quancs <1017241746@qq.com> Date: Thu, 17 Jun 2021 15:49:30 +0800 Subject: [PATCH 071/109] +git --- azure-pipelines.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 
deletions(-) diff --git a/azure-pipelines.yml b/azure-pipelines.yml index 213b8a5d939..4b4ff84b232 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -19,7 +19,7 @@ pr: jobs: - job: pytest # how long to run the job before automatically cancelling - timeoutInMinutes: 40 + timeoutInMinutes: 55 # how much time to give 'run always even if cancelled tasks' before stopping them cancelTimeoutInMinutes: 2 @@ -44,7 +44,7 @@ jobs: displayName: 'Image info & NVIDIA' - bash: | - sudo apt-get install -y cmake ffmpeg + sudo apt-get install -y cmake ffmpeg git # python -m pip install "pip==20.1" pip install --requirement ./requirements/devel.txt --upgrade-strategy only-if-needed pip uninstall -y torchmetrics From e8a6b6d9da98a78d956558002d6eddb24176df1c Mon Sep 17 00:00:00 2001 From: quancs <1017241746@qq.com> Date: Thu, 17 Jun 2021 16:18:43 +0800 Subject: [PATCH 072/109] show-error-codes --- .github/workflows/code-format.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/code-format.yml b/.github/workflows/code-format.yml index 479f6bb1281..bca47ce2cd0 100644 --- a/.github/workflows/code-format.yml +++ b/.github/workflows/code-format.yml @@ -52,7 +52,7 @@ jobs: pip list - name: mypy run: | - mypy + mypy --show-error-codes # format-check-yapf: # runs-on: ubuntu-20.04 From 70a655e7553c1fd9aafcaa54c4d2002611312fd4 Mon Sep 17 00:00:00 2001 From: quancs <1017241746@qq.com> Date: Thu, 17 Jun 2021 17:30:33 +0800 Subject: [PATCH 073/109] .detach().cpu().numpy() first --- tests/audio/test_si_sdr.py | 57 ++++++++++++++++++------------- tests/audio/test_si_snr.py | 4 ++- tests/audio/test_snr.py | 69 ++++++++++++++++++++++++-------------- 3 files changed, 80 insertions(+), 50 deletions(-) diff --git a/tests/audio/test_si_sdr.py b/tests/audio/test_si_sdr.py index dff72942d26..2a651deb07a 100644 --- a/tests/audio/test_si_sdr.py +++ b/tests/audio/test_si_sdr.py @@ -45,11 +45,15 @@ def speechmetrics_si_sdr(preds: Tensor, target: Tensor, zero_mean: 
bool): if zero_mean: preds = preds - preds.mean(dim=2, keepdim=True) target = target - target.mean(dim=2, keepdim=True) + target = target.detach().cpu().numpy() + preds = preds.detach().cpu().numpy() mss = [] for i in range(preds.shape[0]): ms = [] for j in range(preds.shape[1]): - metric = speechmetrics_sisdr(preds[i, j].numpy(), target[i, j].numpy(), rate=16000) + metric = speechmetrics_sisdr(preds[i, j], + target[i, j], + rate=16000) ms.append(metric['sisdr'][0]) mss.append(ms) return torch.tensor(mss) @@ -62,14 +66,16 @@ def average_metric(preds, target, metric_func): speechmetrics_si_sdr_zero_mean = partial(speechmetrics_si_sdr, zero_mean=True) -speechmetrics_si_sdr_no_zero_mean = partial(speechmetrics_si_sdr, zero_mean=False) +speechmetrics_si_sdr_no_zero_mean = partial(speechmetrics_si_sdr, + zero_mean=False) @pytest.mark.parametrize( "preds, target, sk_metric, zero_mean", [ (inputs.preds, inputs.target, speechmetrics_si_sdr_zero_mean, True), - (inputs.preds, inputs.target, speechmetrics_si_sdr_no_zero_mean, False), + (inputs.preds, inputs.target, speechmetrics_si_sdr_no_zero_mean, + False), ], ) class TestSISDR(MetricTester): @@ -77,7 +83,8 @@ class TestSISDR(MetricTester): @pytest.mark.parametrize("ddp", [True, False]) @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_si_sdr(self, preds, target, sk_metric, zero_mean, ddp, dist_sync_on_step): + def test_si_sdr(self, preds, target, sk_metric, zero_mean, ddp, + dist_sync_on_step): self.run_class_metric_test( ddp, preds, @@ -97,33 +104,35 @@ def test_si_sdr_functional(self, preds, target, sk_metric, zero_mean): metric_args=dict(zero_mean=zero_mean), ) - def test_si_sdr_differentiability(self, preds, target, sk_metric, zero_mean): - self.run_differentiability_test( - preds=preds, - target=target, - metric_module=SI_SDR, - metric_functional=si_sdr, - metric_args={'zero_mean': zero_mean} - ) + def test_si_sdr_differentiability(self, preds, target, sk_metric, + zero_mean): + 
self.run_differentiability_test(preds=preds, + target=target, + metric_module=SI_SDR, + metric_functional=si_sdr, + metric_args={'zero_mean': zero_mean}) @pytest.mark.skipif( - not _TORCH_GREATER_EQUAL_1_6, reason='half support of core operations on not support before pytorch v1.6' - ) + not _TORCH_GREATER_EQUAL_1_6, + reason= + 'half support of core operations on not support before pytorch v1.6') def test_si_sdr_half_cpu(self, preds, target, sk_metric, zero_mean): pytest.xfail("SI-SDR metric does not support cpu + half precision") - @pytest.mark.skipif(not torch.cuda.is_available(), reason='test requires cuda') + @pytest.mark.skipif(not torch.cuda.is_available(), + reason='test requires cuda') def test_si_sdr_half_gpu(self, preds, target, sk_metric, zero_mean): - self.run_precision_test_gpu( - preds=preds, - target=target, - metric_module=SI_SDR, - metric_functional=si_sdr, - metric_args={'zero_mean': zero_mean} - ) + self.run_precision_test_gpu(preds=preds, + target=target, + metric_module=SI_SDR, + metric_functional=si_sdr, + metric_args={'zero_mean': zero_mean}) def test_error_on_different_shape(metric_class=SI_SDR): metric = metric_class() - with pytest.raises(RuntimeError, match='Predictions and targets are expected to have the same shape'): - metric(torch.randn(100, ), torch.randn(50, )) + with pytest.raises( + RuntimeError, + match='Predictions and targets are expected to have the same shape' + ): + metric(torch.randn(100,), torch.randn(50,)) diff --git a/tests/audio/test_si_snr.py b/tests/audio/test_si_snr.py index 7f7fb52d298..46d1cf6ed09 100644 --- a/tests/audio/test_si_snr.py +++ b/tests/audio/test_si_snr.py @@ -45,11 +45,13 @@ def speechmetrics_si_sdr(preds: Tensor, target: Tensor, zero_mean: bool = True): if zero_mean: preds = preds - preds.mean(dim=2, keepdim=True) target = target - target.mean(dim=2, keepdim=True) + target = target.detach().cpu().numpy() + preds = preds.detach().cpu().numpy() mss = [] for i in range(preds.shape[0]): ms = [] for j in 
range(preds.shape[1]): - metric = speechmetrics_sisdr(preds[i, j].numpy(), target[i, j].numpy(), rate=16000) + metric = speechmetrics_sisdr(preds[i, j], target[i, j], rate=16000) ms.append(metric['sisdr'][0]) mss.append(ms) return torch.tensor(mss) diff --git a/tests/audio/test_snr.py b/tests/audio/test_snr.py index 5a1a017a99c..f1357ecab0f 100644 --- a/tests/audio/test_snr.py +++ b/tests/audio/test_snr.py @@ -38,22 +38,27 @@ ) -def bss_eval_images_snr(preds: Tensor, target: Tensor, metric_func: Callable, zero_mean: bool): +def bss_eval_images_snr(preds: Tensor, target: Tensor, metric_func: Callable, + zero_mean: bool): # shape: preds [BATCH_SIZE, 1, Time] , target [BATCH_SIZE, 1, Time] # or shape: preds [NUM_BATCHES*BATCH_SIZE, 1, Time] , target [NUM_BATCHES*BATCH_SIZE, 1, Time] if zero_mean: target = target - torch.mean(target, dim=-1, keepdim=True) preds = preds - torch.mean(preds, dim=-1, keepdim=True) - snr_vb = [] - for j in range(preds.shape[0]): - t = target[j] - e = preds[j] - if metric_func == mir_eval_bss_eval_images: - snr_v = metric_func([t.view(-1).numpy()], [e.view(-1).numpy()])[0][0] - else: - snr_v = metric_func([t.view(-1).numpy()], [e.view(-1).numpy()])[0][0][0] - snr_vb.append(snr_v) - return torch.tensor(snr_vb).view(-1, 1) + target = target.detach().cpu().numpy() + preds = preds.detach().cpu().numpy() + mss = [] + for i in range(preds.shape[0]): + ms = [] + for j in range(preds.shape[1]): + if metric_func == mir_eval_bss_eval_images: + snr_v = metric_func([target[i, j]], [preds[i, j]])[0][0] + else: + snr_v = metric_func([target[i, j]], [preds[i, j]])[0][0][0] + ms.append(snr_v) + mss.append(ms) + return torch.tensor(mss) + def average_metric(preds: Tensor, target: Tensor, metric_func: Callable): @@ -62,8 +67,12 @@ def average_metric(preds: Tensor, target: Tensor, metric_func: Callable): return metric_func(preds, target).mean() -mireval_snr_zeromean = partial(bss_eval_images_snr, metric_func=mir_eval_bss_eval_images, zero_mean=True) 
-mireval_snr_nozeromean = partial(bss_eval_images_snr, metric_func=mir_eval_bss_eval_images, zero_mean=False) +mireval_snr_zeromean = partial(bss_eval_images_snr, + metric_func=mir_eval_bss_eval_images, + zero_mean=True) +mireval_snr_nozeromean = partial(bss_eval_images_snr, + metric_func=mir_eval_bss_eval_images, + zero_mean=False) @pytest.mark.parametrize( @@ -78,7 +87,8 @@ class TestSNR(MetricTester): @pytest.mark.parametrize("ddp", [True, False]) @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_snr(self, preds, target, sk_metric, zero_mean, ddp, dist_sync_on_step): + def test_snr(self, preds, target, sk_metric, zero_mean, ddp, + dist_sync_on_step): self.run_class_metric_test( ddp, preds, @@ -99,24 +109,33 @@ def test_snr_functional(self, preds, target, sk_metric, zero_mean): ) def test_snr_differentiability(self, preds, target, sk_metric, zero_mean): - self.run_differentiability_test( - preds=preds, target=target, metric_module=SNR, metric_functional=snr, metric_args={'zero_mean': zero_mean} - ) + self.run_differentiability_test(preds=preds, + target=target, + metric_module=SNR, + metric_functional=snr, + metric_args={'zero_mean': zero_mean}) @pytest.mark.skipif( - not _TORCH_GREATER_EQUAL_1_6, reason='half support of core operations on not support before pytorch v1.6' - ) + not _TORCH_GREATER_EQUAL_1_6, + reason= + 'half support of core operations on not support before pytorch v1.6') def test_snr_half_cpu(self, preds, target, sk_metric, zero_mean): pytest.xfail("SNR metric does not support cpu + half precision") - @pytest.mark.skipif(not torch.cuda.is_available(), reason='test requires cuda') + @pytest.mark.skipif(not torch.cuda.is_available(), + reason='test requires cuda') def test_snr_half_gpu(self, preds, target, sk_metric, zero_mean): - self.run_precision_test_gpu( - preds=preds, target=target, metric_module=SNR, metric_functional=snr, metric_args={'zero_mean': zero_mean} - ) + self.run_precision_test_gpu(preds=preds, + 
target=target, + metric_module=SNR, + metric_functional=snr, + metric_args={'zero_mean': zero_mean}) def test_error_on_different_shape(metric_class=SNR): metric = metric_class() - with pytest.raises(RuntimeError, match='Predictions and targets are expected to have the same shape'): - metric(torch.randn(100, ), torch.randn(50, )) + with pytest.raises( + RuntimeError, + match='Predictions and targets are expected to have the same shape' + ): + metric(torch.randn(100,), torch.randn(50,)) From 96e72805db951e8242c011425fd65a87e17879f5 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 17 Jun 2021 09:31:52 +0000 Subject: [PATCH 074/109] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/audio/test_si_sdr.py | 55 +++++++++++++++++--------------------- tests/audio/test_snr.py | 46 +++++++++++-------------------- 2 files changed, 39 insertions(+), 62 deletions(-) diff --git a/tests/audio/test_si_sdr.py b/tests/audio/test_si_sdr.py index 2a651deb07a..c1ea88e522e 100644 --- a/tests/audio/test_si_sdr.py +++ b/tests/audio/test_si_sdr.py @@ -51,9 +51,7 @@ def speechmetrics_si_sdr(preds: Tensor, target: Tensor, zero_mean: bool): for i in range(preds.shape[0]): ms = [] for j in range(preds.shape[1]): - metric = speechmetrics_sisdr(preds[i, j], - target[i, j], - rate=16000) + metric = speechmetrics_sisdr(preds[i, j], target[i, j], rate=16000) ms.append(metric['sisdr'][0]) mss.append(ms) return torch.tensor(mss) @@ -66,16 +64,14 @@ def average_metric(preds, target, metric_func): speechmetrics_si_sdr_zero_mean = partial(speechmetrics_si_sdr, zero_mean=True) -speechmetrics_si_sdr_no_zero_mean = partial(speechmetrics_si_sdr, - zero_mean=False) +speechmetrics_si_sdr_no_zero_mean = partial(speechmetrics_si_sdr, zero_mean=False) @pytest.mark.parametrize( "preds, target, sk_metric, zero_mean", [ (inputs.preds, inputs.target, speechmetrics_si_sdr_zero_mean, True), - 
(inputs.preds, inputs.target, speechmetrics_si_sdr_no_zero_mean, - False), + (inputs.preds, inputs.target, speechmetrics_si_sdr_no_zero_mean, False), ], ) class TestSISDR(MetricTester): @@ -83,8 +79,7 @@ class TestSISDR(MetricTester): @pytest.mark.parametrize("ddp", [True, False]) @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_si_sdr(self, preds, target, sk_metric, zero_mean, ddp, - dist_sync_on_step): + def test_si_sdr(self, preds, target, sk_metric, zero_mean, ddp, dist_sync_on_step): self.run_class_metric_test( ddp, preds, @@ -104,35 +99,33 @@ def test_si_sdr_functional(self, preds, target, sk_metric, zero_mean): metric_args=dict(zero_mean=zero_mean), ) - def test_si_sdr_differentiability(self, preds, target, sk_metric, - zero_mean): - self.run_differentiability_test(preds=preds, - target=target, - metric_module=SI_SDR, - metric_functional=si_sdr, - metric_args={'zero_mean': zero_mean}) + def test_si_sdr_differentiability(self, preds, target, sk_metric, zero_mean): + self.run_differentiability_test( + preds=preds, + target=target, + metric_module=SI_SDR, + metric_functional=si_sdr, + metric_args={'zero_mean': zero_mean} + ) @pytest.mark.skipif( - not _TORCH_GREATER_EQUAL_1_6, - reason= - 'half support of core operations on not support before pytorch v1.6') + not _TORCH_GREATER_EQUAL_1_6, reason='half support of core operations on not support before pytorch v1.6' + ) def test_si_sdr_half_cpu(self, preds, target, sk_metric, zero_mean): pytest.xfail("SI-SDR metric does not support cpu + half precision") - @pytest.mark.skipif(not torch.cuda.is_available(), - reason='test requires cuda') + @pytest.mark.skipif(not torch.cuda.is_available(), reason='test requires cuda') def test_si_sdr_half_gpu(self, preds, target, sk_metric, zero_mean): - self.run_precision_test_gpu(preds=preds, - target=target, - metric_module=SI_SDR, - metric_functional=si_sdr, - metric_args={'zero_mean': zero_mean}) + self.run_precision_test_gpu( + preds=preds, + 
target=target, + metric_module=SI_SDR, + metric_functional=si_sdr, + metric_args={'zero_mean': zero_mean} + ) def test_error_on_different_shape(metric_class=SI_SDR): metric = metric_class() - with pytest.raises( - RuntimeError, - match='Predictions and targets are expected to have the same shape' - ): - metric(torch.randn(100,), torch.randn(50,)) + with pytest.raises(RuntimeError, match='Predictions and targets are expected to have the same shape'): + metric(torch.randn(100, ), torch.randn(50, )) diff --git a/tests/audio/test_snr.py b/tests/audio/test_snr.py index f1357ecab0f..e0baebeb067 100644 --- a/tests/audio/test_snr.py +++ b/tests/audio/test_snr.py @@ -38,8 +38,7 @@ ) -def bss_eval_images_snr(preds: Tensor, target: Tensor, metric_func: Callable, - zero_mean: bool): +def bss_eval_images_snr(preds: Tensor, target: Tensor, metric_func: Callable, zero_mean: bool): # shape: preds [BATCH_SIZE, 1, Time] , target [BATCH_SIZE, 1, Time] # or shape: preds [NUM_BATCHES*BATCH_SIZE, 1, Time] , target [NUM_BATCHES*BATCH_SIZE, 1, Time] if zero_mean: @@ -60,19 +59,14 @@ def bss_eval_images_snr(preds: Tensor, target: Tensor, metric_func: Callable, return torch.tensor(mss) - def average_metric(preds: Tensor, target: Tensor, metric_func: Callable): # shape: preds [BATCH_SIZE, 1, Time] , target [BATCH_SIZE, 1, Time] # or shape: preds [NUM_BATCHES*BATCH_SIZE, 1, Time] , target [NUM_BATCHES*BATCH_SIZE, 1, Time] return metric_func(preds, target).mean() -mireval_snr_zeromean = partial(bss_eval_images_snr, - metric_func=mir_eval_bss_eval_images, - zero_mean=True) -mireval_snr_nozeromean = partial(bss_eval_images_snr, - metric_func=mir_eval_bss_eval_images, - zero_mean=False) +mireval_snr_zeromean = partial(bss_eval_images_snr, metric_func=mir_eval_bss_eval_images, zero_mean=True) +mireval_snr_nozeromean = partial(bss_eval_images_snr, metric_func=mir_eval_bss_eval_images, zero_mean=False) @pytest.mark.parametrize( @@ -87,8 +81,7 @@ class TestSNR(MetricTester): 
@pytest.mark.parametrize("ddp", [True, False]) @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_snr(self, preds, target, sk_metric, zero_mean, ddp, - dist_sync_on_step): + def test_snr(self, preds, target, sk_metric, zero_mean, ddp, dist_sync_on_step): self.run_class_metric_test( ddp, preds, @@ -109,33 +102,24 @@ def test_snr_functional(self, preds, target, sk_metric, zero_mean): ) def test_snr_differentiability(self, preds, target, sk_metric, zero_mean): - self.run_differentiability_test(preds=preds, - target=target, - metric_module=SNR, - metric_functional=snr, - metric_args={'zero_mean': zero_mean}) + self.run_differentiability_test( + preds=preds, target=target, metric_module=SNR, metric_functional=snr, metric_args={'zero_mean': zero_mean} + ) @pytest.mark.skipif( - not _TORCH_GREATER_EQUAL_1_6, - reason= - 'half support of core operations on not support before pytorch v1.6') + not _TORCH_GREATER_EQUAL_1_6, reason='half support of core operations on not support before pytorch v1.6' + ) def test_snr_half_cpu(self, preds, target, sk_metric, zero_mean): pytest.xfail("SNR metric does not support cpu + half precision") - @pytest.mark.skipif(not torch.cuda.is_available(), - reason='test requires cuda') + @pytest.mark.skipif(not torch.cuda.is_available(), reason='test requires cuda') def test_snr_half_gpu(self, preds, target, sk_metric, zero_mean): - self.run_precision_test_gpu(preds=preds, - target=target, - metric_module=SNR, - metric_functional=snr, - metric_args={'zero_mean': zero_mean}) + self.run_precision_test_gpu( + preds=preds, target=target, metric_module=SNR, metric_functional=snr, metric_args={'zero_mean': zero_mean} + ) def test_error_on_different_shape(metric_class=SNR): metric = metric_class() - with pytest.raises( - RuntimeError, - match='Predictions and targets are expected to have the same shape' - ): - metric(torch.randn(100,), torch.randn(50,)) + with pytest.raises(RuntimeError, match='Predictions and targets are expected to 
have the same shape'): + metric(torch.randn(100, ), torch.randn(50, )) From a23e165108ee9ffa6f1125bcabd7d9143896ebf9 Mon Sep 17 00:00:00 2001 From: quancs <1017241746@qq.com> Date: Thu, 17 Jun 2021 20:25:58 +0800 Subject: [PATCH 075/109] add numpy --- requirements/test.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements/test.txt b/requirements/test.txt index 5251a06f3be..1f632f595d9 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -22,3 +22,4 @@ nltk>=3.6 mir_eval>=0.6 https://github.com/aliutkus/speechmetrics/archive/refs/heads/master.zip +numpy From 23b4683ecd03f2a0174aebc8ff66af9ab66c2472 Mon Sep 17 00:00:00 2001 From: Jirka Date: Thu, 17 Jun 2021 14:42:15 +0200 Subject: [PATCH 076/109] numpy --- requirements/test.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/requirements/test.txt b/requirements/test.txt index 1f632f595d9..5251a06f3be 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -22,4 +22,3 @@ nltk>=3.6 mir_eval>=0.6 https://github.com/aliutkus/speechmetrics/archive/refs/heads/master.zip -numpy From d726ba30db10d964c08eb8e4461a80e1227b37f1 Mon Sep 17 00:00:00 2001 From: quancs <1017241746@qq.com> Date: Thu, 17 Jun 2021 22:31:56 +0800 Subject: [PATCH 077/109] ignore_errors torchmetrics.audio.* --- setup.cfg | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/setup.cfg b/setup.cfg index 3da200ba376..30355af2f98 100644 --- a/setup.cfg +++ b/setup.cfg @@ -104,3 +104,7 @@ ignore_errors = True # todo: add proper typing to this module... [mypy-torchmetrics.retrieval.*] ignore_errors = True + +# todo: add proper typing to this module... 
+[mypy-torchmetrics.audio.*] +ignore_errors = True From 0861c6e45f0e701a07cd2362f1554053a698695c Mon Sep 17 00:00:00 2001 From: quancs <1017241746@qq.com> Date: Thu, 17 Jun 2021 23:41:23 +0800 Subject: [PATCH 078/109] solve mypy no-redef error --- torchmetrics/audio/__init__.py | 6 +++--- torchmetrics/audio/{SI_SDR.py => si_sdr.py} | 0 torchmetrics/audio/{SI_SNR.py => si_snr.py} | 0 torchmetrics/audio/{SNR.py => snr.py} | 0 4 files changed, 3 insertions(+), 3 deletions(-) rename torchmetrics/audio/{SI_SDR.py => si_sdr.py} (100%) rename torchmetrics/audio/{SI_SNR.py => si_snr.py} (100%) rename torchmetrics/audio/{SNR.py => snr.py} (100%) diff --git a/torchmetrics/audio/__init__.py b/torchmetrics/audio/__init__.py index 035f1cd38fb..1ca805ab15a 100644 --- a/torchmetrics/audio/__init__.py +++ b/torchmetrics/audio/__init__.py @@ -11,6 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from torchmetrics.audio.SI_SDR import SI_SDR # noqa: F401 -from torchmetrics.audio.SI_SNR import SI_SNR # noqa: F401 -from torchmetrics.audio.SNR import SNR # noqa: F401 +from torchmetrics.audio.si_sdr import SI_SDR # noqa: F401 +from torchmetrics.audio.si_snr import SI_SNR # noqa: F401 +from torchmetrics.audio.snr import SNR # noqa: F401 diff --git a/torchmetrics/audio/SI_SDR.py b/torchmetrics/audio/si_sdr.py similarity index 100% rename from torchmetrics/audio/SI_SDR.py rename to torchmetrics/audio/si_sdr.py diff --git a/torchmetrics/audio/SI_SNR.py b/torchmetrics/audio/si_snr.py similarity index 100% rename from torchmetrics/audio/SI_SNR.py rename to torchmetrics/audio/si_snr.py diff --git a/torchmetrics/audio/SNR.py b/torchmetrics/audio/snr.py similarity index 100% rename from torchmetrics/audio/SNR.py rename to torchmetrics/audio/snr.py From 1848f9e33a57575fc836fb71c3b2ba6bae5deffa Mon Sep 17 00:00:00 2001 From: quancs <1017241746@qq.com> Date: Fri, 18 Jun 2021 01:58:32 +0800 Subject: [PATCH 079/109] remove --quiet --- .github/workflows/ci_test-conda.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci_test-conda.yml b/.github/workflows/ci_test-conda.yml index 20332f7a502..c95f7cb4a59 100644 --- a/.github/workflows/ci_test-conda.yml +++ b/.github/workflows/ci_test-conda.yml @@ -62,8 +62,8 @@ jobs: pip --version python ./requirements/adjust-versions.py requirements.txt python ./requirements/adjust-versions.py requirements/image.txt - pip install --requirement requirements.txt --upgrade-strategy only-if-needed --quiet - pip install --requirement requirements/test.txt --upgrade-strategy only-if-needed --quiet --find-links https://download.pytorch.org/whl/cpu/torch_stable.html + pip install --requirement requirements.txt --upgrade-strategy only-if-needed + pip install --requirement requirements/test.txt --upgrade-strategy only-if-needed --find-links https://download.pytorch.org/whl/cpu/torch_stable.html pip list python -c 
"import torch; assert torch.__version__[:3] == '${{ matrix.pytorch-version }}', torch.__version__" shell: bash -l {0} From 8469dc8baffd46a335641ec2de7c36335d7687f7 Mon Sep 17 00:00:00 2001 From: Jirka Date: Thu, 17 Jun 2021 23:22:50 +0200 Subject: [PATCH 080/109] pypesq --- requirements/test.txt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/requirements/test.txt b/requirements/test.txt index 5251a06f3be..a8439b8a127 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -20,5 +20,7 @@ nltk>=3.6 # add extra requirements -r image.txt +# audio +pypesq mir_eval>=0.6 https://github.com/aliutkus/speechmetrics/archive/refs/heads/master.zip From 4f3d41f4915a6f1615130def761852edea4cf5c8 Mon Sep 17 00:00:00 2001 From: Jirka Date: Fri, 18 Jun 2021 21:06:17 +0200 Subject: [PATCH 081/109] apt --- azure-pipelines.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/azure-pipelines.yml b/azure-pipelines.yml index 4b4ff84b232..5ed2ff396f8 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -19,7 +19,7 @@ pr: jobs: - job: pytest # how long to run the job before automatically cancelling - timeoutInMinutes: 55 + timeoutInMinutes: 45 # how much time to give 'run always even if cancelled tasks' before stopping them cancelTimeoutInMinutes: 2 @@ -44,7 +44,7 @@ jobs: displayName: 'Image info & NVIDIA' - bash: | - sudo apt-get install -y cmake ffmpeg git + apt install -y cmake ffmpeg git # python -m pip install "pip==20.1" pip install --requirement ./requirements/devel.txt --upgrade-strategy only-if-needed pip uninstall -y torchmetrics From 9358bfd031acc3459f18ad07e20f37e8df3bc038 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Fri, 18 Jun 2021 01:20:45 +0200 Subject: [PATCH 082/109] Inception Score (#299) * implementation and test * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * changelog * add example * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see 
https://pre-commit.ci * Apply suggestions from code review * update to torch fidelity 0.3.0 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * 35min Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Jirka Borovec Co-authored-by: Jirka --- CHANGELOG.md | 3 + docs/source/references/modules.rst | 8 ++ tests/image/test_fid.py | 16 ++- tests/image/test_inception.py | 125 ++++++++++++++++++++ torchmetrics/__init__.py | 2 +- torchmetrics/image/__init__.py | 1 + torchmetrics/image/fid.py | 12 +- torchmetrics/image/inception.py | 176 +++++++++++++++++++++++++++++ 8 files changed, 334 insertions(+), 9 deletions(-) create mode 100644 tests/image/test_inception.py create mode 100644 torchmetrics/image/inception.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 5bd43f0dc14..bebf1b62fa7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -39,6 +39,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added audio metrics: SNR, SI_SDR, SI_SNR ([#292](https://github.com/PyTorchLightning/metrics/pull/292)) +- Added Inception Score metric to image module ([#299](https://github.com/PyTorchLightning/metrics/pull/299)) + + ### Changed - Forward cache is now reset when `reset` method is called ([#260](https://github.com/PyTorchLightning/metrics/pull/260)) diff --git a/docs/source/references/modules.rst b/docs/source/references/modules.rst index 0d4aa7dd935..5f218fbf669 100644 --- a/docs/source/references/modules.rst +++ b/docs/source/references/modules.rst @@ -297,9 +297,17 @@ Image Quality Metrics Image quality metrics can be used to access the quality of synthetic generated images from machine learning algorithms such as `Generative Adverserial Networks (GANs) `_. +FID +~~~ + .. autoclass:: torchmetrics.FID :noindex: +IS +~~ + +.. 
autoclass:: torchmetrics.IS + :noindex: ****************** Regression Metrics diff --git a/tests/image/test_fid.py b/tests/image/test_fid.py index 7987441520f..ec4609b3832 100644 --- a/tests/image/test_fid.py +++ b/tests/image/test_fid.py @@ -91,14 +91,18 @@ def test_fid_raises_errors_and_warnings(): ): _ = FID() + with pytest.raises(TypeError, match='Got unknown input to argument `feature`'): + _ = FID(feature=[1, 2]) + @pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason='test requires torch-fidelity') -def test_fid_same_input(): +@pytest.mark.parametrize("feature", [64, 192, 768, 2048]) +def test_fid_same_input(feature): """ if real and fake are update on the same data the fid score should be 0 """ - metric = FID(feature=192) + metric = FID(feature=feature) for _ in range(2): - img = torch.randint(0, 255, (5, 3, 299, 299), dtype=torch.uint8) + img = torch.randint(0, 255, (10, 3, 299, 299), dtype=torch.uint8) metric.update(img, real=True) metric.update(img, real=False) @@ -140,7 +144,11 @@ def test_compare_fid(tmpdir, feature=2048): metric.update(img2[batch_size * i:batch_size * (i + 1)].cuda(), real=False) torch_fid = calculate_metrics( - _ImgDataset(img1), _ImgDataset(img2), fid=True, feature_layer_fid=str(feature), batch_size=batch_size + input1=_ImgDataset(img1), + input2=_ImgDataset(img2), + fid=True, + feature_layer_fid=str(feature), + batch_size=batch_size ) tm_res = metric.compute() diff --git a/tests/image/test_inception.py b/tests/image/test_inception.py new file mode 100644 index 00000000000..120426374f8 --- /dev/null +++ b/tests/image/test_inception.py @@ -0,0 +1,125 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pickle + +import pytest +import torch +from torch.utils.data import Dataset + +from torchmetrics.image.inception import IS +from torchmetrics.utilities.imports import _TORCH_FIDELITY_AVAILABLE + +torch.manual_seed(42) + + +@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="test requires torch-fidelity") +def test_no_train(): + """ Assert that metric never leaves evaluation mode """ + + class MyModel(torch.nn.Module): + + def __init__(self): + super().__init__() + self.metric = IS() + + def forward(self, x): + return x + + model = MyModel() + model.train() + assert model.training + assert not model.metric.inception.training, 'IS metric was changed to training mode which should not happen' + + +@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason='test requires torch-fidelity') +def test_is_pickle(): + """ Assert that we can initialize the metric and pickle it""" + metric = IS() + assert metric + + # verify metrics work after being loaded from pickled state + pickled_metric = pickle.dumps(metric) + metric = pickle.loads(pickled_metric) + + +def test_is_raises_errors_and_warnings(): + """ Test that expected warnings and errors are raised """ + with pytest.warns( + UserWarning, + match='Metric `IS` will save all extracted features in buffer.' + ' For large datasets this may lead to large memory footprint.' 
+ ): + IS() + + if _TORCH_FIDELITY_AVAILABLE: + with pytest.raises(ValueError, match='Integer input to argument `feature` must be one of .*'): + _ = IS(feature=2) + else: + with pytest.raises( + ValueError, + match='IS metric requires that Torch-fidelity is installed.' + 'Either install as `pip install torchmetrics[image-quality]`' + ' or `pip install torch-fidelity`' + ): + IS() + + with pytest.raises(TypeError, match='Got unknown input to argument `feature`'): + IS(feature=[1, 2]) + + +@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason='test requires torch-fidelity') +def test_is_update_compute(): + metric = IS() + + for _ in range(2): + img = torch.randint(0, 255, (10, 3, 299, 299), dtype=torch.uint8) + metric.update(img) + + mean, std = metric.compute() + assert mean != 0.0 + assert std != 0.0 + + +class _ImgDataset(Dataset): + + def __init__(self, imgs): + self.imgs = imgs + + def __getitem__(self, idx): + return self.imgs[idx] + + def __len__(self): + return self.imgs.shape[0] + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason='test is too slow without gpu') +@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason='test requires torch-fidelity') +def test_compare_is(tmpdir): + """ check that the hole pipeline give the same result as torch-fidelity """ + from torch_fidelity import calculate_metrics + + metric = IS(splits=1).cuda() + + # Generate some synthetic data + img1 = torch.randint(0, 255, (100, 3, 299, 299), dtype=torch.uint8) + + batch_size = 10 + for i in range(img1.shape[0] // batch_size): + metric.update(img1[batch_size * i:batch_size * (i + 1)].cuda()) + + torch_fid = calculate_metrics(input1=_ImgDataset(img1), isc=True, isc_splits=1, batch_size=batch_size) + + tm_mean, tm_std = metric.compute() + + assert torch.allclose(tm_mean.cpu(), torch.tensor([torch_fid['inception_score_mean']]), atol=1e-3) diff --git a/torchmetrics/__init__.py b/torchmetrics/__init__.py index ee0b44f10bd..738094916d2 100644 --- a/torchmetrics/__init__.py 
+++ b/torchmetrics/__init__.py @@ -37,7 +37,7 @@ StatScores, ) from torchmetrics.collections import MetricCollection # noqa: F401 E402 -from torchmetrics.image import FID # noqa: F401 E402 +from torchmetrics.image import FID, IS # noqa: F401 E402 from torchmetrics.metric import Metric # noqa: F401 E402 from torchmetrics.regression import ( # noqa: F401 E402 PSNR, diff --git a/torchmetrics/image/__init__.py b/torchmetrics/image/__init__.py index ec4d70fa511..a0bee1440ee 100644 --- a/torchmetrics/image/__init__.py +++ b/torchmetrics/image/__init__.py @@ -12,3 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. from torchmetrics.image.fid import FID # noqa: F401 +from torchmetrics.image.inception import IS # noqa: F401 diff --git a/torchmetrics/image/fid.py b/torchmetrics/image/fid.py index 3b692564575..6cd2889ab95 100644 --- a/torchmetrics/image/fid.py +++ b/torchmetrics/image/fid.py @@ -141,7 +141,7 @@ class FID(Metric): determines if the images should update the statistics of the real distribution or the fake distribution. .. note:: using this metric with the default feature extractor requires that ``torch-fidelity`` - is installed. Either install as ``pip install torchmetrics[image-quality]`` or + is installed. Either install as ``pip install torchmetrics[image]`` or ``pip install torch-fidelity`` .. 
note:: the ``forward`` method can be used but ``compute_on_step`` is disabled by default (oppesit of @@ -182,6 +182,8 @@ class FID(Metric): If ``feature`` is set to an ``int`` (default settings) and ``torch-fidelity`` is not installed ValueError: If ``feature`` is set to an ``int`` not in [64, 192, 768, 2048] + TypeError: + If ``feature`` is not an ``str``, ``int`` or ``torch.nn.Module`` Example: >>> import torch @@ -204,7 +206,7 @@ def __init__( compute_on_step: bool = False, dist_sync_on_step: bool = False, process_group: Optional[Any] = None, - dist_sync_fn: Callable = None + dist_sync_fn: Callable[[Tensor], List[Tensor]] = None ): super().__init__( compute_on_step=compute_on_step, @@ -222,7 +224,7 @@ def __init__( if not _TORCH_FIDELITY_AVAILABLE: raise ValueError( 'FID metric requires that Torch-fidelity is installed.' - 'Either install as `pip install torchmetrics[image-quality]` or `pip install torch-fidelity`' + 'Either install as `pip install torchmetrics[image]` or `pip install torch-fidelity`' ) valid_int_input = [64, 192, 768, 2048] if feature not in valid_int_input: @@ -231,8 +233,10 @@ def __init__( ) self.inception = NoTrainInceptionV3(name='inception-v3-compat', features_list=[str(feature)]) - else: + elif isinstance(feature, torch.nn.Module): self.inception = feature + else: + raise TypeError('Got unknown input to argument `feature`') self.add_state("real_features", [], dist_reduce_fx=None) self.add_state("fake_features", [], dist_reduce_fx=None) diff --git a/torchmetrics/image/inception.py b/torchmetrics/image/inception.py new file mode 100644 index 00000000000..87416445784 --- /dev/null +++ b/torchmetrics/image/inception.py @@ -0,0 +1,176 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Callable, List, Optional, Tuple, Union + +import torch +from torch import Tensor + +from torchmetrics.image.fid import NoTrainInceptionV3 +from torchmetrics.metric import Metric +from torchmetrics.utilities import rank_zero_warn +from torchmetrics.utilities.imports import _TORCH_FIDELITY_AVAILABLE + + +class IS(Metric): + r""" + Calculates the Inception Score (IS) which is used to access how realistic generated images are. + It is defined as + + .. math:: + IS = exp(\mathbb{E}_x KL(p(y | x ) || p(y))) + + where :math:`KL(p(y | x) || p(y))` is the KL divergence between the conditional distribution :math:`p(y|x)` + and the margianl distribution :math:`p(y)`. Both the conditional and marginal distribution is calculated + from features extracted from the images. The score is calculated on random splits of the images such that + both a mean and standard deviation of the score are returned. The metric was originally proposed in [1]. + + Using the default feature extraction (Inception v3 using the original weights from [2]), the input is + expected to be mini-batches of 3-channel RGB images of shape (3 x H x W) with dtype uint8. All images + will be resized to 299 x 299 which is the size of the original training data. + + .. note:: using this metric with the default feature extractor requires that ``torch-fidelity`` + is installed. Either install as ``pip install torchmetrics[image]`` or + ``pip install torch-fidelity`` + + .. 
note:: the ``forward`` method can be used but ``compute_on_step`` is disabled by default (oppesit of + all other metrics) as this metric does not really make sense to calculate on a single batch. This + means that by default ``forward`` will just call ``update`` underneat. + + Args: + feature: + Either an str, integer or ``nn.Module``: + + - an str or integer will indicate the inceptionv3 feature layer to choose. Can be one of the following: + 'logits_unbiased', 64, 192, 768, 2048 + - an ``nn.Module`` for using a custom feature extractor. Expects that its forward method returns + an ``[N,d]`` matrix where ``N`` is the batch size and ``d`` is the feature size. + + splits: integer determining how many splits the inception score calculation should be split among + + compute_on_step: + Forward only calls ``update()`` and return ``None`` if this is set to ``False``. + dist_sync_on_step: + Synchronize metric state across processes at each ``forward()`` + before returning the value at the step + process_group: + Specify the process group on which synchronization is called. + default: ``None`` (which selects the entire world) + dist_sync_fn: + Callback that performs the allgather operation on the metric state. 
When ``None``, DDP + will be used to perform the allgather + + [1] Improved Techniques for Training GANs + Tim Salimans, Ian Goodfellow, Wojciech Zaremba, Vicki Cheung, Alec Radford, Xi Chen + https://arxiv.org/abs/1606.03498 + + [2] GANs Trained by a Two Time-Scale Update Rule Converge to a Local Nash Equilibrium, + Martin Heusel, Hubert Ramsauer, Thomas Unterthiner, Bernhard Nessler, Sepp Hochreiter + https://arxiv.org/abs/1706.08500 + + Raises: + ValueError: + If ``feature`` is set to an ``str`` or ``int`` and ``torch-fidelity`` is not installed + ValueError: + If ``feature`` is set to an ``str`` or ``int`` and not one of ['logits_unbiased', 64, 192, 768, 2048] + TypeError: + If ``feature`` is not an ``str``, ``int`` or ``torch.nn.Module`` + + Example: + >>> import torch + >>> _ = torch.manual_seed(123) + >>> from torchmetrics import IS + >>> inception = IS() # doctest: +SKIP + >>> # generate some images + >>> imgs = torch.randint(0, 255, (100, 3, 299, 299), dtype=torch.uint8) # doctest: +SKIP + >>> inception.update(imgs) # doctest: +SKIP + >>> inception.compute() # doctest: +SKIP + (tensor(1.0569), tensor(0.0113)) + + """ + + def __init__( + self, + feature: Union[str, int, torch.nn.Module] = 'logits_unbiased', + splits: int = 10, + compute_on_step: bool = False, + dist_sync_on_step: bool = False, + process_group: Optional[Any] = None, + dist_sync_fn: Callable[[Tensor], List[Tensor]] = None + ): + super().__init__( + compute_on_step=compute_on_step, + dist_sync_on_step=dist_sync_on_step, + process_group=process_group, + dist_sync_fn=dist_sync_fn, + ) + + rank_zero_warn( + 'Metric `IS` will save all extracted features in buffer.' + ' For large datasets this may lead to large memory footprint.', UserWarning + ) + + if isinstance(feature, (str, int)): + if not _TORCH_FIDELITY_AVAILABLE: + raise ValueError( + 'IS metric requires that Torch-fidelity is installed.' 
+ 'Either install as `pip install torchmetrics[image]`' + ' or `pip install torch-fidelity`' + ) + valid_int_input = ('logits_unbiased', 64, 192, 768, 2048) + if feature not in valid_int_input: + raise ValueError( + f'Integer input to argument `feature` must be one of {valid_int_input},' + f' but got {feature}.' + ) + + self.inception = NoTrainInceptionV3(name='inception-v3-compat', features_list=[str(feature)]) + elif isinstance(feature, torch.nn.Module): + self.inception = feature + else: + raise TypeError('Got unknown input to argument `feature`') + + self.splits = splits + self.add_state("features", [], dist_reduce_fx=None) + + def update(self, imgs: Tensor) -> None: # type: ignore + """ Update the state with extracted features + + Args: + imgs: tensor with images feed to the feature extractor + """ + features = self.inception(imgs) + self.features.append(features) + + def compute(self) -> Tuple[Tensor, Tensor]: + features = torch.cat(self.features, dim=0) + # random permute the features + idx = torch.randperm(features.shape[0]) + features = features[idx] + + # calculate probs and logits + prob = features.softmax(dim=1) + log_prob = features.log_softmax(dim=1) + + # split into groups + prob = prob.chunk(self.splits, dim=0) + log_prob = log_prob.chunk(self.splits, dim=0) + + # calculate score per split + mean_prob = [p.mean(dim=0, keepdim=True) for p in prob] + kl_ = [p * (log_p - m_p.log()) for p, log_p, m_p in zip(prob, log_prob, mean_prob)] + kl_ = [k.sum(dim=1).mean().exp() for k in kl_] + kl = torch.stack(kl_) + + # return mean and std + return kl.mean(), kl.std() From fd3aa0a5681aa573d11f482826b0990522d98c7a Mon Sep 17 00:00:00 2001 From: Justus Schock <12886177+justusschock@users.noreply.github.com> Date: Fri, 18 Jun 2021 17:05:48 +0200 Subject: [PATCH 083/109] Apply suggestions from code review --- torchmetrics/audio/si_sdr.py | 4 ++-- torchmetrics/audio/si_snr.py | 4 ++-- torchmetrics/audio/snr.py | 4 ++-- 3 files changed, 6 insertions(+), 6 
deletions(-) diff --git a/torchmetrics/audio/si_sdr.py b/torchmetrics/audio/si_sdr.py index 2f42c2a0ca5..18a54c35ea4 100644 --- a/torchmetrics/audio/si_sdr.py +++ b/torchmetrics/audio/si_sdr.py @@ -69,8 +69,8 @@ def __init__( compute_on_step: bool = True, dist_sync_on_step: bool = False, process_group: Optional[Any] = None, - dist_sync_fn: Callable = None, - ): + dist_sync_fn: Optional[Callable[[torch.Tensor], torch.Tensor]] = None, + ) -> None: super().__init__( compute_on_step=compute_on_step, dist_sync_on_step=dist_sync_on_step, diff --git a/torchmetrics/audio/si_snr.py b/torchmetrics/audio/si_snr.py index 10c9d840174..d05f152d90a 100644 --- a/torchmetrics/audio/si_snr.py +++ b/torchmetrics/audio/si_snr.py @@ -67,8 +67,8 @@ def __init__( compute_on_step: bool = True, dist_sync_on_step: bool = False, process_group: Optional[Any] = None, - dist_sync_fn: Callable = None, - ): + dist_sync_fn: Optional[Callable[[torch.Tensor], torch.Tensor]] = None, + ) -> None: super().__init__( compute_on_step=compute_on_step, dist_sync_on_step=dist_sync_on_step, diff --git a/torchmetrics/audio/snr.py b/torchmetrics/audio/snr.py index 223063cc1ac..a94e06b58b8 100644 --- a/torchmetrics/audio/snr.py +++ b/torchmetrics/audio/snr.py @@ -69,8 +69,8 @@ def __init__( compute_on_step: bool = True, dist_sync_on_step: bool = False, process_group: Optional[Any] = None, - dist_sync_fn: Callable = None, - ): + dist_sync_fn: Optional[Callable[[torch.Tensor], torch.Tensor]] = None, + ) -> None: super().__init__( compute_on_step=compute_on_step, dist_sync_on_step=dist_sync_on_step, From 05e0f1905dbdb8691238c09cafa7a0bb8185d999 Mon Sep 17 00:00:00 2001 From: quancs <1017241746@qq.com> Date: Fri, 18 Jun 2021 23:33:10 +0800 Subject: [PATCH 084/109] add # type: ignore --- setup.cfg | 4 ---- torchmetrics/audio/si_sdr.py | 4 ++-- torchmetrics/audio/si_snr.py | 4 ++-- torchmetrics/audio/snr.py | 4 ++-- 4 files changed, 6 insertions(+), 10 deletions(-) diff --git a/setup.cfg b/setup.cfg index 
30355af2f98..3da200ba376 100644 --- a/setup.cfg +++ b/setup.cfg @@ -104,7 +104,3 @@ ignore_errors = True # todo: add proper typing to this module... [mypy-torchmetrics.retrieval.*] ignore_errors = True - -# todo: add proper typing to this module... -[mypy-torchmetrics.audio.*] -ignore_errors = True diff --git a/torchmetrics/audio/si_sdr.py b/torchmetrics/audio/si_sdr.py index 18a54c35ea4..da3fc570ff7 100644 --- a/torchmetrics/audio/si_sdr.py +++ b/torchmetrics/audio/si_sdr.py @@ -69,7 +69,7 @@ def __init__( compute_on_step: bool = True, dist_sync_on_step: bool = False, process_group: Optional[Any] = None, - dist_sync_fn: Optional[Callable[[torch.Tensor], torch.Tensor]] = None, + dist_sync_fn: Optional[Callable[[Tensor], Tensor]] = None, ) -> None: super().__init__( compute_on_step=compute_on_step, @@ -82,7 +82,7 @@ def __init__( self.add_state("sum_si_sdr", default=tensor(0.0), dist_reduce_fx="sum") self.add_state("total", default=tensor(0), dist_reduce_fx="sum") - def update(self, preds: Tensor, target: Tensor) -> None: + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore """ Update state with predictions and targets. 
diff --git a/torchmetrics/audio/si_snr.py b/torchmetrics/audio/si_snr.py index d05f152d90a..35848f79782 100644 --- a/torchmetrics/audio/si_snr.py +++ b/torchmetrics/audio/si_snr.py @@ -67,7 +67,7 @@ def __init__( compute_on_step: bool = True, dist_sync_on_step: bool = False, process_group: Optional[Any] = None, - dist_sync_fn: Optional[Callable[[torch.Tensor], torch.Tensor]] = None, + dist_sync_fn: Optional[Callable[[Tensor], Tensor]] = None, ) -> None: super().__init__( compute_on_step=compute_on_step, @@ -79,7 +79,7 @@ def __init__( self.add_state("sum_si_snr", default=tensor(0.0), dist_reduce_fx="sum") self.add_state("total", default=tensor(0), dist_reduce_fx="sum") - def update(self, preds: Tensor, target: Tensor) -> None: + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore """ Update state with predictions and targets. diff --git a/torchmetrics/audio/snr.py b/torchmetrics/audio/snr.py index a94e06b58b8..07f62f94dfc 100644 --- a/torchmetrics/audio/snr.py +++ b/torchmetrics/audio/snr.py @@ -69,7 +69,7 @@ def __init__( compute_on_step: bool = True, dist_sync_on_step: bool = False, process_group: Optional[Any] = None, - dist_sync_fn: Optional[Callable[[torch.Tensor], torch.Tensor]] = None, + dist_sync_fn: Optional[Callable[[Tensor], Tensor]] = None, ) -> None: super().__init__( compute_on_step=compute_on_step, @@ -82,7 +82,7 @@ def __init__( self.add_state("sum_snr", default=tensor(0.0), dist_reduce_fx="sum") self.add_state("total", default=tensor(0), dist_reduce_fx="sum") - def update(self, preds: Tensor, target: Tensor) -> None: + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore """ Update state with predictions and targets. 
From a351c8d0b686682fc0fb6fe3ded6cd71b81ef2b3 Mon Sep 17 00:00:00 2001 From: quancs <1017241746@qq.com> Date: Fri, 18 Jun 2021 23:45:28 +0800 Subject: [PATCH 085/109] try without test_si_snr & test_si_sdr --- tests/audio/test_si_sdr.py | 131 ------------------------------------- tests/audio/test_si_snr.py | 112 ------------------------------- 2 files changed, 243 deletions(-) delete mode 100644 tests/audio/test_si_sdr.py delete mode 100644 tests/audio/test_si_snr.py diff --git a/tests/audio/test_si_sdr.py b/tests/audio/test_si_sdr.py deleted file mode 100644 index c1ea88e522e..00000000000 --- a/tests/audio/test_si_sdr.py +++ /dev/null @@ -1,131 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-from collections import namedtuple -from functools import partial - -import pytest -import speechmetrics -import torch -from torch import Tensor - -from tests.helpers import seed_all -from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester -from torchmetrics.audio import SI_SDR -from torchmetrics.functional import si_sdr -from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6 - -seed_all(42) - -Time = 100 - -Input = namedtuple('Input', ["preds", "target"]) - -inputs = Input( - preds=torch.rand(NUM_BATCHES, BATCH_SIZE, 1, Time), - target=torch.rand(NUM_BATCHES, BATCH_SIZE, 1, Time), -) - -speechmetrics_sisdr = speechmetrics.load('sisdr') - - -def speechmetrics_si_sdr(preds: Tensor, target: Tensor, zero_mean: bool): - # shape: preds [BATCH_SIZE, 1, Time] , target [BATCH_SIZE, 1, Time] - # or shape: preds [NUM_BATCHES*BATCH_SIZE, 1, Time] , target [NUM_BATCHES*BATCH_SIZE, 1, Time] - if zero_mean: - preds = preds - preds.mean(dim=2, keepdim=True) - target = target - target.mean(dim=2, keepdim=True) - target = target.detach().cpu().numpy() - preds = preds.detach().cpu().numpy() - mss = [] - for i in range(preds.shape[0]): - ms = [] - for j in range(preds.shape[1]): - metric = speechmetrics_sisdr(preds[i, j], target[i, j], rate=16000) - ms.append(metric['sisdr'][0]) - mss.append(ms) - return torch.tensor(mss) - - -def average_metric(preds, target, metric_func): - # shape: preds [BATCH_SIZE, 1, Time] , target [BATCH_SIZE, 1, Time] - # or shape: preds [NUM_BATCHES*BATCH_SIZE, 1, Time] , target [NUM_BATCHES*BATCH_SIZE, 1, Time] - return metric_func(preds, target).mean() - - -speechmetrics_si_sdr_zero_mean = partial(speechmetrics_si_sdr, zero_mean=True) -speechmetrics_si_sdr_no_zero_mean = partial(speechmetrics_si_sdr, zero_mean=False) - - -@pytest.mark.parametrize( - "preds, target, sk_metric, zero_mean", - [ - (inputs.preds, inputs.target, speechmetrics_si_sdr_zero_mean, True), - (inputs.preds, inputs.target, speechmetrics_si_sdr_no_zero_mean, 
False), - ], -) -class TestSISDR(MetricTester): - atol = 1e-2 - - @pytest.mark.parametrize("ddp", [True, False]) - @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_si_sdr(self, preds, target, sk_metric, zero_mean, ddp, dist_sync_on_step): - self.run_class_metric_test( - ddp, - preds, - target, - SI_SDR, - sk_metric=partial(average_metric, metric_func=sk_metric), - dist_sync_on_step=dist_sync_on_step, - metric_args=dict(zero_mean=zero_mean), - ) - - def test_si_sdr_functional(self, preds, target, sk_metric, zero_mean): - self.run_functional_metric_test( - preds, - target, - si_sdr, - sk_metric, - metric_args=dict(zero_mean=zero_mean), - ) - - def test_si_sdr_differentiability(self, preds, target, sk_metric, zero_mean): - self.run_differentiability_test( - preds=preds, - target=target, - metric_module=SI_SDR, - metric_functional=si_sdr, - metric_args={'zero_mean': zero_mean} - ) - - @pytest.mark.skipif( - not _TORCH_GREATER_EQUAL_1_6, reason='half support of core operations on not support before pytorch v1.6' - ) - def test_si_sdr_half_cpu(self, preds, target, sk_metric, zero_mean): - pytest.xfail("SI-SDR metric does not support cpu + half precision") - - @pytest.mark.skipif(not torch.cuda.is_available(), reason='test requires cuda') - def test_si_sdr_half_gpu(self, preds, target, sk_metric, zero_mean): - self.run_precision_test_gpu( - preds=preds, - target=target, - metric_module=SI_SDR, - metric_functional=si_sdr, - metric_args={'zero_mean': zero_mean} - ) - - -def test_error_on_different_shape(metric_class=SI_SDR): - metric = metric_class() - with pytest.raises(RuntimeError, match='Predictions and targets are expected to have the same shape'): - metric(torch.randn(100, ), torch.randn(50, )) diff --git a/tests/audio/test_si_snr.py b/tests/audio/test_si_snr.py deleted file mode 100644 index 46d1cf6ed09..00000000000 --- a/tests/audio/test_si_snr.py +++ /dev/null @@ -1,112 +0,0 @@ -# Copyright The PyTorch Lightning team. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from collections import namedtuple -from functools import partial - -import pytest -import speechmetrics -import torch -from torch import Tensor - -from tests.helpers import seed_all -from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester -from torchmetrics.audio import SI_SNR -from torchmetrics.functional import si_snr -from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6 - -seed_all(42) - -Time = 100 - -Input = namedtuple('Input', ["preds", "target"]) - -inputs = Input( - preds=torch.rand(NUM_BATCHES, BATCH_SIZE, 1, Time), - target=torch.rand(NUM_BATCHES, BATCH_SIZE, 1, Time), -) - -speechmetrics_sisdr = speechmetrics.load('sisdr') - - -def speechmetrics_si_sdr(preds: Tensor, target: Tensor, zero_mean: bool = True): - # shape: preds [BATCH_SIZE, 1, Time] , target [BATCH_SIZE, 1, Time] - # or shape: preds [NUM_BATCHES*BATCH_SIZE, 1, Time] , target [NUM_BATCHES*BATCH_SIZE, 1, Time] - if zero_mean: - preds = preds - preds.mean(dim=2, keepdim=True) - target = target - target.mean(dim=2, keepdim=True) - target = target.detach().cpu().numpy() - preds = preds.detach().cpu().numpy() - mss = [] - for i in range(preds.shape[0]): - ms = [] - for j in range(preds.shape[1]): - metric = speechmetrics_sisdr(preds[i, j], target[i, j], rate=16000) - ms.append(metric['sisdr'][0]) - mss.append(ms) - return torch.tensor(mss) - - -def average_metric(preds, target, metric_func): - # shape: preds [BATCH_SIZE, 
1, Time] , target [BATCH_SIZE, 1, Time] - # or shape: preds [NUM_BATCHES*BATCH_SIZE, 1, Time] , target [NUM_BATCHES*BATCH_SIZE, 1, Time] - return metric_func(preds, target).mean() - - -@pytest.mark.parametrize( - "preds, target, sk_metric", - [ - (inputs.preds, inputs.target, speechmetrics_si_sdr), - ], -) -class TestSISNR(MetricTester): - atol = 1e-2 - - @pytest.mark.parametrize("ddp", [True, False]) - @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_si_snr(self, preds, target, sk_metric, ddp, dist_sync_on_step): - self.run_class_metric_test( - ddp, - preds, - target, - SI_SNR, - sk_metric=partial(average_metric, metric_func=sk_metric), - dist_sync_on_step=dist_sync_on_step, - ) - - def test_si_snr_functional(self, preds, target, sk_metric): - self.run_functional_metric_test( - preds, - target, - si_snr, - sk_metric, - ) - - def test_si_snr_differentiability(self, preds, target, sk_metric): - self.run_differentiability_test(preds=preds, target=target, metric_module=SI_SNR, metric_functional=si_snr) - - @pytest.mark.skipif( - not _TORCH_GREATER_EQUAL_1_6, reason='half support of core operations on not support before pytorch v1.6' - ) - def test_si_snr_half_cpu(self, preds, target, sk_metric): - pytest.xfail("SI-SNR metric does not support cpu + half precision") - - @pytest.mark.skipif(not torch.cuda.is_available(), reason='test requires cuda') - def test_si_snr_half_gpu(self, preds, target, sk_metric): - self.run_precision_test_gpu(preds=preds, target=target, metric_module=SI_SNR, metric_functional=si_snr) - - -def test_error_on_different_shape(metric_class=SI_SNR): - metric = metric_class() - with pytest.raises(RuntimeError, match='Predictions and targets are expected to have the same shape'): - metric(torch.randn(100, ), torch.randn(50, )) From 67da4857af67f6cf79ca38b29f5883709c34eb45 Mon Sep 17 00:00:00 2001 From: quancs <1017241746@qq.com> Date: Sat, 19 Jun 2021 00:10:08 +0800 Subject: [PATCH 086/109] test_import_speechmetrics --- 
tests/audio/test_speechmetrics.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) create mode 100644 tests/audio/test_speechmetrics.py diff --git a/tests/audio/test_speechmetrics.py b/tests/audio/test_speechmetrics.py new file mode 100644 index 00000000000..76350084ab2 --- /dev/null +++ b/tests/audio/test_speechmetrics.py @@ -0,0 +1,27 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from collections import namedtuple +from functools import partial +from typing import Callable + +import pytest +import torch +from torch import Tensor + + +def test_import_speechmetrics() -> None: + try: + import speechmetrics + except ImportError: + pytest.fail('ImportError speechmetrics') From 64ffd9af6a2d05d86f97a38eb980f50fe6780a71 Mon Sep 17 00:00:00 2001 From: quancs <1017241746@qq.com> Date: Sat, 19 Jun 2021 00:37:55 +0800 Subject: [PATCH 087/109] test_speechmetrics_si_sdr --- tests/audio/test_speechmetrics.py | 51 +++++++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/tests/audio/test_speechmetrics.py b/tests/audio/test_speechmetrics.py index 76350084ab2..63f14bfad9c 100644 --- a/tests/audio/test_speechmetrics.py +++ b/tests/audio/test_speechmetrics.py @@ -25,3 +25,54 @@ def test_import_speechmetrics() -> None: import speechmetrics except ImportError: pytest.fail('ImportError speechmetrics') + + +from tests.helpers import seed_all +from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, 
MetricTester + +seed_all(42) + +Time = 100 + +Input = namedtuple('Input', ["preds", "target"]) + +inputs = Input( + preds=torch.rand(NUM_BATCHES, BATCH_SIZE, 1, Time), + target=torch.rand(NUM_BATCHES, BATCH_SIZE, 1, Time), +) + +import multiprocessing +import time +import speechmetrics + +speechmetrics_sisdr = speechmetrics.load('sisdr') +def speechmetrics_si_sdr(preds: Tensor, target: Tensor, + zero_mean: bool) -> Tensor: + if zero_mean: + preds = preds - preds.mean(dim=2, keepdim=True) + target = target - target.mean(dim=2, keepdim=True) + target = target.detach().cpu().numpy() + preds = preds.detach().cpu().numpy() + mss = [] + for i in range(preds.shape[0]): + ms = [] + for j in range(preds.shape[1]): + metric = speechmetrics_sisdr(preds[i, j], + target[i, j], + rate=16000) + ms.append(metric['sisdr'][0]) + mss.append(ms) + return torch.tensor(mss) + +def test_speechmetrics_si_sdr() -> None: + t = multiprocessing.Process(target=speechmetrics_si_sdr, + args=(inputs.preds[0], inputs.target[0], + False)) + t.start() + try: + t.join(timeout=180) # 3min + if t.is_alive(): + pytest.fail(f'timeout 3min. 
t.is_alive()={t.is_alive()}') + t.terminate() + except: + pytest.fail('join except') From 0b5fb27dcaa3cbf5e12650e004c6144a34e021d8 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 18 Jun 2021 16:38:33 +0000 Subject: [PATCH 088/109] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/audio/test_speechmetrics.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/tests/audio/test_speechmetrics.py b/tests/audio/test_speechmetrics.py index 63f14bfad9c..f8fe846bafa 100644 --- a/tests/audio/test_speechmetrics.py +++ b/tests/audio/test_speechmetrics.py @@ -43,11 +43,13 @@ def test_import_speechmetrics() -> None: import multiprocessing import time + import speechmetrics speechmetrics_sisdr = speechmetrics.load('sisdr') -def speechmetrics_si_sdr(preds: Tensor, target: Tensor, - zero_mean: bool) -> Tensor: + + +def speechmetrics_si_sdr(preds: Tensor, target: Tensor, zero_mean: bool) -> Tensor: if zero_mean: preds = preds - preds.mean(dim=2, keepdim=True) target = target - target.mean(dim=2, keepdim=True) @@ -57,17 +59,14 @@ def speechmetrics_si_sdr(preds: Tensor, target: Tensor, for i in range(preds.shape[0]): ms = [] for j in range(preds.shape[1]): - metric = speechmetrics_sisdr(preds[i, j], - target[i, j], - rate=16000) + metric = speechmetrics_sisdr(preds[i, j], target[i, j], rate=16000) ms.append(metric['sisdr'][0]) mss.append(ms) return torch.tensor(mss) + def test_speechmetrics_si_sdr() -> None: - t = multiprocessing.Process(target=speechmetrics_si_sdr, - args=(inputs.preds[0], inputs.target[0], - False)) + t = multiprocessing.Process(target=speechmetrics_si_sdr, args=(inputs.preds[0], inputs.target[0], False)) t.start() try: t.join(timeout=180) # 3min From 85d9f720604e5f50c685adfb682a15d3dbe2c376 Mon Sep 17 00:00:00 2001 From: quancs <1017241746@qq.com> Date: Sat, 19 Jun 2021 00:52:20 +0800 Subject: [PATCH 
089/109] test_si_sdr_functional --- tests/audio/test_speechmetrics.py | 78 +++++++++++++++++++++++++++++++ 1 file changed, 78 insertions(+) diff --git a/tests/audio/test_speechmetrics.py b/tests/audio/test_speechmetrics.py index f8fe846bafa..f8fb4dc0a94 100644 --- a/tests/audio/test_speechmetrics.py +++ b/tests/audio/test_speechmetrics.py @@ -75,3 +75,81 @@ def test_speechmetrics_si_sdr() -> None: t.terminate() except: pytest.fail('join except') + + +from torchmetrics.audio import SI_SDR +from torchmetrics.functional import si_sdr + +speechmetrics_si_sdr_zero_mean = partial(speechmetrics_si_sdr, zero_mean=True) +speechmetrics_si_sdr_no_zero_mean = partial(speechmetrics_si_sdr, zero_mean=False) +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6 + + +# def average_metric(preds, target, metric_func): +# # shape: preds [BATCH_SIZE, 1, Time] , target [BATCH_SIZE, 1, Time] +# # or shape: preds [NUM_BATCHES*BATCH_SIZE, 1, Time] , target [NUM_BATCHES*BATCH_SIZE, 1, Time] +# return metric_func(preds, target).mean() + + +@pytest.mark.parametrize( + "preds, target, sk_metric, zero_mean", + [ + (inputs.preds, inputs.target, speechmetrics_si_sdr_zero_mean, True), + (inputs.preds, inputs.target, speechmetrics_si_sdr_no_zero_mean, False), + ], +) +class TestSISDR(MetricTester): + atol = 1e-2 + + # @pytest.mark.parametrize("ddp", [True, False]) + # @pytest.mark.parametrize("dist_sync_on_step", [True, False]) + # def test_si_sdr(self, preds, target, sk_metric, zero_mean, ddp, dist_sync_on_step): + # self.run_class_metric_test( + # ddp, + # preds, + # target, + # SI_SDR, + # sk_metric=partial(average_metric, metric_func=sk_metric), + # dist_sync_on_step=dist_sync_on_step, + # metric_args=dict(zero_mean=zero_mean), + # ) + + def test_si_sdr_functional(self, preds, target, sk_metric, zero_mean): + self.run_functional_metric_test( + preds, + target, + si_sdr, + sk_metric, + metric_args=dict(zero_mean=zero_mean), + ) + + # def test_si_sdr_differentiability(self, preds, 
target, sk_metric, zero_mean): + # self.run_differentiability_test( + # preds=preds, + # target=target, + # metric_module=SI_SDR, + # metric_functional=si_sdr, + # metric_args={'zero_mean': zero_mean} + # ) + + # @pytest.mark.skipif( + # not _TORCH_GREATER_EQUAL_1_6, reason='half support of core operations on not support before pytorch v1.6' + # ) + # def test_si_sdr_half_cpu(self, preds, target, sk_metric, zero_mean): + # pytest.xfail("SI-SDR metric does not support cpu + half precision") + + # @pytest.mark.skipif(not torch.cuda.is_available(), reason='test requires cuda') + # def test_si_sdr_half_gpu(self, preds, target, sk_metric, zero_mean): + # self.run_precision_test_gpu( + # preds=preds, + # target=target, + # metric_module=SI_SDR, + # metric_functional=si_sdr, + # metric_args={'zero_mean': zero_mean} + # ) + + +# def test_error_on_different_shape(metric_class=SI_SDR): +# metric = metric_class() +# with pytest.raises(RuntimeError, match='Predictions and targets are expected to have the same shape'): +# metric(torch.randn(100, ), torch.randn(50, )) From f5afe0f808d6450b61db4a16ceb8d3b3d0faef1a Mon Sep 17 00:00:00 2001 From: quancs <1017241746@qq.com> Date: Sat, 19 Jun 2021 01:06:17 +0800 Subject: [PATCH 090/109] test audio only --- .github/workflows/ci_test-conda.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci_test-conda.yml b/.github/workflows/ci_test-conda.yml index c95f7cb4a59..d6bca117f00 100644 --- a/.github/workflows/ci_test-conda.yml +++ b/.github/workflows/ci_test-conda.yml @@ -17,7 +17,7 @@ jobs: pytorch-version: [1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9] # Timeout: https://stackoverflow.com/a/59076067/4521646 - timeout-minutes: 55 + timeout-minutes: 20 steps: - uses: actions/checkout@v2 @@ -71,7 +71,7 @@ jobs: - name: Testing run: | # NOTE: run coverage on tests does not propagare faler status for Win, https://github.com/nedbat/coveragepy/issues/1003 - python -m pytest torchmetrics tests -v --durations=35 
--junitxml=junit/test-conda-py${{ matrix.python-version }}-pt${{ matrix.pytorch-version }}.xml + python -m pytest tests/audio -v --durations=35 --junitxml=junit/test-conda-py${{ matrix.python-version }}-pt${{ matrix.pytorch-version }}.xml shell: bash -l {0} - name: Upload pytest test results From 6e11386e913a2b69413b6d79813db4dd4debcf1e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 18 Jun 2021 16:53:09 +0000 Subject: [PATCH 091/109] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/audio/test_speechmetrics.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/audio/test_speechmetrics.py b/tests/audio/test_speechmetrics.py index f8fb4dc0a94..1a14b93b5b8 100644 --- a/tests/audio/test_speechmetrics.py +++ b/tests/audio/test_speechmetrics.py @@ -84,7 +84,6 @@ def test_speechmetrics_si_sdr() -> None: speechmetrics_si_sdr_no_zero_mean = partial(speechmetrics_si_sdr, zero_mean=False) from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6 - # def average_metric(preds, target, metric_func): # # shape: preds [BATCH_SIZE, 1, Time] , target [BATCH_SIZE, 1, Time] # # or shape: preds [NUM_BATCHES*BATCH_SIZE, 1, Time] , target [NUM_BATCHES*BATCH_SIZE, 1, Time] From fe9a7cf0397ca967c2f8aec4402f5f16590fc643 Mon Sep 17 00:00:00 2001 From: quancs <1017241746@qq.com> Date: Sat, 19 Jun 2021 01:24:20 +0800 Subject: [PATCH 092/109] install libsndfile1 --- .github/workflows/ci_test-conda.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/ci_test-conda.yml b/.github/workflows/ci_test-conda.yml index d6bca117f00..0c518198cec 100644 --- a/.github/workflows/ci_test-conda.yml +++ b/.github/workflows/ci_test-conda.yml @@ -54,6 +54,7 @@ jobs: - name: Update Environment run: | + sudo apt install libsndfile1 conda info conda install mkl pytorch=${{ matrix.pytorch-version }} cpuonly conda install cpuonly $(python 
./requirements/adjust-versions.py conda) From 9491551680bfee5a87fb3d93a406b4619851e504 Mon Sep 17 00:00:00 2001 From: quancs <1017241746@qq.com> Date: Sat, 19 Jun 2021 01:33:02 +0800 Subject: [PATCH 093/109] add sisnr sisdr test --- tests/audio/test_si_sdr.py | 131 +++++++++++++++++++++++++++++++++++++ tests/audio/test_si_snr.py | 112 +++++++++++++++++++++++++++++++ 2 files changed, 243 insertions(+) create mode 100644 tests/audio/test_si_sdr.py create mode 100644 tests/audio/test_si_snr.py diff --git a/tests/audio/test_si_sdr.py b/tests/audio/test_si_sdr.py new file mode 100644 index 00000000000..c1ea88e522e --- /dev/null +++ b/tests/audio/test_si_sdr.py @@ -0,0 +1,131 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from collections import namedtuple +from functools import partial + +import pytest +import speechmetrics +import torch +from torch import Tensor + +from tests.helpers import seed_all +from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester +from torchmetrics.audio import SI_SDR +from torchmetrics.functional import si_sdr +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6 + +seed_all(42) + +Time = 100 + +Input = namedtuple('Input', ["preds", "target"]) + +inputs = Input( + preds=torch.rand(NUM_BATCHES, BATCH_SIZE, 1, Time), + target=torch.rand(NUM_BATCHES, BATCH_SIZE, 1, Time), +) + +speechmetrics_sisdr = speechmetrics.load('sisdr') + + +def speechmetrics_si_sdr(preds: Tensor, target: Tensor, zero_mean: bool): + # shape: preds [BATCH_SIZE, 1, Time] , target [BATCH_SIZE, 1, Time] + # or shape: preds [NUM_BATCHES*BATCH_SIZE, 1, Time] , target [NUM_BATCHES*BATCH_SIZE, 1, Time] + if zero_mean: + preds = preds - preds.mean(dim=2, keepdim=True) + target = target - target.mean(dim=2, keepdim=True) + target = target.detach().cpu().numpy() + preds = preds.detach().cpu().numpy() + mss = [] + for i in range(preds.shape[0]): + ms = [] + for j in range(preds.shape[1]): + metric = speechmetrics_sisdr(preds[i, j], target[i, j], rate=16000) + ms.append(metric['sisdr'][0]) + mss.append(ms) + return torch.tensor(mss) + + +def average_metric(preds, target, metric_func): + # shape: preds [BATCH_SIZE, 1, Time] , target [BATCH_SIZE, 1, Time] + # or shape: preds [NUM_BATCHES*BATCH_SIZE, 1, Time] , target [NUM_BATCHES*BATCH_SIZE, 1, Time] + return metric_func(preds, target).mean() + + +speechmetrics_si_sdr_zero_mean = partial(speechmetrics_si_sdr, zero_mean=True) +speechmetrics_si_sdr_no_zero_mean = partial(speechmetrics_si_sdr, zero_mean=False) + + +@pytest.mark.parametrize( + "preds, target, sk_metric, zero_mean", + [ + (inputs.preds, inputs.target, speechmetrics_si_sdr_zero_mean, True), + (inputs.preds, inputs.target, speechmetrics_si_sdr_no_zero_mean, 
False), + ], +) +class TestSISDR(MetricTester): + atol = 1e-2 + + @pytest.mark.parametrize("ddp", [True, False]) + @pytest.mark.parametrize("dist_sync_on_step", [True, False]) + def test_si_sdr(self, preds, target, sk_metric, zero_mean, ddp, dist_sync_on_step): + self.run_class_metric_test( + ddp, + preds, + target, + SI_SDR, + sk_metric=partial(average_metric, metric_func=sk_metric), + dist_sync_on_step=dist_sync_on_step, + metric_args=dict(zero_mean=zero_mean), + ) + + def test_si_sdr_functional(self, preds, target, sk_metric, zero_mean): + self.run_functional_metric_test( + preds, + target, + si_sdr, + sk_metric, + metric_args=dict(zero_mean=zero_mean), + ) + + def test_si_sdr_differentiability(self, preds, target, sk_metric, zero_mean): + self.run_differentiability_test( + preds=preds, + target=target, + metric_module=SI_SDR, + metric_functional=si_sdr, + metric_args={'zero_mean': zero_mean} + ) + + @pytest.mark.skipif( + not _TORCH_GREATER_EQUAL_1_6, reason='half support of core operations on not support before pytorch v1.6' + ) + def test_si_sdr_half_cpu(self, preds, target, sk_metric, zero_mean): + pytest.xfail("SI-SDR metric does not support cpu + half precision") + + @pytest.mark.skipif(not torch.cuda.is_available(), reason='test requires cuda') + def test_si_sdr_half_gpu(self, preds, target, sk_metric, zero_mean): + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=SI_SDR, + metric_functional=si_sdr, + metric_args={'zero_mean': zero_mean} + ) + + +def test_error_on_different_shape(metric_class=SI_SDR): + metric = metric_class() + with pytest.raises(RuntimeError, match='Predictions and targets are expected to have the same shape'): + metric(torch.randn(100, ), torch.randn(50, )) diff --git a/tests/audio/test_si_snr.py b/tests/audio/test_si_snr.py new file mode 100644 index 00000000000..46d1cf6ed09 --- /dev/null +++ b/tests/audio/test_si_snr.py @@ -0,0 +1,112 @@ +# Copyright The PyTorch Lightning team. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from collections import namedtuple +from functools import partial + +import pytest +import speechmetrics +import torch +from torch import Tensor + +from tests.helpers import seed_all +from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester +from torchmetrics.audio import SI_SNR +from torchmetrics.functional import si_snr +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6 + +seed_all(42) + +Time = 100 + +Input = namedtuple('Input', ["preds", "target"]) + +inputs = Input( + preds=torch.rand(NUM_BATCHES, BATCH_SIZE, 1, Time), + target=torch.rand(NUM_BATCHES, BATCH_SIZE, 1, Time), +) + +speechmetrics_sisdr = speechmetrics.load('sisdr') + + +def speechmetrics_si_sdr(preds: Tensor, target: Tensor, zero_mean: bool = True): + # shape: preds [BATCH_SIZE, 1, Time] , target [BATCH_SIZE, 1, Time] + # or shape: preds [NUM_BATCHES*BATCH_SIZE, 1, Time] , target [NUM_BATCHES*BATCH_SIZE, 1, Time] + if zero_mean: + preds = preds - preds.mean(dim=2, keepdim=True) + target = target - target.mean(dim=2, keepdim=True) + target = target.detach().cpu().numpy() + preds = preds.detach().cpu().numpy() + mss = [] + for i in range(preds.shape[0]): + ms = [] + for j in range(preds.shape[1]): + metric = speechmetrics_sisdr(preds[i, j], target[i, j], rate=16000) + ms.append(metric['sisdr'][0]) + mss.append(ms) + return torch.tensor(mss) + + +def average_metric(preds, target, metric_func): + # shape: preds [BATCH_SIZE, 
1, Time] , target [BATCH_SIZE, 1, Time] + # or shape: preds [NUM_BATCHES*BATCH_SIZE, 1, Time] , target [NUM_BATCHES*BATCH_SIZE, 1, Time] + return metric_func(preds, target).mean() + + +@pytest.mark.parametrize( + "preds, target, sk_metric", + [ + (inputs.preds, inputs.target, speechmetrics_si_sdr), + ], +) +class TestSISNR(MetricTester): + atol = 1e-2 + + @pytest.mark.parametrize("ddp", [True, False]) + @pytest.mark.parametrize("dist_sync_on_step", [True, False]) + def test_si_snr(self, preds, target, sk_metric, ddp, dist_sync_on_step): + self.run_class_metric_test( + ddp, + preds, + target, + SI_SNR, + sk_metric=partial(average_metric, metric_func=sk_metric), + dist_sync_on_step=dist_sync_on_step, + ) + + def test_si_snr_functional(self, preds, target, sk_metric): + self.run_functional_metric_test( + preds, + target, + si_snr, + sk_metric, + ) + + def test_si_snr_differentiability(self, preds, target, sk_metric): + self.run_differentiability_test(preds=preds, target=target, metric_module=SI_SNR, metric_functional=si_snr) + + @pytest.mark.skipif( + not _TORCH_GREATER_EQUAL_1_6, reason='half support of core operations on not support before pytorch v1.6' + ) + def test_si_snr_half_cpu(self, preds, target, sk_metric): + pytest.xfail("SI-SNR metric does not support cpu + half precision") + + @pytest.mark.skipif(not torch.cuda.is_available(), reason='test requires cuda') + def test_si_snr_half_gpu(self, preds, target, sk_metric): + self.run_precision_test_gpu(preds=preds, target=target, metric_module=SI_SNR, metric_functional=si_snr) + + +def test_error_on_different_shape(metric_class=SI_SNR): + metric = metric_class() + with pytest.raises(RuntimeError, match='Predictions and targets are expected to have the same shape'): + metric(torch.randn(100, ), torch.randn(50, )) From 2236698f5f82cfcd0d9a6e496776739131ddb6ef Mon Sep 17 00:00:00 2001 From: quancs <1017241746@qq.com> Date: Sat, 19 Jun 2021 01:51:14 +0800 Subject: [PATCH 094/109] test all & add quiet & remove 
test_speechmetrics --- .github/workflows/ci_test-conda.yml | 8 +- tests/audio/test_speechmetrics.py | 154 ---------------------------- 2 files changed, 4 insertions(+), 158 deletions(-) delete mode 100644 tests/audio/test_speechmetrics.py diff --git a/.github/workflows/ci_test-conda.yml b/.github/workflows/ci_test-conda.yml index 0c518198cec..e45f9b28710 100644 --- a/.github/workflows/ci_test-conda.yml +++ b/.github/workflows/ci_test-conda.yml @@ -17,7 +17,7 @@ jobs: pytorch-version: [1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9] # Timeout: https://stackoverflow.com/a/59076067/4521646 - timeout-minutes: 20 + timeout-minutes: 55 steps: - uses: actions/checkout@v2 @@ -63,8 +63,8 @@ jobs: pip --version python ./requirements/adjust-versions.py requirements.txt python ./requirements/adjust-versions.py requirements/image.txt - pip install --requirement requirements.txt --upgrade-strategy only-if-needed - pip install --requirement requirements/test.txt --upgrade-strategy only-if-needed --find-links https://download.pytorch.org/whl/cpu/torch_stable.html + pip install --requirement requirements.txt --upgrade-strategy only-if-needed --quiet + pip install --requirement requirements/test.txt --upgrade-strategy only-if-needed --quiet --find-links https://download.pytorch.org/whl/cpu/torch_stable.html pip list python -c "import torch; assert torch.__version__[:3] == '${{ matrix.pytorch-version }}', torch.__version__" shell: bash -l {0} @@ -72,7 +72,7 @@ jobs: - name: Testing run: | # NOTE: run coverage on tests does not propagare faler status for Win, https://github.com/nedbat/coveragepy/issues/1003 - python -m pytest tests/audio -v --durations=35 --junitxml=junit/test-conda-py${{ matrix.python-version }}-pt${{ matrix.pytorch-version }}.xml + python -m pytest torchmetrics tests -v --durations=35 --junitxml=junit/test-conda-py${{ matrix.python-version }}-pt${{ matrix.pytorch-version }}.xml shell: bash -l {0} - name: Upload pytest test results diff --git a/tests/audio/test_speechmetrics.py 
b/tests/audio/test_speechmetrics.py deleted file mode 100644 index 1a14b93b5b8..00000000000 --- a/tests/audio/test_speechmetrics.py +++ /dev/null @@ -1,154 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from collections import namedtuple -from functools import partial -from typing import Callable - -import pytest -import torch -from torch import Tensor - - -def test_import_speechmetrics() -> None: - try: - import speechmetrics - except ImportError: - pytest.fail('ImportError speechmetrics') - - -from tests.helpers import seed_all -from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester - -seed_all(42) - -Time = 100 - -Input = namedtuple('Input', ["preds", "target"]) - -inputs = Input( - preds=torch.rand(NUM_BATCHES, BATCH_SIZE, 1, Time), - target=torch.rand(NUM_BATCHES, BATCH_SIZE, 1, Time), -) - -import multiprocessing -import time - -import speechmetrics - -speechmetrics_sisdr = speechmetrics.load('sisdr') - - -def speechmetrics_si_sdr(preds: Tensor, target: Tensor, zero_mean: bool) -> Tensor: - if zero_mean: - preds = preds - preds.mean(dim=2, keepdim=True) - target = target - target.mean(dim=2, keepdim=True) - target = target.detach().cpu().numpy() - preds = preds.detach().cpu().numpy() - mss = [] - for i in range(preds.shape[0]): - ms = [] - for j in range(preds.shape[1]): - metric = speechmetrics_sisdr(preds[i, j], target[i, j], rate=16000) - ms.append(metric['sisdr'][0]) - mss.append(ms) - return 
torch.tensor(mss) - - -def test_speechmetrics_si_sdr() -> None: - t = multiprocessing.Process(target=speechmetrics_si_sdr, args=(inputs.preds[0], inputs.target[0], False)) - t.start() - try: - t.join(timeout=180) # 3min - if t.is_alive(): - pytest.fail(f'timeout 3min. t.is_alive()={t.is_alive()}') - t.terminate() - except: - pytest.fail('join except') - - -from torchmetrics.audio import SI_SDR -from torchmetrics.functional import si_sdr - -speechmetrics_si_sdr_zero_mean = partial(speechmetrics_si_sdr, zero_mean=True) -speechmetrics_si_sdr_no_zero_mean = partial(speechmetrics_si_sdr, zero_mean=False) -from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6 - -# def average_metric(preds, target, metric_func): -# # shape: preds [BATCH_SIZE, 1, Time] , target [BATCH_SIZE, 1, Time] -# # or shape: preds [NUM_BATCHES*BATCH_SIZE, 1, Time] , target [NUM_BATCHES*BATCH_SIZE, 1, Time] -# return metric_func(preds, target).mean() - - -@pytest.mark.parametrize( - "preds, target, sk_metric, zero_mean", - [ - (inputs.preds, inputs.target, speechmetrics_si_sdr_zero_mean, True), - (inputs.preds, inputs.target, speechmetrics_si_sdr_no_zero_mean, False), - ], -) -class TestSISDR(MetricTester): - atol = 1e-2 - - # @pytest.mark.parametrize("ddp", [True, False]) - # @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - # def test_si_sdr(self, preds, target, sk_metric, zero_mean, ddp, dist_sync_on_step): - # self.run_class_metric_test( - # ddp, - # preds, - # target, - # SI_SDR, - # sk_metric=partial(average_metric, metric_func=sk_metric), - # dist_sync_on_step=dist_sync_on_step, - # metric_args=dict(zero_mean=zero_mean), - # ) - - def test_si_sdr_functional(self, preds, target, sk_metric, zero_mean): - self.run_functional_metric_test( - preds, - target, - si_sdr, - sk_metric, - metric_args=dict(zero_mean=zero_mean), - ) - - # def test_si_sdr_differentiability(self, preds, target, sk_metric, zero_mean): - # self.run_differentiability_test( - # preds=preds, - # 
target=target, - # metric_module=SI_SDR, - # metric_functional=si_sdr, - # metric_args={'zero_mean': zero_mean} - # ) - - # @pytest.mark.skipif( - # not _TORCH_GREATER_EQUAL_1_6, reason='half support of core operations on not support before pytorch v1.6' - # ) - # def test_si_sdr_half_cpu(self, preds, target, sk_metric, zero_mean): - # pytest.xfail("SI-SDR metric does not support cpu + half precision") - - # @pytest.mark.skipif(not torch.cuda.is_available(), reason='test requires cuda') - # def test_si_sdr_half_gpu(self, preds, target, sk_metric, zero_mean): - # self.run_precision_test_gpu( - # preds=preds, - # target=target, - # metric_module=SI_SDR, - # metric_functional=si_sdr, - # metric_args={'zero_mean': zero_mean} - # ) - - -# def test_error_on_different_shape(metric_class=SI_SDR): -# metric = metric_class() -# with pytest.raises(RuntimeError, match='Predictions and targets are expected to have the same shape'): -# metric(torch.randn(100, ), torch.randn(50, )) From f5a0411946d02740e59312ebb53909f220cc7734 Mon Sep 17 00:00:00 2001 From: quancs <1017241746@qq.com> Date: Sat, 19 Jun 2021 02:17:18 +0800 Subject: [PATCH 095/109] remove sudo & install libsndfile1 --- azure-pipelines.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/azure-pipelines.yml b/azure-pipelines.yml index 5ed2ff396f8..b0ebe3d9723 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -44,7 +44,7 @@ jobs: displayName: 'Image info & NVIDIA' - bash: | - apt install -y cmake ffmpeg git + apt-get install -y cmake ffmpeg git libsndfile1 # python -m pip install "pip==20.1" pip install --requirement ./requirements/devel.txt --upgrade-strategy only-if-needed pip uninstall -y torchmetrics From 38583974e3be13a36421c60a248d099a77e4f1e0 Mon Sep 17 00:00:00 2001 From: Jirka Date: Fri, 18 Jun 2021 21:10:39 +0200 Subject: [PATCH 096/109] apt --- azure-pipelines.yml | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/azure-pipelines.yml b/azure-pipelines.yml 
index b0ebe3d9723..6a322a78e22 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -26,7 +26,7 @@ jobs: pool: gridai-spot-pool container: - image: "pytorch/pytorch:1.7.1-cuda11.0-cudnn8-runtime" + image: "pytorch/pytorch:1.8.1-cuda11.1-cudnn8-runtime" options: "--runtime=nvidia -e NVIDIA_VISIBLE_DEVICES=all" workspace: @@ -44,11 +44,14 @@ jobs: displayName: 'Image info & NVIDIA' - bash: | - apt-get install -y cmake ffmpeg git libsndfile1 + sudo apt-get update + sudo apt install -y cmake ffmpeg git libsndfile1 # python -m pip install "pip==20.1" pip install --requirement ./requirements/devel.txt --upgrade-strategy only-if-needed pip uninstall -y torchmetrics pip list + env: + DEBIAN_FRONTEND: noninteractive displayName: 'Install dependencies' - bash: | From 3eb3d651b1bf8d52aba0f7859dcbf172336d0e5b Mon Sep 17 00:00:00 2001 From: quancs <1017241746@qq.com> Date: Tue, 22 Jun 2021 03:23:04 +0800 Subject: [PATCH 097/109] Update torchmetrics/functional/audio/si_sdr.py Co-authored-by: Nicki Skafte --- torchmetrics/functional/audio/si_sdr.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchmetrics/functional/audio/si_sdr.py b/torchmetrics/functional/audio/si_sdr.py index 229aadf1c0c..2d3b44c3ea6 100644 --- a/torchmetrics/functional/audio/si_sdr.py +++ b/torchmetrics/functional/audio/si_sdr.py @@ -50,8 +50,8 @@ def si_sdr(preds: Tensor, target: Tensor, zero_mean: bool = False) -> Tensor: target = target - torch.mean(target, dim=-1, keepdim=True) preds = preds - torch.mean(preds, dim=-1, keepdim=True) - α = (torch.sum(preds * target, dim=-1, keepdim=True) + EPS) / (torch.sum(target**2, dim=-1, keepdim=True) + EPS) - target_scaled = α * target + alpha = (torch.sum(preds * target, dim=-1, keepdim=True) + EPS) / (torch.sum(target**2, dim=-1, keepdim=True) + EPS) + target_scaled = alpha * target noise = target_scaled - preds From 4c18540cb1bd5a770dd63aedb11f0960c8e3f7ea Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Mon, 21 Jun 2021 
11:27:33 +0100 Subject: [PATCH 098/109] [feat] Add _apply_sync to nn.Metric (#302) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add _apply_sync to nn.Metric * move to context manager * add sync * add restore_cache * add a sync test * Update torchmetrics/metric.py Co-authored-by: Nicki Skafte * remove _update_signature * Apply suggestions from code review Co-authored-by: Carlos Mocholí * resolve failing test Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Nicki Skafte Co-authored-by: Jirka Borovec Co-authored-by: Carlos Mocholí --- tests/bases/test_ddp.py | 55 +++++++++++++++ torchmetrics/metric.py | 153 ++++++++++++++++++++++++++++++---------- 2 files changed, 172 insertions(+), 36 deletions(-) diff --git a/tests/bases/test_ddp.py b/tests/bases/test_ddp.py index 28aca7b1173..20eeee517fc 100644 --- a/tests/bases/test_ddp.py +++ b/tests/bases/test_ddp.py @@ -11,7 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import os import sys +from copy import deepcopy +from unittest import mock import pytest import torch @@ -116,3 +119,55 @@ def compute(self): def test_non_contiguous_tensors(): """ Test that gather_all operation works for non contiguous tensors """ torch.multiprocessing.spawn(_test_non_contiguous_tensors, args=(2, ), nprocs=2) + + +def _test_state_dict_is_synced(rank, worldsize, tmpdir): + setup_ddp(rank, worldsize) + + class DummyCatMetric(Metric): + + def __init__(self): + super().__init__() + self.add_state("x", torch.tensor(0), dist_reduce_fx=torch.sum) + self.add_state("c", torch.tensor(0), dist_reduce_fx=torch.sum) + + def update(self, x): + self.x += x + self.c += 1 + + def compute(self): + return self.x // self.c + + metric = DummyCatMetric() + metric.persistent(True) + + steps = 5 + for i in range(steps): + metric(i) + state_dict = metric.state_dict() + + sum = i * (i + 1) / 2 + assert state_dict["x"] == sum * worldsize + assert metric.x == sum + assert metric.c == (i + 1) + assert state_dict["c"] == metric.c * worldsize + + def reload_state_dict(state_dict, expected_x, expected_c): + metric = DummyCatMetric() + metric.load_state_dict(state_dict) + assert metric.x == expected_x + assert metric.c == expected_c + + with mock.patch.dict(os.environ, {"GLOBAL_RANK": str(rank)}): + reload_state_dict(deepcopy(state_dict), 20 if not rank else 0, 10 if not rank else 0) + + reload_state_dict(deepcopy(state_dict), 20, 10) + + +@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows") +def test_state_dict_is_synced(tmpdir): + """ + This test asserts that metrics are synced while creating the state + dict but restored after to continue accumulation. 
+ """ + torch.multiprocessing.spawn(_test_state_dict_is_synced, args=(2, tmpdir), nprocs=2) diff --git a/torchmetrics/metric.py b/torchmetrics/metric.py index 82262af26d1..e250608bfd5 100644 --- a/torchmetrics/metric.py +++ b/torchmetrics/metric.py @@ -14,10 +14,12 @@ import functools import inspect import operator +import os from abc import ABC, abstractmethod from collections.abc import Sequence +from contextlib import contextmanager from copy import deepcopy -from typing import Any, Callable, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Union import torch from torch import Tensor, nn @@ -28,6 +30,10 @@ from torchmetrics.utilities.imports import _LIGHTNING_AVAILABLE, _compare_version +def distributed_available() -> bool: + return torch.distributed.is_available() and torch.distributed.is_initialized() + + class Metric(nn.Module, ABC): """ Base class for all metrics present in the Metrics API. @@ -83,6 +89,7 @@ def __init__( self.process_group = process_group self.dist_sync_fn = dist_sync_fn self._to_sync = True + self._restore_cache = True self._update_signature = inspect.signature(self.update) self.update = self._wrap_update(self.update) @@ -169,6 +176,8 @@ def forward(self, *args, **kwargs): if self.compute_on_step: self._to_sync = self.dist_sync_on_step + # skip restore cache operation from compute as cache is stored below. 
+ self._restore_cache = False # save context before switch cache = {attr: getattr(self, attr) for attr in self._defaults} @@ -181,27 +190,31 @@ def forward(self, *args, **kwargs): # restore context for attr, val in cache.items(): setattr(self, attr, val) + + self._restore_cache = True self._to_sync = True self._computed = None return self._forward_cache - def _sync_dist(self, dist_sync_fn=gather_all_tensors): + def _sync_dist(self, dist_sync_fn: Callable = gather_all_tensors, process_group: Optional[Any] = None) -> None: input_dict = {attr: getattr(self, attr) for attr in self._reductions} + for attr, reduction_fn in self._reductions.items(): # pre-concatenate metric states that are lists to reduce number of all_gather operations if reduction_fn == dim_zero_cat and isinstance(input_dict[attr], list) and len(input_dict[attr]) > 1: input_dict[attr] = [dim_zero_cat(input_dict[attr])] + output_dict = apply_to_collection( input_dict, Tensor, dist_sync_fn, - group=self.process_group, + group=process_group or self.process_group, ) for attr, reduction_fn in self._reductions.items(): # pre-processing ops (stack or flatten for inputs) - if isinstance(output_dict[attr][0], Tensor): + if isinstance(output_dict[attr], Sequence) and isinstance(output_dict[attr][0], Tensor): output_dict[attr] = torch.stack(output_dict[attr]) elif isinstance(output_dict[attr][0], list): output_dict[attr] = _flatten(output_dict[attr]) @@ -221,6 +234,77 @@ def wrapped_func(*args, **kwargs): return wrapped_func + def sync( + self, + dist_sync_fn: Optional[Callable] = None, + process_group: Optional[Any] = None, + should_sync: bool = True, + distributed_available: Optional[Callable] = distributed_available, + ) -> Dict[str, Tensor]: + """ + Sync function for manually controlling when metrics states should be synced across processes + + Args: + dist_sync_fn: Function to be used to perform states synchronization + process_group: + Specify the process group on which synchronization is called. 
+ default: None (which selects the entire world) + should_sync: Whether to apply to state synchronization. This will have an impact + only when running in a distributed setting. + distributed_available: Function to determine if we are running inside a distributed setting + + Returns: + cache: A dictionary containing the local metric states. The cache will be empty if sync didn't happen. + """ + is_distributed = distributed_available() + if not should_sync or not is_distributed: + return {} + if dist_sync_fn is None: + dist_sync_fn = gather_all_tensors + # cache prior to syncing + cache = {attr: getattr(self, attr) for attr in self._defaults} + # sync + self._sync_dist(dist_sync_fn, process_group=process_group) + return cache + + @contextmanager + def sync_context( + self, + dist_sync_fn: Optional[Callable] = None, + process_group: Optional[Any] = None, + should_sync: bool = True, + restore_cache: bool = True, + distributed_available: Optional[Callable] = distributed_available, + ) -> None: + """ + Context manager to synchronize the states between processes when running in a distributed setting + and restore the local cache states after yielding. + + Args: + dist_sync_fn: Function to be used to perform states synchronization + process_group: + Specify the process group on which synchronization is called. + default: None (which selects the entire world) + should_sync: Whether to apply to state synchronization. This will have an impact + only when running in a distributed setting. + restore_cache: Whether to restore the cache state so that the metrics can + continue to be accumulated. 
+ distributed_available: Function to determine if we are running inside a distributed setting + """ + cache = self.sync( + dist_sync_fn=dist_sync_fn, + process_group=process_group, + should_sync=should_sync, + distributed_available=distributed_available + ) + + yield + + if cache and restore_cache: + # if we synced, restore to cache so that we can continue to accumulate un-synced state + for attr, val in cache.items(): + setattr(self, attr, val) + def _wrap_compute(self, compute): @functools.wraps(compute) @@ -236,26 +320,10 @@ def wrapped_func(*args, **kwargs): if self._computed is not None: return self._computed - dist_sync_fn = self.dist_sync_fn - if dist_sync_fn is None and torch.distributed.is_available() and torch.distributed.is_initialized(): - # User provided a bool, so we assume DDP if available - dist_sync_fn = gather_all_tensors - - synced = False - cache = [] - if self._to_sync and dist_sync_fn is not None: - # cache prior to syncing - cache = {attr: getattr(self, attr) for attr in self._defaults} - - # sync - self._sync_dist(dist_sync_fn) - synced = True - - self._computed = compute(*args, **kwargs) - if synced: - # if we synced, restore to cache so that we can continue to accumulate un-synced state - for attr, val in cache.items(): - setattr(self, attr, val) + with self.sync_context( + dist_sync_fn=self.dist_sync_fn, should_sync=self._to_sync, restore_cache=self._restore_cache + ): + self._computed = compute(*args, **kwargs) return self._computed @@ -299,11 +367,12 @@ def clone(self): def __getstate__(self): # ignore update and compute functions for pickling - return {k: v for k, v in self.__dict__.items() if k not in ["update", "compute"]} + return {k: v for k, v in self.__dict__.items() if k not in ["update", "compute", "_update_signature"]} def __setstate__(self, state): # manually restore update and compute functions for pickling self.__dict__.update(state) + self._update_signature = inspect.signature(self.update) self.update = 
self._wrap_update(self.update) self.compute = self._wrap_compute(self.compute) @@ -341,16 +410,23 @@ def persistent(self, mode: bool = False): def state_dict(self, destination=None, prefix="", keep_vars=False): destination = super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars) # Register metric states to be part of the state_dict - for key in self._defaults: - if self._persistent[key]: - current_val = getattr(self, key) - if not keep_vars: - if torch.is_tensor(current_val): - current_val = current_val.detach() - elif isinstance(current_val, list): - current_val = [cur_v.detach() if torch.is_tensor(cur_v) else cur_v for cur_v in current_val] - destination[prefix + key] = current_val - return destination + with self.sync_context(dist_sync_fn=self.dist_sync_fn): + for key in self._defaults: + if self._persistent[key]: + current_val = getattr(self, key) + if not keep_vars: + if isinstance(current_val, torch.Tensor): + current_val = current_val.detach() + elif isinstance(current_val, list): + current_val = [ + cur_v.detach() if isinstance(cur_v, torch.Tensor) else cur_v for cur_v in current_val + ] + # the tensors will be synced across processes so deepcopy to drop the references + destination[prefix + key] = deepcopy(current_val) + return destination + + def _should_load_from_state_dict(self) -> bool: + return os.getenv("GLOBAL_RANK", "0") == "0" def _load_from_state_dict( self, @@ -363,10 +439,15 @@ def _load_from_state_dict( error_msgs: List[str], ) -> None: """ Loads metric states from state_dict """ + + # only global rank 0 should be reloading the values present in the ``state_dict`` + # as the state contains synced values across all progress_group for key in self._defaults: name = prefix + key if name in state_dict: - setattr(self, key, state_dict.pop(name)) + value = state_dict.pop(name) + if self._should_load_from_state_dict(): + setattr(self, key, value) super()._load_from_state_dict( state_dict, prefix, local_metadata, True, 
missing_keys, unexpected_keys, error_msgs ) From 3408927a5d7ec8c390759629abf1e54e419b426e Mon Sep 17 00:00:00 2001 From: Jirka Date: Mon, 21 Jun 2021 17:30:26 +0200 Subject: [PATCH 099/109] v0.4.0rc0 --- torchmetrics/__about__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchmetrics/__about__.py b/torchmetrics/__about__.py index d0f07475883..1b7964555f5 100644 --- a/torchmetrics/__about__.py +++ b/torchmetrics/__about__.py @@ -1,4 +1,4 @@ -__version__ = '0.4.0dev' +__version__ = '0.4.0rc0' __author__ = 'PyTorchLightning et al.' __author_email__ = 'name@pytorchlightning.ai' __license__ = 'Apache-2.0' From 267380d5646c3c771b90c9cd37b464f1549d1753 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 21 Jun 2021 19:58:53 +0200 Subject: [PATCH 100/109] [pre-commit.ci] pre-commit suggestions (#306) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit updates: - [github.com/PyCQA/isort: 5.8.0 → 5.9.1](https://github.com/PyCQA/isort/compare/5.8.0...5.9.1) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 9b8addaf8ae..f655cbacb51 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -35,7 +35,7 @@ repos: - id: detect-private-key - repo: https://github.com/PyCQA/isort - rev: 5.8.0 + rev: 5.9.1 hooks: - id: isort name: imports From 69c8fbc63b8854b6c3f3f9a2fffdca271519dc22 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Mon, 21 Jun 2021 21:30:35 +0200 Subject: [PATCH 101/109] adding KID metric (#301) * implementation * parameter testing * fix test * implementation * update to torch fidelity 0.3.0 * changelog * docs * Apply suggestions from code review Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com> * add test * update 
* fix tests * typing * fix typing * fix bus error * Apply suggestions from code review Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com> Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com> Co-authored-by: Jirka Borovec --- CHANGELOG.md | 2 + docs/source/references/modules.rst | 4 + tests/image/test_fid.py | 3 +- tests/image/test_inception.py | 8 +- tests/image/test_kid.py | 168 +++++++++++++++++ torchmetrics/__init__.py | 2 +- torchmetrics/image/__init__.py | 1 + torchmetrics/image/fid.py | 5 +- torchmetrics/image/inception.py | 3 +- torchmetrics/image/kid.py | 280 +++++++++++++++++++++++++++++ torchmetrics/utilities/data.py | 2 + 11 files changed, 470 insertions(+), 8 deletions(-) create mode 100644 tests/image/test_kid.py create mode 100644 torchmetrics/image/kid.py diff --git a/CHANGELOG.md b/CHANGELOG.md index bebf1b62fa7..d1171a60f8c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -42,6 +42,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added Inception Score metric to image module ([#299](https://github.com/PyTorchLightning/metrics/pull/299)) +- Added KID metric to image module ([#301](https://github.com/PyTorchLightning/metrics/pull/301)) + ### Changed - Forward cache is now reset when `reset` method is called ([#260](https://github.com/PyTorchLightning/metrics/pull/260)) diff --git a/docs/source/references/modules.rst b/docs/source/references/modules.rst index 5f218fbf669..44f55c41fdf 100644 --- a/docs/source/references/modules.rst +++ b/docs/source/references/modules.rst @@ -309,6 +309,10 @@ IS .. autoclass:: torchmetrics.IS :noindex: + +.. 
autoclass:: torchmetrics.KID + :noindex: + ****************** Regression Metrics ****************** diff --git a/tests/image/test_fid.py b/tests/image/test_fid.py index ec4609b3832..d8a80677a1f 100644 --- a/tests/image/test_fid.py +++ b/tests/image/test_fid.py @@ -148,7 +148,8 @@ def test_compare_fid(tmpdir, feature=2048): input2=_ImgDataset(img2), fid=True, feature_layer_fid=str(feature), - batch_size=batch_size + batch_size=batch_size, + save_cpu_ram=True ) tm_res = metric.compute() diff --git a/tests/image/test_inception.py b/tests/image/test_inception.py index 120426374f8..0a740a076f8 100644 --- a/tests/image/test_inception.py +++ b/tests/image/test_inception.py @@ -87,8 +87,8 @@ def test_is_update_compute(): metric.update(img) mean, std = metric.compute() - assert mean != 0.0 - assert std != 0.0 + assert mean >= 0.0 + assert std >= 0.0 class _ImgDataset(Dataset): @@ -118,7 +118,9 @@ def test_compare_is(tmpdir): for i in range(img1.shape[0] // batch_size): metric.update(img1[batch_size * i:batch_size * (i + 1)].cuda()) - torch_fid = calculate_metrics(input1=_ImgDataset(img1), isc=True, isc_splits=1, batch_size=batch_size) + torch_fid = calculate_metrics( + input1=_ImgDataset(img1), isc=True, isc_splits=1, batch_size=batch_size, save_cpu_ram=True + ) tm_mean, tm_std = metric.compute() diff --git a/tests/image/test_kid.py b/tests/image/test_kid.py new file mode 100644 index 00000000000..ba83b9d4f54 --- /dev/null +++ b/tests/image/test_kid.py @@ -0,0 +1,168 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +import pickle + +import pytest +import torch +from torch.utils.data import Dataset + +from torchmetrics.image.kid import KID +from torchmetrics.utilities.imports import _TORCH_FIDELITY_AVAILABLE + +torch.manual_seed(42) + + +@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="test requires torch-fidelity") +def test_no_train(): + """ Assert that metric never leaves evaluation mode """ + + class MyModel(torch.nn.Module): + + def __init__(self): + super().__init__() + self.metric = KID() + + def forward(self, x): + return x + + model = MyModel() + model.train() + assert model.training + assert not model.metric.inception.training, 'FID metric was changed to training mode which should not happen' + + +@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason='test requires torch-fidelity') +def test_kid_pickle(): + """ Assert that we can initialize the metric and pickle it""" + metric = KID() + assert metric + + # verify metrics work after being loaded from pickled state + pickled_metric = pickle.dumps(metric) + metric = pickle.loads(pickled_metric) + + +def test_kid_raises_errors_and_warnings(): + """ Test that expected warnings and errors are raised """ + with pytest.warns( + UserWarning, + match='Metric `KID` will save all extracted features in buffer.' + ' For large datasets this may lead to large memory footprint.' + ): + KID() + + if _TORCH_FIDELITY_AVAILABLE: + with pytest.raises(ValueError, match='Integer input to argument `feature` must be one of .*'): + KID(feature=2) + else: + with pytest.raises( + ValueError, + match='KID metric requires that Torch-fidelity is installed.' 
+ 'Either install as `pip install torchmetrics[image]`' + ' or `pip install torch-fidelity`' + ): + KID() + + with pytest.raises(TypeError, match='Got unknown input to argument `feature`'): + KID(feature=[1, 2]) + + with pytest.raises(ValueError, match='Argument `subset_size` should be smaller than the number of samples'): + m = KID() + m.update(torch.randint(0, 255, (5, 3, 299, 299), dtype=torch.uint8), real=True) + m.update(torch.randint(0, 255, (5, 3, 299, 299), dtype=torch.uint8), real=False) + m.compute() + + +@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason='test requires torch-fidelity') +def test_kid_extra_parameters(): + with pytest.raises(ValueError, match="Argument `subsets` expected to be integer larger than 0"): + KID(subsets=-1) + + with pytest.raises(ValueError, match="Argument `subset_size` expected to be integer larger than 0"): + KID(subset_size=-1) + + with pytest.raises(ValueError, match="Argument `degree` expected to be integer larger than 0"): + KID(degree=-1) + + with pytest.raises(ValueError, match="Argument `gamma` expected to be `None` or float larger than 0"): + KID(gamma=-1) + + with pytest.raises(ValueError, match="Argument `coef` expected to be float larger than 0"): + KID(coef=-1) + + +@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason='test requires torch-fidelity') +@pytest.mark.parametrize("feature", [64, 192, 768, 2048]) +def test_kid_same_input(feature): + """ test that the metric works """ + metric = KID(feature=feature, subsets=5, subset_size=2) + + for _ in range(2): + img = torch.randint(0, 255, (10, 3, 299, 299), dtype=torch.uint8) + metric.update(img, real=True) + metric.update(img, real=False) + + assert torch.allclose(torch.cat(metric.real_features, dim=0), torch.cat(metric.fake_features, dim=0)) + + mean, std = metric.compute() + assert mean != 0.0 + assert std >= 0.0 + + +class _ImgDataset(Dataset): + + def __init__(self, imgs): + self.imgs = imgs + + def __getitem__(self, idx): + return self.imgs[idx] + 
+ def __len__(self): + return self.imgs.shape[0] + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason='test is too slow without gpu') +@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason='test requires torch-fidelity') +def test_compare_kid(tmpdir, feature=2048): + """ check that the hole pipeline give the same result as torch-fidelity """ + from torch_fidelity import calculate_metrics + + metric = KID(feature=feature, subsets=1, subset_size=100).cuda() + + # Generate some synthetic data + img1 = torch.randint(0, 180, (100, 3, 299, 299), dtype=torch.uint8) + img2 = torch.randint(100, 255, (100, 3, 299, 299), dtype=torch.uint8) + + batch_size = 10 + for i in range(img1.shape[0] // batch_size): + metric.update(img1[batch_size * i:batch_size * (i + 1)].cuda(), real=True) + + for i in range(img2.shape[0] // batch_size): + metric.update(img2[batch_size * i:batch_size * (i + 1)].cuda(), real=False) + + torch_fid = calculate_metrics( + input1=_ImgDataset(img1), + input2=_ImgDataset(img2), + kid=True, + feature_layer_fid=str(feature), + batch_size=batch_size, + kid_subsets=1, + kid_subset_size=100, + save_cpu_ram=True + ) + + tm_mean, tm_std = metric.compute() + + assert torch.allclose(tm_mean.cpu(), torch.tensor([torch_fid['kernel_inception_distance_mean']]), atol=1e-3) + assert torch.allclose(tm_std.cpu(), torch.tensor([torch_fid['kernel_inception_distance_std']]), atol=1e-3) diff --git a/torchmetrics/__init__.py b/torchmetrics/__init__.py index 738094916d2..430552fa961 100644 --- a/torchmetrics/__init__.py +++ b/torchmetrics/__init__.py @@ -37,7 +37,7 @@ StatScores, ) from torchmetrics.collections import MetricCollection # noqa: F401 E402 -from torchmetrics.image import FID, IS # noqa: F401 E402 +from torchmetrics.image import FID, IS, KID # noqa: F401 E402 from torchmetrics.metric import Metric # noqa: F401 E402 from torchmetrics.regression import ( # noqa: F401 E402 PSNR, diff --git a/torchmetrics/image/__init__.py b/torchmetrics/image/__init__.py 
index a0bee1440ee..8098e92e643 100644 --- a/torchmetrics/image/__init__.py +++ b/torchmetrics/image/__init__.py @@ -13,3 +13,4 @@ # limitations under the License. from torchmetrics.image.fid import FID # noqa: F401 from torchmetrics.image.inception import IS # noqa: F401 +from torchmetrics.image.kid import KID # noqa: F401 diff --git a/torchmetrics/image/fid.py b/torchmetrics/image/fid.py index 6cd2889ab95..f23190dcebc 100644 --- a/torchmetrics/image/fid.py +++ b/torchmetrics/image/fid.py @@ -20,6 +20,7 @@ from torchmetrics.metric import Metric from torchmetrics.utilities import rank_zero_info, rank_zero_warn +from torchmetrics.utilities.data import dim_zero_cat from torchmetrics.utilities.imports import _TORCH_FIDELITY_AVAILABLE if _TORCH_FIDELITY_AVAILABLE: @@ -257,8 +258,8 @@ def update(self, imgs: Tensor, real: bool) -> None: # type: ignore def compute(self) -> Tensor: """ Calculate FID score based on accumulated extracted features from the two distributions """ - real_features = torch.cat(self.real_features, dim=0) - fake_features = torch.cat(self.fake_features, dim=0) + real_features = dim_zero_cat(self.real_features) + fake_features = dim_zero_cat(self.fake_features) # computation is extremely sensitive so it needs to happen in double precision orig_dtype = real_features.dtype real_features = real_features.double() diff --git a/torchmetrics/image/inception.py b/torchmetrics/image/inception.py index 87416445784..0b7352e69de 100644 --- a/torchmetrics/image/inception.py +++ b/torchmetrics/image/inception.py @@ -19,6 +19,7 @@ from torchmetrics.image.fid import NoTrainInceptionV3 from torchmetrics.metric import Metric from torchmetrics.utilities import rank_zero_warn +from torchmetrics.utilities.data import dim_zero_cat from torchmetrics.utilities.imports import _TORCH_FIDELITY_AVAILABLE @@ -153,7 +154,7 @@ def update(self, imgs: Tensor) -> None: # type: ignore self.features.append(features) def compute(self) -> Tuple[Tensor, Tensor]: - features = 
torch.cat(self.features, dim=0) + features = dim_zero_cat(self.features) # random permute the features idx = torch.randperm(features.shape[0]) features = features[idx] diff --git a/torchmetrics/image/kid.py b/torchmetrics/image/kid.py new file mode 100644 index 00000000000..42e7c1e23e3 --- /dev/null +++ b/torchmetrics/image/kid.py @@ -0,0 +1,280 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Callable, Optional, Tuple, Union + +import torch +from torch import Tensor + +from torchmetrics.image.fid import NoTrainInceptionV3 +from torchmetrics.metric import Metric +from torchmetrics.utilities import rank_zero_warn +from torchmetrics.utilities.data import dim_zero_cat +from torchmetrics.utilities.imports import _TORCH_FIDELITY_AVAILABLE + + +def maximum_mean_discrepancy(k_xx: Tensor, k_xy: Tensor, k_yy: Tensor) -> Tensor: + """ + Adapted from https://github.com/toshas/torch-fidelity/blob/v0.3.0/torch_fidelity/metric_kid.py + """ + m = k_xx.shape[0] + + diag_x = torch.diag(k_xx) + diag_y = torch.diag(k_yy) + + kt_xx_sums = k_xx.sum(dim=-1) - diag_x + kt_yy_sums = k_yy.sum(dim=-1) - diag_y + k_xy_sums = k_xy.sum(dim=0) + + kt_xx_sum = kt_xx_sums.sum() + kt_yy_sum = kt_yy_sums.sum() + k_xy_sum = k_xy_sums.sum() + + value = (kt_xx_sum + kt_yy_sum) / (m * (m - 1)) + value -= 2 * k_xy_sum / (m**2) + return value + + +def poly_kernel(f1: Tensor, f2: Tensor, degree: int = 3, gamma: Optional[float] = None, coef: 
float = 1.0) -> Tensor:
+    """
+    Adapted from https://github.com/toshas/torch-fidelity/blob/v0.3.0/torch_fidelity/metric_kid.py
+    """
+    if gamma is None:
+        gamma = 1.0 / f1.shape[1]
+    kernel = (f1 @ f2.T * gamma + coef)**degree
+    return kernel
+
+
+def poly_mmd(
+    f_real: Tensor, f_fake: Tensor, degree: int = 3, gamma: Optional[float] = None, coef: float = 1.0
+) -> Tensor:
+    """
+    Adapted from https://github.com/toshas/torch-fidelity/blob/v0.3.0/torch_fidelity/metric_kid.py
+    """
+    k_11 = poly_kernel(f_real, f_real, degree, gamma, coef)
+    k_22 = poly_kernel(f_fake, f_fake, degree, gamma, coef)
+    k_12 = poly_kernel(f_real, f_fake, degree, gamma, coef)
+    return maximum_mean_discrepancy(k_11, k_12, k_22)
+
+
+class KID(Metric):
+    r"""
+    Calculates Kernel Inception Distance (KID) which is used to assess the quality of generated images. Given by
+
+    .. math::
+        KID = MMD(f_{real}, f_{fake})^2
+
+    where :math:`MMD` is the maximum mean discrepancy and :math:`I_{real}, I_{fake}` are extracted features
+    from real and fake images, see [1] for more details. In particular, calculating the MMD requires the
+    evaluation of a polynomial kernel function :math:`k`
+
+    .. math::
+        k(x,y) = (\gamma * x^T y + coef)^degree
+
+    which controls the distance between two features. In practice the MMD is calculated over a number of
+    subsets to be able to both get the mean and standard deviation of KID.
+
+    Using the default feature extraction (Inception v3 using the original weights from [2]), the input is
+    expected to be mini-batches of 3-channel RGB images of shape (3 x H x W) with dtype uint8. All images
+    will be resized to 299 x 299 which is the size of the original training data.
+
+    .. note:: using this metric with the default feature extractor requires that ``torch-fidelity``
+        is installed. Either install as ``pip install torchmetrics[image]`` or
+        ``pip install torch-fidelity``
+
+    .. note:: the ``forward`` method can be used but ``compute_on_step`` is disabled by default (opposite of
+        all other metrics) as this metric does not really make sense to calculate on a single batch. This
+        means that by default ``forward`` will just call ``update`` underneath.
+
+    Args:
+        feature:
+            Either an str, integer or ``nn.Module``:
+
+            - an str or integer will indicate the inceptionv3 feature layer to choose. Can be one of the following:
+              'logits_unbiased', 64, 192, 768, 2048
+            - an ``nn.Module`` for using a custom feature extractor. Expects that its forward method returns
+              an ``[N,d]`` matrix where ``N`` is the batch size and ``d`` is the feature size.
+
+        subsets:
+            Number of subsets to calculate the mean and standard deviation scores over
+        subset_size:
+            Number of randomly picked samples in each subset
+        degree:
+            Degree of the polynomial kernel function
+        gamma:
+            Scale-length of polynomial kernel. If set to ``None`` will be automatically set to the feature size
+        coef:
+            Bias term in the polynomial kernel.
+        compute_on_step:
+            Forward only calls ``update()`` and return ``None`` if this is set to ``False``.
+        dist_sync_on_step:
+            Synchronize metric state across processes at each ``forward()``
+            before returning the value at the step
+        process_group:
+            Specify the process group on which synchronization is called.
+            default: ``None`` (which selects the entire world)
+        dist_sync_fn:
+            Callback that performs the allgather operation on the metric state. When ``None``, DDP
+            will be used to perform the allgather
+
+    [1] Demystifying MMD GANs
+        Mikołaj Bińkowski, Danica J. 
Sutherland, Michael Arbel, Arthur Gretton + https://arxiv.org/abs/1801.01401 + + [2] GANs Trained by a Two Time-Scale Update Rule Converge to a Local Nash Equilibrium, + Martin Heusel, Hubert Ramsauer, Thomas Unterthiner, Bernhard Nessler, Sepp Hochreiter + https://arxiv.org/abs/1706.08500 + + Raises: + ValueError: + If ``feature`` is set to an ``int`` (default settings) and ``torch-fidelity`` is not installed + ValueError: + If ``feature`` is set to an ``int`` not in [64, 192, 768, 2048] + ValueError: + If ``subsets`` is not an integer larger than 0 + ValueError: + If ``subset_size`` is not an integer larger than 0 + ValueError: + If ``degree`` is not an integer larger than 0 + ValueError: + If ``gamma`` is niether ``None`` or a float larger than 0 + ValueError: + If ``coef`` is not an float larger than 0 + + Example: + >>> import torch + >>> _ = torch.manual_seed(123) + >>> from torchmetrics import KID + >>> kid = KID(subset_size=50) # doctest: +SKIP + >>> # generate two slightly overlapping image intensity distributions + >>> imgs_dist1 = torch.randint(0, 200, (100, 3, 299, 299), dtype=torch.uint8) # doctest: +SKIP + >>> imgs_dist2 = torch.randint(100, 255, (100, 3, 299, 299), dtype=torch.uint8) # doctest: +SKIP + >>> kid.update(imgs_dist1, real=True) # doctest: +SKIP + >>> kid.update(imgs_dist2, real=False) # doctest: +SKIP + >>> kid_mean, kid_std = kid.compute() # doctest: +SKIP + >>> print((kid_mean, kid_std)) # doctest: +SKIP + (tensor(0.0338), tensor(0.0025)) + + """ + + def __init__( + self, + feature: Union[str, int, torch.nn.Module] = 2048, + subsets: int = 100, + subset_size: int = 1000, + degree: int = 3, + gamma: Optional[float] = None, # type: ignore + coef: float = 1.0, + compute_on_step: bool = False, + dist_sync_on_step: bool = False, + process_group: Optional[Any] = None, + dist_sync_fn: Callable = None + ): + super().__init__( + compute_on_step=compute_on_step, + dist_sync_on_step=dist_sync_on_step, + process_group=process_group, + 
dist_sync_fn=dist_sync_fn, + ) + + rank_zero_warn( + 'Metric `KID` will save all extracted features in buffer.' + ' For large datasets this may lead to large memory footprint.', UserWarning + ) + + if isinstance(feature, (str, int)): + if not _TORCH_FIDELITY_AVAILABLE: + raise RuntimeError( + 'KID metric requires that Torch-fidelity is installed.' + ' Either install as `pip install torchmetrics[image]`' + ' or `pip install torch-fidelity`' + ) + valid_int_input = ('logits_unbiased', 64, 192, 768, 2048) + if feature not in valid_int_input: + raise ValueError( + f'Integer input to argument `feature` must be one of {valid_int_input},' + f' but got {feature}.' + ) + + self.inception = NoTrainInceptionV3(name='inception-v3-compat', features_list=[str(feature)]) + elif isinstance(feature, torch.nn.Module): + self.inception = feature + else: + raise TypeError('Got unknown input to argument `feature`') + + if not (isinstance(subsets, int) and subsets > 0): + raise ValueError("Argument `subsets` expected to be integer larger than 0") + self.subsets = subsets + + if not (isinstance(subset_size, int) and subset_size > 0): + raise ValueError("Argument `subset_size` expected to be integer larger than 0") + self.subset_size = subset_size + + if not (isinstance(degree, int) and degree > 0): + raise ValueError("Argument `degree` expected to be integer larger than 0") + self.degree = degree + + if gamma is not None and not (isinstance(gamma, float) and gamma > 0): + raise ValueError("Argument `gamma` expected to be `None` or float larger than 0") + self.gamma = gamma + + if not (isinstance(coef, float) and coef > 0): + raise ValueError("Argument `coef` expected to be float larger than 0") + self.coef = coef + + # states for extracted features + self.add_state("real_features", [], dist_reduce_fx=None) + self.add_state("fake_features", [], dist_reduce_fx=None) + + def update(self, imgs: Tensor, real: bool) -> None: # type: ignore + """ Update the state with extracted features + + 
Args: + imgs: tensor with images feed to the feature extractor + real: bool indicating if imgs belong to the real or the fake distribution + """ + features = self.inception(imgs) + + if real: + self.real_features.append(features) + else: + self.fake_features.append(features) + + def compute(self) -> Tuple[Tensor, Tensor]: + """ Calculate KID score based on accumulated extracted features from the two distributions. + Returns a tuple of mean and standard deviation of KID scores calculated on subsets of + extracted features. + + Implementation inspired by https://github.com/toshas/torch-fidelity/blob/v0.3.0/torch_fidelity/metric_kid.py + """ + real_features = dim_zero_cat(self.real_features) + fake_features = dim_zero_cat(self.fake_features) + + n_samples_real = real_features.shape[0] + if n_samples_real < self.subset_size: + raise ValueError('Argument `subset_size` should be smaller than the number of samples') + n_samples_fake = fake_features.shape[0] + if n_samples_fake < self.subset_size: + raise ValueError('Argument `subset_size` should be smaller than the number of samples') + + kid_scores_ = [] + for i in range(self.subsets): + perm = torch.randperm(n_samples_real) + f_real = real_features[perm[:self.subset_size]] + perm = torch.randperm(n_samples_fake) + f_fake = fake_features[perm[:self.subset_size]] + + o = poly_mmd(f_real, f_fake, self.degree, self.gamma, self.coef) + kid_scores_.append(o) + kid_scores = torch.stack(kid_scores_) + return kid_scores.mean(), kid_scores.std(unbiased=False) diff --git a/torchmetrics/utilities/data.py b/torchmetrics/utilities/data.py index 559f752ad55..312b8725c17 100644 --- a/torchmetrics/utilities/data.py +++ b/torchmetrics/utilities/data.py @@ -24,6 +24,8 @@ def dim_zero_cat(x: Union[Tensor, List[Tensor]]) -> Tensor: x = x if isinstance(x, (list, tuple)) else [x] x = [y.unsqueeze(0) if y.numel() == 1 and y.ndim == 0 else y for y in x] + if not x: # empty list + raise ValueError('No samples to concatenate') return torch.cat(x, 
dim=0) From e7753e1c33fdbefd96e2bb25b436354b6169df59 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Tue, 22 Jun 2021 14:04:48 +0200 Subject: [PATCH 102/109] Apply suggestions from code review --- requirements/test.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/test.txt b/requirements/test.txt index a8439b8a127..3d23f08a207 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -23,4 +23,4 @@ nltk>=3.6 # audio pypesq mir_eval>=0.6 -https://github.com/aliutkus/speechmetrics/archive/refs/heads/master.zip +speechmetrics @ https://github.com/aliutkus/speechmetrics/archive/refs/heads/master.zip From e82bc96f12b9d49a6bef559262e5bbd9a136bcf8 Mon Sep 17 00:00:00 2001 From: Jirka Date: Tue, 22 Jun 2021 14:14:07 +0200 Subject: [PATCH 103/109] SRMRpy --- requirements/test.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements/test.txt b/requirements/test.txt index 3d23f08a207..236208b0e50 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -23,4 +23,5 @@ nltk>=3.6 # audio pypesq mir_eval>=0.6 +SRMRpy @ https://github.com/jfsantos/SRMRpy/archive/refs/heads/master.zip speechmetrics @ https://github.com/aliutkus/speechmetrics/archive/refs/heads/master.zip From bb045c99a30a64450a66540a719b6288bfba6c6a Mon Sep 17 00:00:00 2001 From: Jirka Date: Tue, 22 Jun 2021 14:16:48 +0200 Subject: [PATCH 104/109] pesq --- requirements/test.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements/test.txt b/requirements/test.txt index 236208b0e50..c927ade68c3 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -23,5 +23,6 @@ nltk>=3.6 # audio pypesq mir_eval>=0.6 +pesq @ https://github.com/ludlows/python-pesq/archive/refs/heads/master.zip SRMRpy @ https://github.com/jfsantos/SRMRpy/archive/refs/heads/master.zip speechmetrics @ https://github.com/aliutkus/speechmetrics/archive/refs/heads/master.zip From 70687f1350d6afc5dd538663dcd4bdea2ec179ed Mon Sep 17 00:00:00 2001 From: Jirka Date: Tue, 22 Jun 
2021 14:21:56 +0200 Subject: [PATCH 105/109] gcc --- azure-pipelines.yml | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/azure-pipelines.yml b/azure-pipelines.yml index 6a322a78e22..5bac700ebd7 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -44,8 +44,9 @@ jobs: displayName: 'Image info & NVIDIA' - bash: | - sudo apt-get update - sudo apt install -y cmake ffmpeg git libsndfile1 + su + apt-get update + apt install -y gcc cmake ffmpeg git libsndfile1 # python -m pip install "pip==20.1" pip install --requirement ./requirements/devel.txt --upgrade-strategy only-if-needed pip uninstall -y torchmetrics From 050bb7b72ac8285ece76bbb3ee1b9b8d9cef2148 Mon Sep 17 00:00:00 2001 From: Jirka Date: Tue, 22 Jun 2021 15:58:24 +0200 Subject: [PATCH 106/109] comment -u root cuda 10.2 whoami --- azure-pipelines.yml | 17 +++++++++++------ requirements/test.txt | 4 ++-- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/azure-pipelines.yml b/azure-pipelines.yml index 5bac700ebd7..8066ae243d5 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -27,7 +27,7 @@ jobs: container: image: "pytorch/pytorch:1.8.1-cuda11.1-cudnn8-runtime" - options: "--runtime=nvidia -e NVIDIA_VISIBLE_DEVICES=all" + options: "--runtime=nvidia -e NVIDIA_VISIBLE_DEVICES=all --name ci-container -v /usr/bin/docker:/tmp/docker:ro" workspace: clean: all @@ -35,6 +35,8 @@ jobs: steps: - bash: | + whoami + id lspci | egrep 'VGA|3D' whereis nvidia nvidia-smi @@ -43,16 +45,19 @@ jobs: pip list displayName: 'Image info & NVIDIA' + - script: | + /tmp/docker exec -t -u 0 ci-container \ + sh -c "apt-get update && DEBIAN_FRONTEND=noninteractive apt-get -o Dpkg::Options::="--force-confold" -y install sudo" + displayName: 'Install Sudo in container (thanks Microsoft!)' + - bash: | - su - apt-get update - apt install -y gcc cmake ffmpeg git libsndfile1 + set -ex + sudo apt-get update + sudo apt-get install -y gcc cmake ffmpeg git libsndfile1 # python -m pip install 
"pip==20.1" pip install --requirement ./requirements/devel.txt --upgrade-strategy only-if-needed pip uninstall -y torchmetrics pip list - env: - DEBIAN_FRONTEND: noninteractive displayName: 'Install dependencies' - bash: | diff --git a/requirements/test.txt b/requirements/test.txt index c927ade68c3..64e19052893 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -23,6 +23,6 @@ nltk>=3.6 # audio pypesq mir_eval>=0.6 -pesq @ https://github.com/ludlows/python-pesq/archive/refs/heads/master.zip -SRMRpy @ https://github.com/jfsantos/SRMRpy/archive/refs/heads/master.zip +#pesq @ https://github.com/ludlows/python-pesq/archive/refs/heads/master.zip +#SRMRpy @ https://github.com/jfsantos/SRMRpy/archive/refs/heads/master.zip speechmetrics @ https://github.com/aliutkus/speechmetrics/archive/refs/heads/master.zip From de034a397f86771985ec9e36be477a92e42d376b Mon Sep 17 00:00:00 2001 From: Jirka Date: Tue, 22 Jun 2021 18:47:13 +0200 Subject: [PATCH 107/109] env --- azure-pipelines.yml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/azure-pipelines.yml b/azure-pipelines.yml index 8066ae243d5..bea545735a9 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -47,7 +47,9 @@ jobs: - script: | /tmp/docker exec -t -u 0 ci-container \ - sh -c "apt-get update && DEBIAN_FRONTEND=noninteractive apt-get -o Dpkg::Options::="--force-confold" -y install sudo" + sh -c "apt-get update && apt-get -o Dpkg::Options::="--force-confold" -y install sudo" + env: + DEBIAN_FRONTEND: noninteractive displayName: 'Install Sudo in container (thanks Microsoft!)' - bash: | From cda3c6b5227697b203f8ca8a3e4d0bdebd8f7bec Mon Sep 17 00:00:00 2001 From: Jirka Date: Tue, 22 Jun 2021 18:58:09 +0200 Subject: [PATCH 108/109] env --- azure-pipelines.yml | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/azure-pipelines.yml b/azure-pipelines.yml index bea545735a9..8066ae243d5 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -47,9 +47,7 @@ jobs: - 
script: | /tmp/docker exec -t -u 0 ci-container \ - sh -c "apt-get update && apt-get -o Dpkg::Options::="--force-confold" -y install sudo" - env: - DEBIAN_FRONTEND: noninteractive + sh -c "apt-get update && DEBIAN_FRONTEND=noninteractive apt-get -o Dpkg::Options::="--force-confold" -y install sudo" displayName: 'Install Sudo in container (thanks Microsoft!)' - bash: | From fec35986b7d1b52be30b347e032363d1f3a9dcb0 Mon Sep 17 00:00:00 2001 From: Jirka Date: Tue, 22 Jun 2021 19:05:50 +0200 Subject: [PATCH 109/109] env --- azure-pipelines.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/azure-pipelines.yml b/azure-pipelines.yml index 8066ae243d5..dc81c15e47b 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -26,7 +26,7 @@ jobs: pool: gridai-spot-pool container: - image: "pytorch/pytorch:1.8.1-cuda11.1-cudnn8-runtime" + image: "pytorch/pytorch:1.8.1-cuda10.2-cudnn7-runtime" options: "--runtime=nvidia -e NVIDIA_VISIBLE_DEVICES=all --name ci-container -v /usr/bin/docker:/tmp/docker:ro" workspace: