Skip to content

Commit

Permalink
Fixed default value for mdmc_average in Accuracy (#1036)
Browse files Browse the repository at this point in the history
* Fixed default value for mdmc_average in Accuracy as per documentation
* fix tests
* changelog
* typo

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: SkafteNicki <skaftenicki@gmail.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
  • Loading branch information
4 people authored May 26, 2022
1 parent 4a28149 commit fa44471
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 27 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed case where `KLDivergence` could output `Nan` ([#1030](https://github.com/PyTorchLightning/metrics/pull/1030))


- Fixed default value for `mdmc_average` in `Accuracy` ([#1036](https://github.com/PyTorchLightning/metrics/pull/1036))


## [0.8.2] - 2022-05-06


Expand Down
52 changes: 26 additions & 26 deletions tests/classification/test_accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,44 +58,44 @@ def _sk_accuracy(preds, target, subset_accuracy):


@pytest.mark.parametrize(
"preds, target, subset_accuracy",
"preds, target, subset_accuracy, mdmc_average",
[
(_input_binary_logits.preds, _input_binary_logits.target, False),
(_input_binary_prob.preds, _input_binary_prob.target, False),
(_input_binary.preds, _input_binary.target, False),
(_input_mlb_prob.preds, _input_mlb_prob.target, True),
(_input_mlb_logits.preds, _input_mlb_logits.target, False),
(_input_mlb_prob.preds, _input_mlb_prob.target, False),
(_input_mlb.preds, _input_mlb.target, True),
(_input_mlb.preds, _input_mlb.target, False),
(_input_mcls_prob.preds, _input_mcls_prob.target, False),
(_input_mcls_logits.preds, _input_mcls_logits.target, False),
(_input_mcls.preds, _input_mcls.target, False),
(_input_mdmc_prob.preds, _input_mdmc_prob.target, False),
(_input_mdmc_prob.preds, _input_mdmc_prob.target, True),
(_input_mdmc.preds, _input_mdmc.target, False),
(_input_mdmc.preds, _input_mdmc.target, True),
(_input_mlmd_prob.preds, _input_mlmd_prob.target, True),
(_input_mlmd_prob.preds, _input_mlmd_prob.target, False),
(_input_mlmd.preds, _input_mlmd.target, True),
(_input_mlmd.preds, _input_mlmd.target, False),
(_input_binary_logits.preds, _input_binary_logits.target, False, None),
(_input_binary_prob.preds, _input_binary_prob.target, False, None),
(_input_binary.preds, _input_binary.target, False, None),
(_input_mlb_prob.preds, _input_mlb_prob.target, True, None),
(_input_mlb_logits.preds, _input_mlb_logits.target, False, None),
(_input_mlb_prob.preds, _input_mlb_prob.target, False, None),
(_input_mlb.preds, _input_mlb.target, True, None),
(_input_mlb.preds, _input_mlb.target, False, "global"),
(_input_mcls_prob.preds, _input_mcls_prob.target, False, None),
(_input_mcls_logits.preds, _input_mcls_logits.target, False, None),
(_input_mcls.preds, _input_mcls.target, False, None),
(_input_mdmc_prob.preds, _input_mdmc_prob.target, False, "global"),
(_input_mdmc_prob.preds, _input_mdmc_prob.target, True, None),
(_input_mdmc.preds, _input_mdmc.target, False, "global"),
(_input_mdmc.preds, _input_mdmc.target, True, None),
(_input_mlmd_prob.preds, _input_mlmd_prob.target, True, None),
(_input_mlmd_prob.preds, _input_mlmd_prob.target, False, None),
(_input_mlmd.preds, _input_mlmd.target, True, None),
(_input_mlmd.preds, _input_mlmd.target, False, "global"),
],
)
class TestAccuracies(MetricTester):
@pytest.mark.parametrize("ddp", [False, True])
@pytest.mark.parametrize("dist_sync_on_step", [False, True])
def test_accuracy_class(self, ddp, dist_sync_on_step, preds, target, subset_accuracy):
def test_accuracy_class(self, ddp, dist_sync_on_step, preds, target, subset_accuracy, mdmc_average):
self.run_class_metric_test(
ddp=ddp,
preds=preds,
target=target,
metric_class=Accuracy,
sk_metric=partial(_sk_accuracy, subset_accuracy=subset_accuracy),
dist_sync_on_step=dist_sync_on_step,
metric_args={"threshold": THRESHOLD, "subset_accuracy": subset_accuracy},
metric_args={"threshold": THRESHOLD, "subset_accuracy": subset_accuracy, "mdmc_average": mdmc_average},
)

def test_accuracy_fn(self, preds, target, subset_accuracy):
def test_accuracy_fn(self, preds, target, subset_accuracy, mdmc_average):
self.run_functional_metric_test(
preds,
target,
Expand All @@ -104,13 +104,13 @@ def test_accuracy_fn(self, preds, target, subset_accuracy):
metric_args={"threshold": THRESHOLD, "subset_accuracy": subset_accuracy},
)

def test_accuracy_differentiability(self, preds, target, subset_accuracy):
def test_accuracy_differentiability(self, preds, target, subset_accuracy, mdmc_average):
self.run_differentiability_test(
preds=preds,
target=target,
metric_module=Accuracy,
metric_functional=accuracy,
metric_args={"threshold": THRESHOLD, "subset_accuracy": subset_accuracy},
metric_args={"threshold": THRESHOLD, "subset_accuracy": subset_accuracy, "mdmc_average": mdmc_average},
)


Expand Down Expand Up @@ -180,7 +180,7 @@ def test_accuracy_differentiability(self, preds, target, subset_accuracy):
],
)
def test_topk_accuracy(preds, target, exp_result, k, subset_accuracy):
topk = Accuracy(top_k=k, subset_accuracy=subset_accuracy)
topk = Accuracy(top_k=k, subset_accuracy=subset_accuracy, mdmc_average="global")

for batch in range(preds.shape[0]):
topk(preds[batch], target[batch])
Expand Down
2 changes: 1 addition & 1 deletion torchmetrics/classification/accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def __init__(
threshold: float = 0.5,
num_classes: Optional[int] = None,
average: str = "micro",
mdmc_average: Optional[str] = "global",
mdmc_average: Optional[str] = None,
ignore_index: Optional[int] = None,
top_k: Optional[int] = None,
multiclass: Optional[bool] = None,
Expand Down

0 comments on commit fa44471

Please sign in to comment.