diff --git a/.github/workflows/ci-checks.yml b/.github/workflows/ci-checks.yml
index e895cec8704..c72ffbd177c 100644
--- a/.github/workflows/ci-checks.yml
+++ b/.github/workflows/ci-checks.yml
@@ -18,7 +18,7 @@ jobs:
       extra-typing: typing
 
   check-schema:
-    uses: Lightning-AI/utilities/.github/workflows/check-schema.yml@v0.7.1
+    uses: Lightning-AI/utilities/.github/workflows/check-schema.yml@v0.8.0
 
   check-package:
     uses: Lightning-AI/utilities/.github/workflows/check-package.yml@main
diff --git a/.github/workflows/clear-cache.yml b/.github/workflows/clear-cache.yml
index cf1a09b2aa5..69525e37fe0 100644
--- a/.github/workflows/clear-cache.yml
+++ b/.github/workflows/clear-cache.yml
@@ -13,12 +13,12 @@ jobs:
   cron-clear:
     if: github.event_name == 'schedule'
-    uses: Lightning-AI/utilities/.github/workflows/clear-cache.yml@v0.7.1
+    uses: Lightning-AI/utilities/.github/workflows/clear-cache.yml@v0.8.0
     with:
       pattern: 'pip-latest'
 
   direct-clear:
     if: github.event_name == 'workflow_dispatch'
-    uses: Lightning-AI/utilities/.github/workflows/clear-cache.yml@v0.7.1
+    uses: Lightning-AI/utilities/.github/workflows/clear-cache.yml@v0.8.0
     with:
       pattern: ${{ inputs.pattern }}
 
diff --git a/.github/workflows/publish-pypi.yml b/.github/workflows/publish-pypi.yml
index 825a3836c72..9f4c6d01abd 100644
--- a/.github/workflows/publish-pypi.yml
+++ b/.github/workflows/publish-pypi.yml
@@ -73,7 +73,7 @@ jobs:
       # We do this, since failures on test.pypi aren't that bad
      - name: Publish to Test PyPI
        if: startsWith(github.event.ref, 'refs/tags') || github.event_name == 'release'
-        uses: pypa/gh-action-pypi-publish@v1.6.4
+        uses: pypa/gh-action-pypi-publish@v1.8.4
        with:
          user: __token__
          password: ${{ secrets.test_pypi_password }}
@@ -82,7 +82,7 @@
 
      - name: Publish distribution 📦 to PyPI
        if: startsWith(github.event.ref, 'refs/tags') || github.event_name == 'release'
-        uses: pypa/gh-action-pypi-publish@v1.6.4
+        uses: pypa/gh-action-pypi-publish@v1.8.4
        with:
          user: __token__
          password: ${{ secrets.pypi_password }}
diff --git a/src/torchmetrics/classification/accuracy.py b/src/torchmetrics/classification/accuracy.py
index 220b687d777..b192c56493a 100644
--- a/src/torchmetrics/classification/accuracy.py
+++ b/src/torchmetrics/classification/accuracy.py
@@ -480,16 +480,25 @@ def __new__(
     ) -> Metric:
         """Initialize task metric."""
         task = ClassificationTask.from_str(task)
+
         kwargs.update(
             {"multidim_average": multidim_average, "ignore_index": ignore_index, "validate_args": validate_args}
         )
+
         if task == ClassificationTask.BINARY:
             return BinaryAccuracy(threshold, **kwargs)
         if task == ClassificationTask.MULTICLASS:
-            assert isinstance(num_classes, int)
-            assert isinstance(top_k, int)
+            if not isinstance(num_classes, int):
+                raise ValueError(
+                    f"Optional arg `num_classes` must be type `int` when task is {task}. Got {type(num_classes)}"
+                )
+            if not isinstance(top_k, int):
+                raise ValueError(f"Optional arg `top_k` must be type `int` when task is {task}. Got {type(top_k)}")
             return MulticlassAccuracy(num_classes, top_k, average, **kwargs)
         if task == ClassificationTask.MULTILABEL:
-            assert isinstance(num_labels, int)
+            if not isinstance(num_labels, int):
+                raise ValueError(
+                    f"Optional arg `num_labels` must be type `int` when task is {task}. Got {type(num_labels)}"
+                )
             return MultilabelAccuracy(num_labels, threshold, average, **kwargs)
