From 060b4c6a1f01c066bdbfb6e0dfc9f0f1f74d2772 Mon Sep 17 00:00:00 2001 From: Jirka Date: Mon, 12 Sep 2022 23:46:37 +0200 Subject: [PATCH 01/10] try --- .github/workflows/code-format.yml | 3 +-- src/torchmetrics/classification/accuracy.py | 4 ++-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/.github/workflows/code-format.yml b/.github/workflows/code-format.yml index 41340d8138e..003302733b0 100644 --- a/.github/workflows/code-format.yml +++ b/.github/workflows/code-format.yml @@ -2,8 +2,7 @@ name: Code formatting # see: https://help.github.com/en/actions/reference/events-that-trigger-workflows on: # Trigger the workflow on push or pull request, but only for the master branch - push: - branches: [master, "release/*"] + push: {} pull_request: branches: [master, "release/*"] diff --git a/src/torchmetrics/classification/accuracy.py b/src/torchmetrics/classification/accuracy.py index a13f51a6c72..69eae6f2d59 100644 --- a/src/torchmetrics/classification/accuracy.py +++ b/src/torchmetrics/classification/accuracy.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional +from typing import Any, Optional, Union import torch from torch import Tensor, tensor @@ -472,7 +472,7 @@ def __new__( multidim_average: Optional[Literal["global", "samplewise"]] = "global", validate_args: bool = True, **kwargs: Any, - ) -> None: + ) -> Union[None, BinaryAccuracy, MulticlassAccuracy, MultilabelAccuracy]: if task is not None: kwargs.update( dict(multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args) From 16799251f600ae4f8e296db1b3c44e0976eef6e1 Mon Sep 17 00:00:00 2001 From: Jirka Date: Tue, 13 Sep 2022 00:29:07 +0200 Subject: [PATCH 02/10] return Metric --- src/torchmetrics/classification/accuracy.py | 5 +++-- src/torchmetrics/classification/auroc.py | 2 +- src/torchmetrics/classification/average_precision.py | 2 +- src/torchmetrics/classification/calibration_error.py | 2 +- src/torchmetrics/classification/cohen_kappa.py | 2 +- src/torchmetrics/classification/confusion_matrix.py | 2 +- src/torchmetrics/classification/f_beta.py | 5 +++-- src/torchmetrics/classification/hamming.py | 2 +- src/torchmetrics/classification/hinge.py | 2 +- src/torchmetrics/classification/jaccard.py | 3 ++- src/torchmetrics/classification/matthews_corrcoef.py | 2 +- src/torchmetrics/classification/precision_recall.py | 5 +++-- src/torchmetrics/classification/precision_recall_curve.py | 2 +- src/torchmetrics/classification/roc.py | 2 +- src/torchmetrics/classification/specificity.py | 3 ++- src/torchmetrics/classification/stat_scores.py | 2 +- 16 files changed, 24 insertions(+), 19 deletions(-) diff --git a/src/torchmetrics/classification/accuracy.py b/src/torchmetrics/classification/accuracy.py index 69eae6f2d59..5a82089880e 100644 --- a/src/torchmetrics/classification/accuracy.py +++ b/src/torchmetrics/classification/accuracy.py @@ -11,12 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Optional, Union +from typing import Any, Optional import torch from torch import Tensor, tensor from typing_extensions import Literal +from torchmetrics import Metric from torchmetrics.functional.classification.accuracy import ( _accuracy_compute, _accuracy_reduce, @@ -472,7 +473,7 @@ def __new__( multidim_average: Optional[Literal["global", "samplewise"]] = "global", validate_args: bool = True, **kwargs: Any, - ) -> Union[None, BinaryAccuracy, MulticlassAccuracy, MultilabelAccuracy]: + ) -> Metric: if task is not None: kwargs.update( dict(multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args) diff --git a/src/torchmetrics/classification/auroc.py b/src/torchmetrics/classification/auroc.py index 5365e46dbec..2526870d52d 100644 --- a/src/torchmetrics/classification/auroc.py +++ b/src/torchmetrics/classification/auroc.py @@ -416,7 +416,7 @@ def __new__( ignore_index: Optional[int] = None, validate_args: bool = True, **kwargs: Any, - ) -> None: + ) -> Metric: if task is not None: kwargs.update(dict(thresholds=thresholds, ignore_index=ignore_index, validate_args=validate_args)) if task == "binary": diff --git a/src/torchmetrics/classification/average_precision.py b/src/torchmetrics/classification/average_precision.py index 53d5bffb0e9..504af498f9c 100644 --- a/src/torchmetrics/classification/average_precision.py +++ b/src/torchmetrics/classification/average_precision.py @@ -395,7 +395,7 @@ def __new__( ignore_index: Optional[int] = None, validate_args: bool = True, **kwargs: Any, - ) -> None: + ) -> Metric: if task is not None: kwargs.update(dict(thresholds=thresholds, ignore_index=ignore_index, validate_args=validate_args)) if task == "binary": diff --git a/src/torchmetrics/classification/calibration_error.py b/src/torchmetrics/classification/calibration_error.py index 3fcf8ace9a2..ae4ff259351 100644 --- a/src/torchmetrics/classification/calibration_error.py +++ b/src/torchmetrics/classification/calibration_error.py @@ -274,7 +274,7 @@ def __new__( ignore_index: Optional[int] = None, validate_args: bool = True, **kwargs: Any, - ) -> None: + ) -> Metric: if task is not None: kwargs.update(dict(n_bins=n_bins, norm=norm, ignore_index=ignore_index, validate_args=validate_args)) if task == "binary": diff --git a/src/torchmetrics/classification/cohen_kappa.py b/src/torchmetrics/classification/cohen_kappa.py index 3306c419117..52977052391 100644 --- a/src/torchmetrics/classification/cohen_kappa.py +++ b/src/torchmetrics/classification/cohen_kappa.py @@ -245,7 +245,7 @@ def __new__( ignore_index: Optional[int] = None, validate_args: bool = True, **kwargs: Any, - ) -> None: + ) -> Metric: if task is not None: kwargs.update(dict(weights=weights, ignore_index=ignore_index, validate_args=validate_args)) if task == "binary": diff --git a/src/torchmetrics/classification/confusion_matrix.py b/src/torchmetrics/classification/confusion_matrix.py index c012dfddc1b..912bbec5ff7 100644 --- a/src/torchmetrics/classification/confusion_matrix.py +++ b/src/torchmetrics/classification/confusion_matrix.py @@ -401,7 +401,7 @@ def __new__( ignore_index: Optional[int] = None, validate_args: bool = True, **kwargs: Any, - ) -> None: + ) -> Metric: if task is not None: kwargs.update(dict(normalize=normalize, ignore_index=ignore_index, validate_args=validate_args)) if task == "binary": diff --git a/src/torchmetrics/classification/f_beta.py b/src/torchmetrics/classification/f_beta.py index 73c3bdc446c..c40a032dc5b 100644 --- 
a/src/torchmetrics/classification/f_beta.py +++ b/src/torchmetrics/classification/f_beta.py @@ -17,6 +17,7 @@ from torch import Tensor from typing_extensions import Literal +from torchmetrics import Metric from torchmetrics.classification.stat_scores import ( BinaryStatScores, MulticlassStatScores, @@ -832,7 +833,7 @@ def __new__( multidim_average: Optional[Literal["global", "samplewise"]] = "global", validate_args: bool = True, **kwargs: Any, - ) -> None: + ) -> Metric: if task is not None: kwargs.update( dict(multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args) @@ -993,7 +994,7 @@ def __new__( multidim_average: Optional[Literal["global", "samplewise"]] = "global", validate_args: bool = True, **kwargs: Any, - ) -> None: + ) -> Metric: if task is not None: kwargs.update( dict(multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args) diff --git a/src/torchmetrics/classification/hamming.py b/src/torchmetrics/classification/hamming.py index d63f0a1bf72..c031e5c6c98 100644 --- a/src/torchmetrics/classification/hamming.py +++ b/src/torchmetrics/classification/hamming.py @@ -374,7 +374,7 @@ def __new__( ignore_index: Optional[int] = None, validate_args: bool = True, **kwargs: Any, - ) -> None: + ) -> Metric: if task is not None: kwargs.update( dict(multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args) diff --git a/src/torchmetrics/classification/hinge.py b/src/torchmetrics/classification/hinge.py index ce99e28e83a..7af4f1cce3d 100644 --- a/src/torchmetrics/classification/hinge.py +++ b/src/torchmetrics/classification/hinge.py @@ -290,7 +290,7 @@ def __new__( ignore_index: Optional[int] = None, validate_args: bool = True, **kwargs: Any, - ) -> None: + ) -> Metric: if task is not None: kwargs.update(dict(ignore_index=ignore_index, validate_args=validate_args)) if task == "binary": diff --git a/src/torchmetrics/classification/jaccard.py b/src/torchmetrics/classification/jaccard.py index e17e07e4af5..e00f3fb5a1e 100644 --- a/src/torchmetrics/classification/jaccard.py +++ b/src/torchmetrics/classification/jaccard.py @@ -17,6 +17,7 @@ from torch import Tensor from typing_extensions import Literal +from torchmetrics import Metric from torchmetrics.classification import BinaryConfusionMatrix, MulticlassConfusionMatrix, MultilabelConfusionMatrix from torchmetrics.classification.confusion_matrix import ConfusionMatrix from torchmetrics.functional.classification.jaccard import ( @@ -329,7 +330,7 @@ def __new__( num_labels: Optional[int] = None, validate_args: bool = True, **kwargs: Any, - ) -> None: + ) -> Metric: if task is not None: kwargs.update(dict(ignore_index=ignore_index, validate_args=validate_args)) if task == "binary": diff --git a/src/torchmetrics/classification/matthews_corrcoef.py b/src/torchmetrics/classification/matthews_corrcoef.py index b31d24b3e2e..0108e3ea2f3 100644 --- a/src/torchmetrics/classification/matthews_corrcoef.py +++ b/src/torchmetrics/classification/matthews_corrcoef.py @@ -280,7 +280,7 @@ def __new__( ignore_index: Optional[int] = None, validate_args: bool = True, **kwargs: Any, - ) -> None: + ) -> Metric: if task is not None: kwargs.update(dict(ignore_index=ignore_index, validate_args=validate_args)) if task == "binary": diff --git a/src/torchmetrics/classification/precision_recall.py b/src/torchmetrics/classification/precision_recall.py index 49f03d84c92..a6b225cf037 100644 --- a/src/torchmetrics/classification/precision_recall.py +++ 
b/src/torchmetrics/classification/precision_recall.py @@ -17,6 +17,7 @@ from torch import Tensor from typing_extensions import Literal +from torchmetrics import Metric from torchmetrics.classification.stat_scores import ( BinaryStatScores, MulticlassStatScores, @@ -713,7 +714,7 @@ def __new__( multidim_average: Optional[Literal["global", "samplewise"]] = "global", validate_args: bool = True, **kwargs: Any, - ) -> None: + ) -> Metric: if task is not None: kwargs.update( dict(multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args) @@ -882,7 +883,7 @@ def __new__( multidim_average: Optional[Literal["global", "samplewise"]] = "global", validate_args: bool = True, **kwargs: Any, - ) -> None: + ) -> Metric: if task is not None: kwargs.update( dict(multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args) diff --git a/src/torchmetrics/classification/precision_recall_curve.py b/src/torchmetrics/classification/precision_recall_curve.py index a8e9d8c1823..26bbfcdb484 100644 --- a/src/torchmetrics/classification/precision_recall_curve.py +++ b/src/torchmetrics/classification/precision_recall_curve.py @@ -481,7 +481,7 @@ def __new__( ignore_index: Optional[int] = None, validate_args: bool = True, **kwargs: Any, - ) -> None: + ) -> Metric: if task is not None: kwargs.update(dict(thresholds=thresholds, ignore_index=ignore_index, validate_args=validate_args)) if task == "binary": diff --git a/src/torchmetrics/classification/roc.py b/src/torchmetrics/classification/roc.py index 9f30e1691e8..6b6f1bf1eb1 100644 --- a/src/torchmetrics/classification/roc.py +++ b/src/torchmetrics/classification/roc.py @@ -405,7 +405,7 @@ def __new__( ignore_index: Optional[int] = None, validate_args: bool = True, **kwargs: Any, - ) -> None: + ) -> Metric: if task is not None: kwargs.update(dict(thresholds=thresholds, ignore_index=ignore_index, validate_args=validate_args)) if task == "binary": diff --git a/src/torchmetrics/classification/specificity.py b/src/torchmetrics/classification/specificity.py index e087a0e0077..97687f6b668 100644 --- a/src/torchmetrics/classification/specificity.py +++ b/src/torchmetrics/classification/specificity.py @@ -17,6 +17,7 @@ from torch import Tensor from typing_extensions import Literal +from torchmetrics import Metric from torchmetrics.classification.stat_scores import ( BinaryStatScores, MulticlassStatScores, @@ -403,7 +404,7 @@ def __new__( multidim_average: Optional[Literal["global", "samplewise"]] = "global", validate_args: bool = True, **kwargs: Any, - ) -> None: + ) -> Metric: if task is not None: kwargs.update( dict(multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args) diff --git a/src/torchmetrics/classification/stat_scores.py b/src/torchmetrics/classification/stat_scores.py index 39406255441..286ad3b0390 100644 --- a/src/torchmetrics/classification/stat_scores.py +++ b/src/torchmetrics/classification/stat_scores.py @@ -602,7 +602,7 @@ def __new__( multidim_average: Optional[Literal["global", "samplewise"]] = "global", validate_args: bool = True, **kwargs: Any, - ) -> None: + ) -> Metric: if task is not None: kwargs.update( dict(multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args) From 31b50d78ab83b8a212d68887f47c20e1fc37bfbc Mon Sep 17 00:00:00 2001 From: Jirka Date: Tue, 13 Sep 2022 00:34:55 +0200 Subject: [PATCH 03/10] fix top_k, average --- src/torchmetrics/classification/accuracy.py | 2 +- src/torchmetrics/classification/f_beta.py | 4 
++-- src/torchmetrics/classification/hamming.py | 2 +- src/torchmetrics/classification/hinge.py | 2 +- src/torchmetrics/classification/specificity.py | 2 +- src/torchmetrics/classification/stat_scores.py | 2 +- 6 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/torchmetrics/classification/accuracy.py b/src/torchmetrics/classification/accuracy.py index 5a82089880e..783713fb205 100644 --- a/src/torchmetrics/classification/accuracy.py +++ b/src/torchmetrics/classification/accuracy.py @@ -481,7 +481,7 @@ def __new__( if task == "binary": return BinaryAccuracy(threshold, **kwargs) if task == "multiclass": - return MulticlassAccuracy(num_classes, average, top_k, **kwargs) + return MulticlassAccuracy(num_classes, top_k, average, **kwargs) if task == "multilabel": return MultilabelAccuracy(num_labels, threshold, average, **kwargs) raise ValueError( diff --git a/src/torchmetrics/classification/f_beta.py b/src/torchmetrics/classification/f_beta.py index c40a032dc5b..384aa26b2be 100644 --- a/src/torchmetrics/classification/f_beta.py +++ b/src/torchmetrics/classification/f_beta.py @@ -841,7 +841,7 @@ def __new__( if task == "binary": return BinaryFBetaScore(beta, threshold, **kwargs) if task == "multiclass": - return MulticlassFBetaScore(beta, num_classes, average, top_k, **kwargs) + return MulticlassFBetaScore(beta, num_classes, top_k, average, **kwargs) if task == "multilabel": return MultilabelFBetaScore(beta, num_labels, threshold, average, **kwargs) raise ValueError( @@ -1002,7 +1002,7 @@ def __new__( if task == "binary": return BinaryF1Score(threshold, **kwargs) if task == "multiclass": - return MulticlassF1Score(num_classes, average, top_k, **kwargs) + return MulticlassF1Score(num_classes, top_k, average, **kwargs) if task == "multilabel": return MultilabelF1Score(num_labels, threshold, average, **kwargs) raise ValueError( diff --git a/src/torchmetrics/classification/hamming.py b/src/torchmetrics/classification/hamming.py index c031e5c6c98..701387cbaba 100644 --- a/src/torchmetrics/classification/hamming.py +++ b/src/torchmetrics/classification/hamming.py @@ -382,7 +382,7 @@ def __new__( if task == "binary": return BinaryHammingDistance(threshold, **kwargs) if task == "multiclass": - return MulticlassHammingDistance(num_classes, average, top_k, **kwargs) + return MulticlassHammingDistance(num_classes, top_k, average, **kwargs) if task == "multilabel": return MultilabelHammingDistance(num_labels, threshold, average, **kwargs) raise ValueError( diff --git a/src/torchmetrics/classification/hinge.py b/src/torchmetrics/classification/hinge.py index 7af4f1cce3d..dede90b5609 100644 --- a/src/torchmetrics/classification/hinge.py +++ b/src/torchmetrics/classification/hinge.py @@ -284,7 +284,7 @@ class HingeLoss(Metric): def __new__( cls, squared: bool = False, - multiclass_mode: Optional[Union[str, MulticlassMode]] = None, + multiclass_mode: Literal["crammer-singer", "one-vs-all"] = None, task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, num_classes: Optional[int] = None, ignore_index: Optional[int] = None, diff --git a/src/torchmetrics/classification/specificity.py b/src/torchmetrics/classification/specificity.py index 97687f6b668..7fa1be2d19a 100644 --- a/src/torchmetrics/classification/specificity.py +++ b/src/torchmetrics/classification/specificity.py @@ -412,7 +412,7 @@ def __new__( if task == "binary": return BinarySpecificity(threshold, **kwargs) if task == "multiclass": - return MulticlassSpecificity(num_classes, average, top_k, **kwargs) + return 
MulticlassSpecificity(num_classes, top_k, average, **kwargs) if task == "multilabel": return MultilabelSpecificity(num_labels, threshold, average, **kwargs) raise ValueError( diff --git a/src/torchmetrics/classification/stat_scores.py b/src/torchmetrics/classification/stat_scores.py index 286ad3b0390..32b02b083a5 100644 --- a/src/torchmetrics/classification/stat_scores.py +++ b/src/torchmetrics/classification/stat_scores.py @@ -610,7 +610,7 @@ def __new__( if task == "binary": return BinaryStatScores(threshold, **kwargs) if task == "multiclass": - return MulticlassStatScores(num_classes, average, top_k, **kwargs) + return MulticlassStatScores(num_classes, top_k, average, **kwargs) if task == "multilabel": return MultilabelStatScores(num_labels, threshold, average, **kwargs) raise ValueError( From 3804e4271bfb4ae8aad83ce777fb23096dec8c3f Mon Sep 17 00:00:00 2001 From: Jirka Date: Tue, 13 Sep 2022 00:45:55 +0200 Subject: [PATCH 04/10] assert --- .../functional/classification/accuracy.py | 2 +- .../functional/classification/auroc.py | 22 ++++++------ .../classification/average_precision.py | 2 +- .../functional/classification/cohen_kappa.py | 7 ++-- .../functional/classification/f_beta.py | 19 ++++++---- .../functional/classification/hamming.py | 17 ++++++--- .../functional/classification/jaccard.py | 2 +- .../classification/precision_recall.py | 36 +++++++++++++------ .../classification/precision_recall_curve.py | 10 +++--- .../functional/classification/roc.py | 9 ++--- .../functional/classification/specificity.py | 17 ++++++--- .../functional/classification/stat_scores.py | 17 ++++++--- 12 files changed, 102 insertions(+), 58 deletions(-) diff --git a/src/torchmetrics/functional/classification/accuracy.py b/src/torchmetrics/functional/classification/accuracy.py index 88a5174bdd8..d5d77c91334 100644 --- a/src/torchmetrics/functional/classification/accuracy.py +++ b/src/torchmetrics/functional/classification/accuracy.py @@ -625,7 +625,7 @@ def _subset_accuracy_compute(correct: Tensor, total: Tensor) -> Tensor: def accuracy( preds: Tensor, target: Tensor, - average: Optional[str] = "micro", + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", mdmc_average: Optional[str] = "global", threshold: float = 0.5, top_k: Optional[int] = None, diff --git a/src/torchmetrics/functional/classification/auroc.py b/src/torchmetrics/functional/classification/auroc.py index 85a8c8f8313..40bd2f4ba3a 100644 --- a/src/torchmetrics/functional/classification/auroc.py +++ b/src/torchmetrics/functional/classification/auroc.py @@ -53,7 +53,6 @@ def _reduce_auroc( weights: Optional[Tensor] = None, ) -> Tensor: """Utility function for reducing multiple average precision score into one number.""" - res = [] if isinstance(fpr, Tensor): res = _auc_compute_without_check(fpr, tpr, 1.0, axis=1) else: @@ -96,7 +95,7 @@ def _binary_auroc_compute( thresholds: Optional[Tensor], max_fpr: Optional[float] = None, pos_label: int = 1, -) -> Tuple[Tensor, Tensor, Tensor]: +) -> Union[Tensor, Tuple[Tensor, Tensor, Tensor]]: fpr, tpr, _ = _binary_roc_compute(state, thresholds, pos_label) if max_fpr is None or max_fpr == 1: return _auc_compute_without_check(fpr, tpr, 1.0) @@ -221,7 +220,7 @@ def multiclass_auroc( thresholds: Optional[Union[int, List[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, -) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: +) -> Tensor: r""" Compute Area Under the Receiver Operating Characteristic Curve (`ROC 
AUC`_) for multiclass tasks. The AUROC score summarizes the ROC curve into an single number that describes the performance of a model for multiple @@ -317,7 +316,7 @@ def _multilabel_auroc_compute( average: Optional[Literal["micro", "macro", "weighted", "none"]], thresholds: Optional[Tensor], ignore_index: Optional[int] = None, -) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: +) -> Union[Tuple[Tensor, Tensor, Tensor], Tensor]: if average == "micro": if isinstance(state, Tensor) and thresholds is not None: return _binary_auroc_compute(state.sum(1), thresholds, max_fpr=None) @@ -328,7 +327,7 @@ def _multilabel_auroc_compute( idx = target == ignore_index preds = preds[~idx] target = target[~idx] - return _binary_auroc_compute([preds, target], thresholds, max_fpr=None) + return _binary_auroc_compute((preds, target), thresholds, max_fpr=None) else: fpr, tpr, _ = _multilabel_roc_compute(state, num_labels, thresholds, ignore_index) @@ -605,7 +604,7 @@ def auroc( target: Tensor, num_classes: Optional[int] = None, pos_label: Optional[int] = None, - average: Optional[str] = "macro", + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", max_fpr: Optional[float] = None, sample_weights: Optional[Sequence] = None, task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, @@ -613,7 +612,7 @@ def auroc( num_labels: Optional[int] = None, ignore_index: Optional[int] = None, validate_args: bool = True, -) -> Tensor: +) -> Union[Tensor, Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: r""" .. note:: From v0.10 an `'binary_*'`, `'multiclass_*', `'multilabel_*'` version now exist of each classification @@ -686,13 +685,14 @@ def auroc( tensor(0.7778) """ if task is not None: - kwargs = dict(thresholds=thresholds, ignore_index=ignore_index, validate_args=validate_args) if task == "binary": - return binary_auroc(preds, target, max_fpr, **kwargs) + return binary_auroc(preds, target, max_fpr, thresholds, ignore_index, validate_args) if task == "multiclass": - return multiclass_auroc(preds, target, num_classes, average, **kwargs) + assert isinstance(num_classes, int) + return multiclass_auroc(preds, target, num_classes, average, thresholds, ignore_index, validate_args) if task == "multilabel": - return multilabel_auroc(preds, target, num_labels, average, **kwargs) + assert isinstance(num_labels, int) + return multilabel_auroc(preds, target, num_labels, average, thresholds, ignore_index, validate_args) raise ValueError( f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" ) diff --git a/src/torchmetrics/functional/classification/average_precision.py b/src/torchmetrics/functional/classification/average_precision.py index 4bff1961f1c..f23b2600e2b 100644 --- a/src/torchmetrics/functional/classification/average_precision.py +++ b/src/torchmetrics/functional/classification/average_precision.py @@ -562,7 +562,7 @@ def average_precision( target: Tensor, num_classes: Optional[int] = None, pos_label: Optional[int] = None, - average: Optional[str] = "macro", + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, thresholds: Optional[Union[int, List[float], Tensor]] = None, num_labels: Optional[int] = None, diff --git a/src/torchmetrics/functional/classification/cohen_kappa.py b/src/torchmetrics/functional/classification/cohen_kappa.py index 5d3d8a064c6..a32f03039a2 100644 --- 
a/src/torchmetrics/functional/classification/cohen_kappa.py +++ b/src/torchmetrics/functional/classification/cohen_kappa.py @@ -287,7 +287,7 @@ def cohen_kappa( preds: Tensor, target: Tensor, num_classes: int, - weights: Optional[str] = None, + weights: Optional[Literal["linear", "quadratic", "none"]] = None, threshold: float = 0.5, task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, ignore_index: Optional[int] = None, @@ -335,11 +335,10 @@ class labels. tensor(0.5000) """ if task is not None: - kwargs = dict(weights=weights, ignore_index=ignore_index, validate_args=validate_args) if task == "binary": - return binary_cohen_kappa(preds, target, threshold, **kwargs) + return binary_cohen_kappa(preds, target, threshold, weights, ignore_index, validate_args) if task == "multiclass": - return multiclass_cohen_kappa(preds, target, num_classes, **kwargs) + return multiclass_cohen_kappa(preds, target, num_classes, weights, ignore_index, validate_args) raise ValueError( f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" ) diff --git a/src/torchmetrics/functional/classification/f_beta.py b/src/torchmetrics/functional/classification/f_beta.py index bea2e88a6f1..e863fac7aa9 100644 --- a/src/torchmetrics/functional/classification/f_beta.py +++ b/src/torchmetrics/functional/classification/f_beta.py @@ -794,7 +794,7 @@ def fbeta_score( preds: Tensor, target: Tensor, beta: float = 1.0, - average: Optional[str] = "micro", + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", mdmc_average: Optional[str] = None, ignore_index: Optional[int] = None, num_classes: Optional[int] = None, @@ -962,7 +962,7 @@ def f1_score( preds: Tensor, target: Tensor, beta: float = 1.0, - average: Optional[str] = "micro", + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", mdmc_average: Optional[str] = None, ignore_index: Optional[int] = None, num_classes: Optional[int] = None, @@ -1076,13 +1076,20 @@ def f1_score( tensor(0.3333) """ if task is not None: - kwargs = dict(multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args) + assert multidim_average is not None if task == "binary": - return binary_f1_score(preds, target, threshold, **kwargs) + return binary_f1_score(preds, target, threshold, multidim_average, ignore_index, validate_args) if task == "multiclass": - return multiclass_f1_score(preds, target, num_classes, average, top_k, **kwargs) + assert isinstance(num_classes, int) + assert isinstance(top_k, int) + return multiclass_f1_score( + preds, target, num_classes, average, top_k, multidim_average, ignore_index, validate_args + ) if task == "multilabel": - return multilabel_f1_score(preds, target, num_labels, threshold, average, **kwargs) + assert isinstance(num_labels, int) + return multilabel_f1_score( + preds, target, num_labels, threshold, average, multidim_average, ignore_index, validate_args + ) raise ValueError( f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" ) diff --git a/src/torchmetrics/functional/classification/hamming.py b/src/torchmetrics/functional/classification/hamming.py index 84d459de82d..bc7f41c8749 100644 --- a/src/torchmetrics/functional/classification/hamming.py +++ b/src/torchmetrics/functional/classification/hamming.py @@ -437,7 +437,7 @@ def hamming_distance( task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, num_classes: Optional[int] = None, num_labels: Optional[int] = None, 
- average: Optional[str] = "macro", + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", top_k: int = 1, multidim_average: Optional[Literal["global", "samplewise"]] = "global", ignore_index: Optional[int] = None, @@ -482,13 +482,20 @@ def hamming_distance( tensor(0.2500) """ if task is not None: - kwargs = dict(multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args) + assert multidim_average is not None if task == "binary": - return binary_hamming_distance(preds, target, threshold, **kwargs) + return binary_hamming_distance(preds, target, threshold, multidim_average, ignore_index, validate_args) if task == "multiclass": - return multiclass_hamming_distance(preds, target, num_classes, average, top_k, **kwargs) + assert isinstance(num_classes, int) + assert isinstance(top_k, int) + return multiclass_hamming_distance( + preds, target, num_classes, average, top_k, multidim_average, ignore_index, validate_args + ) if task == "multilabel": - return multilabel_hamming_distance(preds, target, num_labels, threshold, average, **kwargs) + assert isinstance(num_labels, int) + return multilabel_hamming_distance( + preds, target, num_labels, threshold, average, multidim_average, ignore_index, validate_args + ) raise ValueError( f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" ) diff --git a/src/torchmetrics/functional/classification/jaccard.py b/src/torchmetrics/functional/classification/jaccard.py index 05b49bb9390..1588f1b442b 100644 --- a/src/torchmetrics/functional/classification/jaccard.py +++ b/src/torchmetrics/functional/classification/jaccard.py @@ -378,7 +378,7 @@ def jaccard_index( preds: Tensor, target: Tensor, num_classes: int, - average: Optional[str] = "macro", + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", ignore_index: Optional[int] = None, absent_score: float = 0.0, threshold: float = 0.5, diff --git a/src/torchmetrics/functional/classification/precision_recall.py b/src/torchmetrics/functional/classification/precision_recall.py index b00172cee7f..066038a8282 100644 --- a/src/torchmetrics/functional/classification/precision_recall.py +++ b/src/torchmetrics/functional/classification/precision_recall.py @@ -712,7 +712,7 @@ def _precision_compute( def precision( preds: Tensor, target: Tensor, - average: Optional[str] = "micro", + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", mdmc_average: Optional[str] = None, ignore_index: Optional[int] = None, num_classes: Optional[int] = None, @@ -833,13 +833,20 @@ def precision( """ if task is not None: - kwargs = dict(multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args) + assert multidim_average is not None if task == "binary": - return binary_precision(preds, target, threshold, **kwargs) + return binary_precision(preds, target, threshold, multidim_average, ignore_index, validate_args) if task == "multiclass": - return multiclass_precision(preds, target, num_classes, average, top_k, **kwargs) + assert isinstance(num_classes, int) + assert isinstance(top_k, int) + return multiclass_precision( + preds, target, num_classes, average, top_k, multidim_average, ignore_index, validate_args + ) if task == "multilabel": - return multilabel_precision(preds, target, num_labels, threshold, average, **kwargs) + assert isinstance(num_labels, int) + return multilabel_precision( + preds, target, num_labels, threshold, average, multidim_average, ignore_index, 
validate_args + ) raise ValueError( f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" ) @@ -936,7 +943,7 @@ def _recall_compute( def recall( preds: Tensor, target: Tensor, - average: Optional[str] = "micro", + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", mdmc_average: Optional[str] = None, ignore_index: Optional[int] = None, num_classes: Optional[int] = None, @@ -1058,13 +1065,20 @@ def recall( """ if task is not None: - kwargs = dict(multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args) + assert multidim_average is not None if task == "binary": - return binary_recall(preds, target, threshold, **kwargs) + return binary_recall(preds, target, threshold, multidim_average, ignore_index, validate_args) if task == "multiclass": - return multiclass_recall(preds, target, num_classes, average, top_k, **kwargs) + assert isinstance(num_classes, int) + assert isinstance(top_k, int) + return multiclass_recall( + preds, target, num_classes, average, top_k, multidim_average, ignore_index, validate_args + ) if task == "multilabel": - return multilabel_recall(preds, target, num_labels, threshold, average, **kwargs) + assert isinstance(num_labels, int) + return multilabel_recall( + preds, target, num_labels, threshold, average, multidim_average, ignore_index, validate_args + ) raise ValueError( f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" ) @@ -1110,7 +1124,7 @@ def recall( def precision_recall( preds: Tensor, target: Tensor, - average: Optional[str] = "micro", + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", mdmc_average: Optional[str] = None, ignore_index: Optional[int] = None, num_classes: Optional[int] = None, diff --git a/src/torchmetrics/functional/classification/precision_recall_curve.py b/src/torchmetrics/functional/classification/precision_recall_curve.py index 84fc046ac7f..1733014b3a1 100644 --- a/src/torchmetrics/functional/classification/precision_recall_curve.py +++ b/src/torchmetrics/functional/classification/precision_recall_curve.py @@ -1053,13 +1053,15 @@ def precision_recall_curve( [tensor([0.7500]), tensor([0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500])] """ if task is not None: - kwargs = dict(thresholds=thresholds, ignore_index=ignore_index, validate_args=validate_args) if task == "binary": - return binary_precision_recall_curve(preds, target, **kwargs) + return binary_precision_recall_curve(preds, target, thresholds, ignore_index, validate_args) if task == "multiclass": - return multiclass_precision_recall_curve(preds, target, num_classes, **kwargs) + assert isinstance(num_classes, int) + return multiclass_precision_recall_curve( + preds, target, num_classes, thresholds, ignore_index, validate_args + ) if task == "multilabel": - return multilabel_precision_recall_curve(preds, target, num_labels, **kwargs) + return multilabel_precision_recall_curve(preds, target, num_labels, thresholds, ignore_index, validate_args) raise ValueError( f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" ) diff --git a/src/torchmetrics/functional/classification/roc.py b/src/torchmetrics/functional/classification/roc.py index f064a294219..b9afd6160ea 100644 --- a/src/torchmetrics/functional/classification/roc.py +++ b/src/torchmetrics/functional/classification/roc.py @@ -696,13 +696,14 @@ def roc( tensor([1.1837, 0.1837, 0.1338, 
0.1183, 0.1138])] """ if task is not None: - kwargs = dict(thresholds=thresholds, ignore_index=ignore_index, validate_args=validate_args) if task == "binary": - return binary_roc(preds, target, **kwargs) + return binary_roc(preds, target, thresholds, ignore_index, validate_args) if task == "multiclass": - return multiclass_roc(preds, target, num_classes, **kwargs) + assert isinstance(num_classes, int) + return multiclass_roc(preds, target, num_classes, thresholds, ignore_index, validate_args) if task == "multilabel": - return multilabel_roc(preds, target, num_labels, **kwargs) + assert isinstance(num_labels, int) + return multilabel_roc(preds, target, num_labels, thresholds, ignore_index, validate_args) raise ValueError( f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" ) diff --git a/src/torchmetrics/functional/classification/specificity.py b/src/torchmetrics/functional/classification/specificity.py index 5893c2fcea3..c032ef0dc14 100644 --- a/src/torchmetrics/functional/classification/specificity.py +++ b/src/torchmetrics/functional/classification/specificity.py @@ -409,7 +409,7 @@ def _specificity_compute( def specificity( preds: Tensor, target: Tensor, - average: Optional[str] = "micro", + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", mdmc_average: Optional[str] = None, ignore_index: Optional[int] = None, num_classes: Optional[int] = None, @@ -529,13 +529,20 @@ def specificity( tensor(0.6250) """ if task is not None: - kwargs = dict(multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args) + assert multidim_average is not None if task == "binary": - return binary_specificity(preds, target, threshold, **kwargs) + return binary_specificity(preds, target, threshold, multidim_average, ignore_index, validate_args) if task == "multiclass": - return multiclass_specificity(preds, target, num_classes, average, top_k, **kwargs) + assert isinstance(num_classes, int) + assert isinstance(top_k, int) + return multiclass_specificity( + preds, target, num_classes, average, top_k, multidim_average, ignore_index, validate_args + ) if task == "multilabel": - return multilabel_specificity(preds, target, num_labels, threshold, average, **kwargs) + assert isinstance(num_labels, int) + return multilabel_specificity( + preds, target, num_labels, threshold, average, multidim_average, ignore_index, validate_args + ) raise ValueError( f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" ) diff --git a/src/torchmetrics/functional/classification/stat_scores.py b/src/torchmetrics/functional/classification/stat_scores.py index f68e1de8080..2602107fbce 100644 --- a/src/torchmetrics/functional/classification/stat_scores.py +++ b/src/torchmetrics/functional/classification/stat_scores.py @@ -1095,7 +1095,7 @@ def stat_scores( ignore_index: Optional[int] = None, task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, num_labels: Optional[int] = None, - average: Optional[str] = "micro", + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", multidim_average: Optional[Literal["global", "samplewise"]] = "global", validate_args: bool = True, ) -> Tensor: @@ -1225,13 +1225,20 @@ def stat_scores( """ if task is not None: - kwargs = dict(multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args) + assert multidim_average is not None if task == "binary": - return binary_stat_scores(preds, target, 
threshold, **kwargs) + return binary_stat_scores(preds, target, threshold, multidim_average, ignore_index, validate_args) if task == "multiclass": - return multiclass_stat_scores(preds, target, num_classes, average, top_k, **kwargs) + assert isinstance(num_classes, int) + assert isinstance(top_k, int) + return multiclass_stat_scores( + preds, target, num_classes, average, top_k, multidim_average, ignore_index, validate_args + ) if task == "multilabel": - return multilabel_stat_scores(preds, target, num_labels, threshold, average, **kwargs) + assert isinstance(num_labels, int) + return multilabel_stat_scores( + preds, target, num_labels, threshold, average, multidim_average, ignore_index, validate_args + ) raise ValueError( f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" ) From 663060fbfb3f5a299e12345b79790aa0bf7c3d5d Mon Sep 17 00:00:00 2001 From: Jirka Date: Tue, 13 Sep 2022 01:58:12 +0200 Subject: [PATCH 05/10] assert --- .../functional/classification/accuracy.py | 15 +++++++++++---- .../classification/average_precision.py | 11 +++++++---- .../classification/calibration_error.py | 9 +++++---- .../functional/classification/confusion_matrix.py | 13 ++++++++----- .../functional/classification/f_beta.py | 15 +++++++++++---- .../functional/classification/hinge.py | 11 +++++++---- .../functional/classification/jaccard.py | 9 +++++---- .../classification/matthews_corrcoef.py | 9 +++++---- .../classification/precision_recall_curve.py | 1 + 9 files changed, 60 insertions(+), 33 deletions(-) diff --git a/src/torchmetrics/functional/classification/accuracy.py b/src/torchmetrics/functional/classification/accuracy.py index d5d77c91334..f8484bf0fdb 100644 --- a/src/torchmetrics/functional/classification/accuracy.py +++ b/src/torchmetrics/functional/classification/accuracy.py @@ -770,13 +770,20 @@ def accuracy( tensor(0.6667) """ if task is not None: - kwargs = dict(multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args) + assert multidim_average is not None if task == "binary": - return binary_accuracy(preds, target, threshold, **kwargs) + return binary_accuracy(preds, target, threshold, multidim_average, ignore_index, validate_args) if task == "multiclass": - return multiclass_accuracy(preds, target, num_classes, average, top_k, **kwargs) + assert isinstance(num_classes, int) + assert isinstance(top_k, int) + return multiclass_accuracy( + preds, target, num_classes, average, top_k, multidim_average, ignore_index, validate_args + ) if task == "multilabel": - return multilabel_accuracy(preds, target, num_labels, threshold, average, **kwargs) + assert isinstance(num_labels, int) + return multilabel_accuracy( + preds, target, num_labels, threshold, average, multidim_average, ignore_index, validate_args + ) raise ValueError( f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" ) diff --git a/src/torchmetrics/functional/classification/average_precision.py b/src/torchmetrics/functional/classification/average_precision.py index f23b2600e2b..af6364ced35 100644 --- a/src/torchmetrics/functional/classification/average_precision.py +++ b/src/torchmetrics/functional/classification/average_precision.py @@ -621,13 +621,16 @@ def average_precision( [tensor(1.), tensor(1.), tensor(0.2500), tensor(0.2500), tensor(nan)] """ if task is not None: - kwargs = dict(thresholds=thresholds, ignore_index=ignore_index, validate_args=validate_args) if task == "binary": - return 
binary_average_precision(preds, target, **kwargs) + return binary_average_precision(preds, target, thresholds, ignore_index, validate_args) if task == "multiclass": - return multiclass_average_precision(preds, target, num_classes, average, **kwargs) + assert isinstance(num_classes, int) + return multiclass_average_precision( + preds, target, num_classes, average, thresholds, ignore_index, validate_args + ) if task == "multilabel": - return multilabel_average_precision(preds, target, num_labels, **kwargs) + assert isinstance(num_labels, int) + return multilabel_average_precision(preds, target, num_labels, thresholds, ignore_index, validate_args) raise ValueError( f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" ) diff --git a/src/torchmetrics/functional/classification/calibration_error.py b/src/torchmetrics/functional/classification/calibration_error.py index e4698d4de47..f0763472d88 100644 --- a/src/torchmetrics/functional/classification/calibration_error.py +++ b/src/torchmetrics/functional/classification/calibration_error.py @@ -393,7 +393,7 @@ def calibration_error( preds: Tensor, target: Tensor, n_bins: int = 15, - norm: str = "l1", + norm: Literal["l1", "l2", "max"] = "l1", task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, num_classes: Optional[int] = None, ignore_index: Optional[int] = None, @@ -441,11 +441,12 @@ def calibration_error( Defaults to "l1", or Expected Calibration Error. """ if task is not None: - kwargs = dict(norm=norm, ignore_index=ignore_index, validate_args=validate_args) + assert norm is not None if task == "binary": - return binary_calibration_error(preds, target, n_bins, **kwargs) + return binary_calibration_error(preds, target, n_bins, norm, ignore_index, validate_args) if task == "multiclass": - return multiclass_calibration_error(preds, target, num_classes, n_bins, **kwargs) + assert isinstance(num_classes, int) + return multiclass_calibration_error(preds, target, num_classes, n_bins, norm, ignore_index, validate_args) raise ValueError(f"Expected argument `task` to either be `'binary'`, `'multiclass'` but got {task}") else: rank_zero_warn( diff --git a/src/torchmetrics/functional/classification/confusion_matrix.py b/src/torchmetrics/functional/classification/confusion_matrix.py index 41cbc67ffa5..7912caa1866 100644 --- a/src/torchmetrics/functional/classification/confusion_matrix.py +++ b/src/torchmetrics/functional/classification/confusion_matrix.py @@ -693,7 +693,7 @@ def confusion_matrix( preds: Tensor, target: Tensor, num_classes: int, - normalize: Optional[str] = None, + normalize: Optional[Literal["true", "pred", "all", "none"]] = None, threshold: float = 0.5, multilabel: bool = False, task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, @@ -771,13 +771,16 @@ def confusion_matrix( """ if task is not None: - kwargs = dict(normalize=normalize, ignore_index=ignore_index, validate_args=validate_args) if task == "binary": - return binary_confusion_matrix(preds, target, threshold, **kwargs) + return binary_confusion_matrix(preds, target, threshold, normalize, ignore_index, validate_args) if task == "multiclass": - return multiclass_confusion_matrix(preds, target, num_classes, **kwargs) + assert isinstance(num_classes, int) + return multiclass_confusion_matrix(preds, target, num_classes, normalize, ignore_index, validate_args) if task == "multilabel": - return multilabel_confusion_matrix(preds, target, num_labels, threshold, **kwargs) + assert isinstance(num_labels, int) + 
return multilabel_confusion_matrix( + preds, target, num_labels, threshold, normalize, ignore_index, validate_args + ) raise ValueError( f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" ) diff --git a/src/torchmetrics/functional/classification/f_beta.py b/src/torchmetrics/functional/classification/f_beta.py index e863fac7aa9..5906533ddfc 100644 --- a/src/torchmetrics/functional/classification/f_beta.py +++ b/src/torchmetrics/functional/classification/f_beta.py @@ -910,13 +910,20 @@ def fbeta_score( """ if task is not None: - kwargs = dict(multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args) + assert multidim_average is not None if task == "binary": - return binary_fbeta_score(preds, target, beta, threshold, **kwargs) + return binary_fbeta_score(preds, target, beta, threshold, multidim_average, ignore_index, validate_args) if task == "multiclass": - return multiclass_fbeta_score(preds, target, beta, num_classes, average, top_k, **kwargs) + assert isinstance(num_classes, int) + assert isinstance(top_k, int) + return multiclass_fbeta_score( + preds, target, beta, num_classes, average, top_k, multidim_average, ignore_index, validate_args + ) if task == "multilabel": - return multilabel_fbeta_score(preds, target, beta, num_labels, threshold, average, **kwargs) + assert isinstance(num_labels, int) + return multilabel_fbeta_score( + preds, target, beta, num_labels, threshold, average, multidim_average, ignore_index, validate_args + ) raise ValueError( f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" ) diff --git a/src/torchmetrics/functional/classification/hinge.py b/src/torchmetrics/functional/classification/hinge.py index e6b9388a3cd..de6e445927d 100644 --- a/src/torchmetrics/functional/classification/hinge.py +++ b/src/torchmetrics/functional/classification/hinge.py @@ -383,7 +383,7 @@ def hinge_loss( preds: Tensor, target: Tensor, squared: bool = False, - multiclass_mode: Optional[Union[str, MulticlassMode]] = None, + multiclass_mode: Optional[Literal["crammer-singer", "one-vs-all"]] = None, task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, num_classes: Optional[int] = None, ignore_index: Optional[int] = None, @@ -464,11 +464,14 @@ def hinge_loss( tensor([2.2333, 1.5000, 1.2333]) """ if task is not None: - kwargs = dict(ignore_index=ignore_index, validate_args=validate_args) if task == "binary": - return binary_hinge_loss(preds, target, squared, **kwargs) + return binary_hinge_loss(preds, target, squared, ignore_index, validate_args) if task == "multiclass": - return multiclass_hinge_loss(preds, target, num_classes, squared, multiclass_mode, **kwargs) + assert isinstance(num_classes, int) + assert multiclass_mode is not None + return multiclass_hinge_loss( + preds, target, num_classes, squared, multiclass_mode, ignore_index, validate_args + ) raise ValueError( f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" ) diff --git a/src/torchmetrics/functional/classification/jaccard.py b/src/torchmetrics/functional/classification/jaccard.py index 1588f1b442b..4f64819bbc0 100644 --- a/src/torchmetrics/functional/classification/jaccard.py +++ b/src/torchmetrics/functional/classification/jaccard.py @@ -454,13 +454,14 @@ def jaccard_index( tensor(0.9660) """ if task is not None: - kwargs = dict(ignore_index=ignore_index, validate_args=validate_args) if task == "binary": - return 
binary_jaccard_index(preds, target, threshold, **kwargs) + return binary_jaccard_index(preds, target, threshold, ignore_index, validate_args) if task == "multiclass": - return multiclass_jaccard_index(preds, target, num_classes, average, **kwargs) + assert isinstance(num_classes, int) + return multiclass_jaccard_index(preds, target, num_classes, average, ignore_index, validate_args) if task == "multilabel": - return multilabel_jaccard_index(preds, target, num_labels, threshold, average, **kwargs) + assert isinstance(num_labels, int) + return multilabel_jaccard_index(preds, target, num_labels, threshold, average, ignore_index, validate_args) raise ValueError( f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" ) diff --git a/src/torchmetrics/functional/classification/matthews_corrcoef.py b/src/torchmetrics/functional/classification/matthews_corrcoef.py index a5774b2a8f7..3fb4fce6136 100644 --- a/src/torchmetrics/functional/classification/matthews_corrcoef.py +++ b/src/torchmetrics/functional/classification/matthews_corrcoef.py @@ -315,13 +315,14 @@ def matthews_corrcoef( """ if task is not None: - kwargs = dict(ignore_index=ignore_index, validate_args=validate_args) if task == "binary": - return binary_matthews_corrcoef(preds, target, threshold, **kwargs) + return binary_matthews_corrcoef(preds, target, threshold, ignore_index, validate_args) if task == "multiclass": - return multiclass_matthews_corrcoef(preds, target, num_classes, **kwargs) + assert isinstance(num_classes, int) + return multiclass_matthews_corrcoef(preds, target, num_classes, ignore_index, validate_args) if task == "multilabel": - return multilabel_matthews_corrcoef(preds, target, num_labels, threshold, **kwargs) + assert isinstance(num_labels, int) + return multilabel_matthews_corrcoef(preds, target, num_labels, threshold, ignore_index, validate_args) raise ValueError( f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" ) diff --git a/src/torchmetrics/functional/classification/precision_recall_curve.py b/src/torchmetrics/functional/classification/precision_recall_curve.py index 1733014b3a1..8aaa26ec755 100644 --- a/src/torchmetrics/functional/classification/precision_recall_curve.py +++ b/src/torchmetrics/functional/classification/precision_recall_curve.py @@ -1061,6 +1061,7 @@ def precision_recall_curve( preds, target, num_classes, thresholds, ignore_index, validate_args ) if task == "multilabel": + assert isinstance(num_labels, int) return multilabel_precision_recall_curve(preds, target, num_labels, thresholds, ignore_index, validate_args) raise ValueError( f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" From 4693ba7a917abc2e9cd96bba57f2c5d4e1e88899 Mon Sep 17 00:00:00 2001 From: Jirka Date: Tue, 13 Sep 2022 02:06:04 +0200 Subject: [PATCH 06/10] assert --- src/torchmetrics/classification/accuracy.py | 3 +++ src/torchmetrics/classification/auroc.py | 2 ++ src/torchmetrics/classification/average_precision.py | 2 ++ src/torchmetrics/classification/calibration_error.py | 1 + src/torchmetrics/classification/cohen_kappa.py | 1 + src/torchmetrics/classification/confusion_matrix.py | 2 ++ src/torchmetrics/classification/f_beta.py | 6 ++++++ src/torchmetrics/classification/hamming.py | 3 +++ src/torchmetrics/classification/hinge.py | 1 + src/torchmetrics/classification/jaccard.py | 2 ++ src/torchmetrics/classification/matthews_corrcoef.py | 2 ++ 
 src/torchmetrics/classification/precision_recall.py       | 3 +++
 src/torchmetrics/classification/precision_recall_curve.py | 2 ++
 src/torchmetrics/classification/roc.py                    | 2 ++
 src/torchmetrics/classification/specificity.py            | 3 +++
 src/torchmetrics/classification/stat_scores.py            | 3 +++
 16 files changed, 38 insertions(+)

diff --git a/src/torchmetrics/classification/accuracy.py b/src/torchmetrics/classification/accuracy.py
index 783713fb205..78d3f1c104a 100644
--- a/src/torchmetrics/classification/accuracy.py
+++ b/src/torchmetrics/classification/accuracy.py
@@ -481,8 +481,11 @@ def __new__(
             if task == "binary":
                 return BinaryAccuracy(threshold, **kwargs)
             if task == "multiclass":
+                assert isinstance(num_classes, int)
+                assert isinstance(top_k, int)
                 return MulticlassAccuracy(num_classes, top_k, average, **kwargs)
             if task == "multilabel":
+                assert isinstance(num_labels, int)
                 return MultilabelAccuracy(num_labels, threshold, average, **kwargs)
             raise ValueError(
                 f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}"
diff --git a/src/torchmetrics/classification/auroc.py b/src/torchmetrics/classification/auroc.py
index 2526870d52d..6ba478f456f 100644
--- a/src/torchmetrics/classification/auroc.py
+++ b/src/torchmetrics/classification/auroc.py
@@ -422,8 +422,10 @@ def __new__(
             if task == "binary":
                 return BinaryAUROC(max_fpr, **kwargs)
             if task == "multiclass":
+                assert isinstance(num_classes, int)
                 return MulticlassAUROC(num_classes, average, **kwargs)
             if task == "multilabel":
+                assert isinstance(num_labels, int)
                 return MultilabelAUROC(num_labels, average, **kwargs)
             raise ValueError(
                 f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}"
diff --git a/src/torchmetrics/classification/average_precision.py b/src/torchmetrics/classification/average_precision.py
index 504af498f9c..a19548784dd 100644
--- a/src/torchmetrics/classification/average_precision.py
+++ b/src/torchmetrics/classification/average_precision.py
@@ -401,8 +401,10 @@ def __new__(
             if task == "binary":
                 return BinaryAveragePrecision(**kwargs)
             if task == "multiclass":
+                assert isinstance(num_classes, int)
                 return MulticlassAveragePrecision(num_classes, average, **kwargs)
             if task == "multilabel":
+                assert isinstance(num_labels, int)
                 return MultilabelAveragePrecision(num_labels, average, **kwargs)
             raise ValueError(
                 f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}"
diff --git a/src/torchmetrics/classification/calibration_error.py b/src/torchmetrics/classification/calibration_error.py
index ae4ff259351..12eb29615ab 100644
--- a/src/torchmetrics/classification/calibration_error.py
+++ b/src/torchmetrics/classification/calibration_error.py
@@ -280,6 +280,7 @@ def __new__(
             if task == "binary":
                 return BinaryCalibrationError(**kwargs)
             if task == "multiclass":
+                assert isinstance(num_classes, int)
                 return MulticlassCalibrationError(num_classes, **kwargs)
             raise ValueError(
                 f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}"
diff --git a/src/torchmetrics/classification/cohen_kappa.py b/src/torchmetrics/classification/cohen_kappa.py
index 52977052391..68d8f3653ac 100644
--- a/src/torchmetrics/classification/cohen_kappa.py
+++ b/src/torchmetrics/classification/cohen_kappa.py
@@ -251,6 +251,7 @@ def __new__(
             if task == "binary":
                 return BinaryCohenKappa(threshold, **kwargs)
             if task == "multiclass":
+                assert isinstance(num_classes, int)
                 return MulticlassCohenKappa(num_classes, **kwargs)
             raise ValueError(
                 f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}"
diff --git a/src/torchmetrics/classification/confusion_matrix.py b/src/torchmetrics/classification/confusion_matrix.py
index 912bbec5ff7..7c674f11f42 100644
--- a/src/torchmetrics/classification/confusion_matrix.py
+++ b/src/torchmetrics/classification/confusion_matrix.py
@@ -407,8 +407,10 @@ def __new__(
             if task == "binary":
                 return BinaryConfusionMatrix(threshold, **kwargs)
             if task == "multiclass":
+                assert isinstance(num_classes, int)
                 return MulticlassConfusionMatrix(num_classes, **kwargs)
             if task == "multilabel":
+                assert isinstance(num_labels, int)
                 return MultilabelConfusionMatrix(num_labels, threshold, **kwargs)
             raise ValueError(
                 f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}"
diff --git a/src/torchmetrics/classification/f_beta.py b/src/torchmetrics/classification/f_beta.py
index 384aa26b2be..51a4255067e 100644
--- a/src/torchmetrics/classification/f_beta.py
+++ b/src/torchmetrics/classification/f_beta.py
@@ -841,8 +841,11 @@ def __new__(
             if task == "binary":
                 return BinaryFBetaScore(beta, threshold, **kwargs)
             if task == "multiclass":
+                assert isinstance(num_classes, int)
+                assert isinstance(top_k, int)
                 return MulticlassFBetaScore(beta, num_classes, top_k, average, **kwargs)
             if task == "multilabel":
+                assert isinstance(num_labels, int)
                 return MultilabelFBetaScore(beta, num_labels, threshold, average, **kwargs)
             raise ValueError(
                 f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}"
@@ -1002,8 +1005,11 @@ def __new__(
             if task == "binary":
                 return BinaryF1Score(threshold, **kwargs)
             if task == "multiclass":
+                assert isinstance(num_classes, int)
+                assert isinstance(top_k, int)
                 return MulticlassF1Score(num_classes, top_k, average, **kwargs)
             if task == "multilabel":
+                assert isinstance(num_labels, int)
                 return MultilabelF1Score(num_labels, threshold, average, **kwargs)
             raise ValueError(
                 f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}"
diff --git a/src/torchmetrics/classification/hamming.py b/src/torchmetrics/classification/hamming.py
index 701387cbaba..273f3a97cb2 100644
--- a/src/torchmetrics/classification/hamming.py
+++ b/src/torchmetrics/classification/hamming.py
@@ -382,8 +382,11 @@ def __new__(
             if task == "binary":
                 return BinaryHammingDistance(threshold, **kwargs)
             if task == "multiclass":
+                assert isinstance(num_classes, int)
+                assert isinstance(top_k, int)
                 return MulticlassHammingDistance(num_classes, top_k, average, **kwargs)
             if task == "multilabel":
+                assert isinstance(num_labels, int)
                 return MultilabelHammingDistance(num_labels, threshold, average, **kwargs)
             raise ValueError(
                 f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}"
diff --git a/src/torchmetrics/classification/hinge.py b/src/torchmetrics/classification/hinge.py
index dede90b5609..ad5e1641ecb 100644
--- a/src/torchmetrics/classification/hinge.py
+++ b/src/torchmetrics/classification/hinge.py
@@ -296,6 +296,7 @@ def __new__(
             if task == "binary":
                 return BinaryHingeLoss(squared, **kwargs)
             if task == "multiclass":
+                assert isinstance(num_classes, int)
                 return MulticlassHingeLoss(num_classes, squared, multiclass_mode, **kwargs)
             raise ValueError(
                 f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}"
diff --git a/src/torchmetrics/classification/jaccard.py b/src/torchmetrics/classification/jaccard.py
index e00f3fb5a1e..8575edb37a2 100644
--- a/src/torchmetrics/classification/jaccard.py
+++ b/src/torchmetrics/classification/jaccard.py
@@ -336,8 +336,10 @@ def __new__(
             if task == "binary":
                 return BinaryJaccardIndex(threshold, **kwargs)
             if task == "multiclass":
+                assert isinstance(num_classes, int)
                 return MulticlassJaccardIndex(num_classes, average, **kwargs)
             if task == "multilabel":
+                assert isinstance(num_labels, int)
                 return MultilabelJaccardIndex(num_labels, threshold, average, **kwargs)
             raise ValueError(
                 f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}"
diff --git a/src/torchmetrics/classification/matthews_corrcoef.py b/src/torchmetrics/classification/matthews_corrcoef.py
index 0108e3ea2f3..754b6b613bd 100644
--- a/src/torchmetrics/classification/matthews_corrcoef.py
+++ b/src/torchmetrics/classification/matthews_corrcoef.py
@@ -286,8 +286,10 @@ def __new__(
             if task == "binary":
                 return BinaryMatthewsCorrCoef(threshold, **kwargs)
             if task == "multiclass":
+                assert isinstance(num_classes, int)
                 return MulticlassMatthewsCorrCoef(num_classes, **kwargs)
             if task == "multilabel":
+                assert isinstance(num_labels, int)
                 return MultilabelMatthewsCorrCoef(num_labels, threshold, **kwargs)
             raise ValueError(
                 f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}"
diff --git a/src/torchmetrics/classification/precision_recall.py b/src/torchmetrics/classification/precision_recall.py
index a6b225cf037..3f8ae519a5a 100644
--- a/src/torchmetrics/classification/precision_recall.py
+++ b/src/torchmetrics/classification/precision_recall.py
@@ -722,8 +722,11 @@ def __new__(
             if task == "binary":
                 return BinaryPrecision(threshold, **kwargs)
             if task == "multiclass":
+                assert isinstance(num_classes, int)
+                assert isinstance(top_k, int)
                 return MulticlassPrecision(num_classes, top_k, average, **kwargs)
             if task == "multilabel":
+                assert isinstance(num_labels, int)
                 return MultilabelPrecision(num_labels, threshold, average, **kwargs)
             raise ValueError(
                 f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}"
diff --git a/src/torchmetrics/classification/precision_recall_curve.py b/src/torchmetrics/classification/precision_recall_curve.py
index 26bbfcdb484..a98542e4f11 100644
--- a/src/torchmetrics/classification/precision_recall_curve.py
+++ b/src/torchmetrics/classification/precision_recall_curve.py
@@ -487,8 +487,10 @@ def __new__(
             if task == "binary":
                 return BinaryPrecisionRecallCurve(**kwargs)
             if task == "multiclass":
+                assert isinstance(num_classes, int)
                 return MulticlassPrecisionRecallCurve(num_classes, **kwargs)
             if task == "multilabel":
+                assert isinstance(num_labels, int)
                 return MultilabelPrecisionRecallCurve(num_labels, **kwargs)
             raise ValueError(
                 f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}"
diff --git a/src/torchmetrics/classification/roc.py b/src/torchmetrics/classification/roc.py
index 6b6f1bf1eb1..e20548e6be6 100644
--- a/src/torchmetrics/classification/roc.py
+++ b/src/torchmetrics/classification/roc.py
@@ -411,8 +411,10 @@ def __new__(
             if task == "binary":
                 return BinaryROC(**kwargs)
             if task == "multiclass":
+                assert isinstance(num_classes, int)
                 return MulticlassROC(num_classes, **kwargs)
             if task == "multilabel":
+                assert isinstance(num_labels, int)
                 return MultilabelROC(num_labels, **kwargs)
             raise ValueError(
                 f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}"
diff --git a/src/torchmetrics/classification/specificity.py b/src/torchmetrics/classification/specificity.py
index 7fa1be2d19a..199cba5f340 100644
--- a/src/torchmetrics/classification/specificity.py
+++ b/src/torchmetrics/classification/specificity.py
@@ -412,8 +412,11 @@ def __new__(
             if task == "binary":
                 return BinarySpecificity(threshold, **kwargs)
             if task == "multiclass":
+                assert isinstance(num_classes, int)
+                assert isinstance(top_k, int)
                 return MulticlassSpecificity(num_classes, top_k, average, **kwargs)
             if task == "multilabel":
+                assert isinstance(num_labels, int)
                 return MultilabelSpecificity(num_labels, threshold, average, **kwargs)
             raise ValueError(
                 f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}"
diff --git a/src/torchmetrics/classification/stat_scores.py b/src/torchmetrics/classification/stat_scores.py
index 32b02b083a5..2a097b7982e 100644
--- a/src/torchmetrics/classification/stat_scores.py
+++ b/src/torchmetrics/classification/stat_scores.py
@@ -610,8 +610,11 @@ def __new__(
             if task == "binary":
                 return BinaryStatScores(threshold, **kwargs)
             if task == "multiclass":
+                assert isinstance(num_classes, int)
+                assert isinstance(top_k, int)
                 return MulticlassStatScores(num_classes, top_k, average, **kwargs)
             if task == "multilabel":
+                assert isinstance(num_labels, int)
                 return MultilabelStatScores(num_labels, threshold, average, **kwargs)
             raise ValueError(
                 f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}"
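The asserts added throughout this patch all follow one pattern: the task-dispatching ``__new__`` receives ``num_classes``, ``num_labels`` and ``top_k`` as ``Optional`` arguments, and ``assert isinstance(...)`` narrows them to concrete types before they are passed positionally to the task-specific classes, which satisfies static type checkers without changing behaviour for valid inputs. Below is a minimal, self-contained sketch of that dispatch pattern; the ``_TaskWrapper`` names are illustrative stand-ins, not torchmetrics classes.

from typing import Any, Optional

class _Base:
    """Stand-in for the common Metric base class."""

class _BinaryAcc(_Base):
    def __init__(self, threshold: float = 0.5) -> None:
        self.threshold = threshold

class _MulticlassAcc(_Base):
    def __init__(self, num_classes: int) -> None:
        self.num_classes = num_classes

class _TaskWrapper(_Base):
    def __new__(cls, task: str, threshold: float = 0.5, num_classes: Optional[int] = None, **kwargs: Any) -> _Base:
        if task == "binary":
            return _BinaryAcc(threshold)
        if task == "multiclass":
            # The assert narrows Optional[int] to int for the type checker,
            # mirroring the `assert isinstance(num_classes, int)` lines above.
            assert isinstance(num_classes, int)
            return _MulticlassAcc(num_classes)
        raise ValueError(f"Expected task to be 'binary' or 'multiclass' but got {task}")

print(type(_TaskWrapper(task="multiclass", num_classes=3)).__name__)  # -> _MulticlassAcc

Because ``__new__`` returns an instance of a different class, Python skips the wrapper's ``__init__``, so the call site gets the task-specific object directly.
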
From 12aa4d8d561b5813d378c93e167e18bb06dec24b Mon Sep 17 00:00:00 2001
From: Jirka
Date: Tue, 13 Sep 2022 02:17:40 +0200
Subject: [PATCH 07/10] assert

---
 src/torchmetrics/classification/accuracy.py          |  5 +++--
 src/torchmetrics/classification/auroc.py             |  4 ++--
 .../classification/average_precision.py              |  4 ++--
 src/torchmetrics/classification/dice.py              |  4 ++--
 src/torchmetrics/classification/f_beta.py            | 10 ++++++----
 src/torchmetrics/classification/hamming.py           |  3 ++-
 src/torchmetrics/classification/jaccard.py           |  4 ++--
 src/torchmetrics/classification/precision_recall.py  | 13 +++++++++----
 src/torchmetrics/classification/specificity.py       |  5 +++--
 src/torchmetrics/classification/stat_scores.py       |  7 ++++++-
 10 files changed, 37 insertions(+), 22 deletions(-)

diff --git a/src/torchmetrics/classification/accuracy.py b/src/torchmetrics/classification/accuracy.py
index 78d3f1c104a..3676f684fb4 100644
--- a/src/torchmetrics/classification/accuracy.py
+++ b/src/torchmetrics/classification/accuracy.py
@@ -462,7 +462,7 @@ def __new__(
         cls,
         threshold: float = 0.5,
         num_classes: Optional[int] = None,
-        average: Optional[str] = "micro",
+        average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
         mdmc_average: Optional[str] = None,
         ignore_index: Optional[int] = None,
         top_k: Optional[int] = None,
@@ -475,6 +475,7 @@ def __new__(
         **kwargs: Any,
     ) -> Metric:
         if task is not None:
+            assert multidim_average is not None
             kwargs.update(
                 dict(multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args)
             )
@@ -496,7 +497,7 @@ def __init__(
         self,
         threshold: float = 0.5,
         num_classes: Optional[int] = None,
-        average: Optional[str] = "micro",
+        average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
         mdmc_average: Optional[str] = None,
         ignore_index: Optional[int] = None,
         top_k: Optional[int] = None,
diff --git a/src/torchmetrics/classification/auroc.py b/src/torchmetrics/classification/auroc.py
index 6ba478f456f..d3d3a8a430a 100644
--- a/src/torchmetrics/classification/auroc.py
+++ b/src/torchmetrics/classification/auroc.py
@@ -408,7 +408,7 @@ def __new__(
         cls,
         num_classes: Optional[int] = None,
         pos_label: Optional[int] = None,
-        average: Optional[str] = "macro",
+        average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
         max_fpr: Optional[float] = None,
         task: Optional[Literal["binary", "multiclass", "multilabel"]] = None,
         thresholds: Optional[Union[int, List[float], Tensor]] = None,
@@ -436,7 +436,7 @@ def __init__(
         self,
         num_classes: Optional[int] = None,
         pos_label: Optional[int] = None,
-        average: Optional[str] = "macro",
+        average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
         max_fpr: Optional[float] = None,
         task: Optional[Literal["binary", "multiclass", "multilabel"]] = None,
         thresholds: Optional[Union[int, List[float], Tensor]] = None,
diff --git a/src/torchmetrics/classification/average_precision.py b/src/torchmetrics/classification/average_precision.py
index a19548784dd..84fd23da8d7 100644
--- a/src/torchmetrics/classification/average_precision.py
+++ b/src/torchmetrics/classification/average_precision.py
@@ -388,7 +388,7 @@ def __new__(
         cls,
         num_classes: Optional[int] = None,
         pos_label: Optional[int] = None,
-        average: Optional[str] = "macro",
+        average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
         task: Optional[Literal["binary", "multiclass", "multilabel"]] = None,
         thresholds: Optional[Union[int, List[float], Tensor]] = None,
         num_labels: Optional[int] = None,
@@ -415,7 +415,7 @@ def __init__(
         self,
         num_classes: Optional[int] = None,
         pos_label: Optional[int] = None,
-        average: Optional[str] = "macro",
+        average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
         task: Optional[Literal["binary", "multiclass", "multilabel"]] = None,
         thresholds: Optional[Union[int, List[float], Tensor]] = None,
         num_labels: Optional[int] = None,
diff --git a/src/torchmetrics/classification/dice.py b/src/torchmetrics/classification/dice.py
index 7f4684a27ee..65e5dcd20e0 100644
--- a/src/torchmetrics/classification/dice.py
+++ b/src/torchmetrics/classification/dice.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Optional
+from typing import Any, Literal, Optional

 from torch import Tensor

@@ -124,7 +124,7 @@ def __init__(
         zero_division: int = 0,
         num_classes: Optional[int] = None,
         threshold: float = 0.5,
-        average: Optional[str] = "micro",
+        average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
         mdmc_average: Optional[str] = "global",
         ignore_index: Optional[int] = None,
         top_k: Optional[int] = None,
diff --git a/src/torchmetrics/classification/f_beta.py b/src/torchmetrics/classification/f_beta.py
index 51a4255067e..9de586e7410 100644
--- a/src/torchmetrics/classification/f_beta.py
+++ b/src/torchmetrics/classification/f_beta.py
@@ -823,7 +823,7 @@ def __new__(
         num_classes: Optional[int] = None,
         beta: float = 1.0,
         threshold: float = 0.5,
-        average: Optional[str] = "micro",
+        average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
         mdmc_average: Optional[str] = None,
         ignore_index: Optional[int] = None,
         top_k: Optional[int] = None,
@@ -835,6 +835,7 @@ def __new__(
         **kwargs: Any,
     ) -> Metric:
         if task is not None:
+            assert multidim_average is not None
             kwargs.update(
                 dict(multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args)
             )
@@ -857,7 +858,7 @@ def __init__(
         num_classes: Optional[int] = None,
         beta: float = 1.0,
         threshold: float = 0.5,
-        average: Optional[str] = "micro",
+        average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
         mdmc_average: Optional[str] = None,
         ignore_index: Optional[int] = None,
         top_k: Optional[int] = None,
@@ -987,7 +988,7 @@ def __new__(
         cls,
         num_classes: Optional[int] = None,
         threshold: float = 0.5,
-        average: Optional[str] = "micro",
+        average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
         mdmc_average: Optional[str] = None,
         ignore_index: Optional[int] = None,
         top_k: Optional[int] = None,
@@ -999,6 +1000,7 @@ def __new__(
         **kwargs: Any,
     ) -> Metric:
         if task is not None:
+            assert multidim_average is not None
             kwargs.update(
                 dict(multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args)
             )
@@ -1020,7 +1022,7 @@ def __init__(
         self,
         num_classes: Optional[int] = None,
         threshold: float = 0.5,
-        average: Optional[str] = "micro",
+        average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
         mdmc_average: Optional[str] = None,
         ignore_index: Optional[int] = None,
         top_k: Optional[int] = None,
diff --git a/src/torchmetrics/classification/hamming.py b/src/torchmetrics/classification/hamming.py
index 273f3a97cb2..edda3b1d1ec 100644
--- a/src/torchmetrics/classification/hamming.py
+++ b/src/torchmetrics/classification/hamming.py
@@ -368,7 +368,7 @@ def __new__(
         task: Optional[Literal["binary", "multiclass", "multilabel"]] = None,
         num_classes: Optional[int] = None,
         num_labels: Optional[int] = None,
-        average: Optional[str] = "micro",
+        average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
         multidim_average: Optional[Literal["global", "samplewise"]] = "global",
         top_k: Optional[int] = None,
         ignore_index: Optional[int] = None,
@@ -376,6 +376,7 @@ def __new__(
         **kwargs: Any,
     ) -> Metric:
         if task is not None:
+            assert multidim_average is not None
             kwargs.update(
                 dict(multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args)
             )
diff --git a/src/torchmetrics/classification/jaccard.py b/src/torchmetrics/classification/jaccard.py
index 8575edb37a2..4871045673b 100644
--- a/src/torchmetrics/classification/jaccard.py
+++ b/src/torchmetrics/classification/jaccard.py
@@ -321,7 +321,7 @@ class JaccardIndex(ConfusionMatrix):
     def __new__(
         cls,
         num_classes: int,
-        average: Optional[str] = "macro",
+        average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
         ignore_index: Optional[int] = None,
         absent_score: float = 0.0,
         threshold: float = 0.5,
@@ -349,7 +349,7 @@ def __new__(
     def __init__(
         self,
         num_classes: int,
-        average: Optional[str] = "macro",
+        average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
         ignore_index: Optional[int] = None,
         absent_score: float = 0.0,
         threshold: float = 0.5,
diff --git a/src/torchmetrics/classification/precision_recall.py b/src/torchmetrics/classification/precision_recall.py
index 3f8ae519a5a..a1b324522c0 100644
--- a/src/torchmetrics/classification/precision_recall.py
+++ b/src/torchmetrics/classification/precision_recall.py
@@ -704,7 +704,7 @@ def __new__(
         cls,
         threshold: float = 0.5,
         num_classes: Optional[int] = None,
-        average: Optional[str] = "micro",
+        average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
         mdmc_average: Optional[str] = None,
         ignore_index: Optional[int] = None,
         top_k: Optional[int] = None,
@@ -716,6 +716,7 @@ def __new__(
         **kwargs: Any,
     ) -> Metric:
         if task is not None:
+            assert multidim_average is not None
             kwargs.update(
                 dict(multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args)
             )
@@ -737,7 +738,7 @@ def __init__(
         self,
         num_classes: Optional[int] = None,
         threshold: float = 0.5,
-        average: Optional[str] = "micro",
+        average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
         mdmc_average: Optional[str] = None,
         ignore_index: Optional[int] = None,
         top_k: Optional[int] = None,
@@ -876,7 +877,7 @@ def __new__(
         cls,
         threshold: float = 0.5,
         num_classes: Optional[int] = None,
-        average: Optional[str] = "micro",
+        average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
         mdmc_average: Optional[str] = None,
         ignore_index: Optional[int] = None,
         top_k: Optional[int] = None,
@@ -888,14 +889,18 @@ def __new__(
         **kwargs: Any,
     ) -> Metric:
         if task is not None:
+            assert multidim_average is not None
             kwargs.update(
                 dict(multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args)
             )
             if task == "binary":
                 return BinaryRecall(threshold, **kwargs)
             if task == "multiclass":
+                assert isinstance(num_classes, int)
+                assert isinstance(top_k, int)
                 return MulticlassRecall(num_classes, top_k, average, **kwargs)
             if task == "multilabel":
+                assert isinstance(num_labels, int)
                 return MultilabelRecall(num_labels, threshold, average, **kwargs)
             raise ValueError(
                 f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}"
@@ -906,7 +911,7 @@ def __init__(
         self,
         num_classes: Optional[int] = None,
         threshold: float = 0.5,
-        average: Optional[str] = "micro",
+        average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
         mdmc_average: Optional[str] = None,
         ignore_index: Optional[int] = None,
         top_k: Optional[int] = None,
diff --git a/src/torchmetrics/classification/specificity.py b/src/torchmetrics/classification/specificity.py
index 199cba5f340..5304273196d 100644
--- a/src/torchmetrics/classification/specificity.py
+++ b/src/torchmetrics/classification/specificity.py
@@ -394,7 +394,7 @@ def __new__(
         cls,
         num_classes: Optional[int] = None,
         threshold: float = 0.5,
-        average: Optional[str] = "micro",
+        average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
         mdmc_average: Optional[str] = None,
         ignore_index: Optional[int] = None,
         top_k: Optional[int] = None,
@@ -406,6 +406,7 @@ def __new__(
         **kwargs: Any,
     ) -> Metric:
         if task is not None:
+            assert multidim_average is not None
             kwargs.update(
                 dict(multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args)
             )
@@ -427,7 +428,7 @@ def __init__(
         self,
         num_classes: Optional[int] = None,
         threshold: float = 0.5,
-        average: Optional[str] = "micro",
+        average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
         mdmc_average: Optional[str] = None,
         ignore_index: Optional[int] = None,
         top_k: Optional[int] = None,
diff --git a/src/torchmetrics/classification/stat_scores.py b/src/torchmetrics/classification/stat_scores.py
index 2a097b7982e..7d9b1e7bc03 100644
--- a/src/torchmetrics/classification/stat_scores.py
+++ b/src/torchmetrics/classification/stat_scores.py
@@ -592,7 +592,7 @@ def __new__(
         cls,
         num_classes: Optional[int] = None,
         threshold: float = 0.5,
-        average: Optional[str] = "micro",
+        average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
         mdmc_average: Optional[str] = None,
         ignore_index: Optional[int] = None,
         top_k: Optional[int] = None,
@@ -604,6 +604,7 @@ def __new__(
         **kwargs: Any,
     ) -> Metric:
         if task is not None:
+            assert multidim_average is not None
             kwargs.update(
                 dict(multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args)
             )
@@ -639,14 +640,18 @@ def __init__(
     ) -> None:
         self.task = task
         if self.task is not None:
+            assert multidim_average is not None
            kwargs.update(
                 dict(multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args)
             )
             if task == "binary":
                 BinaryStatScores.__init__(self, threshold, **kwargs)
             if task == "multiclass":
+                assert isinstance(num_classes, int)
+                assert isinstance(top_k, int)
                 MulticlassStatScores.__init__(self, num_classes, top_k, average, **kwargs)
             if task == "multilabel":
+                assert isinstance(num_labels, int)
                 MultilabelStatScores.__init__(self, num_labels, threshold, average, **kwargs)
             raise ValueError(
                 f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}"
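Besides the asserts, this patch tightens the ``average`` annotations from ``Optional[str]`` to a ``Literal`` of the supported reduction modes. The runtime behaviour is unchanged; the gain is that a type checker flags calls that pass an unsupported string. A small illustrative sketch (not torchmetrics code) of what the ``Literal`` annotation buys:

from typing import Optional
from typing_extensions import Literal

def reduce_scores(average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro") -> str:
    # Return a description of the requested reduction; real code would reduce a tensor here.
    return f"reducing with average={average}"

reduce_scores("macro")       # accepted by mypy and at runtime
reduce_scores("arithmetic")  # still runs, but mypy reports an invalid literal value
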
From b37c7ecf50c64891ff7b2d61debb3b4e60d3d482 Mon Sep 17 00:00:00 2001
From: Jirka
Date: Tue, 13 Sep 2022 02:28:40 +0200
Subject: [PATCH 08/10] fixing

---
 src/torchmetrics/classification/auroc.py             | 3 ++-
 src/torchmetrics/classification/average_precision.py | 2 +-
 src/torchmetrics/classification/hinge.py             | 1 +
 src/torchmetrics/classification/jaccard.py           | 4 ++--
 .../functional/classification/average_precision.py   | 4 +---
 5 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/src/torchmetrics/classification/auroc.py b/src/torchmetrics/classification/auroc.py
index d3d3a8a430a..76da48025d2 100644
--- a/src/torchmetrics/classification/auroc.py
+++ b/src/torchmetrics/classification/auroc.py
@@ -149,6 +149,7 @@ class MulticlassAUROC(MulticlassPrecisionRecallCurve):
         - ``macro``: Calculate score for each class and average them
         - ``weighted``: Calculates score for each class and computes weighted average using their support
         - ``"none"`` or ``None``: Calculates score for each class and applies no reduction
+
     thresholds:
         Can be one of:

@@ -408,7 +409,7 @@ def __new__(
         cls,
         num_classes: Optional[int] = None,
         pos_label: Optional[int] = None,
-        average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
+        average: Optional[Literal["macro", "weighted", "none"]] = "macro",
         max_fpr: Optional[float] = None,
         task: Optional[Literal["binary", "multiclass", "multilabel"]] = None,
         thresholds: Optional[Union[int, List[float], Tensor]] = None,
diff --git a/src/torchmetrics/classification/average_precision.py b/src/torchmetrics/classification/average_precision.py
index 84fd23da8d7..f0a0847e55b 100644
--- a/src/torchmetrics/classification/average_precision.py
+++ b/src/torchmetrics/classification/average_precision.py
@@ -388,7 +388,7 @@ def __new__(
         cls,
         num_classes: Optional[int] = None,
         pos_label: Optional[int] = None,
-        average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
+        average: Optional[Literal["macro", "weighted", "none"]] = "macro",
         task: Optional[Literal["binary", "multiclass", "multilabel"]] = None,
         thresholds: Optional[Union[int, List[float], Tensor]] = None,
         num_labels: Optional[int] = None,
diff --git a/src/torchmetrics/classification/hinge.py b/src/torchmetrics/classification/hinge.py
index ad5e1641ecb..2e651274665 100644
--- a/src/torchmetrics/classification/hinge.py
+++ b/src/torchmetrics/classification/hinge.py
@@ -297,6 +297,7 @@ def __new__(
                 return BinaryHingeLoss(squared, **kwargs)
             if task == "multiclass":
                 assert isinstance(num_classes, int)
+                assert multiclass_mode is not None
                 return MulticlassHingeLoss(num_classes, squared, multiclass_mode, **kwargs)
             raise ValueError(
                 f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}"
diff --git a/src/torchmetrics/classification/jaccard.py b/src/torchmetrics/classification/jaccard.py
index 4871045673b..8bcc1b4f637 100644
--- a/src/torchmetrics/classification/jaccard.py
+++ b/src/torchmetrics/classification/jaccard.py
@@ -156,8 +156,8 @@ class MulticlassJaccardIndex(MulticlassConfusionMatrix):
     def __init__(
         self,
         num_classes: int,
-        ignore_index: Optional[int] = None,
         average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
+        ignore_index: Optional[int] = None,
         validate_args: bool = True,
         **kwargs: Any,
     ) -> None:
@@ -233,8 +233,8 @@ def __init__(
         self,
         num_labels: int,
         threshold: float = 0.5,
-        ignore_index: Optional[int] = None,
         average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
+        ignore_index: Optional[int] = None,
         validate_args: bool = True,
         **kwargs: Any,
     ) -> None:
diff --git a/src/torchmetrics/functional/classification/average_precision.py b/src/torchmetrics/functional/classification/average_precision.py
index af6364ced35..8672423316b 100644
--- a/src/torchmetrics/functional/classification/average_precision.py
+++ b/src/torchmetrics/functional/classification/average_precision.py
@@ -562,7 +562,7 @@ def average_precision(
     target: Tensor,
     num_classes: Optional[int] = None,
     pos_label: Optional[int] = None,
-    average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
+    average: Optional[Literal["macro", "weighted", "none"]] = "macro",
     task: Optional[Literal["binary", "multiclass", "multilabel"]] = None,
     thresholds: Optional[Union[int, List[float], Tensor]] = None,
     num_labels: Optional[int] = None,
@@ -593,8 +593,6 @@ def average_precision(

         - ``'macro'`` [default]: Calculate the metric for each class separately, and average the
           metrics across classes (with equal weights for each class).
-        - ``'micro'``: Calculate the metric globally, across all samples and classes. Cannot be
-          used with multiclass input.
         - ``'weighted'``: Calculate the metric for each class separately, and average the
           metrics across classes, weighting each class by its support.
         - ``'none'`` or ``None``: Calculate the metric for each class separately, and return

From 8aba687a83bae30e4355f9764e8b142a7bd6b092 Mon Sep 17 00:00:00 2001
From: Jirka
Date: Tue, 13 Sep 2022 02:31:13 +0200
Subject: [PATCH 09/10] fixing

---
 src/torchmetrics/functional/classification/auroc.py  | 3 +--
 .../functional/classification/average_precision.py   | 4 +++-
 2 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/src/torchmetrics/functional/classification/auroc.py b/src/torchmetrics/functional/classification/auroc.py
index 40bd2f4ba3a..610293a59d4 100644
--- a/src/torchmetrics/functional/classification/auroc.py
+++ b/src/torchmetrics/functional/classification/auroc.py
@@ -604,7 +604,7 @@ def auroc(
     target: Tensor,
     num_classes: Optional[int] = None,
     pos_label: Optional[int] = None,
-    average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
+    average: Optional[Literal["macro", "weighted", "none"]] = "macro",
     max_fpr: Optional[float] = None,
     sample_weights: Optional[Sequence] = None,
     task: Optional[Literal["binary", "multiclass", "multilabel"]] = None,
@@ -644,7 +644,6 @@ def auroc(
         range [0,num_classes-1]
     average:
-        - ``'micro'`` computes metric globally. Only works for multilabel problems
         - ``'macro'`` computes metric for each class and uniformly averages them
         - ``'weighted'`` computes metric for each class and does a weighted-average,
           where each class is weighted by their support (accounts for class imbalance)
diff --git a/src/torchmetrics/functional/classification/average_precision.py b/src/torchmetrics/functional/classification/average_precision.py
index 8672423316b..ddfeff1df1e 100644
--- a/src/torchmetrics/functional/classification/average_precision.py
+++ b/src/torchmetrics/functional/classification/average_precision.py
@@ -628,7 +628,9 @@ def average_precision(
         )
     if task == "multilabel":
         assert isinstance(num_labels, int)
-        return multilabel_average_precision(preds, target, num_labels, thresholds, ignore_index, validate_args)
+        return multilabel_average_precision(
+            preds, target, num_labels, average, thresholds, ignore_index, validate_args
+        )
     raise ValueError(
         f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}"
     )
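The change above forwards the previously dropped ``average`` argument to ``multilabel_average_precision``, so the multilabel path now honours the requested reduction instead of silently passing ``average`` where ``thresholds`` was expected. A hedged usage sketch, assuming the 0.10-era functional signature shown in this patch (argument names may differ in other releases):

import torch
from torchmetrics.functional import average_precision

# 8 samples, 3 labels, multilabel probabilities and 0/1 targets
preds = torch.rand(8, 3)
target = torch.randint(0, 2, (8, 3))

# With the fix, the "macro" reduction is actually applied to the per-label scores.
score = average_precision(preds, target, task="multilabel", num_labels=3, average="macro")
print(score)
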
From c6745463cdca179ceaeee07d5354ac99a36a96ad Mon Sep 17 00:00:00 2001
From: Jirka
Date: Tue, 13 Sep 2022 02:43:10 +0200
Subject: [PATCH 10/10] imports

---
 src/torchmetrics/classification/accuracy.py         | 2 +-
 src/torchmetrics/classification/f_beta.py           | 2 +-
 src/torchmetrics/classification/jaccard.py          | 2 +-
 src/torchmetrics/classification/precision_recall.py | 2 +-
 src/torchmetrics/classification/specificity.py      | 2 +-
 5 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/src/torchmetrics/classification/accuracy.py b/src/torchmetrics/classification/accuracy.py
index 3676f684fb4..40ba7197a7f 100644
--- a/src/torchmetrics/classification/accuracy.py
+++ b/src/torchmetrics/classification/accuracy.py
@@ -17,7 +17,6 @@
 from torch import Tensor, tensor
 from typing_extensions import Literal

-from torchmetrics import Metric
 from torchmetrics.functional.classification.accuracy import (
     _accuracy_compute,
     _accuracy_reduce,
@@ -27,6 +26,7 @@
     _subset_accuracy_compute,
     _subset_accuracy_update,
 )
+from torchmetrics.metric import Metric
 from torchmetrics.utilities.enums import AverageMethod, DataType

 from torchmetrics.classification.stat_scores import (  # isort:skip
diff --git a/src/torchmetrics/classification/f_beta.py b/src/torchmetrics/classification/f_beta.py
index 9de586e7410..7a8f458b474 100644
--- a/src/torchmetrics/classification/f_beta.py
+++ b/src/torchmetrics/classification/f_beta.py
@@ -17,7 +17,6 @@
 from torch import Tensor
 from typing_extensions import Literal

-from torchmetrics import Metric
 from torchmetrics.classification.stat_scores import (
     BinaryStatScores,
     MulticlassStatScores,
@@ -31,6 +30,7 @@
     _multiclass_fbeta_score_arg_validation,
     _multilabel_fbeta_score_arg_validation,
 )
+from torchmetrics.metric import Metric
 from torchmetrics.utilities.enums import AverageMethod
diff --git a/src/torchmetrics/classification/jaccard.py b/src/torchmetrics/classification/jaccard.py
index 8bcc1b4f637..c748bb68c67 100644
--- a/src/torchmetrics/classification/jaccard.py
+++ b/src/torchmetrics/classification/jaccard.py
@@ -17,7 +17,6 @@
 from torch import Tensor
 from typing_extensions import Literal

-from torchmetrics import Metric
 from torchmetrics.classification import BinaryConfusionMatrix, MulticlassConfusionMatrix, MultilabelConfusionMatrix
 from torchmetrics.classification.confusion_matrix import ConfusionMatrix
 from torchmetrics.functional.classification.jaccard import (
@@ -26,6 +25,7 @@
     _multiclass_jaccard_index_arg_validation,
     _multilabel_jaccard_index_arg_validation,
 )
+from torchmetrics.metric import Metric


 class BinaryJaccardIndex(BinaryConfusionMatrix):
diff --git a/src/torchmetrics/classification/precision_recall.py b/src/torchmetrics/classification/precision_recall.py
index a1b324522c0..191c97b50ed 100644
--- a/src/torchmetrics/classification/precision_recall.py
+++ b/src/torchmetrics/classification/precision_recall.py
@@ -17,7 +17,6 @@
 from torch import Tensor
 from typing_extensions import Literal

-from torchmetrics import Metric
 from torchmetrics.classification.stat_scores import (
     BinaryStatScores,
     MulticlassStatScores,
@@ -29,6 +28,7 @@
     _precision_recall_reduce,
     _recall_compute,
 )
+from torchmetrics.metric import Metric
 from torchmetrics.utilities.enums import AverageMethod
diff --git a/src/torchmetrics/classification/specificity.py b/src/torchmetrics/classification/specificity.py
index 5304273196d..adf212652c5 100644
--- a/src/torchmetrics/classification/specificity.py
+++ b/src/torchmetrics/classification/specificity.py
@@ -17,7 +17,6 @@
 from torch import Tensor
 from typing_extensions import Literal

-from torchmetrics import Metric
 from torchmetrics.classification.stat_scores import (
     BinaryStatScores,
     MulticlassStatScores,
@@ -25,6 +24,7 @@
     StatScores,
 )
 from torchmetrics.functional.classification.specificity import _specificity_compute, _specificity_reduce
+from torchmetrics.metric import Metric
 from torchmetrics.utilities.enums import AverageMethod
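The last patch replaces ``from torchmetrics import Metric`` with ``from torchmetrics.metric import Metric`` in the modules that only need the base class, presumably so the import does not resolve through the package ``__init__`` (which itself imports these classification modules) and risk an import cycle. In isolation the resulting import style looks like the sketch below; ``is_metric`` is an illustrative helper, not part of torchmetrics.

from torchmetrics.metric import Metric  # import from the defining module, not the package root

def is_metric(obj: object) -> bool:
    """Illustrative helper: check an object against the directly imported base class."""
    return isinstance(obj, Metric)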