Add PIT for audio metrics #384

Merged
merged 73 commits on Jul 29, 2021
Changes from 42 commits
Commits
73 commits
95266f7
finish functional version
quancs Jul 17, 2021
3eaf506
change eval_func to str type
quancs Jul 17, 2021
f63fedb
remove return_best_perm
quancs Jul 17, 2021
58dcf1d
raise
quancs Jul 17, 2021
bb186d6
PIT
quancs Jul 17, 2021
13be4e4
max
quancs Jul 17, 2021
1c722de
add test
quancs Jul 17, 2021
9a538eb
Merge branch 'PyTorchLightning:master' into audio-pit
quancs Jul 17, 2021
123a686
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 17, 2021
5d2977c
add
quancs Jul 17, 2021
3b9172f
fix pep8
quancs Jul 17, 2021
2c01d18
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 17, 2021
7c2078f
fix pep8
quancs Jul 17, 2021
70d964e
Merge branch 'audio-pit' of https://github.com/quancs/metrics into au…
quancs Jul 17, 2021
85daf5c
fix mypy
quancs Jul 17, 2021
e129896
add scipy
quancs Jul 17, 2021
bcf43a3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 17, 2021
801b80c
fix
quancs Jul 17, 2021
afaa79c
Merge branch 'audio-pit' of https://github.com/quancs/metrics into au…
quancs Jul 17, 2021
f9f770f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 17, 2021
dee14ea
fix doctest
quancs Jul 17, 2021
5d06bf1
Merge branch 'audio-pit' of https://github.com/quancs/metrics into au…
quancs Jul 17, 2021
de37c92
fix doctest
quancs Jul 17, 2021
05e061b
fix doctest
quancs Jul 17, 2021
ae12aa0
remove scipy dep & add warning
quancs Jul 18, 2021
112bc16
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 18, 2021
dc5e689
change warn
quancs Jul 18, 2021
6a3beda
Merge branch 'audio-pit' of https://github.com/quancs/metrics into au…
quancs Jul 18, 2021
208b479
add
quancs Jul 18, 2021
736e354
move scipy import to inner function
quancs Jul 18, 2021
a7a4588
Merge branch 'master' into audio-pit
Borda Jul 20, 2021
271d899
Apply suggestions from code review
Borda Jul 20, 2021
2e1ae3f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 20, 2021
1059575
simplyfied
quancs Jul 20, 2021
07ecb26
add test_consistency_of_two_implementations
quancs Jul 20, 2021
16d70f9
use _SCIPY_AVAILABLE
quancs Jul 20, 2021
b452792
add docstring
quancs Jul 20, 2021
6367adc
to TIME
quancs Jul 20, 2021
29744ab
a_naive_implementation_of_pit_based_on_scipy
quancs Jul 20, 2021
2070d3a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 20, 2021
859e25e
add docstring for _find_best_perm_by_linear_sum_assignment
quancs Jul 20, 2021
776cf88
Merge branch 'audio-pit' of https://github.com/quancs/metrics into au…
quancs Jul 20, 2021
70adaa3
Update torchmetrics/functional/audio/pit.py
quancs Jul 22, 2021
a17b357
Update torchmetrics/functional/audio/pit.py
quancs Jul 22, 2021
39105ae
Update torchmetrics/functional/audio/pit.py
quancs Jul 22, 2021
dbb0e5a
Update torchmetrics/functional/audio/pit.py
quancs Jul 22, 2021
c308c02
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 22, 2021
a4e347a
use self.run_functional_metric_test
quancs Jul 22, 2021
01d484e
add more description
quancs Jul 22, 2021
5a974e9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 22, 2021
400b323
import warnings
quancs Jul 22, 2021
458b912
check ValueError
quancs Jul 22, 2021
041dfaf
Merge branch 'audio-pit' of https://github.com/quancs/metrics into au…
quancs Jul 22, 2021
5d709d8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 22, 2021
9d15f89
Merge branch 'master' into audio-pit
Borda Jul 24, 2021
30e7f67
Merge branch 'master' into audio-pit
Borda Jul 26, 2021
312999a
Merge branch 'master' into audio-pit
Borda Jul 28, 2021
8d3ed75
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 28, 2021
ca3425d
Apply suggestions from code review
Borda Jul 28, 2021
36911fd
Merge branch 'master' into audio-pit
SkafteNicki Jul 28, 2021
1bd7f5d
fix imports
Borda Jul 28, 2021
6b143c7
Merge branch 'master' into audio-pit
Borda Jul 28, 2021
05cacbb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 28, 2021
9db192f
fix test
Borda Jul 28, 2021
bf41ab6
Merge branch 'audio-pit' of https://github.com/quancs/metrics into au…
Borda Jul 28, 2021
65a27ff
docs
Borda Jul 28, 2021
a63297d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 28, 2021
6177f78
Merge branch 'master' into audio-pit
Borda Jul 29, 2021
8f3dc55
Apply suggestions from code review
Borda Jul 29, 2021
23ab576
Apply suggestions from code review
SkafteNicki Jul 29, 2021
72eace5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 29, 2021
30c94e6
Merge branch 'master' into audio-pit
Borda Jul 29, 2021
40b7e81
Merge branch 'master' into audio-pit
Borda Jul 29, 2021
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -17,6 +17,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added Symmetric Mean Absolute Percentage error (SMAPE) ([#375](https://github.com/PyTorchLightning/metrics/issues/375))


- Added Permutation Invariant Training metric (PIT) ([#294](https://github.com/PyTorchLightning/metrics/issues/294))


### Changed

- Moved `psnr` and `ssim` from `torchmetrics.functional.regression.*` to `torchmetrics.functional.image.*` ([#382](https://github.com/PyTorchLightning/metrics/pull/382))
7 changes: 7 additions & 0 deletions docs/source/references/functional.rst
@@ -9,6 +9,13 @@ Functional metrics
Audio Metrics
*************

pit [func]
~~~~~~~~~~

.. autofunction:: torchmetrics.functional.pit
:noindex:


si_sdr [func]
~~~~~~~~~~~~~

6 changes: 6 additions & 0 deletions docs/source/references/modules.rst
@@ -40,6 +40,12 @@ the metric will be computed over the ``time`` dimension.
>>> snr_val
tensor(16.1805)

PIT
~~~~~~

.. autoclass:: torchmetrics.PIT
:noindex:

SI_SDR
~~~~~~

210 changes: 210 additions & 0 deletions tests/audio/test_pit.py
@@ -0,0 +1,210 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import namedtuple
from functools import partial
from typing import Callable, Tuple

import numpy as np
import pytest
import torch
from scipy.optimize import linear_sum_assignment
from torch import Tensor

from tests.helpers import seed_all
from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester
from torchmetrics.audio import PIT
from torchmetrics.functional import pit, si_sdr, snr
from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6

seed_all(42)

TIME = 10

Input = namedtuple('Input', ["preds", "target"])

# three speaker examples to test _find_best_perm_by_linear_sum_assignment
inputs1 = Input(
    preds=torch.rand(NUM_BATCHES, BATCH_SIZE, 3, TIME),
    target=torch.rand(NUM_BATCHES, BATCH_SIZE, 3, TIME),
)
# two speaker examples to test _find_best_perm_by_exhuastive_method
inputs2 = Input(
    preds=torch.rand(NUM_BATCHES, BATCH_SIZE, 2, TIME),
    target=torch.rand(NUM_BATCHES, BATCH_SIZE, 2, TIME),
)


def a_naive_implementation_of_pit_based_on_scipy(
    preds: Tensor,
    target: Tensor,
    metric_func: Callable,
    eval_func: str,
) -> Tuple[Tensor, Tensor]:
    """a naive implementation of pit based on scipy

    Args:
        preds: predictions, shape[batch, spk, time]
        target: targets, shape[batch, spk, time]
        metric_func: which metric
        eval_func: min or max

    Returns:
        best_metric:
            shape [batch]
        best_perm:
            shape [batch, spk]
    """
    batch_size, spk_num = target.shape[0:2]
    metric_mtx = torch.empty((batch_size, spk_num, spk_num), device=target.device)
    for t in range(spk_num):
        for e in range(spk_num):
            metric_mtx[:, t, e] = metric_func(preds[:, e, ...], target[:, t, ...])

    # pit_r = PIT(metric_func, eval_func)(preds, target)
    metric_mtx = metric_mtx.detach().cpu().numpy()
    best_metrics = []
    best_perms = []
    for b in range(batch_size):
        row_idx, col_idx = linear_sum_assignment(metric_mtx[b, ...], eval_func == 'max')
        best_metrics.append(metric_mtx[b, row_idx, col_idx].mean())
        best_perms.append(col_idx)
    return torch.from_numpy(np.stack(best_metrics)), torch.from_numpy(np.stack(best_perms))


def average_metric(preds: Tensor, target: Tensor, metric_func: Callable) -> Tensor:
    """average the metric values

    Args:
        preds: predictions, shape[batch, spk, time]
        target: targets, shape[batch, spk, time]
        metric_func: a function which returns best_metric and best_perm

    Returns:
        the average of best_metric
    """
    return metric_func(preds, target)[0].mean()


snr_pit_scipy = partial(a_naive_implementation_of_pit_based_on_scipy, metric_func=snr, eval_func='max')
si_sdr_pit_scipy = partial(a_naive_implementation_of_pit_based_on_scipy, metric_func=si_sdr, eval_func='max')


@pytest.mark.parametrize(
    "preds, target, sk_metric, metric_func, eval_func",
    [
        (inputs1.preds, inputs1.target, snr_pit_scipy, snr, 'max'),
        (inputs1.preds, inputs1.target, si_sdr_pit_scipy, si_sdr, 'max'),
        (inputs2.preds, inputs2.target, snr_pit_scipy, snr, 'max'),
        (inputs2.preds, inputs2.target, si_sdr_pit_scipy, si_sdr, 'max'),
    ],
)
class TestPIT(MetricTester):
    atol = 1e-2

    @pytest.mark.parametrize("ddp", [True, False])
    @pytest.mark.parametrize("dist_sync_on_step", [True, False])
    def test_pit(self, preds, target, sk_metric, metric_func, eval_func, ddp, dist_sync_on_step):
        self.run_class_metric_test(
            ddp,
            preds,
            target,
            PIT,
            sk_metric=partial(average_metric, metric_func=sk_metric),
            dist_sync_on_step=dist_sync_on_step,
            metric_args=dict(metric_func=metric_func, eval_func=eval_func),
        )

    def test_pit_functional(self, preds, target, sk_metric, metric_func, eval_func):
        device = 'cuda' if (torch.cuda.is_available() and torch.cuda.device_count() > 0) else 'cpu'

        # move to device
        preds = preds.to(device)
        target = target.to(device)

        for i in range(NUM_BATCHES):
            best_metric, best_perm = pit(preds[i], target[i], metric_func, eval_func)
            best_metric_sk, best_perm_sk = sk_metric(preds[i].cpu(), target[i].cpu())

            # assert it's the same
            assert np.allclose(
                best_metric.detach().cpu().numpy(), best_metric_sk.detach().cpu().numpy(), atol=self.atol
            )
            assert (best_perm.detach().cpu().numpy() == best_perm_sk.detach().cpu().numpy()).all()

    def test_pit_differentiability(self, preds, target, sk_metric, metric_func, eval_func):

        def pit_diff(preds, target, metric_func, eval_func):
            return pit(preds, target, metric_func, eval_func)[0]

        self.run_differentiability_test(
            preds=preds,
            target=target,
            metric_module=PIT,
            metric_functional=pit_diff,
            metric_args={
                'metric_func': metric_func,
                'eval_func': eval_func
            }
        )

    @pytest.mark.skipif(
        not _TORCH_GREATER_EQUAL_1_6, reason='half support of core operations is not available before pytorch v1.6'
    )
    def test_pit_half_cpu(self, preds, target, sk_metric, metric_func, eval_func):
        pytest.xfail("PIT metric does not support cpu + half precision")

    @pytest.mark.skipif(not torch.cuda.is_available(), reason='test requires cuda')
    def test_pit_half_gpu(self, preds, target, sk_metric, metric_func, eval_func):
        self.run_precision_test_gpu(
            preds=preds,
            target=target,
            metric_module=PIT,
            metric_functional=partial(pit, metric_func=metric_func, eval_func=eval_func),
            metric_args={
                'metric_func': metric_func,
                'eval_func': eval_func
            }
        )


def test_error_on_different_shape() -> None:
    metric = PIT(snr, 'max')
    with pytest.raises(RuntimeError, match='Predictions and targets are expected to have the same shape'):
        metric(torch.randn(3, 3, 10), torch.randn(3, 2, 10))


def test_error_on_wrong_eval_func() -> None:
    metric = PIT(snr, 'xxx')
    with pytest.raises(RuntimeError, match='eval_func can only be "max" or "min"'):
        metric(torch.randn(3, 3, 10), torch.randn(3, 3, 10))


def test_error_on_wrong_shape() -> None:
    metric = PIT(snr, 'max')
    with pytest.raises(RuntimeError, match='Inputs must be of shape *'):
        metric(torch.randn(3), torch.randn(3))


def test_consistency_of_two_implementations() -> None:
    from torchmetrics.functional.audio.pit import (
        _find_best_perm_by_exhuastive_method,
        _find_best_perm_by_linear_sum_assignment,
    )
    shapes_test = [(5, 2, 2), (4, 3, 3), (4, 4, 4), (3, 5, 5)]
    for shp in shapes_test:
        metric_mtx = torch.randn(size=shp)
        bm1, bp1 = _find_best_perm_by_linear_sum_assignment(metric_mtx, torch.max)
        bm2, bp2 = _find_best_perm_by_exhuastive_method(metric_mtx, torch.max)
        assert torch.allclose(bm1, bm2)
        assert (bp1 == bp2).all()
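
Note: the test_consistency_of_two_implementations check above rests on the fact that scipy's linear_sum_assignment and an exhaustive search over speaker permutations find the same optimum. A small self-contained sketch of that equivalence (not part of this PR; it uses only numpy, scipy and itertools on a randomly generated metric matrix):

from itertools import permutations

import numpy as np
from scipy.optimize import linear_sum_assignment

rng = np.random.default_rng(0)
metric_mtx = rng.normal(size=(4, 3, 3))  # [batch, target spk, estimate spk]

for b in range(metric_mtx.shape[0]):
    mtx = metric_mtx[b]
    # assignment solver: maximize the sum of the matched entries
    row, col = linear_sum_assignment(mtx, maximize=True)
    lsa_best = mtx[row, col].mean()
    # brute force: evaluate every permutation of the estimate axis
    brute_best = max(mtx[np.arange(3), list(p)].mean() for p in permutations(range(3)))
    assert np.isclose(lsa_best, brute_best)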
2 changes: 1 addition & 1 deletion torchmetrics/__init__.py
@@ -11,7 +11,7 @@
_PACKAGE_ROOT = os.path.dirname(__file__)
_PROJECT_ROOT = os.path.dirname(_PACKAGE_ROOT)

- from torchmetrics.audio import SI_SDR, SI_SNR, SNR # noqa: F401 E402
+ from torchmetrics.audio import PIT, SI_SDR, SI_SNR, SNR # noqa: F401 E402
from torchmetrics.average import AverageMeter # noqa: F401 E402
from torchmetrics.classification import ( # noqa: F401 E402
AUC,
1 change: 1 addition & 0 deletions torchmetrics/audio/__init__.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from torchmetrics.audio.pit import PIT # noqa: F401
from torchmetrics.audio.si_sdr import SI_SDR # noqa: F401
from torchmetrics.audio.si_snr import SI_SNR # noqa: F401
from torchmetrics.audio.snr import SNR # noqa: F401
114 changes: 114 additions & 0 deletions torchmetrics/audio/pit.py
@@ -0,0 +1,114 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Dict, Optional

from torch import Tensor, tensor

from torchmetrics.functional.audio.pit import pit
from torchmetrics.metric import Metric


class PIT(Metric):
    """ Permutation invariant training metric

    Forward accepts

    - ``preds``: ``shape [..., time]``
    - ``target``: ``shape [..., time]``

    Args:
        metric_func:
            a metric function that accepts a batch of targets and estimates, i.e. metric_func(target[:, i, ...],
            estimate[:, j, ...]), and returns a batch of metric tensors of shape [batch]
        eval_func:
            the function used to find the best permutation, can be 'min' or 'max', i.e. the smaller the better
            or the larger the better.
        compute_on_step:
            Forward only calls ``update()`` and returns None if this is set to False. default: True
        dist_sync_on_step:
            Synchronize metric state across processes at each ``forward()``
            before returning the value at the step.
        process_group:
            Specify the process group on which synchronization is called. default: None (which selects the entire world)
        dist_sync_fn:
            Callback that performs the allgather operation on the metric state. When `None`, DDP
            will be used to perform the allgather.
        kwargs:
            additional args for metric_func

    Returns:
        average PIT metric

    Example:
        >>> import torch
        >>> from torchmetrics import PIT
        >>> from torchmetrics.functional.audio import si_snr
        >>> preds = torch.randn(3, 2, 5) # [batch, spk, time]
        >>> target = torch.randn(3, 2, 5) # [batch, spk, time]
        >>> pit = PIT(si_snr, 'max')
        >>> avg_pit_metric = pit(preds, target)

    Reference:
        [1] D. Yu, M. Kolbaek, Z.-H. Tan, J. Jensen, Permutation invariant training of deep models for
        speaker-independent multi-talker speech separation, in: 2017 IEEE Int. Conf. Acoust. Speech
        Signal Process. ICASSP, IEEE, New Orleans, LA, 2017: pp. 241–245. https://doi.org/10.1109/ICASSP.2017.7952154.
    """
    sum_pit_metric: Tensor
    total: Tensor

    def __init__(
        self,
        metric_func: Callable,
        eval_func: str = 'max',
        compute_on_step: bool = True,
        dist_sync_on_step: bool = False,
        process_group: Optional[Any] = None,
        dist_sync_fn: Optional[Callable[[Tensor], Tensor]] = None,
        **kwargs: Dict[str, Any],
    ) -> None:
        super().__init__(
            compute_on_step=compute_on_step,
            dist_sync_on_step=dist_sync_on_step,
            process_group=process_group,
            dist_sync_fn=dist_sync_fn,
        )
        self.metric_func = metric_func
        self.eval_func = eval_func
        self.kwargs = kwargs

        self.add_state("sum_pit_metric", default=tensor(0.0), dist_reduce_fx="sum")
        self.add_state("total", default=tensor(0), dist_reduce_fx="sum")

    def update(self, preds: Tensor, target: Tensor) -> None:  # type: ignore
        """
        Update state with predictions and targets.

        Args:
            preds: Predictions from model
            target: Ground truth values
        """
        pit_metric = pit(preds, target, self.metric_func, self.eval_func, **self.kwargs)[0]

        self.sum_pit_metric += pit_metric.sum()
        self.total += pit_metric.numel()

    def compute(self) -> Tensor:
        """
        Computes average PIT metric.
        """
        return self.sum_pit_metric / self.total

    @property
    def is_differentiable(self) -> bool:
        return True
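
For orientation, a minimal usage sketch of the new PIT module metric, assuming only the interface shown in this diff (si_snr as the underlying metric, 'max' as eval_func; the shapes and number of batches are illustrative):

import torch

from torchmetrics import PIT
from torchmetrics.functional.audio import si_snr

pit_metric = PIT(si_snr, 'max')

for _ in range(4):  # e.g. four validation batches
    preds = torch.randn(8, 2, 16000)   # [batch, spk, time]
    target = torch.randn(8, 2, 16000)
    batch_value = pit_metric(preds, target)  # forward(): updates state and returns the batch value

epoch_value = pit_metric.compute()  # average PIT metric over all updates
pit_metric.reset()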
1 change: 1 addition & 0 deletions torchmetrics/functional/__init__.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from torchmetrics.functional.audio.pit import permutate, pit # noqa: F401
from torchmetrics.functional.audio.si_sdr import si_sdr # noqa: F401
from torchmetrics.functional.audio.si_snr import si_snr # noqa: F401
from torchmetrics.functional.audio.snr import snr # noqa: F401
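
The functional entry point above exports permutate alongside pit. pit returns the best metric value and the best permutation per example; permutate is presumably the helper that reorders the predictions according to that permutation. The diff itself does not show permutate's implementation, so the sketch below is an assumption rather than documented behavior:

import torch

from torchmetrics.functional import permutate, pit
from torchmetrics.functional.audio import si_snr

preds = torch.randn(3, 2, 100)   # [batch, spk, time]
target = torch.randn(3, 2, 100)

best_metric, best_perm = pit(preds, target, si_snr, 'max')
print(best_metric.shape)  # torch.Size([3]): best metric value per example
print(best_perm.shape)    # torch.Size([3, 2]): best speaker permutation per example

# reorder the estimated sources so that they line up with the targets
preds_aligned = permutate(preds, best_perm)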
1 change: 1 addition & 0 deletions torchmetrics/functional/audio/__init__.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from torchmetrics.functional.audio.pit import permutate, pit # noqa: F401
from torchmetrics.functional.audio.si_sdr import si_sdr # noqa: F401
from torchmetrics.functional.audio.si_snr import si_snr # noqa: F401
from torchmetrics.functional.audio.snr import snr # noqa: F401