-        return None
+        raise ValueError(f"Not handled value: {task}")  # this is for compliance with mypy
diff --git a/src/torchmetrics/functional/classification/accuracy.py b/src/torchmetrics/functional/classification/accuracy.py
index a6d11782d64..61bf8b8b860 100644
--- a/src/torchmetrics/functional/classification/accuracy.py
+++ b/src/torchmetrics/functional/classification/accuracy.py
@@ -378,8 +378,8 @@ def accuracy(
     threshold: float = 0.5,
     num_classes: Optional[int] = None,
     num_labels: Optional[int] = None,
-    average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
-    multidim_average: Optional[Literal["global", "samplewise"]] = "global",
+    average: Literal["micro", "macro", "weighted", "none"] = "micro",
+    multidim_average: Literal["global", "samplewise"] = "global",
     top_k: Optional[int] = 1,
     ignore_index: Optional[int] = None,
     validate_args: bool = True,
@@ -409,17 +409,24 @@ def accuracy(
         tensor(0.6667)
     """
     task = ClassificationTask.from_str(task)
-    assert multidim_average is not None
+
     if task == ClassificationTask.BINARY:
         return binary_accuracy(preds, target, threshold, multidim_average, ignore_index, validate_args)
     if task == ClassificationTask.MULTICLASS:
-        assert isinstance(num_classes, int)
-        assert isinstance(top_k, int)
+        if not isinstance(num_classes, int):
+            raise ValueError(
+                f"Optional arg `num_classes` must be type `int` when task is {task}. Got {type(num_classes)}"
+            )
+        if not isinstance(top_k, int):
+            raise ValueError(f"Optional arg `top_k` must be type `int` when task is {task}. Got {type(top_k)}")
         return multiclass_accuracy(
             preds, target, num_classes, average, top_k, multidim_average, ignore_index, validate_args
         )
     if task == ClassificationTask.MULTILABEL:
-        assert isinstance(num_labels, int)
+        if not isinstance(num_labels, int):
+            raise ValueError(
+                f"Optional arg `num_labels` must be type `int` when task is {task}. Got {type(num_labels)}"
+            )
         return multilabel_accuracy(
             preds, target, num_labels, threshold, average, multidim_average, ignore_index, validate_args
         )
diff --git a/tests/unittests/classification/test_accuracy.py b/tests/unittests/classification/test_accuracy.py
index 7c7764a3643..1e4ff2787b1 100644
--- a/tests/unittests/classification/test_accuracy.py
+++ b/tests/unittests/classification/test_accuracy.py
@@ -20,10 +20,15 @@
 from sklearn.metrics import accuracy_score as sk_accuracy
 from sklearn.metrics import confusion_matrix as sk_confusion_matrix
 
-from torchmetrics.classification.accuracy import BinaryAccuracy, MulticlassAccuracy, MultilabelAccuracy
-from torchmetrics.functional.classification.accuracy import binary_accuracy, multiclass_accuracy, multilabel_accuracy
+from torchmetrics.classification.accuracy import Accuracy, BinaryAccuracy, MulticlassAccuracy, MultilabelAccuracy
+from torchmetrics.functional.classification.accuracy import (
+    accuracy,
+    binary_accuracy,
+    multiclass_accuracy,
+    multilabel_accuracy,
+)
 
 from unittests import NUM_CLASSES, THRESHOLD
-from unittests.classification.inputs import _binary_cases, _multiclass_cases, _multilabel_cases
+from unittests.classification.inputs import _binary_cases, _input_binary, _multiclass_cases, _multilabel_cases
 from unittests.helpers import seed_all
 from unittests.helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index
@@ -61,6 +66,34 @@ def _sklearn_accuracy_binary(preds, target, ignore_index, multidim_average):
     return np.stack(res)
 
 
+def test_accuracy_raises_invalid_task():
+    """Test that instantiating `Accuracy` with an invalid task raises a ValueError."""
+    task = "NotValidTask"
+    ignore_index = None
+    multidim_average = "global"
+
+    with pytest.raises(ValueError, match=r"Invalid *"):
+        Accuracy(threshold=THRESHOLD, task=task, ignore_index=ignore_index, multidim_average=multidim_average)
+
+
+def test_accuracy_functional_raises_invalid_task():
+    """Test that the functional `accuracy` raises a ValueError for an invalid task."""
+    preds, target = _input_binary
+    task = "NotValidTask"
+    ignore_index = None
+    multidim_average = "global"
+
+    with pytest.raises(ValueError, match=r"Invalid *"):
+        accuracy(
+            preds,
+            target,
+            threshold=THRESHOLD,
+            task=task,
+            ignore_index=ignore_index,
+            multidim_average=multidim_average,
+        )
+
+
 @pytest.mark.parametrize("input", _binary_cases)
 class TestBinaryAccuracy(MetricTester):
     """Test class for `BinaryAccuracy` metric."""