Skip to content

Commit

Permalink
Fix the dice metric. Incorrect anotation (#1684)
Browse files Browse the repository at this point in the history
  • Loading branch information
YeaMerci authored Apr 4, 2023
1 parent 714a8ca commit 2509448
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions src/torchmetrics/classification/dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class Dice(Metric):
Args:
num_classes:
Number of classes. Necessary for ``'macro'``, ``'weighted'`` and ``None`` average methods.
Number of classes. Necessary for ``'macro'``, and ``None`` average methods.
threshold:
Threshold for transforming probability or logit predictions to binary (0,1) predictions, in the case
of binary or multi-label inputs. Default value of 0.5 corresponds to input being probabilities.
Expand Down Expand Up @@ -116,7 +116,7 @@ class Dice(Metric):
Raises:
ValueError:
If ``average`` is none of ``"micro"``, ``"macro"``, ``"weighted"``, ``"samples"``, ``"none"``, ``None``.
If ``average`` is none of ``"micro"``, ``"macro"``, ``"samples"``, ``"none"``, ``None``.
ValueError:
If ``mdmc_average`` is not one of ``None``, ``"samplewise"``, ``"global"``.
ValueError:
Expand Down Expand Up @@ -146,15 +146,15 @@ def __init__(
zero_division: int = 0,
num_classes: Optional[int] = None,
threshold: float = 0.5,
average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
average: Optional[Literal["micro", "macro", "none"]] = "micro",
mdmc_average: Optional[str] = "global",
ignore_index: Optional[int] = None,
top_k: Optional[int] = None,
multiclass: Optional[bool] = None,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
allowed_average = ("micro", "macro", "weighted", "samples", "none", None)
allowed_average = ("micro", "macro", "samples", "none", None)
if average not in allowed_average:
raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.")

Expand Down

0 comments on commit 2509448

Please sign in to comment.