From 9da607f77c76470be28f5e4feba72a9e8d496151 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Mon, 7 Mar 2022 14:55:39 +0100 Subject: [PATCH] Revert "Refactor: Rename update and compute methods to _update and _compute (#840)" This reverts commit 99a0c6b80cd7c9a23cff0f3e740a195341783ad3. --- CHANGELOG.md | 3 - docs/source/pages/brief_intro.rst | 4 +- docs/source/pages/implement.rst | 29 +++-- docs/source/pages/quickstart.rst | 14 +-- integrations/test_lightning.py | 4 +- tests/bases/test_aggregation.py | 16 +-- tests/bases/test_collections.py | 8 +- tests/bases/test_composition.py | 4 +- tests/bases/test_ddp.py | 8 +- tests/bases/test_metric.py | 29 +---- tests/helpers/testers.py | 18 +-- tests/test_deprecated.py | 43 +------ tests/wrappers/test_bootstrapping.py | 2 +- tests/wrappers/test_minmax.py | 4 +- tests/wrappers/test_multioutput.py | 4 +- torchmetrics/aggregation.py | 18 +-- torchmetrics/audio/pesq.py | 4 +- torchmetrics/audio/pit.py | 4 +- torchmetrics/audio/sdr.py | 8 +- torchmetrics/audio/snr.py | 8 +- torchmetrics/audio/stoi.py | 4 +- torchmetrics/classification/accuracy.py | 4 +- torchmetrics/classification/auc.py | 4 +- torchmetrics/classification/auroc.py | 4 +- torchmetrics/classification/avg_precision.py | 4 +- .../classification/binned_precision_recall.py | 12 +- .../classification/calibration_error.py | 4 +- torchmetrics/classification/cohen_kappa.py | 4 +- .../classification/confusion_matrix.py | 4 +- torchmetrics/classification/f_beta.py | 2 +- torchmetrics/classification/hamming.py | 4 +- torchmetrics/classification/hinge.py | 4 +- torchmetrics/classification/jaccard.py | 2 +- torchmetrics/classification/kl_divergence.py | 4 +- .../classification/matthews_corrcoef.py | 4 +- .../classification/precision_recall.py | 4 +- .../classification/precision_recall_curve.py | 4 +- torchmetrics/classification/roc.py | 4 +- torchmetrics/classification/specificity.py | 2 +- torchmetrics/classification/stat_scores.py | 4 +- torchmetrics/detection/map.py | 4 +- torchmetrics/image/fid.py | 4 +- torchmetrics/image/inception.py | 4 +- torchmetrics/image/kid.py | 4 +- torchmetrics/image/lpip.py | 4 +- torchmetrics/image/psnr.py | 4 +- torchmetrics/image/ssim.py | 8 +- torchmetrics/image/uqi.py | 4 +- torchmetrics/metric.py | 112 +++++------------- torchmetrics/regression/cosine_similarity.py | 4 +- torchmetrics/regression/explained_variance.py | 4 +- torchmetrics/regression/log_mse.py | 4 +- torchmetrics/regression/mae.py | 4 +- torchmetrics/regression/mape.py | 4 +- torchmetrics/regression/mse.py | 4 +- torchmetrics/regression/pearson.py | 4 +- torchmetrics/regression/r2.py | 4 +- torchmetrics/regression/spearman.py | 4 +- torchmetrics/regression/symmetric_mape.py | 4 +- torchmetrics/regression/tweedie_deviance.py | 4 +- torchmetrics/retrieval/base.py | 4 +- torchmetrics/retrieval/fall_out.py | 2 +- torchmetrics/text/bert.py | 4 +- torchmetrics/text/bleu.py | 4 +- torchmetrics/text/cer.py | 4 +- torchmetrics/text/chrf.py | 4 +- torchmetrics/text/eed.py | 4 +- torchmetrics/text/mer.py | 4 +- torchmetrics/text/rouge.py | 4 +- torchmetrics/text/sacre_bleu.py | 2 +- torchmetrics/text/squad.py | 4 +- torchmetrics/text/ter.py | 4 +- torchmetrics/text/wer.py | 4 +- torchmetrics/text/wil.py | 4 +- torchmetrics/text/wip.py | 4 +- torchmetrics/wrappers/bootstrapping.py | 9 +- torchmetrics/wrappers/classwise.py | 7 +- torchmetrics/wrappers/minmax.py | 4 +- torchmetrics/wrappers/multioutput.py | 8 +- 79 files changed, 229 insertions(+), 361 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 34ff900f4f3..7b2e21088b0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -42,9 +42,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Deprecated passing in `dist_sync_on_step`, `process_group`, `dist_sync_fn` direct argument ([#833](https://github.com/PyTorchLightning/metrics/pull/833)) -- Moved particular metrics implementation from `update` and `compute` methods to `_update` and `_compute` ([#840](https://github.com/PyTorchLightning/metrics/pull/840)) - - ### Removed - Removed support for versions of Lightning lower than v1.5 ([#788](https://github.com/PyTorchLightning/metrics/pull/788)) diff --git a/docs/source/pages/brief_intro.rst b/docs/source/pages/brief_intro.rst index 1be88530ba9..d8088a5bc8f 100644 --- a/docs/source/pages/brief_intro.rst +++ b/docs/source/pages/brief_intro.rst @@ -79,7 +79,7 @@ Implementing a metric self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum") self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") - def _update(self, preds: torch.Tensor, target: torch.Tensor): + def update(self, preds: torch.Tensor, target: torch.Tensor): # update metric states preds, target = self._input_format(preds, target) assert preds.shape == target.shape @@ -87,6 +87,6 @@ Implementing a metric self.correct += torch.sum(preds == target) self.total += target.numel() - def _compute(self): + def compute(self): # compute final result return self.correct.float() / self.total diff --git a/docs/source/pages/implement.rst b/docs/source/pages/implement.rst index 7555797faaf..35d3829082e 100644 --- a/docs/source/pages/implement.rst +++ b/docs/source/pages/implement.rst @@ -4,12 +4,11 @@ Implementing a Metric ********************* -To implement your own custom metric, subclass the base :class:`~torchmetrics.Metric` class and -implement the following methods: +To implement your own custom metric, subclass the base :class:`~torchmetrics.Metric` class and implement the following methods: - ``__init__()``: Each state variable should be called using ``self.add_state(...)``. -- ``_update()``: Any code needed to update the state given any inputs to the metric. -- ``_compute()``: Computes a final value from the state of the metric. +- ``update()``: Any code needed to update the state given any inputs to the metric. +- ``compute()``: Computes a final value from the state of the metric. We provide the remaining interface, such as ``reset()`` that will make sure to correctly reset all metric states that have been added using ``add_state``. You should therefore not implement ``reset()`` yourself. @@ -30,14 +29,14 @@ Example implementation: self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum") self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") - def _update(self, preds: torch.Tensor, target: torch.Tensor): + def update(self, preds: torch.Tensor, target: torch.Tensor): preds, target = self._input_format(preds, target) assert preds.shape == target.shape self.correct += torch.sum(preds == target) self.total += target.numel() - def _compute(self): + def compute(self): return self.correct.float() / self.total @@ -45,18 +44,18 @@ Internal implementation details ------------------------------- This section briefly describes how metrics work internally. We encourage looking at the source code for more info. -Whenever the public ``update`` or ``compute`` method is called they will internally try to synchronize and reduce -metric states across multiple device before calling the actual implementation provided in the private methods -`_update()` and `_compute`. More precisely, calling ``update()`` does the following internally: +Internally, TorchMetrics wraps the user defined ``update()`` and ``compute()`` method. We do this to automatically +synchronize and reduce metric states across multiple devices. More precisely, calling ``update()`` does the +following internally: 1. Clears computed cache. -2. Calls user-defined ``_update()``. +2. Calls user-defined ``update()``. Similarly, calling ``compute()`` does the following internally: 1. Syncs metric states between processes. 2. Reduce gathered metric states. -3. Calls the user defined ``_compute()`` method on the gathered metric states. +3. Calls the user defined ``compute()`` method on the gathered metric states. 4. Cache computed result. From a user's standpoint this has one important side-effect: computed results are cached. This means that no @@ -77,6 +76,7 @@ to ``update`` and ``compute`` in the following way: This procedure has the consequence of calling the user defined ``update`` **twice** during a single forward call (one to update global statistics and one for getting the batch statistics). + --------- .. autoclass:: torchmetrics.Metric @@ -95,8 +95,7 @@ and tests gets formatted in the following way: metric (classification, regression, nlp etc) and ``new_metric`` is the name of the metric. In this file, there should be the following three functions: - 1. ``_new_metric_update(...)``: everything that has to do with type/shape checking and all logic required before distributed - syncing need to go here. + 1. ``_new_metric_update(...)``: everything that has to do with type/shape checking and all logic required before distributed syncing need to go here. 2. ``_new_metric_compute(...)``: all remaining logic. 3. ``new_metric(...)``: essentially wraps the ``_update`` and ``_compute`` private functions into one public function that makes up the functional interface for the metric. @@ -110,8 +109,8 @@ and tests gets formatted in the following way: 1. Create a new module metric by subclassing ``torchmetrics.Metric``. 2. In the ``__init__`` of the module call ``self.add_state`` for as many metric states are needed for the metric to proper accumulate metric statistics. - 3. The module interface should essentially call the private ``_new_metric_update(...)`` in its `_update` method and similarly the - ``_new_metric_compute(...)`` function in its ``_compute``. No logic should really be implemented in the module interface. + 3. The module interface should essentially call the private ``_new_metric_update(...)`` in its `update` method and similarly the + ``_new_metric_compute(...)`` function in its ``compute``. No logic should really be implemented in the module interface. We do this to not have duplicate code to maintain. .. note:: diff --git a/docs/source/pages/quickstart.rst b/docs/source/pages/quickstart.rst index 219a0b60619..69a387bb4b9 100644 --- a/docs/source/pages/quickstart.rst +++ b/docs/source/pages/quickstart.rst @@ -115,18 +115,16 @@ Implementing your own metric Implementing your own metric is as easy as subclassing a :class:`torch.nn.Module`. Simply, subclass :class:`~torchmetrics.Metric` and do the following: 1. Implement ``__init__`` where you call ``self.add_state`` for every internal state that is needed for the metrics computations -2. Implement ``_update`` method, where all logic that is necessary for updating metric states go -3. Implement ``_compute`` method, where the final metric computations happens +2. Implement ``update`` method, where all logic that is necessary for updating metric states go +3. Implement ``compute`` method, where the final metric computations happens For practical examples and more info about implementing a metric, please see this :ref:`page `. + Development Environment ~~~~~~~~~~~~~~~~~~~~~~~ -TorchMetrics provides a `Devcontainer `_ configuration for -`Visual Studio Code `_ to use a `Docker container `_ as a -pre-configured development environment. +TorchMetrics provides a `Devcontainer `_ configuration for `Visual Studio Code `_ to use a `Docker container `_ as a pre-configured development environment. This avoids struggles setting up a development environment and makes them reproducible and consistent. -Please follow the `installation instructions `_ -and make yourself familiar with the `container tutorials `_ -if you want to use them. In order to use GPUs, you can enable them within the ``.devcontainer/devcontainer.json`` file. +Please follow the `installation instructions `_ and make yourself familiar with the `container tutorials `_ if you want to use them. +In order to use GPUs, you can enable them within the ``.devcontainer/devcontainer.json`` file. diff --git a/integrations/test_lightning.py b/integrations/test_lightning.py index e76fadf18fc..226faee7e72 100644 --- a/integrations/test_lightning.py +++ b/integrations/test_lightning.py @@ -23,8 +23,8 @@ class DiffMetric(SumMetric): - def _update(self, value): - super()._update(-value) + def update(self, value): + super().update(-value) def test_metric_lightning(tmpdir): diff --git a/tests/bases/test_aggregation.py b/tests/bases/test_aggregation.py index 982f62aa849..106621e9cb4 100644 --- a/tests/bases/test_aggregation.py +++ b/tests/bases/test_aggregation.py @@ -31,33 +31,33 @@ def compare_max(values, weights): class WrappedMinMetric(MinMetric): """Wrapped min metric.""" - def _update(self, values, weights): + def update(self, values, weights): """only pass values on.""" - super()._update(values) + super().update(values) class WrappedMaxMetric(MaxMetric): """Wrapped max metric.""" - def _update(self, values, weights): + def update(self, values, weights): """only pass values on.""" - super()._update(values) + super().update(values) class WrappedSumMetric(SumMetric): """Wrapped min metric.""" - def _update(self, values, weights): + def update(self, values, weights): """only pass values on.""" - super()._update(values) + super().update(values) class WrappedCatMetric(CatMetric): """Wrapped cat metric.""" - def _update(self, values, weights): + def update(self, values, weights): """only pass values on.""" - super()._update(values) + super().update(values) @pytest.mark.parametrize( diff --git a/tests/bases/test_collections.py b/tests/bases/test_collections.py index 1f98246dabb..6427f58dcbc 100644 --- a/tests/bases/test_collections.py +++ b/tests/bases/test_collections.py @@ -258,20 +258,20 @@ class DummyMetric(Metric): def __init__(self): super().__init__() - def _update(self, *args, kwarg): + def update(self, *args, kwarg): print("Entered DummyMetric") - def _compute(self): + def compute(self): return class MyAccuracy(Metric): def __init__(self): super().__init__() - def _update(self, preds, target, kwarg2): + def update(self, preds, target, kwarg2): print("Entered MyAccuracy") - def _compute(self): + def compute(self): return mc = MetricCollection([Accuracy(), DummyMetric()]) diff --git a/tests/bases/test_composition.py b/tests/bases/test_composition.py index 978d43992bf..36516c9997a 100644 --- a/tests/bases/test_composition.py +++ b/tests/bases/test_composition.py @@ -28,10 +28,10 @@ def __init__(self, val_to_return): self._val_to_return = val_to_return self._update_called = True - def _update(self, *args, **kwargs) -> None: + def update(self, *args, **kwargs) -> None: self._num_updates += 1 - def _compute(self): + def compute(self): return tensor(self._val_to_return) diff --git a/tests/bases/test_ddp.py b/tests/bases/test_ddp.py index 2202aa66642..805faa493d2 100644 --- a/tests/bases/test_ddp.py +++ b/tests/bases/test_ddp.py @@ -115,10 +115,10 @@ def __init__(self): super().__init__() self.add_state("x", default=[], dist_reduce_fx=None) - def _update(self, x): + def update(self, x): self.x.append(x) - def _compute(self): + def compute(self): x = torch.cat(self.x, dim=0) return x.sum() @@ -141,11 +141,11 @@ def __init__(self): self.add_state("x", torch.tensor(0), dist_reduce_fx=torch.sum) self.add_state("c", torch.tensor(0), dist_reduce_fx=torch.sum) - def _update(self, x): + def update(self, x): self.x += x self.c += 1 - def _compute(self): + def compute(self): return self.x // self.c def __repr__(self): diff --git a/tests/bases/test_metric.py b/tests/bases/test_metric.py index 23dcadd82b2..bad93b2b0d7 100644 --- a/tests/bases/test_metric.py +++ b/tests/bases/test_metric.py @@ -22,7 +22,6 @@ from tests.helpers import seed_all from tests.helpers.testers import DummyListMetric, DummyMetric, DummyMetricMultiOutput, DummyMetricSum -from torchmetrics.metric import Metric from torchmetrics.utilities.imports import _TORCH_LOWER_1_6 seed_all(42) @@ -37,24 +36,6 @@ def test_error_on_wrong_input(): DummyMetric(dist_sync_fn=[2, 3]) -def test_error_on_not_implemented_methods(): - """Test that error is raised if _update or _compute is not implemented.""" - - class TempMetric(Metric): - def _compute(self): - return None - - with pytest.raises(NotImplementedError, match="Expected method `_update` to be implemented in subclass."): - TempMetric() - - class TempMetric(Metric): - def _update(self): - pass - - with pytest.raises(NotImplementedError, match="Expected method `_compute` to be implemented in subclass."): - TempMetric() - - def test_inherit(): """Test that metric that inherits can be instanciated.""" DummyMetric() @@ -135,7 +116,7 @@ def test_reset_compute(): def test_update(): class A(DummyMetric): - def _update(self, x): + def update(self, x): self.x += x a = A() @@ -151,10 +132,10 @@ def _update(self, x): def test_compute(): class A(DummyMetric): - def _update(self, x): + def update(self, x): self.x += x - def _compute(self): + def compute(self): return self.x a = A() @@ -201,10 +182,10 @@ class B(DummyListMetric): def test_forward(): class A(DummyMetric): - def _update(self, x): + def update(self, x): self.x += x - def _compute(self): + def compute(self): return self.x a = A() diff --git a/tests/helpers/testers.py b/tests/helpers/testers.py index a07f9e9a2b6..f8b7e301ace 100644 --- a/tests/helpers/testers.py +++ b/tests/helpers/testers.py @@ -571,10 +571,10 @@ def __init__(self, **kwargs): super().__init__(**kwargs) self.add_state("x", tensor(0.0), dist_reduce_fx=None) - def _update(self): + def update(self): pass - def _compute(self): + def compute(self): pass @@ -585,29 +585,29 @@ def __init__(self): super().__init__() self.add_state("x", [], dist_reduce_fx=None) - def _update(self): + def update(self): pass - def _compute(self): + def compute(self): pass class DummyMetricSum(DummyMetric): - def _update(self, x): + def update(self, x): self.x += x - def _compute(self): + def compute(self): return self.x class DummyMetricDiff(DummyMetric): - def _update(self, y): + def update(self, y): self.x -= y - def _compute(self): + def compute(self): return self.x class DummyMetricMultiOutput(DummyMetricSum): - def _compute(self): + def compute(self): return [self.x, self.x] diff --git a/tests/test_deprecated.py b/tests/test_deprecated.py index 7074b842b40..14e03324798 100644 --- a/tests/test_deprecated.py +++ b/tests/test_deprecated.py @@ -1,7 +1,6 @@ import pytest -import torch -from torchmetrics import Accuracy, Metric +from torchmetrics import Accuracy def test_compute_on_step(): @@ -9,43 +8,3 @@ def test_compute_on_step(): DeprecationWarning, match="Argument `compute_on_step` is deprecated in v0.8 and will be removed in v0.9" ): Accuracy(compute_on_step=False) # any metric will raise the warning - - -def test_warning_on_overriden_update(): - """Test that deprecation error is raised if user tries to overwrite update method.""" - - class OldMetricAPI(Metric): - def __init__(self): - super().__init__() - self.add_state("x", torch.tensor(0)) - - def update(self, *args, **kwargs): - self.x += 1 - - def compute(self): - return self.x - - with pytest.warns( - DeprecationWarning, match="We detected that you have overwritten the ``update`` method, which was.*" - ): - OldMetricAPI() - - -def test_warning_on_overriden_compute(): - """Test that deprecation error is raised if user tries to overwrite compute method.""" - - class OldMetricAPI(Metric): - def __init__(self): - super().__init__() - self.add_state("x", torch.tensor(0)) - - def update(self, *args, **kwargs): - self.x += 1 - - def compute(self): - return self.x - - with pytest.warns( - DeprecationWarning, match="We detected that you have overwritten the ``compute`` method, which was.*" - ): - OldMetricAPI() diff --git a/tests/wrappers/test_bootstrapping.py b/tests/wrappers/test_bootstrapping.py index 184f9988a7a..1c27bde0b07 100644 --- a/tests/wrappers/test_bootstrapping.py +++ b/tests/wrappers/test_bootstrapping.py @@ -36,7 +36,7 @@ class TestBootStrapper(BootStrapper): """For testing purpose, we subclass the bootstrapper class so we can get the exact permutation the class is creating.""" - def _update(self, *args) -> None: + def update(self, *args) -> None: self.out = [] for idx in range(self.num_bootstraps): size = len(args[0]) diff --git a/tests/wrappers/test_minmax.py b/tests/wrappers/test_minmax.py index af8c20024a3..c1b113f8843 100644 --- a/tests/wrappers/test_minmax.py +++ b/tests/wrappers/test_minmax.py @@ -15,9 +15,9 @@ class TestingMinMaxMetric(MinMaxMetric): """wrap metric to fit testing framework.""" - def _compute(self): + def compute(self): """instead of returning dict, return as list.""" - output_dict = super()._compute() + output_dict = super().compute() return [output_dict["raw"], output_dict["min"], output_dict["max"]] def forward(self, *args, **kwargs): diff --git a/tests/wrappers/test_multioutput.py b/tests/wrappers/test_multioutput.py index 58a7e0d66eb..75e30c29e4a 100644 --- a/tests/wrappers/test_multioutput.py +++ b/tests/wrappers/test_multioutput.py @@ -31,11 +31,11 @@ def __init__( num_outputs=num_outputs, ) - def _update(self, preds: torch.Tensor, target: torch.Tensor) -> None: + def update(self, preds: torch.Tensor, target: torch.Tensor) -> None: """Update the each pair of outputs and predictions.""" return self.metric.update(preds, target) - def _compute(self) -> torch.Tensor: + def compute(self) -> torch.Tensor: """Compute the R2 score between each pair of outputs and predictions.""" return self.metric.compute() diff --git a/torchmetrics/aggregation.py b/torchmetrics/aggregation.py index f6b14f7cfff..6f3dba1ea0f 100644 --- a/torchmetrics/aggregation.py +++ b/torchmetrics/aggregation.py @@ -90,11 +90,11 @@ def _cast_and_nan_check_input(self, x: Union[float, Tensor]) -> Tensor: return x.float() - def _update(self, value: Union[float, Tensor]) -> None: # type: ignore + def update(self, value: Union[float, Tensor]) -> None: # type: ignore """Overwrite in child class.""" pass - def _compute(self) -> Tensor: + def compute(self) -> Tensor: """Compute the aggregated value.""" return self.value @@ -145,7 +145,7 @@ def __init__( **kwargs, ) - def _update(self, value: Union[float, Tensor]) -> None: # type: ignore + def update(self, value: Union[float, Tensor]) -> None: # type: ignore """Update state with data. Args: @@ -203,7 +203,7 @@ def __init__( **kwargs, ) - def _update(self, value: Union[float, Tensor]) -> None: # type: ignore + def update(self, value: Union[float, Tensor]) -> None: # type: ignore """Update state with data. Args: @@ -261,7 +261,7 @@ def __init__( **kwargs, ) - def _update(self, value: Union[float, Tensor]) -> None: # type: ignore + def update(self, value: Union[float, Tensor]) -> None: # type: ignore """Update state with data. Args: @@ -312,7 +312,7 @@ def __init__( ): super().__init__("cat", [], nan_strategy, compute_on_step, **kwargs) - def _update(self, value: Union[float, Tensor]) -> None: # type: ignore + def update(self, value: Union[float, Tensor]) -> None: # type: ignore """Update state with data. Args: @@ -323,7 +323,7 @@ def _update(self, value: Union[float, Tensor]) -> None: # type: ignore if any(value.flatten()): self.value.append(value) - def _compute(self) -> Tensor: + def compute(self) -> Tensor: """Compute the aggregated value.""" if isinstance(self.value, list) and self.value: return dim_zero_cat(self.value) @@ -377,7 +377,7 @@ def __init__( ) self.add_state("weight", default=torch.tensor(0.0), dist_reduce_fx="sum") - def _update(self, value: Union[float, Tensor], weight: Union[float, Tensor] = 1.0) -> None: # type: ignore + def update(self, value: Union[float, Tensor], weight: Union[float, Tensor] = 1.0) -> None: # type: ignore """Update state with data. Args: @@ -403,6 +403,6 @@ def _update(self, value: Union[float, Tensor], weight: Union[float, Tensor] = 1. self.value += (value * weight).sum() self.weight += weight.sum() - def _compute(self) -> Tensor: + def compute(self) -> Tensor: """Compute the aggregated value.""" return self.value / self.weight diff --git a/torchmetrics/audio/pesq.py b/torchmetrics/audio/pesq.py index c70543f8930..aed0157a349 100644 --- a/torchmetrics/audio/pesq.py +++ b/torchmetrics/audio/pesq.py @@ -107,7 +107,7 @@ def __init__( self.add_state("sum_pesq", default=tensor(0.0), dist_reduce_fx="sum") self.add_state("total", default=tensor(0), dist_reduce_fx="sum") - def _update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore """Update state with predictions and targets. Args: @@ -121,6 +121,6 @@ def _update(self, preds: Tensor, target: Tensor) -> None: # type: ignore self.sum_pesq += pesq_batch.sum() self.total += pesq_batch.numel() - def _compute(self) -> Tensor: + def compute(self) -> Tensor: """Computes average PESQ.""" return self.sum_pesq / self.total diff --git a/torchmetrics/audio/pit.py b/torchmetrics/audio/pit.py index aed0ca19c0a..8384ef505b3 100644 --- a/torchmetrics/audio/pit.py +++ b/torchmetrics/audio/pit.py @@ -91,7 +91,7 @@ def __init__( self.add_state("sum_pit_metric", default=tensor(0.0), dist_reduce_fx="sum") self.add_state("total", default=tensor(0), dist_reduce_fx="sum") - def _update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore """Update state with predictions and targets. Args: @@ -103,6 +103,6 @@ def _update(self, preds: Tensor, target: Tensor) -> None: # type: ignore self.sum_pit_metric += pit_metric.sum() self.total += pit_metric.numel() - def _compute(self) -> Tensor: + def compute(self) -> Tensor: """Computes average PermutationInvariantTraining metric.""" return self.sum_pit_metric / self.total diff --git a/torchmetrics/audio/sdr.py b/torchmetrics/audio/sdr.py index c6355ed15f6..63281318b5f 100644 --- a/torchmetrics/audio/sdr.py +++ b/torchmetrics/audio/sdr.py @@ -126,7 +126,7 @@ def __init__( self.add_state("sum_sdr", default=tensor(0.0), dist_reduce_fx="sum") self.add_state("total", default=tensor(0), dist_reduce_fx="sum") - def _update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore """Update state with predictions and targets. Args: @@ -140,7 +140,7 @@ def _update(self, preds: Tensor, target: Tensor) -> None: # type: ignore self.sum_sdr += sdr_batch.sum() self.total += sdr_batch.numel() - def _compute(self) -> Tensor: + def compute(self) -> Tensor: """Computes average SDR.""" return self.sum_sdr / self.total @@ -204,7 +204,7 @@ def __init__( self.add_state("sum_si_sdr", default=tensor(0.0), dist_reduce_fx="sum") self.add_state("total", default=tensor(0), dist_reduce_fx="sum") - def _update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore """Update state with predictions and targets. Args: @@ -216,6 +216,6 @@ def _update(self, preds: Tensor, target: Tensor) -> None: # type: ignore self.sum_si_sdr += si_sdr_batch.sum() self.total += si_sdr_batch.numel() - def _compute(self) -> Tensor: + def compute(self) -> Tensor: """Computes average SI-SDR.""" return self.sum_si_sdr / self.total diff --git a/torchmetrics/audio/snr.py b/torchmetrics/audio/snr.py index f3b5dc3c4be..5bab6b8bcc2 100644 --- a/torchmetrics/audio/snr.py +++ b/torchmetrics/audio/snr.py @@ -84,7 +84,7 @@ def __init__( self.add_state("sum_snr", default=tensor(0.0), dist_reduce_fx="sum") self.add_state("total", default=tensor(0), dist_reduce_fx="sum") - def _update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore """Update state with predictions and targets. Args: @@ -96,7 +96,7 @@ def _update(self, preds: Tensor, target: Tensor) -> None: # type: ignore self.sum_snr += snr_batch.sum() self.total += snr_batch.numel() - def _compute(self) -> Tensor: + def compute(self) -> Tensor: """Computes average SNR.""" return self.sum_snr / self.total @@ -156,7 +156,7 @@ def __init__( self.add_state("sum_si_snr", default=tensor(0.0), dist_reduce_fx="sum") self.add_state("total", default=tensor(0), dist_reduce_fx="sum") - def _update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore """Update state with predictions and targets. Args: @@ -168,6 +168,6 @@ def _update(self, preds: Tensor, target: Tensor) -> None: # type: ignore self.sum_si_snr += si_snr_batch.sum() self.total += si_snr_batch.numel() - def _compute(self) -> Tensor: + def compute(self) -> Tensor: """Computes average SI-SNR.""" return self.sum_si_snr / self.total diff --git a/torchmetrics/audio/stoi.py b/torchmetrics/audio/stoi.py index b1740910786..3bc10dbb33e 100644 --- a/torchmetrics/audio/stoi.py +++ b/torchmetrics/audio/stoi.py @@ -109,7 +109,7 @@ def __init__( self.add_state("sum_stoi", default=tensor(0.0), dist_reduce_fx="sum") self.add_state("total", default=tensor(0), dist_reduce_fx="sum") - def _update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore """Update state with predictions and targets. Args: @@ -123,6 +123,6 @@ def _update(self, preds: Tensor, target: Tensor) -> None: # type: ignore self.sum_stoi += stoi_batch.sum() self.total += stoi_batch.numel() - def _compute(self) -> Tensor: + def compute(self) -> Tensor: """Computes average STOI.""" return self.sum_stoi / self.total diff --git a/torchmetrics/classification/accuracy.py b/torchmetrics/classification/accuracy.py index ccf7480c8de..b5a2339ddcf 100644 --- a/torchmetrics/classification/accuracy.py +++ b/torchmetrics/classification/accuracy.py @@ -210,7 +210,7 @@ def __init__( self.add_state("correct", default=tensor(0), dist_reduce_fx="sum") self.add_state("total", default=tensor(0), dist_reduce_fx="sum") - def _update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore """Update state with predictions and targets. See :ref:`references/modules:input types` for more information on input types. @@ -264,7 +264,7 @@ def _update(self, preds: Tensor, target: Tensor) -> None: # type: ignore self.tn.append(tn) self.fn.append(fn) - def _compute(self) -> Tensor: + def compute(self) -> Tensor: """Computes accuracy based on inputs passed in to ``update`` previously.""" if not self.mode: raise RuntimeError("You have to have determined mode.") diff --git a/torchmetrics/classification/auc.py b/torchmetrics/classification/auc.py index 841c9a0fe52..722e7a47ec1 100644 --- a/torchmetrics/classification/auc.py +++ b/torchmetrics/classification/auc.py @@ -63,7 +63,7 @@ def __init__( " For large datasets this may lead to large memory footprint." ) - def _update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore """Update state with predictions and targets. Args: @@ -75,7 +75,7 @@ def _update(self, preds: Tensor, target: Tensor) -> None: # type: ignore self.x.append(x) self.y.append(y) - def _compute(self) -> Tensor: + def compute(self) -> Tensor: """Computes AUC based on inputs passed in to ``update`` previously.""" x = dim_zero_cat(self.x) y = dim_zero_cat(self.y) diff --git a/torchmetrics/classification/auroc.py b/torchmetrics/classification/auroc.py index a2bd2298eca..92cdedfa6ed 100644 --- a/torchmetrics/classification/auroc.py +++ b/torchmetrics/classification/auroc.py @@ -147,7 +147,7 @@ def __init__( " For large datasets this may lead to large memory footprint." ) - def _update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore """Update state with predictions and targets. Args: @@ -166,7 +166,7 @@ def _update(self, preds: Tensor, target: Tensor) -> None: # type: ignore ) self.mode = mode - def _compute(self) -> Tensor: + def compute(self) -> Tensor: """Computes AUROC based on inputs passed in to ``update`` previously.""" if not self.mode: raise RuntimeError("You have to have determined mode.") diff --git a/torchmetrics/classification/avg_precision.py b/torchmetrics/classification/avg_precision.py index dbc7221c62a..d14745b5f86 100644 --- a/torchmetrics/classification/avg_precision.py +++ b/torchmetrics/classification/avg_precision.py @@ -114,7 +114,7 @@ def __init__( " For large datasets this may lead to large memory footprint." ) - def _update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore """Update state with predictions and targets. Args: @@ -129,7 +129,7 @@ def _update(self, preds: Tensor, target: Tensor) -> None: # type: ignore self.num_classes = num_classes self.pos_label = pos_label - def _compute(self) -> Union[Tensor, List[Tensor]]: + def compute(self) -> Union[Tensor, List[Tensor]]: """Compute the average precision score. Returns: diff --git a/torchmetrics/classification/binned_precision_recall.py b/torchmetrics/classification/binned_precision_recall.py index defd0f7c775..66665b78eff 100644 --- a/torchmetrics/classification/binned_precision_recall.py +++ b/torchmetrics/classification/binned_precision_recall.py @@ -147,7 +147,7 @@ def __init__( dist_reduce_fx="sum", ) - def _update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore """ Args preds: (n_samples, n_classes) tensor @@ -169,7 +169,7 @@ def _update(self, preds: Tensor, target: Tensor) -> None: # type: ignore self.FPs[:, i] += ((~target) & predictions).sum(dim=0) self.FNs[:, i] += (target & (~predictions)).sum(dim=0) - def _compute(self) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: + def compute(self) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: """Returns float tensor of size n_classes.""" precisions = (self.TPs + METRIC_EPS) / (self.TPs + self.FPs + METRIC_EPS) recalls = self.TPs / (self.TPs + self.FNs + METRIC_EPS) @@ -237,8 +237,8 @@ class BinnedAveragePrecision(BinnedPrecisionRecallCurve): [tensor(1.0000), tensor(1.0000), tensor(0.2500), tensor(0.2500), tensor(-0.)] """ - def _compute(self) -> Union[List[Tensor], Tensor]: # type: ignore - precisions, recalls, _ = super()._compute() + def compute(self) -> Union[List[Tensor], Tensor]: # type: ignore + precisions, recalls, _ = super().compute() return _average_precision_compute_with_precision_recall(precisions, recalls, self.num_classes, average=None) @@ -305,9 +305,9 @@ def __init__( super().__init__(num_classes=num_classes, thresholds=thresholds, compute_on_step=compute_on_step, **kwargs) self.min_precision = min_precision - def _compute(self) -> Tuple[Tensor, Tensor]: # type: ignore + def compute(self) -> Tuple[Tensor, Tensor]: # type: ignore """Returns float tensor of size n_classes.""" - precisions, recalls, thresholds = super()._compute() + precisions, recalls, thresholds = super().compute() if self.num_classes == 1: return _recall_at_precision(precisions, recalls, thresholds, self.min_precision) diff --git a/torchmetrics/classification/calibration_error.py b/torchmetrics/classification/calibration_error.py index 0d810e5ed34..72519b47980 100644 --- a/torchmetrics/classification/calibration_error.py +++ b/torchmetrics/classification/calibration_error.py @@ -90,7 +90,7 @@ def __init__( self.add_state("confidences", [], dist_reduce_fx="cat") self.add_state("accuracies", [], dist_reduce_fx="cat") - def _update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore """Computes top-level confidences and accuracies for the input probabilites and appends them to internal state. @@ -103,7 +103,7 @@ def _update(self, preds: Tensor, target: Tensor) -> None: # type: ignore self.confidences.append(confidences) self.accuracies.append(accuracies) - def _compute(self) -> Tensor: + def compute(self) -> Tensor: """Computes calibration error across all confidences and accuracies. Returns: diff --git a/torchmetrics/classification/cohen_kappa.py b/torchmetrics/classification/cohen_kappa.py index f38a67fa0fb..42f85f50aa5 100644 --- a/torchmetrics/classification/cohen_kappa.py +++ b/torchmetrics/classification/cohen_kappa.py @@ -99,7 +99,7 @@ def __init__( self.add_state("confmat", default=torch.zeros(num_classes, num_classes, dtype=torch.long), dist_reduce_fx="sum") - def _update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore """Update state with predictions and targets. Args: @@ -109,6 +109,6 @@ def _update(self, preds: Tensor, target: Tensor) -> None: # type: ignore confmat = _cohen_kappa_update(preds, target, self.num_classes, self.threshold) self.confmat += confmat - def _compute(self) -> Tensor: + def compute(self) -> Tensor: """Computes cohen kappa score.""" return _cohen_kappa_compute(self.confmat, self.weights) diff --git a/torchmetrics/classification/confusion_matrix.py b/torchmetrics/classification/confusion_matrix.py index 568f06718ef..6773e88375e 100644 --- a/torchmetrics/classification/confusion_matrix.py +++ b/torchmetrics/classification/confusion_matrix.py @@ -120,7 +120,7 @@ def __init__( default = torch.zeros(num_classes, num_classes, dtype=torch.long) self.add_state("confmat", default=default, dist_reduce_fx="sum") - def _update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore """Update state with predictions and targets. Args: @@ -130,7 +130,7 @@ def _update(self, preds: Tensor, target: Tensor) -> None: # type: ignore confmat = _confusion_matrix_update(preds, target, self.num_classes, self.threshold, self.multilabel) self.confmat += confmat - def _compute(self) -> Tensor: + def compute(self) -> Tensor: """Computes confusion matrix. Returns: diff --git a/torchmetrics/classification/f_beta.py b/torchmetrics/classification/f_beta.py index 0c4b99121a5..3a093d350ba 100644 --- a/torchmetrics/classification/f_beta.py +++ b/torchmetrics/classification/f_beta.py @@ -161,7 +161,7 @@ def __init__( self.average = average - def _compute(self) -> Tensor: + def compute(self) -> Tensor: """Computes fbeta over state.""" tp, fp, tn, fn = self._get_final_stats() return _fbeta_compute(tp, fp, tn, fn, self.beta, self.ignore_index, self.average, self.mdmc_reduce) diff --git a/torchmetrics/classification/hamming.py b/torchmetrics/classification/hamming.py index 858290ef793..86f241e1024 100644 --- a/torchmetrics/classification/hamming.py +++ b/torchmetrics/classification/hamming.py @@ -82,7 +82,7 @@ def __init__( self.threshold = threshold - def _update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore """Update state with predictions and targets. See :ref:`references/modules:input types` for more information on input types. @@ -96,6 +96,6 @@ def _update(self, preds: Tensor, target: Tensor) -> None: # type: ignore self.correct += correct self.total += total - def _compute(self) -> Tensor: + def compute(self) -> Tensor: """Computes hamming distance based on inputs passed in to ``update`` previously.""" return _hamming_distance_compute(self.correct, self.total) diff --git a/torchmetrics/classification/hinge.py b/torchmetrics/classification/hinge.py index 1d8fe2c4335..eb1900a387f 100644 --- a/torchmetrics/classification/hinge.py +++ b/torchmetrics/classification/hinge.py @@ -120,11 +120,11 @@ def __init__( self.squared = squared self.multiclass_mode = multiclass_mode - def _update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore measure, total = _hinge_update(preds, target, squared=self.squared, multiclass_mode=self.multiclass_mode) self.measure = measure + self.measure self.total = total + self.total - def _compute(self) -> Tensor: + def compute(self) -> Tensor: return _hinge_compute(self.measure, self.total) diff --git a/torchmetrics/classification/jaccard.py b/torchmetrics/classification/jaccard.py index e33da253520..84e87e8eae7 100644 --- a/torchmetrics/classification/jaccard.py +++ b/torchmetrics/classification/jaccard.py @@ -105,7 +105,7 @@ def __init__( self.ignore_index = ignore_index self.absent_score = absent_score - def _compute(self) -> Tensor: + def compute(self) -> Tensor: """Computes intersection over union (IoU)""" return _jaccard_from_confmat( self.confmat, self.num_classes, self.ignore_index, self.absent_score, self.reduction diff --git a/torchmetrics/classification/kl_divergence.py b/torchmetrics/classification/kl_divergence.py index 38d064f25e0..17f26d1107a 100644 --- a/torchmetrics/classification/kl_divergence.py +++ b/torchmetrics/classification/kl_divergence.py @@ -99,7 +99,7 @@ def __init__( self.add_state("measures", [], dist_reduce_fx="cat") self.add_state("total", torch.tensor(0), dist_reduce_fx="sum") - def _update(self, p: Tensor, q: Tensor) -> None: # type: ignore + def update(self, p: Tensor, q: Tensor) -> None: # type: ignore measures, total = _kld_update(p, q, self.log_prob) if self.reduction is None or self.reduction == "none": self.measures.append(measures) @@ -107,6 +107,6 @@ def _update(self, p: Tensor, q: Tensor) -> None: # type: ignore self.measures += measures.sum() self.total += total - def _compute(self) -> Tensor: + def compute(self) -> Tensor: measures = dim_zero_cat(self.measures) if self.reduction is None or self.reduction == "none" else self.measures return _kld_compute(measures, self.total, self.reduction) diff --git a/torchmetrics/classification/matthews_corrcoef.py b/torchmetrics/classification/matthews_corrcoef.py index 4dd9d53436f..a95a3380708 100644 --- a/torchmetrics/classification/matthews_corrcoef.py +++ b/torchmetrics/classification/matthews_corrcoef.py @@ -88,7 +88,7 @@ def __init__( self.add_state("confmat", default=torch.zeros(num_classes, num_classes), dist_reduce_fx="sum") - def _update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore """Update state with predictions and targets. Args: @@ -98,6 +98,6 @@ def _update(self, preds: Tensor, target: Tensor) -> None: # type: ignore confmat = _matthews_corrcoef_update(preds, target, self.num_classes, self.threshold) self.confmat += confmat - def _compute(self) -> Tensor: + def compute(self) -> Tensor: """Computes matthews correlation coefficient.""" return _matthews_corrcoef_compute(self.confmat) diff --git a/torchmetrics/classification/precision_recall.py b/torchmetrics/classification/precision_recall.py index 1d889b51e78..76b13d9654e 100644 --- a/torchmetrics/classification/precision_recall.py +++ b/torchmetrics/classification/precision_recall.py @@ -150,7 +150,7 @@ def __init__( self.average = average - def _compute(self) -> Tensor: + def compute(self) -> Tensor: """Computes the precision score based on inputs passed in to ``update`` previously. Return: @@ -295,7 +295,7 @@ def __init__( self.average = average - def _compute(self) -> Tensor: + def compute(self) -> Tensor: """Computes the recall score based on inputs passed in to ``update`` previously. Return: diff --git a/torchmetrics/classification/precision_recall_curve.py b/torchmetrics/classification/precision_recall_curve.py index 1607ed7089a..3c0b29e1c78 100644 --- a/torchmetrics/classification/precision_recall_curve.py +++ b/torchmetrics/classification/precision_recall_curve.py @@ -106,7 +106,7 @@ def __init__( " For large datasets this may lead to large memory footprint." ) - def _update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore """Update state with predictions and targets. Args: @@ -121,7 +121,7 @@ def _update(self, preds: Tensor, target: Tensor) -> None: # type: ignore self.num_classes = num_classes self.pos_label = pos_label - def _compute(self) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: + def compute(self) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: """Compute the precision-recall curve. Returns: diff --git a/torchmetrics/classification/roc.py b/torchmetrics/classification/roc.py index ebffe9cc6a5..e7e5bc5b8eb 100644 --- a/torchmetrics/classification/roc.py +++ b/torchmetrics/classification/roc.py @@ -130,7 +130,7 @@ def __init__( " For large datasets this may lead to large memory footprint." ) - def _update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore """Update state with predictions and targets. Args: @@ -143,7 +143,7 @@ def _update(self, preds: Tensor, target: Tensor) -> None: # type: ignore self.num_classes = num_classes self.pos_label = pos_label - def _compute(self) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: + def compute(self) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: """Compute the receiver operating characteristic. Returns: diff --git a/torchmetrics/classification/specificity.py b/torchmetrics/classification/specificity.py index a847f986b9e..045c0b172cd 100644 --- a/torchmetrics/classification/specificity.py +++ b/torchmetrics/classification/specificity.py @@ -151,7 +151,7 @@ def __init__( self.average = average - def _compute(self) -> Tensor: + def compute(self) -> Tensor: """Computes the specificity score based on inputs passed in to ``update`` previously. Return: diff --git a/torchmetrics/classification/stat_scores.py b/torchmetrics/classification/stat_scores.py index 0a6f5687798..0122a9c20d5 100644 --- a/torchmetrics/classification/stat_scores.py +++ b/torchmetrics/classification/stat_scores.py @@ -180,7 +180,7 @@ def __init__( for s in ("tp", "fp", "tn", "fn"): self.add_state(s, default=default(), dist_reduce_fx=reduce_fn) - def _update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore """Update state with predictions and targets. See :ref:`references/modules:input types` for more information on input types. @@ -222,7 +222,7 @@ def _get_final_stats(self) -> Tuple[Tensor, Tensor, Tensor, Tensor]: fn = torch.cat(self.fn) if isinstance(self.fn, list) else self.fn return tp, fp, tn, fn - def _compute(self) -> Tensor: + def compute(self) -> Tensor: """Computes the stat scores based on inputs passed in to ``update`` previously. Return: diff --git a/torchmetrics/detection/map.py b/torchmetrics/detection/map.py index b9433563eed..1395a014b56 100644 --- a/torchmetrics/detection/map.py +++ b/torchmetrics/detection/map.py @@ -268,7 +268,7 @@ def __init__( self.add_state("groundtruth_boxes", default=[], dist_reduce_fx=None) self.add_state("groundtruth_labels", default=[], dist_reduce_fx=None) - def _update(self, preds: List[Dict[str, Tensor]], target: List[Dict[str, Tensor]]) -> None: # type: ignore + def update(self, preds: List[Dict[str, Tensor]], target: List[Dict[str, Tensor]]) -> None: # type: ignore """Add detections and ground truth to the metric. Args: @@ -671,7 +671,7 @@ def __calculate_recall_precision_scores( return recall, precision, scores - def _compute(self) -> dict: + def compute(self) -> dict: """Compute the `Mean-Average-Precision (mAP) and Mean-Average-Recall (mAR)` scores. Note: diff --git a/torchmetrics/image/fid.py b/torchmetrics/image/fid.py index eec6e37ef27..2b7bdef7ee8 100644 --- a/torchmetrics/image/fid.py +++ b/torchmetrics/image/fid.py @@ -240,7 +240,7 @@ def __init__( self.add_state("real_features", [], dist_reduce_fx=None) self.add_state("fake_features", [], dist_reduce_fx=None) - def _update(self, imgs: Tensor, real: bool) -> None: # type: ignore + def update(self, imgs: Tensor, real: bool) -> None: # type: ignore """Update the state with extracted features. Args: @@ -254,7 +254,7 @@ def _update(self, imgs: Tensor, real: bool) -> None: # type: ignore else: self.fake_features.append(features) - def _compute(self) -> Tensor: + def compute(self) -> Tensor: """Calculate FID score based on accumulated extracted features from the two distributions.""" real_features = dim_zero_cat(self.real_features) fake_features = dim_zero_cat(self.fake_features) diff --git a/torchmetrics/image/inception.py b/torchmetrics/image/inception.py index d503dd1f75b..dfe9d0b8e09 100644 --- a/torchmetrics/image/inception.py +++ b/torchmetrics/image/inception.py @@ -138,7 +138,7 @@ def __init__( self.splits = splits self.add_state("features", [], dist_reduce_fx=None) - def _update(self, imgs: Tensor) -> None: # type: ignore + def update(self, imgs: Tensor) -> None: # type: ignore """Update the state with extracted features. Args: @@ -147,7 +147,7 @@ def _update(self, imgs: Tensor) -> None: # type: ignore features = self.inception(imgs) self.features.append(features) - def _compute(self) -> Tuple[Tensor, Tensor]: + def compute(self) -> Tuple[Tensor, Tensor]: features = dim_zero_cat(self.features) # random permute the features idx = torch.randperm(features.shape[0]) diff --git a/torchmetrics/image/kid.py b/torchmetrics/image/kid.py index a9ecb324239..ae26e13aec7 100644 --- a/torchmetrics/image/kid.py +++ b/torchmetrics/image/kid.py @@ -226,7 +226,7 @@ def __init__( self.add_state("real_features", [], dist_reduce_fx=None) self.add_state("fake_features", [], dist_reduce_fx=None) - def _update(self, imgs: Tensor, real: bool) -> None: # type: ignore + def update(self, imgs: Tensor, real: bool) -> None: # type: ignore """Update the state with extracted features. Args: @@ -240,7 +240,7 @@ def _update(self, imgs: Tensor, real: bool) -> None: # type: ignore else: self.fake_features.append(features) - def _compute(self) -> Tuple[Tensor, Tensor]: + def compute(self) -> Tuple[Tensor, Tensor]: """Calculate KID score based on accumulated extracted features from the two distributions. Returns a tuple of mean and standard deviation of KID scores calculated on subsets of extracted features. diff --git a/torchmetrics/image/lpip.py b/torchmetrics/image/lpip.py index da5fce55235..be44f1f55a9 100644 --- a/torchmetrics/image/lpip.py +++ b/torchmetrics/image/lpip.py @@ -122,7 +122,7 @@ def __init__( self.add_state("sum_scores", torch.tensor(0.0), dist_reduce_fx="sum") self.add_state("total", torch.tensor(0.0), dist_reduce_fx="sum") - def _update(self, img1: Tensor, img2: Tensor) -> None: # type: ignore + def update(self, img1: Tensor, img2: Tensor) -> None: # type: ignore """Update internal states with lpips score. Args: @@ -141,7 +141,7 @@ def _update(self, img1: Tensor, img2: Tensor) -> None: # type: ignore self.sum_scores += loss.sum() self.total += img1.shape[0] - def _compute(self) -> Tensor: + def compute(self) -> Tensor: """Compute final perceptual similarity metric.""" if self.reduction == "mean": return self.sum_scores / self.total diff --git a/torchmetrics/image/psnr.py b/torchmetrics/image/psnr.py index d68dff20fe6..4aed12a3aec 100644 --- a/torchmetrics/image/psnr.py +++ b/torchmetrics/image/psnr.py @@ -108,7 +108,7 @@ def __init__( self.reduction = reduction self.dim = tuple(dim) if isinstance(dim, Sequence) else dim - def _update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore """Update state with predictions and targets. Args: @@ -128,7 +128,7 @@ def _update(self, preds: Tensor, target: Tensor) -> None: # type: ignore self.sum_squared_error.append(sum_squared_error) self.total.append(n_obs) - def _compute(self) -> Tensor: + def compute(self) -> Tensor: """Compute peak signal-to-noise ratio over state.""" if self.data_range is not None: data_range = self.data_range diff --git a/torchmetrics/image/ssim.py b/torchmetrics/image/ssim.py index 5355c4b19c7..6d80532aa20 100644 --- a/torchmetrics/image/ssim.py +++ b/torchmetrics/image/ssim.py @@ -91,7 +91,7 @@ def __init__( self.k2 = k2 self.reduction = reduction - def _update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore """Update state with predictions and targets. Args: @@ -102,7 +102,7 @@ def _update(self, preds: Tensor, target: Tensor) -> None: # type: ignore self.preds.append(preds) self.target.append(target) - def _compute(self) -> Tensor: + def compute(self) -> Tensor: """Computes explained variance over state.""" preds = dim_zero_cat(self.preds) target = dim_zero_cat(self.target) @@ -206,7 +206,7 @@ def __init__( raise ValueError("Argument `normalize` to be expected either `None` or one of 'relu' or 'simple'") self.normalize = normalize - def _update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore """Update state with predictions and targets. Args: @@ -217,7 +217,7 @@ def _update(self, preds: Tensor, target: Tensor) -> None: # type: ignore self.preds.append(preds) self.target.append(target) - def _compute(self) -> Tensor: + def compute(self) -> Tensor: """Computes explained variance over state.""" preds = dim_zero_cat(self.preds) target = dim_zero_cat(self.target) diff --git a/torchmetrics/image/uqi.py b/torchmetrics/image/uqi.py index ab67429df06..27ed688c619 100644 --- a/torchmetrics/image/uqi.py +++ b/torchmetrics/image/uqi.py @@ -85,7 +85,7 @@ def __init__( self.data_range = data_range self.reduction = reduction - def _update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore """Update state with predictions and targets. Args: @@ -96,7 +96,7 @@ def _update(self, preds: Tensor, target: Tensor) -> None: # type: ignore self.preds.append(preds) self.target.append(target) - def _compute(self) -> Tensor: + def compute(self) -> Tensor: """Computes explained variance over state.""" preds = dim_zero_cat(self.preds) target = dim_zero_cat(self.target) diff --git a/torchmetrics/metric.py b/torchmetrics/metric.py index e5bcd5e67c8..7508f438261 100644 --- a/torchmetrics/metric.py +++ b/torchmetrics/metric.py @@ -14,7 +14,7 @@ import functools import inspect import warnings -from abc import ABC +from abc import ABC, abstractmethod from contextlib import contextmanager from copy import deepcopy from typing import Any, Callable, Dict, Generator, List, Optional, Sequence, Union @@ -41,20 +41,6 @@ def jit_distributed_available() -> bool: return torch.distributed.is_available() and torch.distributed.is_initialized() -def is_overridden(method_name: str, instance: object, parent: object) -> bool: - """Tempoary needed function to make sure that users move from old interface of overwriting update and compute - to instead of implementing _update and _compute. - - Remove in v0.9 - """ - instance_attr = getattr(instance, method_name, None) - parent_attr = getattr(parent, method_name) - if instance_attr is None: - return False - - return instance_attr.__code__ != parent_attr.__code__ - - class Metric(Module, ABC): """Base class for all metrics present in the Metrics API. @@ -134,48 +120,25 @@ def __init__( warnings.warn( "Argument `compute_on_step` is deprecated in v0.8 and will be removed in v0.9", DeprecationWarning ) + self.dist_sync_on_step = kwargs.pop("dist_sync_on_step", False) if not isinstance(self.dist_sync_on_step, bool): raise ValueError( f"Expected keyword argument `dist_sync_on_step` to be an `bool` but got {self.dist_sync_on_step}" ) + self.process_group = kwargs.pop("process_group", None) + self.dist_sync_fn = kwargs.pop("dist_sync_fn", None) if self.dist_sync_fn is not None and not callable(self.dist_sync_fn): raise ValueError( f"Expected keyword argument `dist_sync_fn` to be an callable function but got {self.dist_sync_fn}" ) - # check update and compute format - if is_overridden("update", self, Metric): - warnings.warn( - "We detected that you have overwritten the ``update`` method, which was the API" - " for torchmetrics v0.7 and below. Insted implement the ``_update`` method." - " (exact same as before just with a ``_`` infront to make the implementation private)" - " Implementing `update` directly was deprecated in v0.8 and will be removed in v0.9.", - DeprecationWarning, - ) - self._update_signature = inspect.signature(self.update) - self.update: Callable = self._wrap_update(self.update) # type: ignore - else: - if not hasattr(self, "_update"): - raise NotImplementedError("Expected method `_update` to be implemented in subclass.") - self._update_signature = inspect.signature(self._update) - - if is_overridden("compute", self, Metric): - warnings.warn( - "We detected that you have overwritten the ``compute`` method, which was the API" - " for torchmetrics v0.7 and below. Insted implement the ``_compute`` method." - " (exact same as before just with a ``_`` infront to make the implementation private)" - " Implementing `compute` directly was deprecated in v0.8 and will be removed in v0.9.", - DeprecationWarning, - ) - self.compute: Callable = self._wrap_compute(self.compute) # type: ignore - else: - if not hasattr(self, "_compute"): - raise NotImplementedError("Expected method `_compute` to be implemented in subclass.") - # initialize + self._update_signature = inspect.signature(self.update) + self.update: Callable = self._wrap_update(self.update) # type: ignore + self.compute: Callable = self._wrap_compute(self.compute) # type: ignore self._computed = None self._forward_cache = None self._update_called = False @@ -326,6 +289,15 @@ def _sync_dist(self, dist_sync_fn: Callable = gather_all_tensors, process_group: reduced = reduction_fn(output_dict[attr]) if reduction_fn is not None else output_dict[attr] setattr(self, attr, reduced) + def _wrap_update(self, update: Callable) -> Callable: + @functools.wraps(update) + def wrapped_func(*args: Any, **kwargs: Any) -> Optional[Any]: + self._computed = None + self._update_called = True + return update(*args, **kwargs) + + return wrapped_func + def sync( self, dist_sync_fn: Optional[Callable] = None, @@ -418,15 +390,6 @@ def sync_context( self.unsync(should_unsync=self._is_synced and should_unsync) - def _wrap_update(self, update: Callable) -> Callable: - @functools.wraps(update) - def wrapped_func(*args: Any, **kwargs: Any) -> Optional[Any]: - self._computed = None - self._update_called = True - return update(*args, **kwargs) - - return wrapped_func - def _wrap_compute(self, compute: Callable) -> Callable: @functools.wraps(compute) def wrapped_func(*args: Any, **kwargs: Any) -> Any: @@ -457,36 +420,14 @@ def wrapped_func(*args: Any, **kwargs: Any) -> Any: return wrapped_func - def update(self, *args: Any, **kwargs: Any) -> None: - self._computed = None - self._update_called = True - self._update(*args, **kwargs) + @abstractmethod + def update(self, *_: Any, **__: Any) -> None: + """Override this method to update the state variables of your metric class.""" + @abstractmethod def compute(self) -> Any: - if not self._update_called: - rank_zero_warn( - f"The ``compute`` method of metric {self.__class__.__name__}" - " was called before the ``update`` method which may lead to errors," - " as metric states have not yet been updated.", - UserWarning, - ) - - # return cached value - if self._computed is not None: - return self._computed - - # compute relies on the sync context manager to gather the states across processes and apply reduction - # if synchronization happened, the current rank accumulated states will be restored to keep - # accumulation going if ``should_unsync=True``, - with self.sync_context( - dist_sync_fn=self.dist_sync_fn, # type: ignore - should_sync=self._to_sync, - should_unsync=self._should_unsync, - ): - value = self._compute() - self._computed = _squeeze_if_scalar(value) - - return self._computed + """Override this method to compute the final metric value from state variables synchronized across the + distributed backend.""" def reset(self) -> None: """This method automatically resets the metric state variables to their default value.""" @@ -517,6 +458,8 @@ def __setstate__(self, state: Dict[str, Any]) -> None: # manually restore update and compute functions for pickling self.__dict__.update(state) self._update_signature = inspect.signature(self.update) + self.update: Callable = self._wrap_update(self.update) # type: ignore + self.compute: Callable = self._wrap_compute(self.compute) # type: ignore def __setattr__(self, name: str, value: Any) -> None: if name in ("higher_is_better", "is_differentiable"): @@ -828,14 +771,14 @@ def _sync_dist(self, dist_sync_fn: Optional[Callable] = None, process_group: Opt # No syncing required here. syncing will be done in metric_a and metric_b pass - def _update(self, *args: Any, **kwargs: Any) -> None: + def update(self, *args: Any, **kwargs: Any) -> None: if isinstance(self.metric_a, Metric): self.metric_a.update(*args, **self.metric_a._filter_kwargs(**kwargs)) if isinstance(self.metric_b, Metric): self.metric_b.update(*args, **self.metric_b._filter_kwargs(**kwargs)) - def _compute(self) -> Any: + def compute(self) -> Any: # also some parsing for kwargs? if isinstance(self.metric_a, Metric): @@ -900,3 +843,6 @@ def __repr__(self) -> str: repr_str = self.__class__.__name__ + _op_metrics return repr_str + + def _wrap_compute(self, compute: Callable) -> Callable: + return compute diff --git a/torchmetrics/regression/cosine_similarity.py b/torchmetrics/regression/cosine_similarity.py index 557f98c84d4..e2d82f2035e 100644 --- a/torchmetrics/regression/cosine_similarity.py +++ b/torchmetrics/regression/cosine_similarity.py @@ -78,7 +78,7 @@ def __init__( self.add_state("preds", [], dist_reduce_fx="cat") self.add_state("target", [], dist_reduce_fx="cat") - def _update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore """Update metric states with predictions and targets. Args: @@ -90,7 +90,7 @@ def _update(self, preds: Tensor, target: Tensor) -> None: # type: ignore self.preds.append(preds) self.target.append(target) - def _compute(self) -> Tensor: + def compute(self) -> Tensor: preds = dim_zero_cat(self.preds) target = dim_zero_cat(self.target) return _cosine_similarity_compute(preds, target, self.reduction) diff --git a/torchmetrics/regression/explained_variance.py b/torchmetrics/regression/explained_variance.py index 2d21ab803e5..b2422c237cd 100644 --- a/torchmetrics/regression/explained_variance.py +++ b/torchmetrics/regression/explained_variance.py @@ -105,7 +105,7 @@ def __init__( self.add_state("sum_squared_target", default=tensor(0.0), dist_reduce_fx="sum") self.add_state("n_obs", default=tensor(0.0), dist_reduce_fx="sum") - def _update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore """Update state with predictions and targets. Args: @@ -119,7 +119,7 @@ def _update(self, preds: Tensor, target: Tensor) -> None: # type: ignore self.sum_target = self.sum_target + sum_target self.sum_squared_target = self.sum_squared_target + sum_squared_target - def _compute(self) -> Union[Tensor, Sequence[Tensor]]: + def compute(self) -> Union[Tensor, Sequence[Tensor]]: """Computes explained variance over state.""" return _explained_variance_compute( self.n_obs, diff --git a/torchmetrics/regression/log_mse.py b/torchmetrics/regression/log_mse.py index bd8ccd072e7..87e7f816fa3 100644 --- a/torchmetrics/regression/log_mse.py +++ b/torchmetrics/regression/log_mse.py @@ -65,7 +65,7 @@ def __init__( self.add_state("sum_squared_log_error", default=tensor(0.0), dist_reduce_fx="sum") self.add_state("total", default=tensor(0), dist_reduce_fx="sum") - def _update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore """Update state with predictions and targets. Args: @@ -77,6 +77,6 @@ def _update(self, preds: Tensor, target: Tensor) -> None: # type: ignore self.sum_squared_log_error += sum_squared_log_error self.total += n_obs - def _compute(self) -> Tensor: + def compute(self) -> Tensor: """Compute mean squared logarithmic error over state.""" return _mean_squared_log_error_compute(self.sum_squared_log_error, self.total) diff --git a/torchmetrics/regression/mae.py b/torchmetrics/regression/mae.py index 0ae4fc83065..42e63ccb980 100644 --- a/torchmetrics/regression/mae.py +++ b/torchmetrics/regression/mae.py @@ -61,7 +61,7 @@ def __init__( self.add_state("sum_abs_error", default=tensor(0.0), dist_reduce_fx="sum") self.add_state("total", default=tensor(0), dist_reduce_fx="sum") - def _update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore """Update state with predictions and targets. Args: @@ -73,6 +73,6 @@ def _update(self, preds: Tensor, target: Tensor) -> None: # type: ignore self.sum_abs_error += sum_abs_error self.total += n_obs - def _compute(self) -> Tensor: + def compute(self) -> Tensor: """Computes mean absolute error over state.""" return _mean_absolute_error_compute(self.sum_abs_error, self.total) diff --git a/torchmetrics/regression/mape.py b/torchmetrics/regression/mape.py index 5ed760dbdb9..c343783a165 100644 --- a/torchmetrics/regression/mape.py +++ b/torchmetrics/regression/mape.py @@ -73,7 +73,7 @@ def __init__( self.add_state("sum_abs_per_error", default=tensor(0.0), dist_reduce_fx="sum") self.add_state("total", default=tensor(0.0), dist_reduce_fx="sum") - def _update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore """Update state with predictions and targets. Args: @@ -85,6 +85,6 @@ def _update(self, preds: Tensor, target: Tensor) -> None: # type: ignore self.sum_abs_per_error += sum_abs_per_error self.total += num_obs - def _compute(self) -> Tensor: + def compute(self) -> Tensor: """Computes mean absolute percentage error over state.""" return _mean_absolute_percentage_error_compute(self.sum_abs_per_error, self.total) diff --git a/torchmetrics/regression/mse.py b/torchmetrics/regression/mse.py index cbd33e920cc..82ed567af4e 100644 --- a/torchmetrics/regression/mse.py +++ b/torchmetrics/regression/mse.py @@ -66,7 +66,7 @@ def __init__( self.add_state("total", default=tensor(0), dist_reduce_fx="sum") self.squared = squared - def _update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore """Update state with predictions and targets. Args: @@ -78,6 +78,6 @@ def _update(self, preds: Tensor, target: Tensor) -> None: # type: ignore self.sum_squared_error += sum_squared_error self.total += n_obs - def _compute(self) -> Tensor: + def compute(self) -> Tensor: """Computes mean squared error over state.""" return _mean_squared_error_compute(self.sum_squared_error, self.total, squared=self.squared) diff --git a/torchmetrics/regression/pearson.py b/torchmetrics/regression/pearson.py index b95bfa9fe1f..7f5ba8f5ff7 100644 --- a/torchmetrics/regression/pearson.py +++ b/torchmetrics/regression/pearson.py @@ -111,7 +111,7 @@ def __init__( self.add_state("corr_xy", default=torch.tensor(0.0), dist_reduce_fx=None) self.add_state("n_total", default=torch.tensor(0.0), dist_reduce_fx=None) - def _update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore """Update state with predictions and targets. Args: @@ -122,7 +122,7 @@ def _update(self, preds: Tensor, target: Tensor) -> None: # type: ignore preds, target, self.mean_x, self.mean_y, self.var_x, self.var_y, self.corr_xy, self.n_total ) - def _compute(self) -> Tensor: + def compute(self) -> Tensor: """Computes pearson correlation coefficient over state.""" if self.mean_x.numel() > 1: # multiple devices, need further reduction var_x, var_y, corr_xy, n_total = _final_aggregation( diff --git a/torchmetrics/regression/r2.py b/torchmetrics/regression/r2.py index be22258ae24..0ce8ce0739a 100644 --- a/torchmetrics/regression/r2.py +++ b/torchmetrics/regression/r2.py @@ -123,7 +123,7 @@ def __init__( self.add_state("residual", default=torch.zeros(self.num_outputs), dist_reduce_fx="sum") self.add_state("total", default=tensor(0), dist_reduce_fx="sum") - def _update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore """Update state with predictions and targets. Args: @@ -137,7 +137,7 @@ def _update(self, preds: Tensor, target: Tensor) -> None: # type: ignore self.residual += residual self.total += total - def _compute(self) -> Tensor: + def compute(self) -> Tensor: """Computes r2 score over the metric states.""" return _r2_score_compute( self.sum_squared_error, self.sum_error, self.residual, self.total, self.adjusted, self.multioutput diff --git a/torchmetrics/regression/spearman.py b/torchmetrics/regression/spearman.py index ecabe8ba057..f98c72369d2 100644 --- a/torchmetrics/regression/spearman.py +++ b/torchmetrics/regression/spearman.py @@ -70,7 +70,7 @@ def __init__( self.add_state("preds", default=[], dist_reduce_fx="cat") self.add_state("target", default=[], dist_reduce_fx="cat") - def _update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore """Update state with predictions and targets. Args: @@ -81,7 +81,7 @@ def _update(self, preds: Tensor, target: Tensor) -> None: # type: ignore self.preds.append(preds) self.target.append(target) - def _compute(self) -> Tensor: + def compute(self) -> Tensor: """Computes spearmans correlation coefficient.""" preds = dim_zero_cat(self.preds) target = dim_zero_cat(self.target) diff --git a/torchmetrics/regression/symmetric_mape.py b/torchmetrics/regression/symmetric_mape.py index 776a5ec579b..d9170154f2a 100644 --- a/torchmetrics/regression/symmetric_mape.py +++ b/torchmetrics/regression/symmetric_mape.py @@ -70,7 +70,7 @@ def __init__( self.add_state("sum_abs_per_error", default=tensor(0.0), dist_reduce_fx="sum") self.add_state("total", default=tensor(0.0), dist_reduce_fx="sum") - def _update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore """Update state with predictions and targets. Args: @@ -82,6 +82,6 @@ def _update(self, preds: Tensor, target: Tensor) -> None: # type: ignore self.sum_abs_per_error += sum_abs_per_error self.total += num_obs - def _compute(self) -> Tensor: + def compute(self) -> Tensor: """Computes mean absolute percentage error over state.""" return _symmetric_mean_absolute_percentage_error_compute(self.sum_abs_per_error, self.total) diff --git a/torchmetrics/regression/tweedie_deviance.py b/torchmetrics/regression/tweedie_deviance.py index 1255c3e4fd4..85336bdafec 100644 --- a/torchmetrics/regression/tweedie_deviance.py +++ b/torchmetrics/regression/tweedie_deviance.py @@ -91,7 +91,7 @@ def __init__( self.add_state("sum_deviance_score", torch.tensor(0.0), dist_reduce_fx="sum") self.add_state("num_observations", torch.tensor(0), dist_reduce_fx="sum") - def _update(self, preds: Tensor, targets: Tensor) -> None: # type: ignore + def update(self, preds: Tensor, targets: Tensor) -> None: # type: ignore """Update metric states with predictions and targets. Args: @@ -103,5 +103,5 @@ def _update(self, preds: Tensor, targets: Tensor) -> None: # type: ignore self.sum_deviance_score += sum_deviance_score self.num_observations += num_observations - def _compute(self) -> Tensor: + def compute(self) -> Tensor: return _tweedie_deviance_score_compute(self.sum_deviance_score, self.num_observations) diff --git a/torchmetrics/retrieval/base.py b/torchmetrics/retrieval/base.py index 32b89aaa881..cbc5bb80439 100644 --- a/torchmetrics/retrieval/base.py +++ b/torchmetrics/retrieval/base.py @@ -99,7 +99,7 @@ def __init__( self.add_state("preds", default=[], dist_reduce_fx=None) self.add_state("target", default=[], dist_reduce_fx=None) - def _update(self, preds: Tensor, target: Tensor, indexes: Tensor) -> None: # type: ignore + def update(self, preds: Tensor, target: Tensor, indexes: Tensor) -> None: # type: ignore """Check shape, check and convert dtypes, flatten and add to accumulators.""" if indexes is None: raise ValueError("Argument `indexes` cannot be None") @@ -112,7 +112,7 @@ def _update(self, preds: Tensor, target: Tensor, indexes: Tensor) -> None: # ty self.preds.append(preds) self.target.append(target) - def _compute(self) -> Tensor: + def compute(self) -> Tensor: """First concat state ``indexes``, ``preds`` and ``target`` since they were stored as lists. After that, compute list of groups that will help in keeping together predictions about the same query. Finally, diff --git a/torchmetrics/retrieval/fall_out.py b/torchmetrics/retrieval/fall_out.py index 512eda78c4a..dad1218a156 100644 --- a/torchmetrics/retrieval/fall_out.py +++ b/torchmetrics/retrieval/fall_out.py @@ -97,7 +97,7 @@ def __init__( raise ValueError("`k` has to be a positive integer or None") self.k = k - def _compute(self) -> Tensor: + def compute(self) -> Tensor: """First concat state `indexes`, `preds` and `target` since they were stored as lists. After that, compute list of groups that will help in keeping together predictions about the same query. Finally, diff --git a/torchmetrics/text/bert.py b/torchmetrics/text/bert.py index 610b84a418c..b2392cd0f2f 100644 --- a/torchmetrics/text/bert.py +++ b/torchmetrics/text/bert.py @@ -192,7 +192,7 @@ def __init__( self.add_state("target_input_ids", [], dist_reduce_fx="cat") self.add_state("target_attention_mask", [], dist_reduce_fx="cat") - def _update(self, preds: List[str], target: List[str]) -> None: # type: ignore + def update(self, preds: List[str], target: List[str]) -> None: # type: ignore """Store predictions/references for computing BERT scores. It is necessary to store sentences in a tokenized form to ensure the DDP mode working. @@ -224,7 +224,7 @@ def _update(self, preds: List[str], target: List[str]) -> None: # type: ignore self.target_input_ids.append(target_dict["input_ids"]) self.target_attention_mask.append(target_dict["attention_mask"]) - def _compute(self) -> Dict[str, Union[List[float], str]]: + def compute(self) -> Dict[str, Union[List[float], str]]: """Calculate BERT scores. Return: diff --git a/torchmetrics/text/bleu.py b/torchmetrics/text/bleu.py index 5f49fa20b34..3b41f66eb4a 100644 --- a/torchmetrics/text/bleu.py +++ b/torchmetrics/text/bleu.py @@ -86,7 +86,7 @@ def __init__( self.add_state("numerator", torch.zeros(self.n_gram), dist_reduce_fx="sum") self.add_state("denominator", torch.zeros(self.n_gram), dist_reduce_fx="sum") - def _update(self, preds: Sequence[str], target: Sequence[Sequence[str]]) -> None: # type: ignore + def update(self, preds: Sequence[str], target: Sequence[Sequence[str]]) -> None: # type: ignore """Compute Precision Scores. Args: @@ -104,7 +104,7 @@ def _update(self, preds: Sequence[str], target: Sequence[Sequence[str]]) -> None _tokenize_fn, ) - def _compute(self) -> Tensor: + def compute(self) -> Tensor: """Calculate BLEU score. Return: diff --git a/torchmetrics/text/cer.py b/torchmetrics/text/cer.py index a24fc8bb0b9..9aaf89d4b0f 100644 --- a/torchmetrics/text/cer.py +++ b/torchmetrics/text/cer.py @@ -75,7 +75,7 @@ def __init__( self.add_state("errors", tensor(0, dtype=torch.float), dist_reduce_fx="sum") self.add_state("total", tensor(0, dtype=torch.float), dist_reduce_fx="sum") - def _update(self, preds: Union[str, List[str]], target: Union[str, List[str]]) -> None: # type: ignore + def update(self, preds: Union[str, List[str]], target: Union[str, List[str]]) -> None: # type: ignore """Store references/predictions for computing Character Error Rate scores. Args: @@ -86,7 +86,7 @@ def _update(self, preds: Union[str, List[str]], target: Union[str, List[str]]) - self.errors += errors self.total += total - def _compute(self) -> Tensor: + def compute(self) -> Tensor: """Calculate the character error rate. Returns: diff --git a/torchmetrics/text/chrf.py b/torchmetrics/text/chrf.py index fe363d767e1..111f246d2b3 100644 --- a/torchmetrics/text/chrf.py +++ b/torchmetrics/text/chrf.py @@ -134,7 +134,7 @@ def __init__( if self.return_sentence_level_score: self.add_state("sentence_chrf_score", [], dist_reduce_fx="cat") - def _update(self, preds: Sequence[str], target: Sequence[Sequence[str]]) -> None: # type: ignore + def update(self, preds: Sequence[str], target: Sequence[Sequence[str]]) -> None: # type: ignore """Compute Precision Scores. Args: @@ -159,7 +159,7 @@ def _update(self, preds: Sequence[str], target: Sequence[Sequence[str]]) -> None if self.sentence_chrf_score is not None: self.sentence_chrf_score = n_grams_dicts_tuple[-1] - def _compute(self) -> Union[Tensor, Tuple[Tensor, Tensor]]: + def compute(self) -> Union[Tensor, Tuple[Tensor, Tensor]]: """Calculate chrF/chrF++ score. Return: diff --git a/torchmetrics/text/eed.py b/torchmetrics/text/eed.py index 40408501724..9a42da4e76d 100644 --- a/torchmetrics/text/eed.py +++ b/torchmetrics/text/eed.py @@ -97,7 +97,7 @@ def __init__( self.add_state("sentence_eed", [], dist_reduce_fx="cat") - def _update( # type: ignore + def update( # type: ignore self, preds: Union[str, Sequence[str]], target: Sequence[Union[str, Sequence[str]]], @@ -119,7 +119,7 @@ def _update( # type: ignore self.sentence_eed, ) - def _compute(self) -> Union[Tensor, Tuple[Tensor, Tensor]]: + def compute(self) -> Union[Tensor, Tuple[Tensor, Tensor]]: """Calculate extended edit distance score. Return: diff --git a/torchmetrics/text/mer.py b/torchmetrics/text/mer.py index a60f8e8c375..b89d05feb01 100644 --- a/torchmetrics/text/mer.py +++ b/torchmetrics/text/mer.py @@ -73,7 +73,7 @@ def __init__( self.add_state("errors", tensor(0, dtype=torch.float), dist_reduce_fx="sum") self.add_state("total", tensor(0, dtype=torch.float), dist_reduce_fx="sum") - def _update( # type: ignore + def update( # type: ignore self, preds: Union[str, List[str]], target: Union[str, List[str]], @@ -91,7 +91,7 @@ def _update( # type: ignore self.errors += errors self.total += total - def _compute(self) -> Tensor: + def compute(self) -> Tensor: """Calculate the Match error rate. Returns: diff --git a/torchmetrics/text/rouge.py b/torchmetrics/text/rouge.py index 1397cc75272..b6344803dc5 100644 --- a/torchmetrics/text/rouge.py +++ b/torchmetrics/text/rouge.py @@ -132,7 +132,7 @@ def __init__( for score in ["fmeasure", "precision", "recall"]: self.add_state(f"{rouge_key}_{score}", [], dist_reduce_fx=None) - def _update( # type: ignore + def update( # type: ignore self, preds: Union[str, Sequence[str]], target: Union[str, Sequence[str], Sequence[Sequence[str]]] ) -> None: """Compute rouge scores. @@ -167,7 +167,7 @@ def _update( # type: ignore for tp, value in metric.items(): getattr(self, f"rouge{rouge_key}_{tp}").append(value.to(self.device)) - def _compute(self) -> Dict[str, Tensor]: + def compute(self) -> Dict[str, Tensor]: """Calculate (Aggregate and provide confidence intervals) ROUGE score. Return: diff --git a/torchmetrics/text/sacre_bleu.py b/torchmetrics/text/sacre_bleu.py index 678d50d50f7..4e8563a336f 100644 --- a/torchmetrics/text/sacre_bleu.py +++ b/torchmetrics/text/sacre_bleu.py @@ -99,7 +99,7 @@ def __init__( ) self.tokenizer = _SacreBLEUTokenizer(tokenize, lowercase) - def _update(self, preds: Sequence[str], target: Sequence[Sequence[str]]) -> None: # type: ignore + def update(self, preds: Sequence[str], target: Sequence[Sequence[str]]) -> None: # type: ignore """Compute Precision Scores. Args: diff --git a/torchmetrics/text/squad.py b/torchmetrics/text/squad.py index e8142aa6e6b..f0b4b13ef3e 100644 --- a/torchmetrics/text/squad.py +++ b/torchmetrics/text/squad.py @@ -71,7 +71,7 @@ def __init__( self.add_state(name="exact_match", default=torch.tensor(0, dtype=torch.float), dist_reduce_fx="sum") self.add_state(name="total", default=torch.tensor(0, dtype=torch.int), dist_reduce_fx="sum") - def _update(self, preds: PREDS_TYPE, target: TARGETS_TYPE) -> None: # type: ignore + def update(self, preds: PREDS_TYPE, target: TARGETS_TYPE) -> None: # type: ignore """Compute F1 Score and Exact Match for a collection of predictions and references. Args: @@ -118,7 +118,7 @@ def _update(self, preds: PREDS_TYPE, target: TARGETS_TYPE) -> None: # type: ign self.exact_match += exact_match self.total += total - def _compute(self) -> Dict[str, Tensor]: + def compute(self) -> Dict[str, Tensor]: """Aggregate the F1 Score and Exact match for the batch. Return: diff --git a/torchmetrics/text/ter.py b/torchmetrics/text/ter.py index 4dc077bd7d1..b83185c3db0 100644 --- a/torchmetrics/text/ter.py +++ b/torchmetrics/text/ter.py @@ -93,7 +93,7 @@ def __init__( if self.return_sentence_level_score: self.add_state("sentence_ter", [], dist_reduce_fx="cat") - def _update( # type: ignore + def update( # type: ignore self, preds: Union[str, Sequence[str]], target: Sequence[Union[str, Sequence[str]]] ) -> None: """Update TER statistics. @@ -113,7 +113,7 @@ def _update( # type: ignore self.sentence_ter, ) - def _compute(self) -> Union[Tensor, Tuple[Tensor, Tensor]]: + def compute(self) -> Union[Tensor, Tuple[Tensor, Tensor]]: """Calculate the translate error rate (TER). Return: diff --git a/torchmetrics/text/wer.py b/torchmetrics/text/wer.py index 6c61b85ec18..3a38b11e0f7 100644 --- a/torchmetrics/text/wer.py +++ b/torchmetrics/text/wer.py @@ -73,7 +73,7 @@ def __init__( self.add_state("errors", tensor(0, dtype=torch.float), dist_reduce_fx="sum") self.add_state("total", tensor(0, dtype=torch.float), dist_reduce_fx="sum") - def _update(self, preds: Union[str, List[str]], target: Union[str, List[str]]) -> None: # type: ignore + def update(self, preds: Union[str, List[str]], target: Union[str, List[str]]) -> None: # type: ignore """Store references/predictions for computing Word Error Rate scores. Args: @@ -84,7 +84,7 @@ def _update(self, preds: Union[str, List[str]], target: Union[str, List[str]]) - self.errors += errors self.total += total - def _compute(self) -> Tensor: + def compute(self) -> Tensor: """Calculate the word error rate. Returns: diff --git a/torchmetrics/text/wil.py b/torchmetrics/text/wil.py index 0ab361cbc3f..1f0f22e983a 100644 --- a/torchmetrics/text/wil.py +++ b/torchmetrics/text/wil.py @@ -74,7 +74,7 @@ def __init__( self.add_state("target_total", tensor(0.0), dist_reduce_fx="sum") self.add_state("preds_total", tensor(0.0), dist_reduce_fx="sum") - def _update(self, preds: Union[str, List[str]], target: Union[str, List[str]]) -> None: # type: ignore + def update(self, preds: Union[str, List[str]], target: Union[str, List[str]]) -> None: # type: ignore """Store predictions/references for computing Word Information Lost scores. Args: @@ -88,7 +88,7 @@ def _update(self, preds: Union[str, List[str]], target: Union[str, List[str]]) - self.target_total += target_total self.preds_total += preds_total - def _compute(self) -> Tensor: + def compute(self) -> Tensor: """Calculate the Word Information Lost. Returns: diff --git a/torchmetrics/text/wip.py b/torchmetrics/text/wip.py index 6681cad42be..5ad016a0279 100644 --- a/torchmetrics/text/wip.py +++ b/torchmetrics/text/wip.py @@ -74,7 +74,7 @@ def __init__( self.add_state("target_total", tensor(0.0), dist_reduce_fx="sum") self.add_state("preds_total", tensor(0.0), dist_reduce_fx="sum") - def _update(self, preds: Union[str, List[str]], target: Union[str, List[str]]) -> None: # type: ignore + def update(self, preds: Union[str, List[str]], target: Union[str, List[str]]) -> None: # type: ignore """Store predictions/references for computing word Information Preserved scores. Args: @@ -88,7 +88,7 @@ def _update(self, preds: Union[str, List[str]], target: Union[str, List[str]]) - self.target_total += target_total self.preds_total += preds_total - def _compute(self) -> Tensor: + def compute(self) -> Tensor: """Calculate the word Information Preserved. Returns: diff --git a/torchmetrics/wrappers/bootstrapping.py b/torchmetrics/wrappers/bootstrapping.py index 9e715bff20a..6c95da1e52e 100644 --- a/torchmetrics/wrappers/bootstrapping.py +++ b/torchmetrics/wrappers/bootstrapping.py @@ -131,7 +131,7 @@ def __init__( ) self.sampling_strategy = sampling_strategy - def _update(self, *args: Any, **kwargs: Any) -> None: + def update(self, *args: Any, **kwargs: Any) -> None: """Updates the state of the base metric. Any tensor passed in will be bootstrapped along dimension 0 @@ -150,7 +150,7 @@ def _update(self, *args: Any, **kwargs: Any) -> None: new_kwargs = apply_to_collection(kwargs, Tensor, torch.index_select, dim=0, index=sample_idx) self.metrics[idx].update(*new_args, **new_kwargs) - def _compute(self) -> Dict[str, Tensor]: + def compute(self) -> Dict[str, Tensor]: """Computes the bootstrapped metric values. Allways returns a dict of tensors, which can contain the following keys: ``mean``, ``std``, ``quantile`` and @@ -167,8 +167,3 @@ def _compute(self) -> Dict[str, Tensor]: if self.raw: output_dict["raw"] = computed_vals return output_dict - - def reset(self) -> None: - """Reset all underlying metrics.""" - for metric in self.metrics: - metric.reset() diff --git a/torchmetrics/wrappers/classwise.py b/torchmetrics/wrappers/classwise.py index 5dbf4ed61a6..63f519a5403 100644 --- a/torchmetrics/wrappers/classwise.py +++ b/torchmetrics/wrappers/classwise.py @@ -67,11 +67,8 @@ def _convert(self, x: Tensor) -> Dict[str, Any]: return {f"{name}_{i}": val for i, val in enumerate(x)} return {f"{name}_{lab}": val for lab, val in zip(self.labels, x)} - def _update(self, *args: Any, **kwargs: Any) -> None: + def update(self, *args: Any, **kwargs: Any) -> None: self.metric.update(*args, **kwargs) - def _compute(self) -> Dict[str, Tensor]: + def compute(self) -> Dict[str, Tensor]: return self._convert(self.metric.compute()) - - def reset(self) -> None: - self.metric.reset() diff --git a/torchmetrics/wrappers/minmax.py b/torchmetrics/wrappers/minmax.py index e3a5db042e4..f8451eb3058 100644 --- a/torchmetrics/wrappers/minmax.py +++ b/torchmetrics/wrappers/minmax.py @@ -76,11 +76,11 @@ def __init__( self.register_buffer("min_val", torch.tensor(float("inf"))) self.register_buffer("max_val", torch.tensor(float("-inf"))) - def _update(self, *args: Any, **kwargs: Any) -> None: # type: ignore + def update(self, *args: Any, **kwargs: Any) -> None: # type: ignore """Updates the underlying metric.""" self._base_metric.update(*args, **kwargs) - def _compute(self) -> Dict[str, Tensor]: # type: ignore + def compute(self) -> Dict[str, Tensor]: # type: ignore """Computes the underlying metric as well as max and min values for this metric. Returns a dictionary that consists of the computed value (``raw``), as well as the minimum (``min``) and maximum diff --git a/torchmetrics/wrappers/multioutput.py b/torchmetrics/wrappers/multioutput.py index 65f43b44b5f..b1f55d3c6d8 100644 --- a/torchmetrics/wrappers/multioutput.py +++ b/torchmetrics/wrappers/multioutput.py @@ -90,10 +90,6 @@ def __init__( squeeze_outputs: bool = True, ): super().__init__() - if not isinstance(base_metric, Metric): - raise ValueError( - "Expected base metric to be an instance of torchmetrics.Metric" f" but received {base_metric}" - ) self.metrics = nn.ModuleList([deepcopy(base_metric) for _ in range(num_outputs)]) self.output_dim = output_dim self.remove_nans = remove_nans @@ -122,13 +118,13 @@ def _get_args_kwargs_by_output( args_kwargs_by_output.append((selected_args, selected_kwargs)) return args_kwargs_by_output - def _update(self, *args: Any, **kwargs: Any) -> None: + def update(self, *args: Any, **kwargs: Any) -> None: """Update each underlying metric with the corresponding output.""" reshaped_args_kwargs = self._get_args_kwargs_by_output(*args, **kwargs) for metric, (selected_args, selected_kwargs) in zip(self.metrics, reshaped_args_kwargs): metric.update(*selected_args, **selected_kwargs) - def _compute(self) -> List[torch.Tensor]: + def compute(self) -> List[torch.Tensor]: """Compute metrics.""" return [m.compute() for m in self.metrics]