Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

imports: deprecate from pkg root [1/n] Audio #1685

Merged
merged 24 commits into from
Apr 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions src/torchmetrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,15 @@

from torchmetrics import functional # noqa: E402
from torchmetrics.aggregation import CatMetric, MaxMetric, MeanMetric, MinMetric, SumMetric # noqa: E402
from torchmetrics.audio import PermutationInvariantTraining # noqa: E402
from torchmetrics.audio import ( # noqa: E402
ScaleInvariantSignalDistortionRatio,
ScaleInvariantSignalNoiseRatio,
SignalDistortionRatio,
SignalNoiseRatio,
from torchmetrics.audio._deprecated import _PermutationInvariantTraining as PermutationInvariantTraining # noqa: E402
from torchmetrics.audio._deprecated import ( # noqa: E402
_ScaleInvariantSignalDistortionRatio as ScaleInvariantSignalDistortionRatio,
)
from torchmetrics.audio._deprecated import ( # noqa: E402
_ScaleInvariantSignalNoiseRatio as ScaleInvariantSignalNoiseRatio,
)
from torchmetrics.audio._deprecated import _SignalDistortionRatio as SignalDistortionRatio # noqa: E402
from torchmetrics.audio._deprecated import _SignalNoiseRatio as SignalNoiseRatio # noqa: E402
from torchmetrics.classification import ( # noqa: E402
AUROC,
ROC,
Expand Down
123 changes: 123 additions & 0 deletions src/torchmetrics/audio/_deprecated.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
from typing import Any, Callable, Optional

from typing_extensions import Literal

from torchmetrics.audio.pit import PermutationInvariantTraining
from torchmetrics.audio.sdr import ScaleInvariantSignalDistortionRatio, SignalDistortionRatio
from torchmetrics.audio.snr import ScaleInvariantSignalNoiseRatio, SignalNoiseRatio
from torchmetrics.utilities.prints import _deprecated_root_import_class


class _PermutationInvariantTraining(PermutationInvariantTraining):
"""Wrapper for deprecated import.

>>> import torch
>>> from torchmetrics.functional import scale_invariant_signal_noise_ratio
>>> _ = torch.manual_seed(42)
>>> preds = torch.randn(3, 2, 5) # [batch, spk, time]
>>> target = torch.randn(3, 2, 5) # [batch, spk, time]
>>> pit = _PermutationInvariantTraining(scale_invariant_signal_noise_ratio, 'max')
>>> pit(preds, target)
tensor(-2.1065)
"""

def __init__(
self,
metric_func: Callable,
eval_func: Literal["max", "min"] = "max",
**kwargs: Any,
) -> None:
_deprecated_root_import_class("PermutationInvariantTraining", "audio")
return super().__init__(metric_func=metric_func, eval_func=eval_func, **kwargs)


class _ScaleInvariantSignalDistortionRatio(ScaleInvariantSignalDistortionRatio):
"""Wrapper for deprecated import.

>>> from torch import tensor
>>> target = tensor([3.0, -0.5, 2.0, 7.0])
>>> preds = tensor([2.5, 0.0, 2.0, 8.0])
>>> si_sdr = _ScaleInvariantSignalDistortionRatio()
>>> si_sdr(preds, target)
tensor(18.4030)
"""

def __init__(
self,
zero_mean: bool = False,
**kwargs: Any,
) -> None:
_deprecated_root_import_class("ScaleInvariantSignalDistortionRatio", "audio")
return super().__init__(zero_mean=zero_mean, **kwargs)


class _ScaleInvariantSignalNoiseRatio(ScaleInvariantSignalNoiseRatio):
"""Wrapper for deprecated import.

>>> from torch import tensor
>>> target = tensor([3.0, -0.5, 2.0, 7.0])
>>> preds = tensor([2.5, 0.0, 2.0, 8.0])
>>> si_snr = _ScaleInvariantSignalNoiseRatio()
>>> si_snr(preds, target)
tensor(15.0918)
"""

def __init__(
self,
**kwargs: Any,
) -> None:
_deprecated_root_import_class("ScaleInvariantSignalNoiseRatio", "audio")
return super().__init__(**kwargs)


class _SignalDistortionRatio(SignalDistortionRatio):
"""Wrapper for deprecated import.

>>> import torch
>>> g = torch.manual_seed(1)
>>> preds = torch.randn(8000)
>>> target = torch.randn(8000)
>>> sdr = _SignalDistortionRatio()
>>> sdr(preds, target)
tensor(-12.0589)
>>> # use with pit
>>> from torchmetrics.functional import signal_distortion_ratio
>>> preds = torch.randn(4, 2, 8000) # [batch, spk, time]
>>> target = torch.randn(4, 2, 8000)
>>> pit = _PermutationInvariantTraining(signal_distortion_ratio, 'max')
>>> pit(preds, target)
tensor(-11.6051)
"""

def __init__(
self,
use_cg_iter: Optional[int] = None,
filter_length: int = 512,
zero_mean: bool = False,
load_diag: Optional[float] = None,
**kwargs: Any,
) -> None:
_deprecated_root_import_class("SignalDistortionRatio", "audio")
return super().__init__(
use_cg_iter=use_cg_iter, filter_length=filter_length, zero_mean=zero_mean, load_diag=load_diag, **kwargs
)


class _SignalNoiseRatio(SignalNoiseRatio):
"""Wrapper for deprecated import.

>>> from torch import tensor
>>> target = tensor([3.0, -0.5, 2.0, 7.0])
>>> preds = tensor([2.5, 0.0, 2.0, 8.0])
>>> snr = _SignalNoiseRatio()
>>> snr(preds, target)
tensor(16.1805)
"""

def __init__(
self,
zero_mean: bool = False,
**kwargs: Any,
) -> None:
_deprecated_root_import_class("SignalNoiseRatio", "audio")
return super().__init__(zero_mean=zero_mean, **kwargs)
6 changes: 3 additions & 3 deletions src/torchmetrics/audio/pesq.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ class PerceptualEvaluationSpeechQuality(Metric):

Example:
>>> import torch
>>> from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality
>>> from torchmetrics.audio import PerceptualEvaluationSpeechQuality
>>> g = torch.manual_seed(1)
>>> preds = torch.randn(8000)
>>> target = torch.randn(8000)
Expand Down Expand Up @@ -147,7 +147,7 @@ def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_

>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality
>>> from torchmetrics.audio import PerceptualEvaluationSpeechQuality
>>> metric = PerceptualEvaluationSpeechQuality(8000, 'nb')
>>> metric.update(torch.rand(8000), torch.rand(8000))
>>> fig_, ax_ = metric.plot()
Expand All @@ -157,7 +157,7 @@ def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_

>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality
>>> from torchmetrics.audio import PerceptualEvaluationSpeechQuality
>>> metric = PerceptualEvaluationSpeechQuality(8000, 'nb')
>>> values = [ ]
>>> for _ in range(10):
Expand Down
12 changes: 6 additions & 6 deletions src/torchmetrics/audio/pit.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ class PermutationInvariantTraining(Metric):

Example:
>>> import torch
>>> from torchmetrics import PermutationInvariantTraining
>>> from torchmetrics.functional import scale_invariant_signal_noise_ratio
>>> from torchmetrics.audio import PermutationInvariantTraining
>>> from torchmetrics.functional.audio import scale_invariant_signal_noise_ratio
>>> _ = torch.manual_seed(42)
>>> preds = torch.randn(3, 2, 5) # [batch, spk, time]
>>> target = torch.randn(3, 2, 5) # [batch, spk, time]
Expand Down Expand Up @@ -122,8 +122,8 @@ def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_

>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.audio.pit import PermutationInvariantTraining
>>> from torchmetrics.functional import scale_invariant_signal_noise_ratio
>>> from torchmetrics.audio import PermutationInvariantTraining
>>> from torchmetrics.functional.audio import scale_invariant_signal_noise_ratio
>>> preds = torch.randn(3, 2, 5) # [batch, spk, time]
>>> target = torch.randn(3, 2, 5) # [batch, spk, time]
>>> metric = PermutationInvariantTraining(scale_invariant_signal_noise_ratio, 'max')
Expand All @@ -135,8 +135,8 @@ def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_

>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.audio.pit import PermutationInvariantTraining
>>> from torchmetrics.functional import scale_invariant_signal_noise_ratio
>>> from torchmetrics.audio import PermutationInvariantTraining
>>> from torchmetrics.functional.audio import scale_invariant_signal_noise_ratio
>>> preds = torch.randn(3, 2, 5) # [batch, spk, time]
>>> target = torch.randn(3, 2, 5) # [batch, spk, time]
>>> metric = PermutationInvariantTraining(scale_invariant_signal_noise_ratio, 'max')
Expand Down
4 changes: 2 additions & 2 deletions src/torchmetrics/audio/sdr.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_

>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.audio.sdr import SignalDistortionRatio
>>> from torchmetrics.audio import SignalDistortionRatio
>>> metric = SignalDistortionRatio()
>>> metric.update(torch.rand(8000), torch.rand(8000))
>>> fig_, ax_ = metric.plot()
Expand All @@ -149,7 +149,7 @@ def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_

>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.audio.sdr import SignalDistortionRatio
>>> from torchmetrics.audio import SignalDistortionRatio
>>> metric = SignalDistortionRatio()
>>> values = [ ]
>>> for _ in range(10):
Expand Down
13 changes: 10 additions & 3 deletions src/torchmetrics/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from torchmetrics.functional.audio.pit import permutation_invariant_training, pit_permutate
from torchmetrics.functional.audio.sdr import scale_invariant_signal_distortion_ratio, signal_distortion_ratio
from torchmetrics.functional.audio.snr import scale_invariant_signal_noise_ratio, signal_noise_ratio
from torchmetrics.functional.audio._deprecated import _permutation_invariant_training as permutation_invariant_training
from torchmetrics.functional.audio._deprecated import _pit_permutate as pit_permutate
from torchmetrics.functional.audio._deprecated import (
_scale_invariant_signal_distortion_ratio as scale_invariant_signal_distortion_ratio,
)
from torchmetrics.functional.audio._deprecated import (
_scale_invariant_signal_noise_ratio as scale_invariant_signal_noise_ratio,
)
from torchmetrics.functional.audio._deprecated import _signal_distortion_ratio as signal_distortion_ratio
from torchmetrics.functional.audio._deprecated import _signal_noise_ratio as signal_noise_ratio
from torchmetrics.functional.classification.accuracy import accuracy
from torchmetrics.functional.classification.auroc import auroc
from torchmetrics.functional.classification.average_precision import average_precision
Expand Down
117 changes: 117 additions & 0 deletions src/torchmetrics/functional/audio/_deprecated.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
from typing import Any, Callable, Optional, Tuple

from torch import Tensor
from typing_extensions import Literal

from torchmetrics.functional.audio.pit import permutation_invariant_training, pit_permutate
from torchmetrics.functional.audio.sdr import scale_invariant_signal_distortion_ratio, signal_distortion_ratio
from torchmetrics.functional.audio.snr import scale_invariant_signal_noise_ratio, signal_noise_ratio
from torchmetrics.utilities.prints import _deprecated_root_import_func


def _permutation_invariant_training(
preds: Tensor, target: Tensor, metric_func: Callable, eval_func: Literal["max", "min"] = "max", **kwargs: Any
) -> Tuple[Tensor, Tensor]:
"""Wrapper for deprecated import.

>>> from torch import tensor
>>> preds = tensor([[[-0.0579, 0.3560, -0.9604], [-0.1719, 0.3205, 0.2951]]])
>>> target = tensor([[[ 1.0958, -0.1648, 0.5228], [-0.4100, 1.1942, -0.5103]]])
>>> best_metric, best_perm = _permutation_invariant_training(
... preds, target, _scale_invariant_signal_distortion_ratio, 'max')
>>> best_metric
tensor([-5.1091])
>>> best_perm
tensor([[0, 1]])
>>> pit_permutate(preds, best_perm)
tensor([[[-0.0579, 0.3560, -0.9604],
[-0.1719, 0.3205, 0.2951]]])
"""
_deprecated_root_import_func("permutation_invariant_training", "audio")
return permutation_invariant_training(
preds=preds, target=target, metric_func=metric_func, eval_func=eval_func, **kwargs
)


def _pit_permutate(preds: Tensor, perm: Tensor) -> Tensor:
"""Wrapper for deprecated import."""
_deprecated_root_import_func("pit_permutate", "audio")
return pit_permutate(preds=preds, perm=perm)


def _scale_invariant_signal_distortion_ratio(preds: Tensor, target: Tensor, zero_mean: bool = False) -> Tensor:
"""Wrapper for deprecated import.

>>> from torch import tensor
>>> target = tensor([3.0, -0.5, 2.0, 7.0])
>>> preds = tensor([2.5, 0.0, 2.0, 8.0])
>>> _scale_invariant_signal_distortion_ratio(preds, target)
tensor(18.4030)
"""
_deprecated_root_import_func("scale_invariant_signal_distortion_ratio", "audio")
return scale_invariant_signal_distortion_ratio(preds=preds, target=target, zero_mean=zero_mean)


def _signal_distortion_ratio(
preds: Tensor,
target: Tensor,
use_cg_iter: Optional[int] = None,
filter_length: int = 512,
zero_mean: bool = False,
load_diag: Optional[float] = None,
) -> Tensor:
"""Wrapper for deprecated import.

>>> import torch
>>> g = torch.manual_seed(1)
>>> preds = torch.randn(8000)
>>> target = torch.randn(8000)
>>> _signal_distortion_ratio(preds, target)
tensor(-12.0589)
>>> # use with permutation_invariant_training
>>> preds = torch.randn(4, 2, 8000) # [batch, spk, time]
>>> target = torch.randn(4, 2, 8000)
>>> best_metric, best_perm = _permutation_invariant_training(preds, target, _signal_distortion_ratio, 'max')
>>> best_metric
tensor([-11.6375, -11.4358, -11.7148, -11.6325])
>>> best_perm
tensor([[1, 0],
[0, 1],
[1, 0],
[0, 1]])
"""
_deprecated_root_import_func("signal_distortion_ratio", "audio")
return signal_distortion_ratio(
preds=preds,
target=target,
use_cg_iter=use_cg_iter,
filter_length=filter_length,
zero_mean=zero_mean,
load_diag=load_diag,
)


def _scale_invariant_signal_noise_ratio(preds: Tensor, target: Tensor) -> Tensor:
"""Wrapper for deprecated import.

>>> from torch import tensor
>>> target = tensor([3.0, -0.5, 2.0, 7.0])
>>> preds = tensor([2.5, 0.0, 2.0, 8.0])
>>> _scale_invariant_signal_noise_ratio(preds, target)
tensor(15.0918)
"""
_deprecated_root_import_func("scale_invariant_signal_noise_ratio", "audio")
return scale_invariant_signal_noise_ratio(preds=preds, target=target)


def _signal_noise_ratio(preds: Tensor, target: Tensor, zero_mean: bool = False) -> Tensor:
"""Wrapper for deprecated import.

>>> from torch import tensor
>>> target = tensor([3.0, -0.5, 2.0, 7.0])
>>> preds = tensor([2.5, 0.0, 2.0, 8.0])
>>> _signal_noise_ratio(preds, target)
tensor(16.1805)
"""
_deprecated_root_import_func("signal_noise_ratio", "audio")
return signal_noise_ratio(preds=preds, target=target, zero_mean=zero_mean)
16 changes: 16 additions & 0 deletions src/torchmetrics/utilities/prints.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,19 @@ def _debug(*args: Any, **kwargs: Any) -> None:
rank_zero_info = rank_zero_only(_info)
rank_zero_warn = rank_zero_only(_warn)
_future_warning = partial(warnings.warn, category=FutureWarning)


def _deprecated_root_import_class(name: str, domain: str) -> None:
"""Warn user that he is importing class from location it has been deprecated."""
_future_warning(
f"Importing `{name}` from `torchmetrics` was deprecated and will be removed in 2.0."
f" Import `{name}` from `torchmetrics.{domain}` instead."
)


def _deprecated_root_import_func(name: str, domain: str) -> None:
"""Warn user that he is importing function from location it has been deprecated."""
_future_warning(
f"Importing `{name}` from `torchmetrics.functional` was deprecated and will be removed in 2.0."
f" Import `{name}` from `torchmetrics.{domain}` instead."
)
4 changes: 2 additions & 2 deletions tests/unittests/audio/test_pesq.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
from scipy.io import wavfile
from torch import Tensor

from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality
from torchmetrics.functional.audio.pesq import perceptual_evaluation_speech_quality
from torchmetrics.audio import PerceptualEvaluationSpeechQuality
from torchmetrics.functional.audio import perceptual_evaluation_speech_quality
from unittests.audio import _SAMPLE_AUDIO_SPEECH, _SAMPLE_AUDIO_SPEECH_BAB_DB
from unittests.helpers import seed_all
from unittests.helpers.testers import MetricTester
Expand Down
Loading