diff --git a/CHANGELOG.md b/CHANGELOG.md
index 38c8b90397e..5955815102b 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -24,6 +24,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Added support for unnormalized scores (e.g. logits) in `Accuracy`, `Precision`, `Recall`, `FBeta`, `F1`, `StatScore`, `Hamming`, `ConfusionMatrix` metrics ([#200](https://github.com/PyTorchLightning/metrics/pull/200))
 
+- Added `MeanAbsolutePercentageError` (MAPE) metric ([#248](https://github.com/PyTorchLightning/metrics/pull/248))
+
+
 - Added `squared` argument to `MeanSquaredError` for computing `RMSE` ([#249](https://github.com/PyTorchLightning/metrics/pull/249))
 
 
@@ -36,6 +39,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ### Deprecated
 
+- Deprecated `torchmetrics.functional.mean_relative_error` ([#248](https://github.com/PyTorchLightning/metrics/pull/248))
 
 ### Removed
 
diff --git a/docs/source/references/functional.rst b/docs/source/references/functional.rst
index d1ac594e17a..7de8d14c065 100644
--- a/docs/source/references/functional.rst
+++ b/docs/source/references/functional.rst
@@ -188,6 +188,13 @@ mean_absolute_error [func]
     :noindex:
 
 
+mean_absolute_percentage_error [func]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: torchmetrics.functional.mean_absolute_percentage_error
+    :noindex:
+
+
 mean_squared_error [func]
 ~~~~~~~~~~~~~~~~~~~~~~~~~
 
diff --git a/docs/source/references/modules.rst b/docs/source/references/modules.rst
index 9889d7a6f46..1220cc413ce 100644
--- a/docs/source/references/modules.rst
+++ b/docs/source/references/modules.rst
@@ -268,6 +268,13 @@ MeanAbsoluteError
     :noindex:
 
 
+MeanAbsolutePercentageError
+~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: torchmetrics.MeanAbsolutePercentageError
+    :noindex:
+
+
 MeanSquaredError
 ~~~~~~~~~~~~~~~~
 
diff --git a/requirements/test.txt b/requirements/test.txt
index 606c99e7913..f501671f387 100644
--- a/requirements/test.txt
+++ b/requirements/test.txt
@@ -15,6 +15,6 @@
 yapf>=0.29.0
 phmdoctest>=1.1.1
 cloudpickle>=1.3
-scikit-learn>0.22.1
+scikit-learn>=0.24
 scikit-image>0.17.1
 nltk>=3.6
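Note on the requirements bump: `sklearn.metrics.mean_absolute_percentage_error`, used below as the reference implementation, first shipped in scikit-learn 0.24, so the old `>0.22.1` pin would let the new parity tests fail at import time. A minimal guard along these lines (a sketch, not part of this PR, which simply pins the version) would skip rather than error on older environments:

```python
import pytest

# scikit-learn only gained mean_absolute_percentage_error in 0.24,
# so skip the MAPE parity tests (rather than erroring) on older installs.
sklearn_metrics = pytest.importorskip("sklearn.metrics")
if not hasattr(sklearn_metrics, "mean_absolute_percentage_error"):
    pytest.skip("scikit-learn>=0.24 is required for the MAPE reference", allow_module_level=True)
```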
diff --git a/tests/regression/test_mean_error.py b/tests/regression/test_mean_error.py
index 8c80f65eef6..5c927c18ee1 100644
--- a/tests/regression/test_mean_error.py
+++ b/tests/regression/test_mean_error.py
@@ -18,13 +18,24 @@
 import pytest
 import torch
 from sklearn.metrics import mean_absolute_error as sk_mean_absolute_error
+from sklearn.metrics import mean_absolute_percentage_error as sk_mean_abs_percentage_error
 from sklearn.metrics import mean_squared_error as sk_mean_squared_error
 from sklearn.metrics import mean_squared_log_error as sk_mean_squared_log_error
 
 from tests.helpers import seed_all
 from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester
-from torchmetrics.functional import mean_absolute_error, mean_squared_error, mean_squared_log_error
-from torchmetrics.regression import MeanAbsoluteError, MeanSquaredError, MeanSquaredLogError
+from torchmetrics.functional import (
+    mean_absolute_error,
+    mean_absolute_percentage_error,
+    mean_squared_error,
+    mean_squared_log_error,
+)
+from torchmetrics.regression import (
+    MeanAbsoluteError,
+    MeanAbsolutePercentageError,
+    MeanSquaredError,
+    MeanSquaredLogError,
+)
 from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6
 
 seed_all(42)
@@ -47,14 +58,22 @@ def _single_target_sk_metric(preds, target, sk_fn, metric_args):
     sk_preds = preds.view(-1).numpy()
     sk_target = target.view(-1).numpy()
-    res = sk_fn(sk_preds, sk_target)
+
+    # `sk_target` and `sk_preds` switched to fix failing tests.
+    # For more info, check https://github.com/PyTorchLightning/metrics/pull/248#issuecomment-841232277
+    res = sk_fn(sk_target, sk_preds)
+
     return math.sqrt(res) if (metric_args and not metric_args['squared']) else res
 
 
 def _multi_target_sk_metric(preds, target, sk_fn, metric_args):
     sk_preds = preds.view(-1, num_targets).numpy()
     sk_target = target.view(-1, num_targets).numpy()
-    res = sk_fn(sk_preds, sk_target)
+
+    # `sk_target` and `sk_preds` switched to fix failing tests.
+    # For more info, check https://github.com/PyTorchLightning/metrics/pull/248#issuecomment-841232277
+    res = sk_fn(sk_target, sk_preds)
+
     return math.sqrt(res) if (metric_args and not metric_args['squared']) else res
 
 
@@ -75,6 +94,7 @@ def _multi_target_sk_metric(preds, target, sk_fn, metric_args):
         'squared': False
     }),
     (MeanAbsoluteError, mean_absolute_error, sk_mean_absolute_error, {}),
+    (MeanAbsolutePercentageError, mean_absolute_percentage_error, sk_mean_abs_percentage_error, {}),
     (MeanSquaredLogError, mean_squared_log_error, sk_mean_squared_log_error, {}),
     ],
 )
@@ -124,6 +144,11 @@ def test_mean_error_half_cpu(self, preds, target, sk_metric, metric_class, metric_functional):
         if metric_class == MeanSquaredLogError:
             # MeanSquaredLogError half + cpu does not work due to missing support in torch.log
             pytest.xfail("MeanSquaredLogError metric does not support cpu + half precision")
+
+        if metric_class == MeanAbsolutePercentageError:
+            # MeanAbsolutePercentageError half + cpu does not work due to missing support in torch.clamp
+            pytest.xfail("MeanAbsolutePercentageError metric does not support cpu + half precision")
+
         self.run_precision_test_cpu(preds, target, metric_class, metric_functional)
 
     @pytest.mark.skipif(not torch.cuda.is_available(), reason='test requires cuda')
@@ -131,7 +156,9 @@ def test_mean_error_half_gpu(self, preds, target, sk_metric, metric_class, metric_functional):
         self.run_precision_test_gpu(preds, target, metric_class, metric_functional)
 
 
-@pytest.mark.parametrize("metric_class", [MeanSquaredError, MeanAbsoluteError, MeanSquaredLogError])
+@pytest.mark.parametrize(
+    "metric_class", [MeanSquaredError, MeanAbsoluteError, MeanSquaredLogError, MeanAbsolutePercentageError]
+)
 def test_error_on_different_shape(metric_class):
     metric = metric_class()
     with pytest.raises(RuntimeError, match='Predictions and targets are expected to have the same shape'):
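The `sk_target`/`sk_preds` swap matters because scikit-learn's metric signatures are `(y_true, y_pred)`. For symmetric metrics such as MSE and MAE the old argument order went unnoticed, but MAPE normalizes by the true values, so swapping the arguments changes the result. A quick illustration (numbers made up for the example):

```python
import numpy as np
from sklearn.metrics import mean_absolute_percentage_error  # scikit-learn >= 0.24

y_true = np.array([1.0, 10.0])
y_pred = np.array([2.0, 20.0])

# Correct order: errors are normalized by the true values.
# (|1 - 2| / 1 + |10 - 20| / 10) / 2 = (1.0 + 1.0) / 2
print(mean_absolute_percentage_error(y_true, y_pred))  # 1.0

# Swapped order: errors are normalized by the predictions instead.
# (|2 - 1| / 2 + |20 - 10| / 20) / 2 = (0.5 + 0.5) / 2
print(mean_absolute_percentage_error(y_pred, y_true))  # 0.5
```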
diff --git a/torchmetrics/__init__.py b/torchmetrics/__init__.py
index 487ece2d45b..52a48638a6f 100644
--- a/torchmetrics/__init__.py
+++ b/torchmetrics/__init__.py
@@ -42,6 +42,7 @@
     SSIM,
     ExplainedVariance,
     MeanAbsoluteError,
+    MeanAbsolutePercentageError,
     MeanSquaredError,
     MeanSquaredLogError,
     PearsonCorrcoef,
diff --git a/torchmetrics/functional/__init__.py b/torchmetrics/functional/__init__.py
index d88936b2635..939987268b5 100644
--- a/torchmetrics/functional/__init__.py
+++ b/torchmetrics/functional/__init__.py
@@ -32,6 +32,9 @@
 from torchmetrics.functional.nlp import bleu_score  # noqa: F401
 from torchmetrics.functional.regression.explained_variance import explained_variance  # noqa: F401
 from torchmetrics.functional.regression.mean_absolute_error import mean_absolute_error  # noqa: F401
+from torchmetrics.functional.regression.mean_absolute_percentage_error import (  # noqa: F401
+    mean_absolute_percentage_error,
+)
 from torchmetrics.functional.regression.mean_relative_error import mean_relative_error  # noqa: F401
 from torchmetrics.functional.regression.mean_squared_error import mean_squared_error  # noqa: F401
 from torchmetrics.functional.regression.mean_squared_log_error import mean_squared_log_error  # noqa: F401
diff --git a/torchmetrics/functional/regression/__init__.py b/torchmetrics/functional/regression/__init__.py
index 28aabb6ec62..a5c8ad5cd4e 100644
--- a/torchmetrics/functional/regression/__init__.py
+++ b/torchmetrics/functional/regression/__init__.py
@@ -13,6 +13,9 @@
 # limitations under the License.
 from torchmetrics.functional.regression.explained_variance import explained_variance  # noqa: F401
 from torchmetrics.functional.regression.mean_absolute_error import mean_absolute_error  # noqa: F401
+from torchmetrics.functional.regression.mean_absolute_percentage_error import (  # noqa: F401
+    mean_absolute_percentage_error,
+)
 from torchmetrics.functional.regression.mean_squared_error import mean_squared_error  # noqa: F401
 from torchmetrics.functional.regression.mean_squared_log_error import mean_squared_log_error  # noqa: F401
 from torchmetrics.functional.regression.pearson import pearson_corrcoef  # noqa: F401
diff --git a/torchmetrics/functional/regression/mean_absolute_percentage_error.py b/torchmetrics/functional/regression/mean_absolute_percentage_error.py
new file mode 100644
index 00000000000..b3b5214c2f7
--- /dev/null
+++ b/torchmetrics/functional/regression/mean_absolute_percentage_error.py
@@ -0,0 +1,70 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Tuple
+
+import torch
+from torch import Tensor
+
+from torchmetrics.utilities.checks import _check_same_shape
+
+
+def _mean_absolute_percentage_error_update(
+    preds: Tensor,
+    target: Tensor,
+    epsilon: float = 1.17e-06,
+) -> Tuple[Tensor, int]:
+
+    _check_same_shape(preds, target)
+
+    abs_diff = torch.abs(preds - target)
+    abs_per_error = abs_diff / torch.clamp(torch.abs(target), min=epsilon)
+
+    sum_abs_per_error = torch.sum(abs_per_error)
+
+    num_obs = target.numel()
+
+    return sum_abs_per_error, num_obs
+
+
+def _mean_absolute_percentage_error_compute(sum_abs_per_error: Tensor, num_obs: int) -> Tensor:
+    return sum_abs_per_error / num_obs
+
+
+def mean_absolute_percentage_error(preds: Tensor, target: Tensor) -> Tensor:
+    """
+    Computes mean absolute percentage error.
+
+    Args:
+        preds: estimated labels
+        target: ground truth labels
+
+    Return:
+        Tensor with MAPE
+
+    Note:
+        The epsilon value is taken from `scikit-learn's
+        implementation
+        `_.
+
+    Example:
+        >>> from torchmetrics.functional import mean_absolute_percentage_error
+        >>> target = torch.tensor([1, 10, 1e6])
+        >>> preds = torch.tensor([0.9, 15, 1.2e6])
+        >>> mean_absolute_percentage_error(preds, target)
+        tensor(0.2667)
+    """
+    sum_abs_per_error, num_obs = _mean_absolute_percentage_error_update(preds, target)
+    mean_ape = _mean_absolute_percentage_error_compute(sum_abs_per_error, num_obs)
+
+    return mean_ape
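Since the update step clamps `|target|` to `epsilon` before dividing, a zero target produces a very large finite score rather than `inf`. A small sketch checking both that behavior and the doctest arithmetic against the new functional:

```python
import torch
from torchmetrics.functional import mean_absolute_percentage_error

# Zero target: |1 - 0| / max(|0|, 1.17e-6) ≈ 8.5e5, large but finite.
print(mean_absolute_percentage_error(torch.tensor([1.0]), torch.tensor([0.0])))

# The doctest value, element by element:
# |0.9 - 1| / 1 = 0.1, |15 - 10| / 10 = 0.5, |1.2e6 - 1e6| / 1e6 = 0.2
# mean = (0.1 + 0.5 + 0.2) / 3 ≈ 0.2667
print(mean_absolute_percentage_error(torch.tensor([0.9, 15.0, 1.2e6]), torch.tensor([1.0, 10.0, 1e6])))
```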
diff --git a/torchmetrics/functional/regression/mean_relative_error.py b/torchmetrics/functional/regression/mean_relative_error.py
index 286a5c9b687..86db486257f 100644
--- a/torchmetrics/functional/regression/mean_relative_error.py
+++ b/torchmetrics/functional/regression/mean_relative_error.py
@@ -11,25 +11,15 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Tuple
+from warnings import warn
 
 import torch
 from torch import Tensor
 
-from torchmetrics.utilities.checks import _check_same_shape
-
-
-def _mean_relative_error_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, int]:
-    _check_same_shape(preds, target)
-    target_nz = target.clone()
-    target_nz[target == 0] = 1
-    sum_rltv_error = torch.sum(torch.abs((preds - target) / target_nz))
-    n_obs = target.numel()
-    return sum_rltv_error, n_obs
-
-
-def _mean_relative_error_compute(sum_rltv_error: Tensor, n_obs: int) -> Tensor:
-    return sum_rltv_error / n_obs
+from torchmetrics.functional.regression.mean_absolute_percentage_error import (
+    _mean_absolute_percentage_error_compute,
+    _mean_absolute_percentage_error_update,
+)
 
 
 def mean_relative_error(preds: Tensor, target: Tensor) -> Tensor:
@@ -50,6 +40,13 @@ def mean_relative_error(preds: Tensor, target: Tensor) -> Tensor:
         >>> mean_relative_error(x, y)
         tensor(0.1250)
 
+    .. deprecated::
+        Use :func:`torchmetrics.functional.mean_absolute_percentage_error`. Will be removed in v0.5.0.
+
     """
-    sum_rltv_error, n_obs = _mean_relative_error_update(preds, target)
-    return _mean_relative_error_compute(sum_rltv_error, n_obs)
+    warn(
+        "Function `mean_relative_error` was deprecated in v0.4 and will be removed in v0.5. "
+        "Use `mean_absolute_percentage_error` instead.", DeprecationWarning
+    )
+    sum_rltv_error, n_obs = _mean_absolute_percentage_error_update(preds, target)
+    return _mean_absolute_percentage_error_compute(sum_rltv_error, n_obs)
diff --git a/torchmetrics/regression/__init__.py b/torchmetrics/regression/__init__.py
index d10c35ae864..b9d804d954c 100644
--- a/torchmetrics/regression/__init__.py
+++ b/torchmetrics/regression/__init__.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 from torchmetrics.regression.explained_variance import ExplainedVariance  # noqa: F401
 from torchmetrics.regression.mean_absolute_error import MeanAbsoluteError  # noqa: F401
+from torchmetrics.regression.mean_absolute_percentage_error import MeanAbsolutePercentageError  # noqa: F401
 from torchmetrics.regression.mean_squared_error import MeanSquaredError  # noqa: F401
 from torchmetrics.regression.mean_squared_log_error import MeanSquaredLogError  # noqa: F401
 from torchmetrics.regression.pearson import PearsonCorrcoef  # noqa: F401
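With this shim, `mean_relative_error` routes through the MAPE helpers (identical for non-zero targets, since `max(|y|, eps) = |y|` there) and emits a `DeprecationWarning` on every call. A hedged sketch of how that surfaces to callers (inputs made up):

```python
import warnings

import torch
from torchmetrics.functional import mean_relative_error

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # Internally dispatches to the MAPE update/compute helpers now.
    res = mean_relative_error(torch.tensor([2.5]), torch.tensor([2.0]))

assert any(issubclass(w.category, DeprecationWarning) for w in caught)
print(res)  # tensor(0.2500): |2.5 - 2.0| / 2.0
```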
diff --git a/torchmetrics/regression/mean_absolute_percentage_error.py b/torchmetrics/regression/mean_absolute_percentage_error.py
new file mode 100644
index 00000000000..0d9adbe0225
--- /dev/null
+++ b/torchmetrics/regression/mean_absolute_percentage_error.py
@@ -0,0 +1,101 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Any, Callable, Optional
+
+import torch
+from torch import Tensor, tensor
+
+from torchmetrics.functional.regression.mean_absolute_percentage_error import (
+    _mean_absolute_percentage_error_compute,
+    _mean_absolute_percentage_error_update,
+)
+from torchmetrics.metric import Metric
+
+
+class MeanAbsolutePercentageError(Metric):
+    r"""
+    Computes `mean absolute percentage error `_ (MAPE):
+
+    .. math:: \text{MAPE} = \frac{1}{n}\sum_{i=1}^n\frac{|y_i - \hat{y_i}|}{\max(\epsilon, |y_i|)}
+
+    Where :math:`y` is a tensor of target values, and :math:`\hat{y}` is a tensor of predictions.
+
+    Args:
+        compute_on_step:
+            Forward only calls ``update()`` and returns None if this is set to False. default: True
+        dist_sync_on_step:
+            Synchronize metric state across processes at each ``forward()``
+            before returning the value at the step. default: False
+        process_group:
+            Specify the process group on which synchronization is called. default: None (which selects the entire world)
+
+    Note:
+        The epsilon value is taken from `scikit-learn's implementation
+        `_.
+
+    Note:
+        MAPE output is a non-negative floating point, and the best value is 0.0. Note that
+        bad predictions can lead to arbitrarily large values, especially when some ``target`` values are close to 0.
+        This implementation returns a very large number instead of ``inf``.
+        For more information, `read here
+        `_.
+
+    Example:
+        >>> from torchmetrics import MeanAbsolutePercentageError
+        >>> target = torch.tensor([1, 10, 1e6])
+        >>> preds = torch.tensor([0.9, 15, 1.2e6])
+        >>> mean_abs_percentage_error = MeanAbsolutePercentageError()
+        >>> mean_abs_percentage_error(preds, target)
+        tensor(0.2667)
+    """
+
+    def __init__(
+        self,
+        compute_on_step: bool = True,
+        dist_sync_on_step: bool = False,
+        process_group: Optional[Any] = None,
+        dist_sync_fn: Callable = None,
+    ):
+        super().__init__(
+            compute_on_step=compute_on_step,
+            dist_sync_on_step=dist_sync_on_step,
+            process_group=process_group,
+            dist_sync_fn=dist_sync_fn,
+        )
+
+        self.add_state("sum_abs_per_error", default=tensor(0.0), dist_reduce_fx="sum")
+        self.add_state("total", default=tensor(0.0), dist_reduce_fx="sum")
+
+    def update(self, preds: Tensor, target: Tensor) -> None:
+        """
+        Update state with predictions and targets.
+
+        Args:
+            preds: Predictions from model
+            target: Ground truth values
+        """
+        sum_abs_per_error, num_obs = _mean_absolute_percentage_error_update(preds, target)
+
+        self.sum_abs_per_error += sum_abs_per_error
+        self.total += num_obs
+
+    def compute(self) -> Tensor:
+        """
+        Computes mean absolute percentage error over state.
+        """
+        return _mean_absolute_percentage_error_compute(self.sum_abs_per_error, self.total)
+
+    @property
+    def is_differentiable(self) -> bool:
+        return True
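Because both states reduce with `"sum"`, accumulating over several `update()` calls matches computing the metric once on the concatenated tensors. A short usage sketch with made-up batches:

```python
import torch
from torchmetrics import MeanAbsolutePercentageError

mape = MeanAbsolutePercentageError()

# Two "batches": per-element errors are |0.9 - 1| / 1 = 0.1 and |15 - 10| / 10 = 0.5.
mape.update(torch.tensor([0.9]), torch.tensor([1.0]))
mape.update(torch.tensor([15.0]), torch.tensor([10.0]))

print(mape.compute())  # (0.1 + 0.5) / 2 = tensor(0.3000)
```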