Commit

assert
Borda committed Sep 12, 2022
1 parent 85fcb31 commit 29f4bad
Showing 12 changed files with 102 additions and 58 deletions.
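The changes across these files follow two related patterns: loose `Optional[str]` annotations for `average`/`weights` are narrowed to `Literal[...]`, and in each `task`-dispatching wrapper the `kwargs = dict(...)` indirection is replaced by explicit argument forwarding with `assert isinstance(...)` guards before the task-specific calls (hence the commit title "assert"). Below is a minimal, self-contained sketch of that pattern with hypothetical names (`metric`, `_multiclass_metric`), not the actual torchmetrics code; the asserts both fail fast at runtime and let a static type checker narrow `Optional[int]` to `int`.

from typing import Literal, Optional

import torch
from torch import Tensor


def _multiclass_metric(preds: Tensor, target: Tensor, num_classes: int, average: Optional[str]) -> Tensor:
    # Stand-in for a task-specific function that requires a concrete ``int``.
    return (preds.argmax(dim=-1) == target).float().mean()


def metric(
    preds: Tensor,
    target: Tensor,
    task: Optional[Literal["binary", "multiclass", "multilabel"]] = None,
    num_classes: Optional[int] = None,
    average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
) -> Tensor:
    if task == "multiclass":
        # Narrows ``Optional[int]`` to ``int`` for the call below and raises
        # immediately if the caller forgot to pass ``num_classes``.
        assert isinstance(num_classes, int)
        return _multiclass_metric(preds, target, num_classes, average)
    raise ValueError(f"Expected argument `task` to be 'binary', 'multiclass' or 'multilabel' but got {task}")


preds = torch.tensor([[0.9, 0.1], [0.2, 0.8]])
target = torch.tensor([0, 1])
print(metric(preds, target, task="multiclass", num_classes=2))  # tensor(1.)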
2 changes: 1 addition & 1 deletion src/torchmetrics/functional/classification/accuracy.py
@@ -625,7 +625,7 @@ def _subset_accuracy_compute(correct: Tensor, total: Tensor) -> Tensor:
def accuracy(
preds: Tensor,
target: Tensor,
average: Optional[str] = "micro",
average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
mdmc_average: Optional[str] = "global",
threshold: float = 0.5,
top_k: Optional[int] = None,
22 changes: 11 additions & 11 deletions src/torchmetrics/functional/classification/auroc.py
@@ -53,7 +53,6 @@ def _reduce_auroc(
weights: Optional[Tensor] = None,
) -> Tensor:
"""Utility function for reducing multiple average precision score into one number."""
res = []
if isinstance(fpr, Tensor):
res = _auc_compute_without_check(fpr, tpr, 1.0, axis=1)
else:
@@ -96,7 +95,7 @@ def _binary_auroc_compute(
thresholds: Optional[Tensor],
max_fpr: Optional[float] = None,
pos_label: int = 1,
) -> Tuple[Tensor, Tensor, Tensor]:
) -> Union[Tensor, Tuple[Tensor, Tensor, Tensor]]:
fpr, tpr, _ = _binary_roc_compute(state, thresholds, pos_label)
if max_fpr is None or max_fpr == 1:
return _auc_compute_without_check(fpr, tpr, 1.0)
@@ -221,7 +220,7 @@ def multiclass_auroc(
thresholds: Optional[Union[int, List[float], Tensor]] = None,
ignore_index: Optional[int] = None,
validate_args: bool = True,
) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]:
) -> Tensor:
r"""
Compute Area Under the Receiver Operating Characteristic Curve (`ROC AUC`_) for multiclass tasks. The AUROC score
summarizes the ROC curve into a single number that describes the performance of a model for multiple
@@ -317,7 +316,7 @@ def _multilabel_auroc_compute(
average: Optional[Literal["micro", "macro", "weighted", "none"]],
thresholds: Optional[Tensor],
ignore_index: Optional[int] = None,
) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]:
) -> Union[Tuple[Tensor, Tensor, Tensor], Tensor]:
if average == "micro":
if isinstance(state, Tensor) and thresholds is not None:
return _binary_auroc_compute(state.sum(1), thresholds, max_fpr=None)
@@ -328,7 +327,7 @@ def _multilabel_auroc_compute(
idx = target == ignore_index
preds = preds[~idx]
target = target[~idx]
return _binary_auroc_compute([preds, target], thresholds, max_fpr=None)
return _binary_auroc_compute((preds, target), thresholds, max_fpr=None)

else:
fpr, tpr, _ = _multilabel_roc_compute(state, num_labels, thresholds, ignore_index)
@@ -605,15 +604,15 @@ def auroc(
target: Tensor,
num_classes: Optional[int] = None,
pos_label: Optional[int] = None,
average: Optional[str] = "macro",
average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
max_fpr: Optional[float] = None,
sample_weights: Optional[Sequence] = None,
task: Optional[Literal["binary", "multiclass", "multilabel"]] = None,
thresholds: Optional[Union[int, List[float], Tensor]] = None,
num_labels: Optional[int] = None,
ignore_index: Optional[int] = None,
validate_args: bool = True,
) -> Tensor:
) -> Union[Tensor, Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]:
r"""
.. note::
From v0.10 a `'binary_*'`, `'multiclass_*'`, `'multilabel_*'` version now exists of each classification
@@ -686,13 +685,14 @@ def auroc(
tensor(0.7778)
"""
if task is not None:
kwargs = dict(thresholds=thresholds, ignore_index=ignore_index, validate_args=validate_args)
if task == "binary":
return binary_auroc(preds, target, max_fpr, **kwargs)
return binary_auroc(preds, target, max_fpr, thresholds, ignore_index, validate_args)
if task == "multiclass":
return multiclass_auroc(preds, target, num_classes, average, **kwargs)
assert isinstance(num_classes, int)
return multiclass_auroc(preds, target, num_classes, average, thresholds, ignore_index, validate_args)
if task == "multilabel":
return multilabel_auroc(preds, target, num_labels, average, **kwargs)
assert isinstance(num_labels, int)
return multilabel_auroc(preds, target, num_labels, average, thresholds, ignore_index, validate_args)
raise ValueError(
f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}"
)
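As a usage sketch of the rewritten dispatch (assuming torchmetrics at this commit, with illustrative values chosen here), the multiclass branch now requires `num_classes` to be a concrete `int`:

import torch
from torchmetrics.functional import auroc

preds = torch.tensor([[0.75, 0.05, 0.20],
                      [0.10, 0.80, 0.10],
                      [0.20, 0.30, 0.50]])
target = torch.tensor([0, 1, 2])

# Dispatches to multiclass_auroc; thresholds, ignore_index and validate_args
# are now forwarded positionally instead of through a kwargs dict.
score = auroc(preds, target, task="multiclass", num_classes=3)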
src/torchmetrics/functional/classification/average_precision.py
@@ -562,7 +562,7 @@ def average_precision(
target: Tensor,
num_classes: Optional[int] = None,
pos_label: Optional[int] = None,
average: Optional[str] = "macro",
average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
task: Optional[Literal["binary", "multiclass", "multilabel"]] = None,
thresholds: Optional[Union[int, List[float], Tensor]] = None,
num_labels: Optional[int] = None,
7 changes: 3 additions & 4 deletions src/torchmetrics/functional/classification/cohen_kappa.py
@@ -287,7 +287,7 @@ def cohen_kappa(
preds: Tensor,
target: Tensor,
num_classes: int,
weights: Optional[str] = None,
weights: Optional[Literal["linear", "quadratic", "none"]] = None,
threshold: float = 0.5,
task: Optional[Literal["binary", "multiclass", "multilabel"]] = None,
ignore_index: Optional[int] = None,
@@ -335,11 +335,10 @@ class labels.
tensor(0.5000)
"""
if task is not None:
kwargs = dict(weights=weights, ignore_index=ignore_index, validate_args=validate_args)
if task == "binary":
return binary_cohen_kappa(preds, target, threshold, **kwargs)
return binary_cohen_kappa(preds, target, threshold, weights, ignore_index, validate_args)
if task == "multiclass":
return multiclass_cohen_kappa(preds, target, num_classes, **kwargs)
return multiclass_cohen_kappa(preds, target, num_classes, weights, ignore_index, validate_args)
raise ValueError(
f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}"
)
19 changes: 13 additions & 6 deletions src/torchmetrics/functional/classification/f_beta.py
@@ -794,7 +794,7 @@ def fbeta_score(
preds: Tensor,
target: Tensor,
beta: float = 1.0,
average: Optional[str] = "micro",
average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
mdmc_average: Optional[str] = None,
ignore_index: Optional[int] = None,
num_classes: Optional[int] = None,
@@ -962,7 +962,7 @@ def f1_score(
preds: Tensor,
target: Tensor,
beta: float = 1.0,
average: Optional[str] = "micro",
average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
mdmc_average: Optional[str] = None,
ignore_index: Optional[int] = None,
num_classes: Optional[int] = None,
@@ -1076,13 +1076,20 @@ def f1_score(
tensor(0.3333)
"""
if task is not None:
kwargs = dict(multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args)
assert multidim_average is not None
if task == "binary":
return binary_f1_score(preds, target, threshold, **kwargs)
return binary_f1_score(preds, target, threshold, multidim_average, ignore_index, validate_args)
if task == "multiclass":
return multiclass_f1_score(preds, target, num_classes, average, top_k, **kwargs)
assert isinstance(num_classes, int)
assert isinstance(top_k, int)
return multiclass_f1_score(
preds, target, num_classes, average, top_k, multidim_average, ignore_index, validate_args
)
if task == "multilabel":
return multilabel_f1_score(preds, target, num_labels, threshold, average, **kwargs)
assert isinstance(num_labels, int)
return multilabel_f1_score(
preds, target, num_labels, threshold, average, multidim_average, ignore_index, validate_args
)
raise ValueError(
f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}"
)
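A brief usage example of the `f1_score` dispatch after this change (illustrative tensors; assuming the wrapper at this commit): the multiclass branch asserts that both `num_classes` and `top_k` are ints before forwarding them, so it is safest to pass both explicitly.

import torch
from torchmetrics.functional import f1_score

preds = torch.tensor([2, 0, 2, 1])
target = torch.tensor([1, 0, 2, 1])

# All arguments are forwarded explicitly to multiclass_f1_score.
score = f1_score(preds, target, task="multiclass", num_classes=3, top_k=1, average="macro")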
17 changes: 12 additions & 5 deletions src/torchmetrics/functional/classification/hamming.py
@@ -437,7 +437,7 @@ def hamming_distance(
task: Optional[Literal["binary", "multiclass", "multilabel"]] = None,
num_classes: Optional[int] = None,
num_labels: Optional[int] = None,
average: Optional[str] = "macro",
average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
top_k: int = 1,
multidim_average: Optional[Literal["global", "samplewise"]] = "global",
ignore_index: Optional[int] = None,
@@ -482,13 +482,20 @@ def hamming_distance(
tensor(0.2500)
"""
if task is not None:
kwargs = dict(multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args)
assert multidim_average is not None
if task == "binary":
return binary_hamming_distance(preds, target, threshold, **kwargs)
return binary_hamming_distance(preds, target, threshold, multidim_average, ignore_index, validate_args)
if task == "multiclass":
return multiclass_hamming_distance(preds, target, num_classes, average, top_k, **kwargs)
assert isinstance(num_classes, int)
assert isinstance(top_k, int)
return multiclass_hamming_distance(
preds, target, num_classes, average, top_k, multidim_average, ignore_index, validate_args
)
if task == "multilabel":
return multilabel_hamming_distance(preds, target, num_labels, threshold, average, **kwargs)
assert isinstance(num_labels, int)
return multilabel_hamming_distance(
preds, target, num_labels, threshold, average, multidim_average, ignore_index, validate_args
)
raise ValueError(
f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}"
)
2 changes: 1 addition & 1 deletion src/torchmetrics/functional/classification/jaccard.py
@@ -378,7 +378,7 @@ def jaccard_index(
preds: Tensor,
target: Tensor,
num_classes: int,
average: Optional[str] = "macro",
average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
ignore_index: Optional[int] = None,
absent_score: float = 0.0,
threshold: float = 0.5,
36 changes: 25 additions & 11 deletions src/torchmetrics/functional/classification/precision_recall.py
@@ -712,7 +712,7 @@ def _precision_compute(
def precision(
preds: Tensor,
target: Tensor,
average: Optional[str] = "micro",
average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
mdmc_average: Optional[str] = None,
ignore_index: Optional[int] = None,
num_classes: Optional[int] = None,
@@ -833,13 +833,20 @@ def precision(
"""
if task is not None:
kwargs = dict(multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args)
assert multidim_average is not None
if task == "binary":
return binary_precision(preds, target, threshold, **kwargs)
return binary_precision(preds, target, threshold, multidim_average, ignore_index, validate_args)
if task == "multiclass":
return multiclass_precision(preds, target, num_classes, average, top_k, **kwargs)
assert isinstance(num_classes, int)
assert isinstance(top_k, int)
return multiclass_precision(
preds, target, num_classes, average, top_k, multidim_average, ignore_index, validate_args
)
if task == "multilabel":
return multilabel_precision(preds, target, num_labels, threshold, average, **kwargs)
assert isinstance(num_labels, int)
return multilabel_precision(
preds, target, num_labels, threshold, average, multidim_average, ignore_index, validate_args
)
raise ValueError(
f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}"
)
@@ -936,7 +943,7 @@ def _recall_compute(
def recall(
preds: Tensor,
target: Tensor,
average: Optional[str] = "micro",
average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
mdmc_average: Optional[str] = None,
ignore_index: Optional[int] = None,
num_classes: Optional[int] = None,
@@ -1058,13 +1065,20 @@ def recall(
"""
if task is not None:
kwargs = dict(multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args)
assert multidim_average is not None
if task == "binary":
return binary_recall(preds, target, threshold, **kwargs)
return binary_recall(preds, target, threshold, multidim_average, ignore_index, validate_args)
if task == "multiclass":
return multiclass_recall(preds, target, num_classes, average, top_k, **kwargs)
assert isinstance(num_classes, int)
assert isinstance(top_k, int)
return multiclass_recall(
preds, target, num_classes, average, top_k, multidim_average, ignore_index, validate_args
)
if task == "multilabel":
return multilabel_recall(preds, target, num_labels, threshold, average, **kwargs)
assert isinstance(num_labels, int)
return multilabel_recall(
preds, target, num_labels, threshold, average, multidim_average, ignore_index, validate_args
)
raise ValueError(
f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}"
)
@@ -1110,7 +1124,7 @@ def recall(
def precision_recall(
preds: Tensor,
target: Tensor,
average: Optional[str] = "micro",
average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
mdmc_average: Optional[str] = None,
ignore_index: Optional[int] = None,
num_classes: Optional[int] = None,
src/torchmetrics/functional/classification/precision_recall_curve.py
@@ -1053,13 +1053,15 @@ def precision_recall_curve(
[tensor([0.7500]), tensor([0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500])]
"""
if task is not None:
kwargs = dict(thresholds=thresholds, ignore_index=ignore_index, validate_args=validate_args)
if task == "binary":
return binary_precision_recall_curve(preds, target, **kwargs)
return binary_precision_recall_curve(preds, target, thresholds, ignore_index, validate_args)
if task == "multiclass":
return multiclass_precision_recall_curve(preds, target, num_classes, **kwargs)
assert isinstance(num_classes, int)
return multiclass_precision_recall_curve(
preds, target, num_classes, thresholds, ignore_index, validate_args
)
if task == "multilabel":
return multilabel_precision_recall_curve(preds, target, num_labels, **kwargs)
return multilabel_precision_recall_curve(preds, target, num_labels, thresholds, ignore_index, validate_args)
raise ValueError(
f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}"
)
9 changes: 5 additions & 4 deletions src/torchmetrics/functional/classification/roc.py
@@ -696,13 +696,14 @@ def roc(
tensor([1.1837, 0.1837, 0.1338, 0.1183, 0.1138])]
"""
if task is not None:
kwargs = dict(thresholds=thresholds, ignore_index=ignore_index, validate_args=validate_args)
if task == "binary":
return binary_roc(preds, target, **kwargs)
return binary_roc(preds, target, thresholds, ignore_index, validate_args)
if task == "multiclass":
return multiclass_roc(preds, target, num_classes, **kwargs)
assert isinstance(num_classes, int)
return multiclass_roc(preds, target, num_classes, thresholds, ignore_index, validate_args)
if task == "multilabel":
return multilabel_roc(preds, target, num_labels, **kwargs)
assert isinstance(num_labels, int)
return multilabel_roc(preds, target, num_labels, thresholds, ignore_index, validate_args)
raise ValueError(
f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}"
)
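For the curve-returning metrics the same rewrite applies; a small usage sketch of the binary branch (illustrative values, assuming this commit's wrapper):

import torch
from torchmetrics.functional import roc

preds = torch.tensor([0.1, 0.8, 0.4, 0.9])
target = torch.tensor([0, 1, 0, 1])

# binary_roc now receives thresholds, ignore_index and validate_args positionally.
fpr, tpr, thresholds = roc(preds, target, task="binary")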
17 changes: 12 additions & 5 deletions src/torchmetrics/functional/classification/specificity.py
@@ -409,7 +409,7 @@ def _specificity_compute(
def specificity(
preds: Tensor,
target: Tensor,
average: Optional[str] = "micro",
average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
mdmc_average: Optional[str] = None,
ignore_index: Optional[int] = None,
num_classes: Optional[int] = None,
@@ -529,13 +529,20 @@ def specificity(
tensor(0.6250)
"""
if task is not None:
kwargs = dict(multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args)
assert multidim_average is not None
if task == "binary":
return binary_specificity(preds, target, threshold, **kwargs)
return binary_specificity(preds, target, threshold, multidim_average, ignore_index, validate_args)
if task == "multiclass":
return multiclass_specificity(preds, target, num_classes, average, top_k, **kwargs)
assert isinstance(num_classes, int)
assert isinstance(top_k, int)
return multiclass_specificity(
preds, target, num_classes, average, top_k, multidim_average, ignore_index, validate_args
)
if task == "multilabel":
return multilabel_specificity(preds, target, num_labels, threshold, average, **kwargs)
assert isinstance(num_labels, int)
return multilabel_specificity(
preds, target, num_labels, threshold, average, multidim_average, ignore_index, validate_args
)
raise ValueError(
f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}"
)
17 changes: 12 additions & 5 deletions src/torchmetrics/functional/classification/stat_scores.py
@@ -1095,7 +1095,7 @@ def stat_scores(
ignore_index: Optional[int] = None,
task: Optional[Literal["binary", "multiclass", "multilabel"]] = None,
num_labels: Optional[int] = None,
average: Optional[str] = "micro",
average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
multidim_average: Optional[Literal["global", "samplewise"]] = "global",
validate_args: bool = True,
) -> Tensor:
@@ -1225,13 +1225,20 @@ def stat_scores(
"""
if task is not None:
kwargs = dict(multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args)
assert multidim_average is not None
if task == "binary":
return binary_stat_scores(preds, target, threshold, **kwargs)
return binary_stat_scores(preds, target, threshold, multidim_average, ignore_index, validate_args)
if task == "multiclass":
return multiclass_stat_scores(preds, target, num_classes, average, top_k, **kwargs)
assert isinstance(num_classes, int)
assert isinstance(top_k, int)
return multiclass_stat_scores(
preds, target, num_classes, average, top_k, multidim_average, ignore_index, validate_args
)
if task == "multilabel":
return multilabel_stat_scores(preds, target, num_labels, threshold, average, **kwargs)
assert isinstance(num_labels, int)
return multilabel_stat_scores(
preds, target, num_labels, threshold, average, multidim_average, ignore_index, validate_args
)
raise ValueError(
f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}"
)