Commit

Merge branch 'master' into bugfix/ms_ssim_pad
Borda authored Apr 3, 2023
2 parents 4345e83 + cb9540f commit f3f0b0c
Showing 6 changed files with 67 additions and 18 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci-checks.yml
@@ -18,7 +18,7 @@ jobs:
       extra-typing: typing
 
   check-schema:
-    uses: Lightning-AI/utilities/.github/workflows/check-schema.yml@v0.7.1
+    uses: Lightning-AI/utilities/.github/workflows/check-schema.yml@v0.8.0
 
   check-package:
     uses: Lightning-AI/utilities/.github/workflows/check-package.yml@main
4 changes: 2 additions & 2 deletions .github/workflows/clear-cache.yml
@@ -13,12 +13,12 @@ jobs:
 
   cron-clear:
     if: github.event_name == 'schedule'
-    uses: Lightning-AI/utilities/.github/workflows/clear-cache.yml@v0.7.1
+    uses: Lightning-AI/utilities/.github/workflows/clear-cache.yml@v0.8.0
     with:
       pattern: 'pip-latest'
 
   direct-clear:
     if: github.event_name == 'workflow_dispatch'
-    uses: Lightning-AI/utilities/.github/workflows/clear-cache.yml@v0.7.1
+    uses: Lightning-AI/utilities/.github/workflows/clear-cache.yml@v0.8.0
     with:
       pattern: ${{ inputs.pattern }}
4 changes: 2 additions & 2 deletions .github/workflows/publish-pypi.yml
@@ -73,7 +73,7 @@ jobs:
       # We do this, since failures on test.pypi aren't that bad
       - name: Publish to Test PyPI
         if: startsWith(github.event.ref, 'refs/tags') || github.event_name == 'release'
-        uses: pypa/gh-action-pypi-publish@v1.6.4
+        uses: pypa/gh-action-pypi-publish@v1.8.4
         with:
           user: __token__
           password: ${{ secrets.test_pypi_password }}
@@ -82,7 +82,7 @@ jobs:
 
       - name: Publish distribution 📦 to PyPI
         if: startsWith(github.event.ref, 'refs/tags') || github.event_name == 'release'
-        uses: pypa/gh-action-pypi-publish@v1.6.4
+        uses: pypa/gh-action-pypi-publish@v1.8.4
         with:
           user: __token__
           password: ${{ secrets.pypi_password }}
17 changes: 13 additions & 4 deletions src/torchmetrics/classification/accuracy.py
@@ -480,16 +480,25 @@ def __new__(
     ) -> Metric:
         """Initialize task metric."""
         task = ClassificationTask.from_str(task)
+
         kwargs.update(
             {"multidim_average": multidim_average, "ignore_index": ignore_index, "validate_args": validate_args}
         )
+
         if task == ClassificationTask.BINARY:
             return BinaryAccuracy(threshold, **kwargs)
         if task == ClassificationTask.MULTICLASS:
-            assert isinstance(num_classes, int)
-            assert isinstance(top_k, int)
+            if not isinstance(num_classes, int):
+                raise ValueError(
+                    f"Optional arg `num_classes` must be type `int` when task is {task}. Got {type(num_classes)}"
+                )
+            if not isinstance(top_k, int):
+                raise ValueError(f"Optional arg `top_k` must be type `int` when task is {task}. Got {type(top_k)}")
             return MulticlassAccuracy(num_classes, top_k, average, **kwargs)
         if task == ClassificationTask.MULTILABEL:
-            assert isinstance(num_labels, int)
+            if not isinstance(num_labels, int):
+                raise ValueError(
+                    f"Optional arg `num_labels` must be type `int` when task is {task}. Got {type(num_labels)}"
+                )
             return MultilabelAccuracy(num_labels, threshold, average, **kwargs)
-        return None
+        raise ValueError(f"Not handled value: {task}")  # this is to satisfy mypy
19 changes: 13 additions & 6 deletions src/torchmetrics/functional/classification/accuracy.py
@@ -378,8 +378,8 @@ def accuracy(
     threshold: float = 0.5,
     num_classes: Optional[int] = None,
     num_labels: Optional[int] = None,
-    average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
-    multidim_average: Optional[Literal["global", "samplewise"]] = "global",
+    average: Literal["micro", "macro", "weighted", "none"] = "micro",
+    multidim_average: Literal["global", "samplewise"] = "global",
     top_k: Optional[int] = 1,
     ignore_index: Optional[int] = None,
     validate_args: bool = True,
@@ -409,17 +409,24 @@ def accuracy(
         tensor(0.6667)
     """
     task = ClassificationTask.from_str(task)
-    assert multidim_average is not None
+
     if task == ClassificationTask.BINARY:
         return binary_accuracy(preds, target, threshold, multidim_average, ignore_index, validate_args)
     if task == ClassificationTask.MULTICLASS:
-        assert isinstance(num_classes, int)
-        assert isinstance(top_k, int)
+        if not isinstance(num_classes, int):
+            raise ValueError(
+                f"Optional arg `num_classes` must be type `int` when task is {task}. Got {type(num_classes)}"
+            )
+        if not isinstance(top_k, int):
+            raise ValueError(f"Optional arg `top_k` must be type `int` when task is {task}. Got {type(top_k)}")
         return multiclass_accuracy(
             preds, target, num_classes, average, top_k, multidim_average, ignore_index, validate_args
         )
     if task == ClassificationTask.MULTILABEL:
-        assert isinstance(num_labels, int)
+        if not isinstance(num_labels, int):
+            raise ValueError(
+                f"Optional arg `num_labels` must be type `int` when task is {task}. Got {type(num_labels)}"
+            )
         return multilabel_accuracy(
             preds, target, num_labels, threshold, average, multidim_average, ignore_index, validate_args
         )
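
The functional path gets the same treatment; the signature change also narrows `average` and `multidim_average` so `None` is no longer advertised, which lets the `assert multidim_average is not None` guard go. A quick usage sketch, assuming torchmetrics is installed (two of four predictions match, so the result should come out to 0.5):

    import torch
    from torchmetrics.functional import accuracy

    preds = torch.tensor([0, 2, 1, 3])
    target = torch.tensor([0, 1, 2, 3])
    print(accuracy(preds, target, task="multiclass", num_classes=4))  # tensor(0.5000)

    # Misconfiguration is likewise a descriptive ValueError now:
    try:
        accuracy(preds, target, task="multiclass", num_classes=None)
    except ValueError as err:
        print(err)
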
39 changes: 36 additions & 3 deletions tests/unittests/classification/test_accuracy.py
@@ -20,10 +20,15 @@
 from sklearn.metrics import accuracy_score as sk_accuracy
 from sklearn.metrics import confusion_matrix as sk_confusion_matrix
 
-from torchmetrics.classification.accuracy import BinaryAccuracy, MulticlassAccuracy, MultilabelAccuracy
-from torchmetrics.functional.classification.accuracy import binary_accuracy, multiclass_accuracy, multilabel_accuracy
+from torchmetrics.classification.accuracy import Accuracy, BinaryAccuracy, MulticlassAccuracy, MultilabelAccuracy
+from torchmetrics.functional.classification.accuracy import (
+    accuracy,
+    binary_accuracy,
+    multiclass_accuracy,
+    multilabel_accuracy,
+)
 from unittests import NUM_CLASSES, THRESHOLD
-from unittests.classification.inputs import _binary_cases, _multiclass_cases, _multilabel_cases
+from unittests.classification.inputs import _binary_cases, _input_binary, _multiclass_cases, _multilabel_cases
 from unittests.helpers import seed_all
 from unittests.helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index

@@ -61,6 +66,34 @@ def _sklearn_accuracy_binary(preds, target, ignore_index, multidim_average):
     return np.stack(res)
 
 
+def test_accuracy_raises_invalid_task():
+    """Test that an invalid task string raises a ValueError in the `Accuracy` wrapper."""
+    task = "NotValidTask"
+    ignore_index = None
+    multidim_average = "global"
+
+    with pytest.raises(ValueError, match=r"Invalid *"):
+        Accuracy(threshold=THRESHOLD, task=task, ignore_index=ignore_index, multidim_average=multidim_average)
+
+
+def test_accuracy_functional_raises_invalid_task():
+    """Test that an invalid task string raises a ValueError in functional `accuracy`."""
+    preds, target = _input_binary
+    task = "NotValidTask"
+    ignore_index = None
+    multidim_average = "global"
+
+    with pytest.raises(ValueError, match=r"Invalid *"):
+        accuracy(
+            preds,
+            target,
+            threshold=THRESHOLD,
+            task=task,
+            ignore_index=ignore_index,
+            multidim_average=multidim_average,
+        )
+
+
 @pytest.mark.parametrize("input", _binary_cases)
 class TestBinaryAccuracy(MetricTester):
     """Test class for `BinaryAccuracy` metric."""
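Both new tests share the `invalid_task` suffix, so they can be exercised in isolation with an invocation along the lines of `pytest tests/unittests/classification/test_accuracy.py -k invalid_task` (a hypothetical local run from the repository root).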
