Commit 13e170a
Merge branch 'master' into fix_dtype
SkafteNicki authored Sep 7, 2021
2 parents 8314b3b + f96c717
Showing 5 changed files with 165 additions and 29 deletions.
9 changes: 9 additions & 0 deletions CHANGELOG.md
@@ -15,8 +15,17 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added Tweedie Deviance Score ([#499](https://github.com/PyTorchLightning/metrics/pull/499))


- Added support for float targets in `nDCG` metric ([#437](https://github.com/PyTorchLightning/metrics/pull/437))


- Added `average` argument to `AveragePrecision` metric for reducing multilabel and multiclass problems ([#477](https://github.com/PyTorchLightning/metrics/pull/477))


### Changed

- `AveragePrecision` now outputs the `macro` average by default for multilabel and multiclass problems ([#477](https://github.com/PyTorchLightning/metrics/pull/477))


- `half`, `double`, `float` will no longer change the dtype of the metric states. Use `metric.set_dtype` instead ([#493](https://github.com/PyTorchLightning/metrics/pull/493))


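As a quick illustration of the `average` argument described in the entries above, the sketch below reuses the prediction/target pair from the `AveragePrecision` docstring example further down. The top-level import `from torchmetrics import AveragePrecision` is assumed here; only the `average=None` output is documented in the docstring, so the macro result is printed without asserting a value.

import torch
from torchmetrics import AveragePrecision  # top-level export assumed; the class lives in torchmetrics/classification/average_precision.py

pred = torch.tensor(
    [
        [0.75, 0.05, 0.05, 0.05, 0.05],
        [0.05, 0.75, 0.05, 0.05, 0.05],
        [0.05, 0.05, 0.75, 0.05, 0.05],
        [0.05, 0.05, 0.05, 0.75, 0.05],
    ]
)
target = torch.tensor([0, 1, 3, 2])

# average=None keeps the per-class scores, as in the docstring example below
per_class = AveragePrecision(num_classes=5, average=None)
print(per_class(pred, target))  # [tensor(1.), tensor(1.), tensor(0.2500), tensor(0.2500), tensor(nan)]

# the new default, average="macro", reduces to a single tensor; classes whose
# score is nan trigger a UserWarning (see the new test further down)
macro = AveragePrecision(num_classes=5, average="macro")
print(macro(pred, target))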
74 changes: 58 additions & 16 deletions tests/classification/test_average_precision.py
@@ -30,7 +30,7 @@
seed_all(42)


def _sk_average_precision_score(y_true, probas_pred, num_classes=1):
def _sk_average_precision_score(y_true, probas_pred, num_classes=1, average=None):
if num_classes == 1:
return sk_average_precision_score(y_true, probas_pred)

@@ -39,33 +39,41 @@ def _sk_average_precision_score(y_true, probas_pred, num_classes=1):
y_true_temp = np.zeros_like(y_true)
y_true_temp[y_true == i] = 1
res.append(sk_average_precision_score(y_true_temp, probas_pred[:, i]))

if average == "macro":
return np.array(res).mean()
elif average == "weighted":
weights = np.bincount(y_true) if y_true.max() > 1 else y_true.sum(axis=0)
weights = weights / sum(weights)
return (np.array(res) * weights).sum()

return res


def _sk_avg_prec_binary_prob(preds, target, num_classes=1):
def _sk_avg_prec_binary_prob(preds, target, num_classes=1, average=None):
sk_preds = preds.view(-1).numpy()
sk_target = target.view(-1).numpy()

return _sk_average_precision_score(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes)
return _sk_average_precision_score(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes, average=average)


def _sk_avg_prec_multiclass_prob(preds, target, num_classes=1):
def _sk_avg_prec_multiclass_prob(preds, target, num_classes=1, average=None):
sk_preds = preds.reshape(-1, num_classes).numpy()
sk_target = target.view(-1).numpy()

return _sk_average_precision_score(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes)
return _sk_average_precision_score(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes, average=average)


def _sk_avg_prec_multilabel_prob(preds, target, num_classes):
def _sk_avg_prec_multilabel_prob(preds, target, num_classes=1, average=None):
sk_preds = preds.reshape(-1, num_classes).numpy()
sk_target = target.view(-1, num_classes).numpy()
return sk_average_precision_score(sk_target, sk_preds, average=None)
return sk_average_precision_score(sk_target, sk_preds, average=average)


def _sk_avg_prec_multidim_multiclass_prob(preds, target, num_classes=1):
def _sk_avg_prec_multidim_multiclass_prob(preds, target, num_classes=1, average=None):
sk_preds = preds.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1).numpy()
sk_target = target.view(-1).numpy()
return _sk_average_precision_score(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes)
return _sk_average_precision_score(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes, average=average)


@pytest.mark.parametrize(
@@ -77,30 +85,37 @@ def _sk_avg_prec_multidim_multiclass_prob(preds, target, num_classes=1):
(_input_multilabel.preds, _input_multilabel.target, _sk_avg_prec_multilabel_prob, NUM_CLASSES),
],
)
@pytest.mark.parametrize("average", ["micro", "macro", "weighted", None])
class TestAveragePrecision(MetricTester):
@pytest.mark.parametrize("ddp", [True, False])
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
def test_average_precision(self, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step):
def test_average_precision(self, preds, target, sk_metric, num_classes, average, ddp, dist_sync_on_step):
if target.max() > 1 and average == "micro":
pytest.skip("average=micro and multiclass input cannot be used together")

self.run_class_metric_test(
ddp=ddp,
preds=preds,
target=target,
metric_class=AveragePrecision,
sk_metric=partial(sk_metric, num_classes=num_classes),
sk_metric=partial(sk_metric, num_classes=num_classes, average=average),
dist_sync_on_step=dist_sync_on_step,
metric_args={"num_classes": num_classes},
metric_args={"num_classes": num_classes, "average": average},
)

def test_average_precision_functional(self, preds, target, sk_metric, num_classes):
def test_average_precision_functional(self, preds, target, sk_metric, num_classes, average):
if target.max() > 1 and average == "micro":
pytest.skip("average=micro and multiclass input cannot be used together")

self.run_functional_metric_test(
preds=preds,
target=target,
metric_functional=average_precision,
sk_metric=partial(sk_metric, num_classes=num_classes),
metric_args={"num_classes": num_classes},
sk_metric=partial(sk_metric, num_classes=num_classes, average=average),
metric_args={"num_classes": num_classes, "average": average},
)

def test_average_precision_differentiability(self, preds, sk_metric, target, num_classes):
def test_average_precision_differentiability(self, preds, sk_metric, target, num_classes, average):
self.run_differentiability_test(
preds=preds,
target=target,
@@ -126,3 +141,30 @@ def test_average_precision_differentiability(self, preds, sk_metric, target, num
)
def test_average_precision(scores, target, expected_score):
assert average_precision(scores, target) == expected_score


def test_average_precision_warnings_and_errors():
"""Test that the correct errors and warnings gets raised."""

# check average argument
with pytest.raises(ValueError, match="Expected argument `average` to be one .*"):
AveragePrecision(num_classes=5, average="samples")

# check that micro average cannot be used with multilabel input
pred = tensor(
[
[0.75, 0.05, 0.05, 0.05, 0.05],
[0.05, 0.75, 0.05, 0.05, 0.05],
[0.05, 0.05, 0.75, 0.05, 0.05],
[0.05, 0.05, 0.05, 0.75, 0.05],
]
)
target = tensor([0, 1, 3, 2])
average_precision = AveragePrecision(num_classes=5, average="micro")
with pytest.raises(ValueError, match="Cannot use `micro` average with multi-class input"):
average_precision(pred, target)

# check that a warning is raised when average="macro" and nan is encountered in individual scores
average_precision = AveragePrecision(num_classes=5, average="macro")
with pytest.warns(UserWarning, match="Average precision score for one or more classes was `nan`.*"):
average_precision(pred, target)
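For readers less familiar with the reference reduction exercised above, here is a self-contained numpy sketch of the macro vs. weighted averaging of one-vs-rest AP scores that `_sk_average_precision_score` performs, using scikit-learn's `average_precision_score` directly; the data is illustrative only and not taken from the test suite.

import numpy as np
from sklearn.metrics import average_precision_score

y_true = np.array([0, 0, 1, 2])  # class supports: 2, 1, 1
probas = np.array(
    [
        [0.80, 0.10, 0.10],
        [0.60, 0.30, 0.10],
        [0.20, 0.70, 0.10],
        [0.10, 0.20, 0.70],
    ]
)
num_classes = 3

# one-vs-rest average precision per class
res = [average_precision_score((y_true == i).astype(int), probas[:, i]) for i in range(num_classes)]

macro = np.array(res).mean()  # every class contributes equally

weights = np.bincount(y_true, minlength=num_classes)  # weight each class by its support
weights = weights / weights.sum()
weighted = (np.array(res) * weights).sum()

print(res, macro, weighted)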
24 changes: 21 additions & 3 deletions torchmetrics/classification/average_precision.py
@@ -44,6 +44,19 @@ class AveragePrecision(Metric):
which for binary problems translates to 1. For multiclass problems
this argument should not be set, as it is iteratively changed over the
range [0, num_classes-1]
average:
defines the reduction that is applied in the case of multiclass and multilabel input.
Should be one of the following:
- ``'macro'`` [default]: Calculate the metric for each class separately, and average the
metrics across classes (with equal weights for each class).
- ``'micro'``: Calculate the metric globally, across all samples and classes. Cannot be
used with multiclass input.
- ``'weighted'``: Calculate the metric for each class separately, and average the
metrics across classes, weighting each class by its support.
- ``'none'`` or ``None``: Calculate the metric for each class separately, and return
the metric for every class.
compute_on_step:
Forward only calls ``update()`` and returns ``None`` if this is set to False. default: True
dist_sync_on_step:
@@ -66,7 +79,7 @@ class AveragePrecision(Metric):
... [0.05, 0.05, 0.75, 0.05, 0.05],
... [0.05, 0.05, 0.05, 0.75, 0.05]])
>>> target = torch.tensor([0, 1, 3, 2])
>>> average_precision = AveragePrecision(num_classes=5)
>>> average_precision = AveragePrecision(num_classes=5, average=None)
>>> average_precision(pred, target)
[tensor(1.), tensor(1.), tensor(0.2500), tensor(0.2500), tensor(nan)]
"""
@@ -78,6 +91,7 @@ def __init__(
self,
num_classes: Optional[int] = None,
pos_label: Optional[int] = None,
average: Optional[str] = "macro",
compute_on_step: bool = True,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
@@ -90,6 +104,10 @@ def __init__(

self.num_classes = num_classes
self.pos_label = pos_label
allowed_average = ("micro", "macro", "weighted", None)
if average not in allowed_average:
raise ValueError(f"Expected argument `average` to be one of {allowed_average}" f" but got {average}")
self.average = average

self.add_state("preds", default=[], dist_reduce_fx="cat")
self.add_state("target", default=[], dist_reduce_fx="cat")
@@ -107,7 +125,7 @@ def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore
target: Ground truth values
"""
preds, target, num_classes, pos_label = _average_precision_update(
preds, target, self.num_classes, self.pos_label
preds, target, self.num_classes, self.pos_label, self.average
)
self.preds.append(preds)
self.target.append(target)
@@ -125,7 +143,7 @@ def compute(self) -> Union[Tensor, List[Tensor]]:
target = dim_zero_cat(self.target)
if not self.num_classes:
raise ValueError(f"`num_classes` bas to be positive number, but got {self.num_classes}")
return _average_precision_compute(preds, target, self.num_classes, self.pos_label)
return _average_precision_compute(preds, target, self.num_classes, self.pos_label, self.average)

@property
def is_differentiable(self) -> bool:
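The functional interface gains the same argument (the tests above pass `average` through `metric_args` to it). A minimal sketch, assuming the usual import path `torchmetrics.functional.average_precision`:

import torch
from torchmetrics.functional import average_precision  # import path assumed

pred = torch.tensor(
    [
        [0.75, 0.05, 0.05, 0.05, 0.05],
        [0.05, 0.75, 0.05, 0.05, 0.05],
        [0.05, 0.05, 0.75, 0.05, 0.05],
        [0.05, 0.05, 0.05, 0.75, 0.05],
    ]
)
target = torch.tensor([0, 1, 3, 2])

average_precision(pred, target, num_classes=5, average=None)     # list of per-class tensors
average_precision(pred, target, num_classes=5, average="macro")  # single macro-averaged tensor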
2 changes: 1 addition & 1 deletion torchmetrics/classification/binned_precision_recall.py
@@ -239,7 +239,7 @@ class BinnedAveragePrecision(BinnedPrecisionRecallCurve):

def compute(self) -> Union[List[Tensor], Tensor]: # type: ignore
precisions, recalls, _ = super().compute()
return _average_precision_compute_with_precision_recall(precisions, recalls, self.num_classes)
return _average_precision_compute_with_precision_recall(precisions, recalls, self.num_classes, average=None)


class BinnedRecallAtFixedPrecision(BinnedPrecisionRecallCurve):