diff --git a/src/torchmetrics/classification/dice.py b/src/torchmetrics/classification/dice.py index bd9c42219cb..770830e19ba 100644 --- a/src/torchmetrics/classification/dice.py +++ b/src/torchmetrics/classification/dice.py @@ -56,7 +56,7 @@ class Dice(Metric): Args: num_classes: - Number of classes. Necessary for ``'macro'``, ``'weighted'`` and ``None`` average methods. + Number of classes. Necessary for ``'macro'``, and ``None`` average methods. threshold: Threshold for transforming probability or logit predictions to binary (0,1) predictions, in the case of binary or multi-label inputs. Default value of 0.5 corresponds to input being probabilities. @@ -116,7 +116,7 @@ class Dice(Metric): Raises: ValueError: - If ``average`` is none of ``"micro"``, ``"macro"``, ``"weighted"``, ``"samples"``, ``"none"``, ``None``. + If ``average`` is none of ``"micro"``, ``"macro"``, ``"samples"``, ``"none"``, ``None``. ValueError: If ``mdmc_average`` is not one of ``None``, ``"samplewise"``, ``"global"``. ValueError: @@ -146,7 +146,7 @@ def __init__( zero_division: int = 0, num_classes: Optional[int] = None, threshold: float = 0.5, - average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", + average: Optional[Literal["micro", "macro", "none"]] = "micro", mdmc_average: Optional[str] = "global", ignore_index: Optional[int] = None, top_k: Optional[int] = None, @@ -154,7 +154,7 @@ def __init__( **kwargs: Any, ) -> None: super().__init__(**kwargs) - allowed_average = ("micro", "macro", "weighted", "samples", "none", None) + allowed_average = ("micro", "macro", "samples", "none", None) if average not in allowed_average: raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.")