Add Mean Absolute Percentage Error #248

Merged

45 commits from metrics/mape were merged into master on Jun 10, 2021. The diff below shows changes from 2 commits.

Commits (45)
0f81384  initial code, test failing (pranjaldatta, May 14, 2021)
844fc9e  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], May 14, 2021)
ecccca9  increase requirement (SkafteNicki, May 15, 2021)
b04fa68  Merge branch 'master' into metrics/mape (SkafteNicki, May 15, 2021)
62107cd  Apply suggestions from code review (SkafteNicki, May 15, 2021)
a61efa6  added docs+changelog and other review changes (pranjaldatta, May 15, 2021)
df634e3  consistent shorted type hints (pranjaldatta, May 16, 2021)
2e2d582  fix tests (SkafteNicki, May 18, 2021)
8515662  remove arg (SkafteNicki, May 18, 2021)
b862861  Merge branch 'master' into metrics/mape (SkafteNicki, May 18, 2021)
e3c2db9  Merge branch 'master' into metrics/mape (mergify[bot], May 18, 2021)
ef8e342  Merge branch 'master' into metrics/mape (pranjaldatta, May 20, 2021)
8b01f0b  minor merge fixes (pranjaldatta, May 20, 2021)
fdc3c3b  Merge branch 'master' into metrics/mape (mergify[bot], May 23, 2021)
6a97671  Merge branch 'master' into metrics/mape (mergify[bot], May 25, 2021)
53c218f  Merge branch 'master' into metrics/mape (mergify[bot], May 27, 2021)
6c35769  Merge branch 'master' into metrics/mape (mergify[bot], May 28, 2021)
e0a540d  Merge branch 'master' into metrics/mape (mergify[bot], May 31, 2021)
29364ae  Merge branch 'master' into metrics/mape (mergify[bot], May 31, 2021)
a20720e  Merge branch 'master' into metrics/mape (mergify[bot], Jun 3, 2021)
ee9c22b  Merge branch 'master' into metrics/mape (mergify[bot], Jun 3, 2021)
f0498d6  Merge branch 'master' into metrics/mape (mergify[bot], Jun 8, 2021)
537302b  Merge branch 'master' into metrics/mape (mergify[bot], Jun 8, 2021)
d7ceccf  Merge branch 'master' into metrics/mape (mergify[bot], Jun 8, 2021)
31d3435  Merge branch 'master' into metrics/mape (mergify[bot], Jun 9, 2021)
b5e7590  Merge branch 'master' into metrics/mape (mergify[bot], Jun 9, 2021)
6046312  Merge branch 'master' into metrics/mape (mergify[bot], Jun 9, 2021)
7b93371  Merge branch 'master' into metrics/mape (mergify[bot], Jun 9, 2021)
1f490e6  Merge branch 'master' into metrics/mape (mergify[bot], Jun 9, 2021)
e41b49a  Merge branch 'master' into metrics/mape (mergify[bot], Jun 9, 2021)
fc99d2a  Merge branch 'master' into metrics/mape (mergify[bot], Jun 9, 2021)
6dad077  Merge branch 'master' into metrics/mape (mergify[bot], Jun 9, 2021)
dca19b5  dep warning added to mean_rel_err (pranjaldatta, Jun 9, 2021)
7f6d9da  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jun 9, 2021)
f608e56  Apply suggestions from code review (Borda, Jun 9, 2021)
ef336e1  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jun 9, 2021)
f679d44  Merge branch 'master' into metrics/mape (mergify[bot], Jun 9, 2021)
eae2eef  applied suggestions from code review (pranjaldatta, Jun 9, 2021)
c8709d3  update (Borda, Jun 9, 2021)
13574f8  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jun 9, 2021)
2ffd6a4  added comments regarding test case fixes (pranjaldatta, Jun 9, 2021)
dd860da  removed unused imports (pranjaldatta, Jun 9, 2021)
3f48a8d  , (Borda, Jun 9, 2021)
8064453  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jun 9, 2021)
65e09f0  Merge branch 'master' into metrics/mape (mergify[bot], Jun 10, 2021)
20 changes: 17 additions & 3 deletions tests/regression/test_mean_error.py
@@ -17,13 +17,24 @@
import pytest
import torch
from sklearn.metrics import mean_absolute_error as sk_mean_absolute_error
from sklearn.metrics import mean_absolute_percentage_error as sk_mean_abs_percentage_error
from sklearn.metrics import mean_squared_error as sk_mean_squared_error
from sklearn.metrics import mean_squared_log_error as sk_mean_squared_log_error

from tests.helpers import seed_all
from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester
from torchmetrics.functional import mean_absolute_error, mean_squared_error, mean_squared_log_error
from torchmetrics.regression import MeanAbsoluteError, MeanSquaredError, MeanSquaredLogError
from torchmetrics.functional import (
    mean_absolute_error,
    mean_absolute_percentage_error,
    mean_squared_error,
    mean_squared_log_error,
)
from torchmetrics.regression import (
    MeanAbsoluteError,
    MeanAbsolutePercentageError,
    MeanSquaredError,
    MeanSquaredLogError,
)
from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6

seed_all(42)
@@ -68,6 +79,7 @@ def _multi_target_sk_metric(preds, target, sk_fn=mean_squared_error):
        (MeanSquaredError, mean_squared_error, sk_mean_squared_error),
        (MeanAbsoluteError, mean_absolute_error, sk_mean_absolute_error),
        (MeanSquaredLogError, mean_squared_log_error, sk_mean_squared_log_error),
        (MeanAbsolutePercentageError, mean_absolute_percentage_error, sk_mean_abs_percentage_error),
    ],
)
class TestMeanError(MetricTester):
@@ -115,7 +127,9 @@ def test_mean_error_half_gpu(self, preds, target, sk_metric, metric_class, metric_functional):
        self.run_precision_test_gpu(preds, target, metric_class, metric_functional)


@pytest.mark.parametrize("metric_class", [MeanSquaredError, MeanAbsoluteError, MeanSquaredLogError])
@pytest.mark.parametrize(
    "metric_class", [MeanSquaredError, MeanAbsoluteError, MeanSquaredLogError, MeanAbsolutePercentageError]
)
def test_error_on_different_shape(metric_class):
    metric = metric_class()
    with pytest.raises(RuntimeError, match='Predictions and targets are expected to have the same shape'):
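
For reference, the parametrization above makes MetricTester check the new metric against scikit-learn's mean_absolute_percentage_error. A minimal sketch of that comparison follows; the tensor shapes, the offset that keeps targets away from zero, and the tolerance are illustrative assumptions, not what the test helper actually does internally.

import torch
from sklearn.metrics import mean_absolute_percentage_error as sk_mean_abs_percentage_error

from torchmetrics.functional import mean_absolute_percentage_error

# Illustrative data only; MetricTester builds its own batches of shape (NUM_BATCHES, BATCH_SIZE).
preds = torch.rand(16)
target = torch.rand(16) + 0.5  # keep targets away from zero so MAPE stays bounded

tm_value = mean_absolute_percentage_error(preds, target)
sk_value = sk_mean_abs_percentage_error(target.numpy(), preds.numpy())  # sklearn takes y_true first

assert torch.allclose(tm_value, torch.tensor(sk_value, dtype=tm_value.dtype), atol=1e-5)
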
3 changes: 3 additions & 0 deletions torchmetrics/functional/__init__.py
@@ -32,6 +32,9 @@
from torchmetrics.functional.nlp import bleu_score # noqa: F401
from torchmetrics.functional.regression.explained_variance import explained_variance # noqa: F401
from torchmetrics.functional.regression.mean_absolute_error import mean_absolute_error # noqa: F401
from torchmetrics.functional.regression.mean_absolute_percentage_error import ( # noqa: F401
    mean_absolute_percentage_error,
)
from torchmetrics.functional.regression.mean_relative_error import mean_relative_error # noqa: F401
from torchmetrics.functional.regression.mean_squared_error import mean_squared_error # noqa: F401
from torchmetrics.functional.regression.mean_squared_log_error import mean_squared_log_error # noqa: F401
3 changes: 3 additions & 0 deletions torchmetrics/functional/regression/__init__.py
@@ -13,6 +13,9 @@
# limitations under the License.
from torchmetrics.functional.regression.explained_variance import explained_variance # noqa: F401
from torchmetrics.functional.regression.mean_absolute_error import mean_absolute_error # noqa: F401
from torchmetrics.functional.regression.mean_absolute_percentage_error import ( # noqa: F401
    mean_absolute_percentage_error,
)
from torchmetrics.functional.regression.mean_squared_error import mean_squared_error # noqa: F401
from torchmetrics.functional.regression.mean_squared_log_error import mean_squared_log_error # noqa: F401
from torchmetrics.functional.regression.pearson import pearson_corrcoef # noqa: F401
52 changes: 52 additions & 0 deletions torchmetrics/functional/regression/mean_absolute_percentage_error.py
@@ -0,0 +1,52 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple

import torch
from torch import Tensor

from torchmetrics.utilities.checks import _check_same_shape


def _mean_absolute_percentage_error_update(preds: Tensor, target: Tensor, eps: Tensor) -> Tuple[Tensor, int]:
    """Return the summed absolute percentage error and the number of observations."""
    _check_same_shape(preds, target)

    abs_diff = torch.abs(preds - target)
    # clamp the denominator with eps so targets at (or near) zero do not cause division by zero
    abs_per_error = abs_diff / torch.max(eps, torch.abs(target))

    sum_abs_per_error = torch.sum(abs_per_error)
    num_obs = target.numel()

    return sum_abs_per_error, num_obs


def _mean_absolute_percentage_error_compute(sum_abs_per_error: Tensor, num_obs: int) -> Tensor:
    """Average the accumulated absolute percentage error over the number of observations."""
    return sum_abs_per_error / num_obs


def mean_absolute_percentage_error(preds: Tensor, target: Tensor, eps: float = 1.17e-07) -> Tensor:
    """Compute the mean absolute percentage error between ``preds`` and ``target``."""
    eps = torch.tensor(eps)
    sum_abs_per_error, num_obs = _mean_absolute_percentage_error_update(preds, target, eps)
    mean_ape = _mean_absolute_percentage_error_compute(sum_abs_per_error, num_obs)

    return mean_ape
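
Taken together, the update/compute pair implements MAPE = (1/n) * sum_i |target_i - preds_i| / max(eps, |target_i|). A quick usage sketch with made-up values (the printed number is approximate, not part of the PR):

import torch

from torchmetrics.functional import mean_absolute_percentage_error

preds = torch.tensor([2.5, 0.0, 2.0, 8.0])
target = torch.tensor([3.0, -0.5, 2.0, 7.0])

# per-element terms: 0.5/3, 0.5/0.5, 0/2, 1/7 -> mean is roughly 0.3274
print(mean_absolute_percentage_error(preds, target))  # tensor(0.3274)
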
1 change: 1 addition & 0 deletions torchmetrics/regression/__init__.py
@@ -13,6 +13,7 @@
# limitations under the License.
from torchmetrics.regression.explained_variance import ExplainedVariance # noqa: F401
from torchmetrics.regression.mean_absolute_error import MeanAbsoluteError # noqa: F401
from torchmetrics.regression.mean_absolute_percentage_error import MeanAbsolutePercentageError # noqa: F401
from torchmetrics.regression.mean_squared_error import MeanSquaredError # noqa: F401
from torchmetrics.regression.mean_squared_log_error import MeanSquaredLogError # noqa: F401
from torchmetrics.regression.pearson import PearsonCorrcoef # noqa: F401
58 changes: 58 additions & 0 deletions torchmetrics/regression/mean_absolute_percentage_error.py
@@ -0,0 +1,58 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Optional

import torch
from torch import Tensor, tensor

from torchmetrics.functional.regression.mean_absolute_percentage_error import (
    _mean_absolute_percentage_error_compute,
    _mean_absolute_percentage_error_update,
)
from torchmetrics.metric import Metric


class MeanAbsolutePercentageError(Metric):
    """Computes the mean absolute percentage error (MAPE) between predictions and targets."""

    def __init__(
        self,
        eps: float = 1.17e-07,
        compute_on_step: bool = True,
        dist_sync_on_step: bool = False,
        process_group: Optional[Any] = None,
        dist_sync_fn: Callable = None,
    ):
        super().__init__(
            compute_on_step=compute_on_step,
            dist_sync_on_step=dist_sync_on_step,
            process_group=process_group,
            dist_sync_fn=dist_sync_fn,
        )

        self.add_state("sum_abs_per_error", default=tensor(0.0), dist_reduce_fx="sum")
        self.add_state("total", default=tensor(0.0), dist_reduce_fx="sum")
        self.eps = torch.tensor(eps)

    def update(self, preds: Tensor, target: Tensor) -> None:
        """Update state with a new batch of predictions and targets."""
        sum_abs_per_error, num_obs = _mean_absolute_percentage_error_update(preds, target, eps=self.eps)

        self.sum_abs_per_error += sum_abs_per_error
        self.total += num_obs

    def compute(self) -> Tensor:
        """Compute the mean absolute percentage error over all batches seen so far."""
        return _mean_absolute_percentage_error_compute(self.sum_abs_per_error, self.total)

    @property
    def is_differentiable(self) -> bool:
        return False
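
The module interface accumulates the summed percentage error and the observation count across update() calls, so compute() returns the mean over everything seen so far. A minimal usage sketch with illustrative tensors (not part of the PR):

import torch

from torchmetrics.regression import MeanAbsolutePercentageError

metric = MeanAbsolutePercentageError()

# two batches; both states are reduced with "sum", so this matches a single pass over all four values
metric.update(torch.tensor([2.5, 0.0]), torch.tensor([3.0, -0.5]))
metric.update(torch.tensor([2.0, 8.0]), torch.tensor([2.0, 7.0]))

print(metric.compute())  # tensor(0.3274), same data as the functional example above
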