diff --git a/src/torchmetrics/__init__.py b/src/torchmetrics/__init__.py index d62a1b8452b..1aedeb14c07 100644 --- a/src/torchmetrics/__init__.py +++ b/src/torchmetrics/__init__.py @@ -13,13 +13,15 @@ from torchmetrics import functional # noqa: E402 from torchmetrics.aggregation import CatMetric, MaxMetric, MeanMetric, MinMetric, SumMetric # noqa: E402 -from torchmetrics.audio import PermutationInvariantTraining # noqa: E402 -from torchmetrics.audio import ( # noqa: E402 - ScaleInvariantSignalDistortionRatio, - ScaleInvariantSignalNoiseRatio, - SignalDistortionRatio, - SignalNoiseRatio, +from torchmetrics.audio._deprecated import _PermutationInvariantTraining as PermutationInvariantTraining # noqa: E402 +from torchmetrics.audio._deprecated import ( # noqa: E402 + _ScaleInvariantSignalDistortionRatio as ScaleInvariantSignalDistortionRatio, ) +from torchmetrics.audio._deprecated import ( # noqa: E402 + _ScaleInvariantSignalNoiseRatio as ScaleInvariantSignalNoiseRatio, +) +from torchmetrics.audio._deprecated import _SignalDistortionRatio as SignalDistortionRatio # noqa: E402 +from torchmetrics.audio._deprecated import _SignalNoiseRatio as SignalNoiseRatio # noqa: E402 from torchmetrics.classification import ( # noqa: E402 AUROC, ROC, diff --git a/src/torchmetrics/audio/_deprecated.py b/src/torchmetrics/audio/_deprecated.py new file mode 100644 index 00000000000..9a9f73eaf30 --- /dev/null +++ b/src/torchmetrics/audio/_deprecated.py @@ -0,0 +1,123 @@ +from typing import Any, Callable, Optional + +from typing_extensions import Literal + +from torchmetrics.audio.pit import PermutationInvariantTraining +from torchmetrics.audio.sdr import ScaleInvariantSignalDistortionRatio, SignalDistortionRatio +from torchmetrics.audio.snr import ScaleInvariantSignalNoiseRatio, SignalNoiseRatio +from torchmetrics.utilities.prints import _deprecated_root_import_class + + +class _PermutationInvariantTraining(PermutationInvariantTraining): + """Wrapper for deprecated import. + + >>> import torch + >>> from torchmetrics.functional import scale_invariant_signal_noise_ratio + >>> _ = torch.manual_seed(42) + >>> preds = torch.randn(3, 2, 5) # [batch, spk, time] + >>> target = torch.randn(3, 2, 5) # [batch, spk, time] + >>> pit = _PermutationInvariantTraining(scale_invariant_signal_noise_ratio, 'max') + >>> pit(preds, target) + tensor(-2.1065) + """ + + def __init__( + self, + metric_func: Callable, + eval_func: Literal["max", "min"] = "max", + **kwargs: Any, + ) -> None: + _deprecated_root_import_class("PermutationInvariantTraining", "audio") + return super().__init__(metric_func=metric_func, eval_func=eval_func, **kwargs) + + +class _ScaleInvariantSignalDistortionRatio(ScaleInvariantSignalDistortionRatio): + """Wrapper for deprecated import. + + >>> from torch import tensor + >>> target = tensor([3.0, -0.5, 2.0, 7.0]) + >>> preds = tensor([2.5, 0.0, 2.0, 8.0]) + >>> si_sdr = _ScaleInvariantSignalDistortionRatio() + >>> si_sdr(preds, target) + tensor(18.4030) + """ + + def __init__( + self, + zero_mean: bool = False, + **kwargs: Any, + ) -> None: + _deprecated_root_import_class("ScaleInvariantSignalDistortionRatio", "audio") + return super().__init__(zero_mean=zero_mean, **kwargs) + + +class _ScaleInvariantSignalNoiseRatio(ScaleInvariantSignalNoiseRatio): + """Wrapper for deprecated import. + + >>> from torch import tensor + >>> target = tensor([3.0, -0.5, 2.0, 7.0]) + >>> preds = tensor([2.5, 0.0, 2.0, 8.0]) + >>> si_snr = _ScaleInvariantSignalNoiseRatio() + >>> si_snr(preds, target) + tensor(15.0918) + """ + + def __init__( + self, + **kwargs: Any, + ) -> None: + _deprecated_root_import_class("ScaleInvariantSignalNoiseRatio", "audio") + return super().__init__(**kwargs) + + +class _SignalDistortionRatio(SignalDistortionRatio): + """Wrapper for deprecated import. + + >>> import torch + >>> g = torch.manual_seed(1) + >>> preds = torch.randn(8000) + >>> target = torch.randn(8000) + >>> sdr = _SignalDistortionRatio() + >>> sdr(preds, target) + tensor(-12.0589) + >>> # use with pit + >>> from torchmetrics.functional import signal_distortion_ratio + >>> preds = torch.randn(4, 2, 8000) # [batch, spk, time] + >>> target = torch.randn(4, 2, 8000) + >>> pit = _PermutationInvariantTraining(signal_distortion_ratio, 'max') + >>> pit(preds, target) + tensor(-11.6051) + """ + + def __init__( + self, + use_cg_iter: Optional[int] = None, + filter_length: int = 512, + zero_mean: bool = False, + load_diag: Optional[float] = None, + **kwargs: Any, + ) -> None: + _deprecated_root_import_class("SignalDistortionRatio", "audio") + return super().__init__( + use_cg_iter=use_cg_iter, filter_length=filter_length, zero_mean=zero_mean, load_diag=load_diag, **kwargs + ) + + +class _SignalNoiseRatio(SignalNoiseRatio): + """Wrapper for deprecated import. + + >>> from torch import tensor + >>> target = tensor([3.0, -0.5, 2.0, 7.0]) + >>> preds = tensor([2.5, 0.0, 2.0, 8.0]) + >>> snr = _SignalNoiseRatio() + >>> snr(preds, target) + tensor(16.1805) + """ + + def __init__( + self, + zero_mean: bool = False, + **kwargs: Any, + ) -> None: + _deprecated_root_import_class("SignalNoiseRatio", "audio") + return super().__init__(zero_mean=zero_mean, **kwargs) diff --git a/src/torchmetrics/audio/pesq.py b/src/torchmetrics/audio/pesq.py index 2ddbb88a15e..304eaf943c6 100644 --- a/src/torchmetrics/audio/pesq.py +++ b/src/torchmetrics/audio/pesq.py @@ -68,7 +68,7 @@ class PerceptualEvaluationSpeechQuality(Metric): Example: >>> import torch - >>> from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality + >>> from torchmetrics.audio import PerceptualEvaluationSpeechQuality >>> g = torch.manual_seed(1) >>> preds = torch.randn(8000) >>> target = torch.randn(8000) @@ -147,7 +147,7 @@ def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_ >>> # Example plotting a single value >>> import torch - >>> from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality + >>> from torchmetrics.audio import PerceptualEvaluationSpeechQuality >>> metric = PerceptualEvaluationSpeechQuality(8000, 'nb') >>> metric.update(torch.rand(8000), torch.rand(8000)) >>> fig_, ax_ = metric.plot() @@ -157,7 +157,7 @@ def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_ >>> # Example plotting multiple values >>> import torch - >>> from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality + >>> from torchmetrics.audio import PerceptualEvaluationSpeechQuality >>> metric = PerceptualEvaluationSpeechQuality(8000, 'nb') >>> values = [ ] >>> for _ in range(10): diff --git a/src/torchmetrics/audio/pit.py b/src/torchmetrics/audio/pit.py index 97968ea8d2d..a540bd86936 100644 --- a/src/torchmetrics/audio/pit.py +++ b/src/torchmetrics/audio/pit.py @@ -55,8 +55,8 @@ class PermutationInvariantTraining(Metric): Example: >>> import torch - >>> from torchmetrics import PermutationInvariantTraining - >>> from torchmetrics.functional import scale_invariant_signal_noise_ratio + >>> from torchmetrics.audio import PermutationInvariantTraining + >>> from torchmetrics.functional.audio import scale_invariant_signal_noise_ratio >>> _ = torch.manual_seed(42) >>> preds = torch.randn(3, 2, 5) # [batch, spk, time] >>> target = torch.randn(3, 2, 5) # [batch, spk, time] @@ -122,8 +122,8 @@ def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_ >>> # Example plotting a single value >>> import torch - >>> from torchmetrics.audio.pit import PermutationInvariantTraining - >>> from torchmetrics.functional import scale_invariant_signal_noise_ratio + >>> from torchmetrics.audio import PermutationInvariantTraining + >>> from torchmetrics.functional.audio import scale_invariant_signal_noise_ratio >>> preds = torch.randn(3, 2, 5) # [batch, spk, time] >>> target = torch.randn(3, 2, 5) # [batch, spk, time] >>> metric = PermutationInvariantTraining(scale_invariant_signal_noise_ratio, 'max') @@ -135,8 +135,8 @@ def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_ >>> # Example plotting multiple values >>> import torch - >>> from torchmetrics.audio.pit import PermutationInvariantTraining - >>> from torchmetrics.functional import scale_invariant_signal_noise_ratio + >>> from torchmetrics.audio import PermutationInvariantTraining + >>> from torchmetrics.functional.audio import scale_invariant_signal_noise_ratio >>> preds = torch.randn(3, 2, 5) # [batch, spk, time] >>> target = torch.randn(3, 2, 5) # [batch, spk, time] >>> metric = PermutationInvariantTraining(scale_invariant_signal_noise_ratio, 'max') diff --git a/src/torchmetrics/audio/sdr.py b/src/torchmetrics/audio/sdr.py index b83ab27011c..7af2ec44f9f 100644 --- a/src/torchmetrics/audio/sdr.py +++ b/src/torchmetrics/audio/sdr.py @@ -139,7 +139,7 @@ def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_ >>> # Example plotting a single value >>> import torch - >>> from torchmetrics.audio.sdr import SignalDistortionRatio + >>> from torchmetrics.audio import SignalDistortionRatio >>> metric = SignalDistortionRatio() >>> metric.update(torch.rand(8000), torch.rand(8000)) >>> fig_, ax_ = metric.plot() @@ -149,7 +149,7 @@ def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_ >>> # Example plotting multiple values >>> import torch - >>> from torchmetrics.audio.sdr import SignalDistortionRatio + >>> from torchmetrics.audio import SignalDistortionRatio >>> metric = SignalDistortionRatio() >>> values = [ ] >>> for _ in range(10): diff --git a/src/torchmetrics/functional/__init__.py b/src/torchmetrics/functional/__init__.py index dc050c66f95..b48d0b7f586 100644 --- a/src/torchmetrics/functional/__init__.py +++ b/src/torchmetrics/functional/__init__.py @@ -11,9 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from torchmetrics.functional.audio.pit import permutation_invariant_training, pit_permutate -from torchmetrics.functional.audio.sdr import scale_invariant_signal_distortion_ratio, signal_distortion_ratio -from torchmetrics.functional.audio.snr import scale_invariant_signal_noise_ratio, signal_noise_ratio +from torchmetrics.functional.audio._deprecated import _permutation_invariant_training as permutation_invariant_training +from torchmetrics.functional.audio._deprecated import _pit_permutate as pit_permutate +from torchmetrics.functional.audio._deprecated import ( + _scale_invariant_signal_distortion_ratio as scale_invariant_signal_distortion_ratio, +) +from torchmetrics.functional.audio._deprecated import ( + _scale_invariant_signal_noise_ratio as scale_invariant_signal_noise_ratio, +) +from torchmetrics.functional.audio._deprecated import _signal_distortion_ratio as signal_distortion_ratio +from torchmetrics.functional.audio._deprecated import _signal_noise_ratio as signal_noise_ratio from torchmetrics.functional.classification.accuracy import accuracy from torchmetrics.functional.classification.auroc import auroc from torchmetrics.functional.classification.average_precision import average_precision diff --git a/src/torchmetrics/functional/audio/_deprecated.py b/src/torchmetrics/functional/audio/_deprecated.py new file mode 100644 index 00000000000..49ae6735f20 --- /dev/null +++ b/src/torchmetrics/functional/audio/_deprecated.py @@ -0,0 +1,117 @@ +from typing import Any, Callable, Optional, Tuple + +from torch import Tensor +from typing_extensions import Literal + +from torchmetrics.functional.audio.pit import permutation_invariant_training, pit_permutate +from torchmetrics.functional.audio.sdr import scale_invariant_signal_distortion_ratio, signal_distortion_ratio +from torchmetrics.functional.audio.snr import scale_invariant_signal_noise_ratio, signal_noise_ratio +from torchmetrics.utilities.prints import _deprecated_root_import_func + + +def _permutation_invariant_training( + preds: Tensor, target: Tensor, metric_func: Callable, eval_func: Literal["max", "min"] = "max", **kwargs: Any +) -> Tuple[Tensor, Tensor]: + """Wrapper for deprecated import. + + >>> from torch import tensor + >>> preds = tensor([[[-0.0579, 0.3560, -0.9604], [-0.1719, 0.3205, 0.2951]]]) + >>> target = tensor([[[ 1.0958, -0.1648, 0.5228], [-0.4100, 1.1942, -0.5103]]]) + >>> best_metric, best_perm = _permutation_invariant_training( + ... preds, target, _scale_invariant_signal_distortion_ratio, 'max') + >>> best_metric + tensor([-5.1091]) + >>> best_perm + tensor([[0, 1]]) + >>> pit_permutate(preds, best_perm) + tensor([[[-0.0579, 0.3560, -0.9604], + [-0.1719, 0.3205, 0.2951]]]) + """ + _deprecated_root_import_func("permutation_invariant_training", "audio") + return permutation_invariant_training( + preds=preds, target=target, metric_func=metric_func, eval_func=eval_func, **kwargs + ) + + +def _pit_permutate(preds: Tensor, perm: Tensor) -> Tensor: + """Wrapper for deprecated import.""" + _deprecated_root_import_func("pit_permutate", "audio") + return pit_permutate(preds=preds, perm=perm) + + +def _scale_invariant_signal_distortion_ratio(preds: Tensor, target: Tensor, zero_mean: bool = False) -> Tensor: + """Wrapper for deprecated import. + + >>> from torch import tensor + >>> target = tensor([3.0, -0.5, 2.0, 7.0]) + >>> preds = tensor([2.5, 0.0, 2.0, 8.0]) + >>> _scale_invariant_signal_distortion_ratio(preds, target) + tensor(18.4030) + """ + _deprecated_root_import_func("scale_invariant_signal_distortion_ratio", "audio") + return scale_invariant_signal_distortion_ratio(preds=preds, target=target, zero_mean=zero_mean) + + +def _signal_distortion_ratio( + preds: Tensor, + target: Tensor, + use_cg_iter: Optional[int] = None, + filter_length: int = 512, + zero_mean: bool = False, + load_diag: Optional[float] = None, +) -> Tensor: + """Wrapper for deprecated import. + + >>> import torch + >>> g = torch.manual_seed(1) + >>> preds = torch.randn(8000) + >>> target = torch.randn(8000) + >>> _signal_distortion_ratio(preds, target) + tensor(-12.0589) + >>> # use with permutation_invariant_training + >>> preds = torch.randn(4, 2, 8000) # [batch, spk, time] + >>> target = torch.randn(4, 2, 8000) + >>> best_metric, best_perm = _permutation_invariant_training(preds, target, _signal_distortion_ratio, 'max') + >>> best_metric + tensor([-11.6375, -11.4358, -11.7148, -11.6325]) + >>> best_perm + tensor([[1, 0], + [0, 1], + [1, 0], + [0, 1]]) + """ + _deprecated_root_import_func("signal_distortion_ratio", "audio") + return signal_distortion_ratio( + preds=preds, + target=target, + use_cg_iter=use_cg_iter, + filter_length=filter_length, + zero_mean=zero_mean, + load_diag=load_diag, + ) + + +def _scale_invariant_signal_noise_ratio(preds: Tensor, target: Tensor) -> Tensor: + """Wrapper for deprecated import. + + >>> from torch import tensor + >>> target = tensor([3.0, -0.5, 2.0, 7.0]) + >>> preds = tensor([2.5, 0.0, 2.0, 8.0]) + >>> _scale_invariant_signal_noise_ratio(preds, target) + tensor(15.0918) + """ + _deprecated_root_import_func("scale_invariant_signal_noise_ratio", "audio") + return scale_invariant_signal_noise_ratio(preds=preds, target=target) + + +def _signal_noise_ratio(preds: Tensor, target: Tensor, zero_mean: bool = False) -> Tensor: + """Wrapper for deprecated import. + + >>> from torch import tensor + >>> target = tensor([3.0, -0.5, 2.0, 7.0]) + >>> preds = tensor([2.5, 0.0, 2.0, 8.0]) + >>> _signal_noise_ratio(preds, target) + tensor(16.1805) + """ + _deprecated_root_import_func("signal_noise_ratio", "audio") + return signal_noise_ratio(preds=preds, target=target, zero_mean=zero_mean) diff --git a/src/torchmetrics/utilities/prints.py b/src/torchmetrics/utilities/prints.py index b82b5e75673..a86388b8098 100644 --- a/src/torchmetrics/utilities/prints.py +++ b/src/torchmetrics/utilities/prints.py @@ -54,3 +54,19 @@ def _debug(*args: Any, **kwargs: Any) -> None: rank_zero_info = rank_zero_only(_info) rank_zero_warn = rank_zero_only(_warn) _future_warning = partial(warnings.warn, category=FutureWarning) + + +def _deprecated_root_import_class(name: str, domain: str) -> None: + """Warn user that he is importing class from location it has been deprecated.""" + _future_warning( + f"Importing `{name}` from `torchmetrics` was deprecated and will be removed in 2.0." + f" Import `{name}` from `torchmetrics.{domain}` instead." + ) + + +def _deprecated_root_import_func(name: str, domain: str) -> None: + """Warn user that he is importing function from location it has been deprecated.""" + _future_warning( + f"Importing `{name}` from `torchmetrics.functional` was deprecated and will be removed in 2.0." + f" Import `{name}` from `torchmetrics.{domain}` instead." + ) diff --git a/tests/unittests/audio/test_pesq.py b/tests/unittests/audio/test_pesq.py index 2d012a74a19..aff00109d67 100644 --- a/tests/unittests/audio/test_pesq.py +++ b/tests/unittests/audio/test_pesq.py @@ -20,8 +20,8 @@ from scipy.io import wavfile from torch import Tensor -from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality -from torchmetrics.functional.audio.pesq import perceptual_evaluation_speech_quality +from torchmetrics.audio import PerceptualEvaluationSpeechQuality +from torchmetrics.functional.audio import perceptual_evaluation_speech_quality from unittests.audio import _SAMPLE_AUDIO_SPEECH, _SAMPLE_AUDIO_SPEECH_BAB_DB from unittests.helpers import seed_all from unittests.helpers.testers import MetricTester diff --git a/tests/unittests/audio/test_pit.py b/tests/unittests/audio/test_pit.py index c90278a3c44..06bb9f2db4e 100644 --- a/tests/unittests/audio/test_pit.py +++ b/tests/unittests/audio/test_pit.py @@ -22,11 +22,15 @@ from torch import Tensor from torchmetrics.audio import PermutationInvariantTraining -from torchmetrics.functional import ( +from torchmetrics.functional.audio import ( permutation_invariant_training, scale_invariant_signal_distortion_ratio, signal_noise_ratio, ) +from torchmetrics.functional.audio.pit import ( + _find_best_perm_by_exhaustive_method, + _find_best_perm_by_linear_sum_assignment, +) from unittests import BATCH_SIZE, NUM_BATCHES from unittests.helpers import seed_all from unittests.helpers.testers import MetricTester @@ -198,11 +202,6 @@ def test_error_on_wrong_shape() -> None: def test_consistency_of_two_implementations() -> None: """Test that both backend functions for computing metric (depending on torch version) returns the same result.""" - from torchmetrics.functional.audio.pit import ( - _find_best_perm_by_exhaustive_method, - _find_best_perm_by_linear_sum_assignment, - ) - shapes_test = [(5, 2, 2), (4, 3, 3), (4, 4, 4), (3, 5, 5)] for shp in shapes_test: metric_mtx = torch.randn(size=shp) diff --git a/tests/unittests/audio/test_si_sdr.py b/tests/unittests/audio/test_si_sdr.py index 622c895ee3b..78a6af0e94c 100644 --- a/tests/unittests/audio/test_si_sdr.py +++ b/tests/unittests/audio/test_si_sdr.py @@ -20,7 +20,7 @@ from torch import Tensor from torchmetrics.audio import ScaleInvariantSignalDistortionRatio -from torchmetrics.functional import scale_invariant_signal_distortion_ratio +from torchmetrics.functional.audio import scale_invariant_signal_distortion_ratio from unittests import BATCH_SIZE, NUM_BATCHES from unittests.helpers import seed_all from unittests.helpers.testers import MetricTester diff --git a/tests/unittests/audio/test_si_snr.py b/tests/unittests/audio/test_si_snr.py index baff783a551..4c960d1ef4f 100644 --- a/tests/unittests/audio/test_si_snr.py +++ b/tests/unittests/audio/test_si_snr.py @@ -20,7 +20,7 @@ from torch import Tensor from torchmetrics.audio import ScaleInvariantSignalNoiseRatio -from torchmetrics.functional import scale_invariant_signal_noise_ratio +from torchmetrics.functional.audio import scale_invariant_signal_noise_ratio from unittests import BATCH_SIZE, NUM_BATCHES from unittests.helpers import seed_all from unittests.helpers.testers import MetricTester diff --git a/tests/unittests/audio/test_snr.py b/tests/unittests/audio/test_snr.py index 54d4a0e15e2..895c72eeb99 100644 --- a/tests/unittests/audio/test_snr.py +++ b/tests/unittests/audio/test_snr.py @@ -21,7 +21,7 @@ from torch import Tensor from torchmetrics.audio import SignalNoiseRatio -from torchmetrics.functional import signal_noise_ratio +from torchmetrics.functional.audio import signal_noise_ratio from unittests import NUM_BATCHES from unittests.helpers import seed_all from unittests.helpers.testers import MetricTester diff --git a/tests/unittests/audio/test_stoi.py b/tests/unittests/audio/test_stoi.py index 4365eb7e75b..ad1b1a0af6b 100644 --- a/tests/unittests/audio/test_stoi.py +++ b/tests/unittests/audio/test_stoi.py @@ -20,8 +20,8 @@ from scipy.io import wavfile from torch import Tensor -from torchmetrics.audio.stoi import ShortTimeObjectiveIntelligibility -from torchmetrics.functional.audio.stoi import short_time_objective_intelligibility +from torchmetrics.audio import ShortTimeObjectiveIntelligibility +from torchmetrics.functional.audio import short_time_objective_intelligibility from unittests.audio import _SAMPLE_AUDIO_SPEECH, _SAMPLE_AUDIO_SPEECH_BAB_DB from unittests.helpers import seed_all from unittests.helpers.testers import MetricTester diff --git a/tests/unittests/deprecations/__init__.py b/tests/unittests/deprecations/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/unittests/deprecations/root_class_imports.py b/tests/unittests/deprecations/root_class_imports.py new file mode 100644 index 00000000000..6d8ab1f64e7 --- /dev/null +++ b/tests/unittests/deprecations/root_class_imports.py @@ -0,0 +1,31 @@ +"""Test that domain metric with import from root raises deprecation warning.""" +from functools import partial + +import pytest + +from torchmetrics import ( + PermutationInvariantTraining, + ScaleInvariantSignalDistortionRatio, + ScaleInvariantSignalNoiseRatio, + SignalDistortionRatio, + SignalNoiseRatio, +) +from torchmetrics.functional.audio import scale_invariant_signal_noise_ratio + + +@pytest.mark.parametrize( + "metric_cls", + [ + pytest.param( + partial(PermutationInvariantTraining, scale_invariant_signal_noise_ratio), id="PermutationInvariantTraining" + ), + ScaleInvariantSignalDistortionRatio, + ScaleInvariantSignalNoiseRatio, + SignalDistortionRatio, + SignalNoiseRatio, + ], +) +def test_import_from_root_package(metric_cls): + """Test that domain metric with import from root raises deprecation warning.""" + with pytest.warns(FutureWarning, match=r".+ was deprecated and will be removed in 2.0.+"): + metric_cls()