Fix num_classes arg in F1 metric (#5663)
* fix f1 metric

* Apply suggestions from code review

* chlog

Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Jirka Borovec <jirka.borovec@seznam.cz>
4 people committed Feb 5, 2021
1 parent 1feff5d commit 605c5a8
Showing 3 changed files with 11 additions and 7 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -188,6 +188,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

+- Fixed `num_classes` argument in F1 metric ([#5663](https://github.com/PyTorchLightning/pytorch-lightning/pull/5663))

- Fixed support custom DataLoader with DDP if they can be re-instantiated ([#5745](https://github.com/PyTorchLightning/pytorch-lightning/pull/5745))

3 changes: 2 additions & 1 deletion pytorch_lightning/metrics/classification/f_beta.py
@@ -185,7 +185,7 @@ class F1(FBeta):

def __init__(
self,
-num_classes: int = 1,
+num_classes: int,
threshold: float = 0.5,
average: str = "micro",
multilabel: bool = False,
@@ -201,6 +201,7 @@ def __init__(
beta=1.0,
threshold=threshold,
average=average,
+multilabel=multilabel,
compute_on_step=compute_on_step,
dist_sync_on_step=dist_sync_on_step,
process_group=process_group,
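For context, a minimal usage sketch of the class-based metric after this change: `num_classes` no longer defaults to 1 and must be passed explicitly. The 3-class setup and tensors below are illustrative, not taken from the repository:

import torch
from pytorch_lightning.metrics import F1

# Hypothetical 3-class example; num_classes is now a required argument.
f1_metric = F1(num_classes=3, average="micro", multilabel=False)

preds = torch.tensor([0, 2, 1, 1])
target = torch.tensor([0, 1, 1, 2])

# Calling the metric updates its state; with the default compute_on_step it also returns the batch F1.
score = f1_metric(preds, target)
print(score)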
14 changes: 8 additions & 6 deletions tests/metrics/classification/test_f_beta.py
@@ -5,7 +5,7 @@
import torch
from sklearn.metrics import fbeta_score

-from pytorch_lightning.metrics import FBeta
+from pytorch_lightning.metrics import F1, FBeta
from pytorch_lightning.metrics.functional import f1, fbeta
from tests.metrics.classification.inputs import (
_binary_inputs,
@@ -97,22 +97,23 @@ def _sk_fbeta_multidim_multiclass(preds, target, average='micro', beta=1.0):
],
)
@pytest.mark.parametrize("average", ['micro', 'macro', 'weighted', None])
@pytest.mark.parametrize("beta", [0.5, 1.0])
@pytest.mark.parametrize("beta", [0.5, 1.0, 2.0])
class TestFBeta(MetricTester):
@pytest.mark.parametrize("ddp", [True, False])
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
def test_fbeta(
self, preds, target, sk_metric, num_classes, multilabel, average, beta, ddp, dist_sync_on_step
):
+metric_class = F1 if beta == 1.0 else partial(FBeta, beta=beta)

self.run_class_metric_test(
ddp=ddp,
preds=preds,
target=target,
-metric_class=FBeta,
+metric_class=metric_class,
sk_metric=partial(sk_metric, average=average, beta=beta),
dist_sync_on_step=dist_sync_on_step,
metric_args={
"beta": beta,
"num_classes": num_classes,
"average": average,
"multilabel": multilabel,
@@ -125,12 +125,13 @@ def test_fbeta(
def test_fbeta_functional(
self, preds, target, sk_metric, num_classes, multilabel, average, beta
):
+metric_functional = f1 if beta == 1.0 else partial(fbeta, beta=beta)

self.run_functional_metric_test(preds=preds,
target=target,
-metric_functional=fbeta,
+metric_functional=metric_functional,
sk_metric=partial(sk_metric, average=average, beta=beta),
metric_args={
"beta": beta,
"num_classes": num_classes,
"average": average,
"multilabel": multilabel,
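As a rough sketch of the dispatch pattern the updated functional test exercises (made-up tensors; only `f1`, `fbeta`, and the `num_classes`/`beta`/`average` keyword arguments are taken from the imports and arguments shown in the diff above):

import torch
from pytorch_lightning.metrics.functional import f1, fbeta

preds = torch.tensor([0, 2, 1, 1])
target = torch.tensor([0, 1, 1, 2])

beta = 2.0
# beta == 1.0 routes to the plain f1 functional; any other beta goes through fbeta.
if beta == 1.0:
    score = f1(preds, target, num_classes=3, average="micro")
else:
    score = fbeta(preds, target, num_classes=3, beta=beta, average="micro")
print(score)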
