diff --git a/pyproject.toml b/pyproject.toml
index 735e7e554aa..84e8cc4adca 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -109,7 +109,6 @@ unfixable = ["F401"]
 "__init__.py" = ["D100"]
 "src/**" = [
     "D101", # todo # Missing docstring in public class
-    "D102", # todo # Missing docstring in public method
     "D103", # todo # Missing docstring in public function
     "D105", # todo # Missing docstring in magic method
     "D205", # todo # 1 blank line required between summary line and description
diff --git a/src/torchmetrics/classification/accuracy.py b/src/torchmetrics/classification/accuracy.py
index f68a05db6db..0cfebff56d1 100644
--- a/src/torchmetrics/classification/accuracy.py
+++ b/src/torchmetrics/classification/accuracy.py
@@ -491,6 +491,7 @@ def __new__(
         validate_args: bool = True,
         **kwargs: Any,
     ) -> Metric:
+        """Initialize task metric."""
         task = ClassificationTask.from_str(task)
         kwargs.update(
             {"multidim_average": multidim_average, "ignore_index": ignore_index, "validate_args": validate_args}
diff --git a/src/torchmetrics/classification/auroc.py b/src/torchmetrics/classification/auroc.py
index 14bb0b2e953..0f09abbabd4 100644
--- a/src/torchmetrics/classification/auroc.py
+++ b/src/torchmetrics/classification/auroc.py
@@ -108,6 +108,7 @@ def __init__(
         self.max_fpr = max_fpr
 
     def compute(self) -> Tensor:
+        """Compute metric."""
         state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)] if self.thresholds is None else self.confmat
         return _binary_auroc_compute(state, self.thresholds, self.max_fpr)
@@ -207,6 +208,7 @@ def __init__(
         self.validate_args = validate_args
 
     def compute(self) -> Tensor:
+        """Compute metric."""
         state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)] if self.thresholds is None else self.confmat
         return _multiclass_auroc_compute(state, self.num_classes, self.average, self.thresholds)
@@ -308,6 +310,7 @@ def __init__(
         self.validate_args = validate_args
 
     def compute(self) -> Tensor:
+        """Compute metric."""
         state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)] if self.thresholds is None else self.confmat
         return _multilabel_auroc_compute(state, self.num_labels, self.average, self.thresholds, self.ignore_index)
@@ -353,6 +356,7 @@ def __new__(
         validate_args: bool = True,
         **kwargs: Any,
     ) -> Metric:
+        """Initialize task metric."""
         task = ClassificationTask.from_str(task)
         kwargs.update({"thresholds": thresholds, "ignore_index": ignore_index, "validate_args": validate_args})
         if task == ClassificationTask.BINARY:
diff --git a/src/torchmetrics/classification/average_precision.py b/src/torchmetrics/classification/average_precision.py
index c50d4708bf4..15deb00349a 100644
--- a/src/torchmetrics/classification/average_precision.py
+++ b/src/torchmetrics/classification/average_precision.py
@@ -98,6 +98,7 @@ class BinaryAveragePrecision(BinaryPrecisionRecallCurve):
     full_state_update: bool = False
 
     def compute(self) -> Tensor:
+        """Compute metric."""
         state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)] if self.thresholds is None else self.confmat
         return _binary_average_precision_compute(state, self.thresholds)
@@ -201,6 +202,7 @@ def __init__(
         self.validate_args = validate_args
 
     def compute(self) -> Tensor:
+        """Compute metric."""
         state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)] if self.thresholds is None else self.confmat
         return _multiclass_average_precision_compute(state, self.num_classes, self.average, self.thresholds)
@@ -307,6 +309,7 @@ def __init__(
         self.validate_args = validate_args
 
     def compute(self) -> Tensor:
+        """Compute metric."""
         state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)] if self.thresholds is None else self.confmat
         return _multilabel_average_precision_compute(
             state, self.num_labels, self.average, self.thresholds, self.ignore_index
@@ -357,6 +360,7 @@ def __new__(
         validate_args: bool = True,
         **kwargs: Any,
     ) -> Metric:
+        """Initialize task metric."""
         task = ClassificationTask.from_str(task)
         kwargs.update({"thresholds": thresholds, "ignore_index": ignore_index, "validate_args": validate_args})
         if task == ClassificationTask.BINARY:
diff --git a/src/torchmetrics/classification/calibration_error.py b/src/torchmetrics/classification/calibration_error.py
index 3518870d7fa..082efa48b50 100644
--- a/src/torchmetrics/classification/calibration_error.py
+++ b/src/torchmetrics/classification/calibration_error.py
@@ -114,6 +114,7 @@ def __init__(
         self.add_state("accuracies", [], dist_reduce_fx="cat")
 
     def update(self, preds: Tensor, target: Tensor) -> None:  # type: ignore
+        """Update metric states with predictions and targets."""
         if self.validate_args:
             _binary_calibration_error_tensor_validation(preds, target, self.ignore_index)
         preds, target = _binary_confusion_matrix_format(
@@ -124,6 +125,7 @@ def update(self, preds: Tensor, target: Tensor) -> None:  # type: ignore
         self.accuracies.append(accuracies)
 
     def compute(self) -> Tensor:
+        """Compute metric."""
         confidences = dim_zero_cat(self.confidences)
         accuracies = dim_zero_cat(self.accuracies)
         return _ce_compute(confidences, accuracies, self.n_bins, norm=self.norm)
@@ -217,6 +219,7 @@ def __init__(
         self.add_state("accuracies", [], dist_reduce_fx="cat")
 
     def update(self, preds: Tensor, target: Tensor) -> None:  # type: ignore
+        """Update metric states with predictions and targets."""
         if self.validate_args:
             _multiclass_calibration_error_tensor_validation(preds, target, self.num_classes, self.ignore_index)
         preds, target = _multiclass_confusion_matrix_format(
@@ -227,6 +230,7 @@ def update(self, preds: Tensor, target: Tensor) -> None:  # type: ignore
         self.accuracies.append(accuracies)
 
     def compute(self) -> Tensor:
+        """Compute metric."""
         confidences = dim_zero_cat(self.confidences)
         accuracies = dim_zero_cat(self.accuracies)
         return _ce_compute(confidences, accuracies, self.n_bins, norm=self.norm)
@@ -268,6 +272,7 @@ def __new__(
         validate_args: bool = True,
         **kwargs: Any,
     ) -> Metric:
+        """Initialize task metric."""
         task = ClassificationTaskNoMultilabel.from_str(task)
         kwargs.update({"n_bins": n_bins, "norm": norm, "ignore_index": ignore_index, "validate_args": validate_args})
         if task == ClassificationTaskNoMultilabel.BINARY:
diff --git a/src/torchmetrics/classification/cohen_kappa.py b/src/torchmetrics/classification/cohen_kappa.py
index 524d8ff5d7d..65f87fe19ac 100644
--- a/src/torchmetrics/classification/cohen_kappa.py
+++ b/src/torchmetrics/classification/cohen_kappa.py
@@ -102,6 +102,7 @@ def __init__(
         self.validate_args = validate_args
 
     def compute(self) -> Tensor:
+        """Compute metric."""
         return _cohen_kappa_reduce(self.confmat, self.weights)
@@ -184,6 +185,7 @@ def __init__(
         self.validate_args = validate_args
 
     def compute(self) -> Tensor:
+        """Compute metric."""
         return _cohen_kappa_reduce(self.confmat, self.weights)
@@ -222,6 +224,7 @@ def __new__(
         validate_args: bool = True,
         **kwargs: Any,
     ) -> Metric:
+        """Initialize task metric."""
         task = ClassificationTaskNoMultilabel.from_str(task)
         kwargs.update({"weights": weights, "ignore_index": ignore_index, "validate_args": validate_args})
         if task == ClassificationTaskNoMultilabel.BINARY:
diff --git a/src/torchmetrics/classification/confusion_matrix.py b/src/torchmetrics/classification/confusion_matrix.py
index 2f9277d76a1..6f21d52ae5f 100644
--- a/src/torchmetrics/classification/confusion_matrix.py
+++ b/src/torchmetrics/classification/confusion_matrix.py
@@ -398,6 +398,7 @@ def __new__(
         validate_args: bool = True,
         **kwargs: Any,
     ) -> Metric:
+        """Initialize task metric."""
         task = ClassificationTask.from_str(task)
         kwargs.update({"normalize": normalize, "ignore_index": ignore_index, "validate_args": validate_args})
         if task == ClassificationTask.BINARY:
diff --git a/src/torchmetrics/classification/exact_match.py b/src/torchmetrics/classification/exact_match.py
index 82dfa3e8747..2a99048d911 100644
--- a/src/torchmetrics/classification/exact_match.py
+++ b/src/torchmetrics/classification/exact_match.py
@@ -118,6 +118,7 @@ def __init__(
         )
 
     def update(self, preds, target) -> None:
+        """Update metric states with predictions and targets."""
         if self.validate_args:
             _multiclass_stat_scores_tensor_validation(
                 preds, target, self.num_classes, self.multidim_average, self.ignore_index
@@ -132,6 +133,7 @@ def update(self, preds, target) -> None:
         self.total += total
 
     def compute(self) -> Tensor:
+        """Compute metric."""
         correct = dim_zero_cat(self.correct) if isinstance(self.correct, list) else self.correct
         return _exact_match_reduce(correct, self.total)
@@ -250,6 +252,7 @@ def update(self, preds: Tensor, target: Tensor) -> None:  # type: ignore
         self.total += total
 
     def compute(self) -> Tensor:
+        """Compute metric."""
         correct = dim_zero_cat(self.correct) if isinstance(self.correct, list) else self.correct
         return _exact_match_reduce(correct, self.total)
@@ -289,6 +292,7 @@ def __new__(
         validate_args: bool = True,
         **kwargs: Any,
     ) -> Metric:
+        """Initialize task metric."""
         task = ClassificationTaskNoBinary.from_str(task)
         kwargs.update(
             {"multidim_average": multidim_average, "ignore_index": ignore_index, "validate_args": validate_args}
diff --git a/src/torchmetrics/classification/f_beta.py b/src/torchmetrics/classification/f_beta.py
index 4dd02190e27..c24e4261673 100644
--- a/src/torchmetrics/classification/f_beta.py
+++ b/src/torchmetrics/classification/f_beta.py
@@ -117,6 +117,7 @@ def __init__(
         self.beta = beta
 
     def compute(self) -> Tensor:
+        """Compute metric."""
         tp, fp, tn, fn = self._final_state()
         return _fbeta_reduce(tp, fp, tn, fn, self.beta, average="binary", multidim_average=self.multidim_average)
@@ -245,6 +246,7 @@ def __init__(
         self.beta = beta
 
     def compute(self) -> Tensor:
+        """Compute metric."""
         tp, fp, tn, fn = self._final_state()
         return _fbeta_reduce(tp, fp, tn, fn, self.beta, average=self.average, multidim_average=self.multidim_average)
@@ -369,6 +371,7 @@ def __init__(
         self.beta = beta
 
     def compute(self) -> Tensor:
+        """Compute metric."""
         tp, fp, tn, fn = self._final_state()
         return _fbeta_reduce(tp, fp, tn, fn, self.beta, average=self.average, multidim_average=self.multidim_average)
@@ -726,6 +729,7 @@ def __new__(
         validate_args: bool = True,
         **kwargs: Any,
     ) -> Metric:
+        """Initialize task metric."""
         assert multidim_average is not None
         kwargs.update(
             {"multidim_average": multidim_average, "ignore_index": ignore_index, "validate_args": validate_args}
@@ -777,6 +781,7 @@ def __new__(
         validate_args: bool = True,
         **kwargs: Any,
     ) -> Metric:
+        """Initialize task metric."""
         task = ClassificationTask.from_str(task)
         assert multidim_average is not None
         kwargs.update(
diff --git a/src/torchmetrics/classification/hamming.py b/src/torchmetrics/classification/hamming.py
index ccdd13400f8..3d911c063d9 100644
--- a/src/torchmetrics/classification/hamming.py
+++ b/src/torchmetrics/classification/hamming.py
@@ -94,6 +94,7 @@ class BinaryHammingDistance(BinaryStatScores):
     full_state_update: bool = False
 
     def compute(self) -> Tensor:
+        """Compute metric."""
         tp, fp, tn, fn = self._final_state()
         return _hamming_distance_reduce(tp, fp, tn, fn, average="binary", multidim_average=self.multidim_average)
@@ -199,6 +200,7 @@ class MulticlassHammingDistance(MulticlassStatScores):
     full_state_update: bool = False
 
     def compute(self) -> Tensor:
+        """Compute metric."""
         tp, fp, tn, fn = self._final_state()
         return _hamming_distance_reduce(tp, fp, tn, fn, average=self.average, multidim_average=self.multidim_average)
@@ -302,6 +304,7 @@ class MultilabelHammingDistance(MultilabelStatScores):
     full_state_update: bool = False
 
     def compute(self) -> Tensor:
+        """Compute metric."""
         tp, fp, tn, fn = self._final_state()
         return _hamming_distance_reduce(
             tp, fp, tn, fn, average=self.average, multidim_average=self.multidim_average, multilabel=True
@@ -345,6 +348,7 @@ def __new__(
         validate_args: bool = True,
         **kwargs: Any,
     ) -> Metric:
+        """Initialize task metric."""
         task = ClassificationTask.from_str(task)
         assert multidim_average is not None
         kwargs.update(
diff --git a/src/torchmetrics/classification/hinge.py b/src/torchmetrics/classification/hinge.py
index b68584e7f1b..247728f14ee 100644
--- a/src/torchmetrics/classification/hinge.py
+++ b/src/torchmetrics/classification/hinge.py
@@ -99,6 +99,7 @@ def __init__(
         self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
 
     def update(self, preds: Tensor, target: Tensor) -> None:  # type: ignore
+        """Update metric state."""
         if self.validate_args:
             _binary_hinge_loss_tensor_validation(preds, target, self.ignore_index)
         preds, target = _binary_confusion_matrix_format(
@@ -109,6 +110,7 @@ def update(self, preds: Tensor, target: Tensor) -> None:  # type: ignore
         self.total += total
 
     def compute(self) -> Tensor:
+        """Compute metric."""
         return _hinge_loss_compute(self.measures, self.total)
@@ -203,6 +205,7 @@ def __init__(
         self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
 
     def update(self, preds: Tensor, target: Tensor) -> None:  # type: ignore
+        """Update metric state."""
         if self.validate_args:
             _multiclass_hinge_loss_tensor_validation(preds, target, self.num_classes, self.ignore_index)
         preds, target = _multiclass_confusion_matrix_format(preds, target, self.ignore_index, convert_to_labels=False)
@@ -211,6 +214,7 @@ def update(self, preds: Tensor, target: Tensor) -> None:  # type: ignore
         self.total += total
 
     def compute(self) -> Tensor:
+        """Compute metric."""
         return _hinge_loss_compute(self.measures, self.total)
@@ -253,6 +257,7 @@ def __new__(
         validate_args: bool = True,
         **kwargs: Any,
     ) -> Metric:
+        """Initialize task metric."""
         task = ClassificationTaskNoMultilabel.from_str(task)
         kwargs.update({"ignore_index": ignore_index, "validate_args": validate_args})
         if task == ClassificationTaskNoMultilabel.BINARY:
diff --git a/src/torchmetrics/classification/jaccard.py b/src/torchmetrics/classification/jaccard.py
index ae1c37fd93c..fa9e9317897 100644
--- a/src/torchmetrics/classification/jaccard.py
+++ b/src/torchmetrics/classification/jaccard.py
@@ -89,6 +89,7 @@ def __init__(
         )
 
     def compute(self) -> Tensor:
+        """Compute metric."""
         return _jaccard_index_reduce(self.confmat, average="binary")
@@ -172,6 +173,7 @@ def __init__(
         self.average = average
 
     def compute(self) -> Tensor:
+        """Compute metric."""
         return _jaccard_index_reduce(self.confmat, average=self.average)
@@ -259,6 +261,7 @@ def __init__(
         self.average = average
 
     def compute(self) -> Tensor:
+        """Compute metric."""
         return _jaccard_index_reduce(self.confmat, average=self.average)
@@ -296,6 +299,7 @@ def __new__(
         validate_args: bool = True,
         **kwargs: Any,
     ) -> Metric:
+        """Initialize task metric."""
         task = ClassificationTask.from_str(task)
         kwargs.update({"ignore_index": ignore_index, "validate_args": validate_args})
         if task == ClassificationTask.BINARY:
diff --git a/src/torchmetrics/classification/matthews_corrcoef.py b/src/torchmetrics/classification/matthews_corrcoef.py
index cceaa4e4b43..6084893b17d 100644
--- a/src/torchmetrics/classification/matthews_corrcoef.py
+++ b/src/torchmetrics/classification/matthews_corrcoef.py
@@ -80,6 +80,7 @@ def __init__(
         super().__init__(threshold, ignore_index, normalize=None, validate_args=validate_args, **kwargs)
 
     def compute(self) -> Tensor:
+        """Compute metric."""
         return _matthews_corrcoef_reduce(self.confmat)
@@ -144,6 +145,7 @@ def __init__(
         super().__init__(num_classes, ignore_index, normalize=None, validate_args=validate_args, **kwargs)
 
     def compute(self) -> Tensor:
+        """Compute metric."""
         return _matthews_corrcoef_reduce(self.confmat)
@@ -207,6 +209,7 @@ def __init__(
         super().__init__(num_labels, threshold, ignore_index, normalize=None, validate_args=validate_args, **kwargs)
 
     def compute(self) -> Tensor:
+        """Compute metric."""
         return _matthews_corrcoef_reduce(self.confmat)
@@ -238,6 +241,7 @@ def __new__(
         validate_args: bool = True,
         **kwargs: Any,
     ) -> Metric:
+        """Initialize task metric."""
         task = ClassificationTask.from_str(task)
         kwargs.update({"ignore_index": ignore_index, "validate_args": validate_args})
         if task == ClassificationTask.BINARY:
diff --git a/src/torchmetrics/classification/precision_recall.py b/src/torchmetrics/classification/precision_recall.py
index 300dc3364aa..bbe7c61e2ae 100644
--- a/src/torchmetrics/classification/precision_recall.py
+++ b/src/torchmetrics/classification/precision_recall.py
@@ -89,6 +89,7 @@ class BinaryPrecision(BinaryStatScores):
     full_state_update: bool = False
 
     def compute(self) -> Tensor:
+        """Compute metric."""
         tp, fp, tn, fn = self._final_state()
         return _precision_recall_reduce(
             "precision", tp, fp, tn, fn, average="binary", multidim_average=self.multidim_average
@@ -193,6 +194,7 @@ class MulticlassPrecision(MulticlassStatScores):
     full_state_update: bool = False
 
     def compute(self) -> Tensor:
+        """Compute metric."""
         tp, fp, tn, fn = self._final_state()
         return _precision_recall_reduce(
             "precision", tp, fp, tn, fn, average=self.average, multidim_average=self.multidim_average
@@ -295,6 +297,7 @@ class MultilabelPrecision(MultilabelStatScores):
     full_state_update: bool = False
 
     def compute(self) -> Tensor:
+        """Compute metric."""
         tp, fp, tn, fn = self._final_state()
         return _precision_recall_reduce(
             "precision", tp, fp, tn, fn, average=self.average, multidim_average=self.multidim_average
@@ -368,6 +371,7 @@ class BinaryRecall(BinaryStatScores):
     full_state_update: bool = False
 
     def compute(self) -> Tensor:
+        """Compute metric."""
         tp, fp, tn, fn = self._final_state()
         return _precision_recall_reduce(
             "recall", tp, fp, tn, fn, average="binary", multidim_average=self.multidim_average
@@ -472,6 +476,7 @@ class MulticlassRecall(MulticlassStatScores):
     full_state_update: bool = False
 
     def compute(self) -> Tensor:
+        """Compute metric."""
         tp, fp, tn, fn = self._final_state()
         return _precision_recall_reduce(
             "recall", tp, fp, tn, fn, average=self.average, multidim_average=self.multidim_average
@@ -573,6 +578,7 @@ class MultilabelRecall(MultilabelStatScores):
     full_state_update: bool = False
 
     def compute(self) -> Tensor:
+        """Compute metric."""
         tp, fp, tn, fn = self._final_state()
         return _precision_recall_reduce(
             "recall", tp, fp, tn, fn, average=self.average, multidim_average=self.multidim_average
@@ -617,6 +623,7 @@ def __new__(
         validate_args: bool = True,
         **kwargs: Any,
     ) -> Metric:
+        """Initialize task metric."""
         assert multidim_average is not None
         kwargs.update(
             {"multidim_average": multidim_average, "ignore_index": ignore_index, "validate_args": validate_args}
@@ -671,6 +678,7 @@ def __new__(
         validate_args: bool = True,
         **kwargs: Any,
     ) -> Metric:
+        """Initialize task metric."""
         task = ClassificationTask.from_str(task)
         assert multidim_average is not None
         kwargs.update(
diff --git a/src/torchmetrics/classification/precision_recall_curve.py b/src/torchmetrics/classification/precision_recall_curve.py
index dafdea2e6cf..36fb608cff2 100644
--- a/src/torchmetrics/classification/precision_recall_curve.py
+++ b/src/torchmetrics/classification/precision_recall_curve.py
@@ -138,6 +138,7 @@ def __init__(
         )
 
     def update(self, preds: Tensor, target: Tensor) -> None:  # type: ignore
+        """Update metric states."""
         if self.validate_args:
             _binary_precision_recall_curve_tensor_validation(preds, target, self.ignore_index)
         preds, target, _ = _binary_precision_recall_curve_format(preds, target, self.thresholds, self.ignore_index)
@@ -149,6 +150,7 @@ def update(self, preds: Tensor, target: Tensor) -> None:  # type: ignore
         self.target.append(state[1])
 
     def compute(self) -> Tuple[Tensor, Tensor, Tensor]:
+        """Compute metric."""
         state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)] if self.thresholds is None else self.confmat
         return _binary_precision_recall_curve_compute(state, self.thresholds)
@@ -264,6 +266,7 @@ def __init__(
         )
 
     def update(self, preds: Tensor, target: Tensor) -> None:  # type: ignore
+        """Update metric states."""
         if self.validate_args:
             _multiclass_precision_recall_curve_tensor_validation(preds, target, self.num_classes, self.ignore_index)
         preds, target, _ = _multiclass_precision_recall_curve_format(
@@ -277,6 +280,7 @@ def update(self, preds: Tensor, target: Tensor) -> None:  # type: ignore
         self.target.append(state[1])
 
     def compute(self) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]:
+        """Compute metric."""
         state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)] if self.thresholds is None else self.confmat
         return _multiclass_precision_recall_curve_compute(state, self.num_classes, self.thresholds)
@@ -403,6 +407,7 @@ def __init__(
         )
 
     def update(self, preds: Tensor, target: Tensor) -> None:  # type: ignore
+        """Update metric states."""
         if self.validate_args:
             _multilabel_precision_recall_curve_tensor_validation(preds, target, self.num_labels, self.ignore_index)
         preds, target, _ = _multilabel_precision_recall_curve_format(
@@ -416,6 +421,7 @@ def update(self, preds: Tensor, target: Tensor) -> None:  # type: ignore
         self.target.append(state[1])
 
     def compute(self) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]:
+        """Compute metric."""
         state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)] if self.thresholds is None else self.confmat
         return _multilabel_precision_recall_curve_compute(state, self.num_labels, self.thresholds, self.ignore_index)
@@ -467,6 +473,7 @@ def __new__(
         validate_args: bool = True,
         **kwargs: Any,
     ) -> Metric:
+        """Initialize task metric."""
         task = ClassificationTask.from_str(task)
kwargs.update({"thresholds": thresholds, "ignore_index": ignore_index, "validate_args": validate_args}) if task == ClassificationTask.BINARY: diff --git a/src/torchmetrics/classification/ranking.py b/src/torchmetrics/classification/ranking.py index a24a59866f4..0dcffb68f21 100644 --- a/src/torchmetrics/classification/ranking.py +++ b/src/torchmetrics/classification/ranking.py @@ -85,6 +85,7 @@ def __init__( self.add_state("total", torch.tensor(0.0), dist_reduce_fx="sum") def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + """Update metric states.""" if self.validate_args: _multilabel_ranking_tensor_validation(preds, target, self.num_labels, self.ignore_index) preds, target = _multilabel_confusion_matrix_format( @@ -95,6 +96,7 @@ def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore self.total += n_elements def compute(self) -> Tensor: + """Compute metric.""" return _ranking_reduce(self.measure, self.total) @@ -156,6 +158,7 @@ def __init__( self.add_state("total", torch.tensor(0.0), dist_reduce_fx="sum") def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + """Update metric states.""" if self.validate_args: _multilabel_ranking_tensor_validation(preds, target, self.num_labels, self.ignore_index) preds, target = _multilabel_confusion_matrix_format( @@ -166,6 +169,7 @@ def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore self.total += n_elements def compute(self) -> Tensor: + """Compute metric.""" return _ranking_reduce(self.measure, self.total) @@ -229,6 +233,7 @@ def __init__( self.add_state("total", torch.tensor(0.0), dist_reduce_fx="sum") def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + """Update metric states.""" if self.validate_args: _multilabel_ranking_tensor_validation(preds, target, self.num_labels, self.ignore_index) preds, target = _multilabel_confusion_matrix_format( @@ -239,4 +244,5 @@ def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore self.total += n_elements def compute(self) -> Tensor: + """Compute metric.""" return _ranking_reduce(self.measure, self.total) diff --git a/src/torchmetrics/classification/recall_at_fixed_precision.py b/src/torchmetrics/classification/recall_at_fixed_precision.py index 188b71ce272..f32d118230e 100644 --- a/src/torchmetrics/classification/recall_at_fixed_precision.py +++ b/src/torchmetrics/classification/recall_at_fixed_precision.py @@ -111,6 +111,7 @@ def __init__( self.min_precision = min_precision def compute(self) -> Tuple[Tensor, Tensor]: # type: ignore[override] + """Compute metric.""" state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)] if self.thresholds is None else self.confmat return _binary_recall_at_fixed_precision_compute(state, self.thresholds, self.min_precision) @@ -201,6 +202,7 @@ def __init__( self.min_precision = min_precision def compute(self) -> Tuple[Tensor, Tensor]: # type: ignore[override] + """Compute metric.""" state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)] if self.thresholds is None else self.confmat return _multiclass_recall_at_fixed_precision_arg_compute( state, self.num_classes, self.thresholds, self.min_precision @@ -296,6 +298,7 @@ def __init__( self.min_precision = min_precision def compute(self) -> Tuple[Tensor, Tensor]: # type: ignore[override] + """Compute metric.""" state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)] if self.thresholds is None else self.confmat return _multilabel_recall_at_fixed_precision_arg_compute( state, self.num_labels, 
             state, self.num_labels, self.thresholds, self.ignore_index, self.min_precision
@@ -324,6 +327,7 @@ def __new__(
         validate_args: bool = True,
         **kwargs: Any,
     ) -> Metric:
+        """Initialize task metric."""
         task = ClassificationTask.from_str(task)
         if task == ClassificationTask.BINARY:
             return BinaryRecallAtFixedPrecision(min_precision, thresholds, ignore_index, validate_args, **kwargs)
diff --git a/src/torchmetrics/classification/roc.py b/src/torchmetrics/classification/roc.py
index aa7d86c1cc0..12c5ae03888 100644
--- a/src/torchmetrics/classification/roc.py
+++ b/src/torchmetrics/classification/roc.py
@@ -103,6 +103,7 @@ class BinaryROC(BinaryPrecisionRecallCurve):
     full_state_update: bool = False
 
    def compute(self) -> Tuple[Tensor, Tensor, Tensor]:
+        """Compute metric."""
         state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)] if self.thresholds is None else self.confmat
         return _binary_roc_compute(state, self.thresholds)
@@ -204,6 +205,7 @@ class MulticlassROC(MulticlassPrecisionRecallCurve):
     full_state_update: bool = False
 
     def compute(self) -> Tuple[Tensor, Tensor, Tensor]:
+        """Compute metric."""
         state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)] if self.thresholds is None else self.confmat
         return _multiclass_roc_compute(state, self.num_classes, self.thresholds)
@@ -307,6 +309,7 @@ class MultilabelROC(MultilabelPrecisionRecallCurve):
     full_state_update: bool = False
 
     def compute(self) -> Tuple[Tensor, Tensor, Tensor]:
+        """Compute metric."""
         state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)] if self.thresholds is None else self.confmat
         return _multilabel_roc_compute(state, self.num_labels, self.thresholds, self.ignore_index)
@@ -382,6 +385,7 @@ def __new__(
         validate_args: bool = True,
         **kwargs: Any,
     ) -> Metric:
+        """Initialize task metric."""
         task = ClassificationTask.from_str(task)
         kwargs.update({"thresholds": thresholds, "ignore_index": ignore_index, "validate_args": validate_args})
         if task == ClassificationTask.BINARY:
diff --git a/src/torchmetrics/classification/specificity.py b/src/torchmetrics/classification/specificity.py
index 0965a99bffc..d7f9e90a394 100644
--- a/src/torchmetrics/classification/specificity.py
+++ b/src/torchmetrics/classification/specificity.py
@@ -85,6 +85,7 @@ class BinarySpecificity(BinaryStatScores):
     """
 
     def compute(self) -> Tensor:
+        """Compute metric."""
         tp, fp, tn, fn = self._final_state()
         return _specificity_reduce(tp, fp, tn, fn, average="binary", multidim_average=self.multidim_average)
@@ -184,6 +185,7 @@ class MulticlassSpecificity(MulticlassStatScores):
     """
 
     def compute(self) -> Tensor:
+        """Compute metric."""
         tp, fp, tn, fn = self._final_state()
         return _specificity_reduce(tp, fp, tn, fn, average=self.average, multidim_average=self.multidim_average)
@@ -279,6 +281,7 @@ class MultilabelSpecificity(MultilabelStatScores):
     """
 
     def compute(self) -> Tensor:
+        """Compute metric."""
         tp, fp, tn, fn = self._final_state()
         return _specificity_reduce(tp, fp, tn, fn, average=self.average, multidim_average=self.multidim_average)
@@ -321,6 +324,7 @@ def __new__(
         validate_args: bool = True,
         **kwargs: Any,
     ) -> Metric:
+        """Initialize task metric."""
         task = ClassificationTask.from_str(task)
         assert multidim_average is not None
         kwargs.update(
diff --git a/src/torchmetrics/classification/specificity_at_sensitivity.py b/src/torchmetrics/classification/specificity_at_sensitivity.py
index 05f4aa03c14..386d5742842 100644
--- a/src/torchmetrics/classification/specificity_at_sensitivity.py
+++ b/src/torchmetrics/classification/specificity_at_sensitivity.py
@@ -109,6 +109,7 @@ def __init__(
         self.min_sensitivity = min_sensitivity
 
     def compute(self) -> Tuple[Tensor, Tensor]:  # type: ignore[override]
+        """Compute metric."""
         if self.thresholds is None:
             state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)]  # type: ignore
         else:
@@ -202,6 +203,7 @@ def __init__(
         self.min_sensitivity = min_sensitivity
 
     def compute(self) -> Tuple[Tensor, Tensor]:  # type: ignore
+        """Compute metric."""
         if self.thresholds is None:
             state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)]  # type: ignore
         else:
@@ -297,6 +299,7 @@ def __init__(
         self.min_sensitivity = min_sensitivity
 
     def compute(self) -> Tuple[Tensor, Tensor]:  # type: ignore[override]
+        """Compute metric."""
         if self.thresholds is None:
             state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)]  # type: ignore
         else:
@@ -328,6 +331,7 @@ def __new__(  # type: ignore
         validate_args: bool = True,
         **kwargs: Any,
     ) -> Metric:
+        """Initialize task metric."""
         task = ClassificationTask.from_str(task)
         if task == ClassificationTask.BINARY:
             return BinarySpecificityAtSensitivity(min_sensitivity, thresholds, ignore_index, validate_args, **kwargs)
diff --git a/src/torchmetrics/classification/stat_scores.py b/src/torchmetrics/classification/stat_scores.py
index 8fd61a10748..7082bef3e61 100644
--- a/src/torchmetrics/classification/stat_scores.py
+++ b/src/torchmetrics/classification/stat_scores.py
@@ -496,6 +496,7 @@ def __new__(
         validate_args: bool = True,
         **kwargs: Any,
     ) -> Metric:
+        """Initialize task metric."""
         task = ClassificationTask.from_str(task)
         assert multidim_average is not None
         kwargs.update(
diff --git a/src/torchmetrics/image/fid.py b/src/torchmetrics/image/fid.py
index 49bbc50b451..bac25fa11ce 100644
--- a/src/torchmetrics/image/fid.py
+++ b/src/torchmetrics/image/fid.py
@@ -54,6 +54,7 @@ def train(self, mode: bool) -> "NoTrainInceptionV3":
         return super().train(False)
 
     def forward(self, x: Tensor) -> Tensor:
+        """Forward pass of neural network with reshaping of output."""
         out = super().forward(x)
         return out[0].reshape(x.shape[0], -1)
@@ -66,6 +67,7 @@ class MatrixSquareRoot(Function):
     @staticmethod
     def forward(ctx: Any, input_data: Tensor) -> Tensor:
+        """Implements the forward pass for the matrix square root."""
         # TODO: update whenever pytorch gets an matrix square root function
         # Issue: https://github.com/pytorch/pytorch/issues/9983
         m = input_data.detach().cpu().numpy().astype(np.float_)
@@ -76,6 +78,7 @@ def forward(ctx: Any, input_data: Tensor) -> Tensor:
     @staticmethod
     def backward(ctx: Any, grad_output: Tensor) -> Tensor:
+        """Implements the backward pass for matrix square root."""
         grad_input = None
         if ctx.needs_input_grad[0]:
             (sqrtm,) = ctx.saved_tensors
@@ -288,6 +291,7 @@ def compute(self) -> Tensor:
         return _compute_fid(mean_real.squeeze(0), cov_real, mean_fake.squeeze(0), cov_fake).to(self.orig_dtype)
 
     def reset(self) -> None:
+        """Reset metric states."""
         if not self.reset_real_features:
             real_features_sum = deepcopy(self.real_features_sum)
             real_features_cov_sum = deepcopy(self.real_features_cov_sum)
diff --git a/src/torchmetrics/image/inception.py b/src/torchmetrics/image/inception.py
index 0dd01c2cdd0..ef137816796 100644
--- a/src/torchmetrics/image/inception.py
+++ b/src/torchmetrics/image/inception.py
@@ -141,6 +141,7 @@ def update(self, imgs: Tensor) -> None:  # type: ignore
         self.features.append(features)
 
     def compute(self) -> Tuple[Tensor, Tensor]:
+        """Compute metric."""
         features = dim_zero_cat(self.features)
         # random permute the features
         idx = torch.randperm(features.shape[0])
diff --git a/src/torchmetrics/image/kid.py b/src/torchmetrics/image/kid.py
index c4403d99d05..da956bcfa01 100644
--- a/src/torchmetrics/image/kid.py
+++ b/src/torchmetrics/image/kid.py
@@ -268,6 +268,7 @@ def compute(self) -> Tuple[Tensor, Tensor]:
         return kid_scores.mean(), kid_scores.std(unbiased=False)
 
     def reset(self) -> None:
+        """Reset metric states."""
         if not self.reset_real_features:
             # remove temporarily to avoid resetting
             value = self._defaults.pop("real_features")
diff --git a/src/torchmetrics/metric.py b/src/torchmetrics/metric.py
index 70e191273fc..552efaaf841 100644
--- a/src/torchmetrics/metric.py
+++ b/src/torchmetrics/metric.py
@@ -687,6 +687,15 @@ def state_dict(
         prefix: str = "",
         keep_vars: bool = False,
     ) -> Optional[Dict[str, Any]]:
+        """Get the current state of metric as an dictionary.
+
+        Args:
+            destination: Optional dictionary, that if provided, the state of module will be updated into the dict and
+                the same object is returned. Otherwise, an ``OrderedDict`` will be created and returned.
+            prefix: optional string, a prefix added to parameter and buffer names to compose the keys in state_dict.
+            keep_vars: by default the :class:`~torch.Tensor`s returned in the state dict are detached from autograd.
+                If set to ``True``, detaching will not be performed.
+        """
         destination = super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
         # Register metric states to be part of the state_dict
         for key in self._defaults:
@@ -912,6 +921,7 @@ def _sync_dist(self, dist_sync_fn: Optional[Callable] = None, process_group: Opt
         pass
 
     def update(self, *args: Any, **kwargs: Any) -> None:
+        """Update metric state."""
         if isinstance(self.metric_a, Metric):
             self.metric_a.update(*args, **self.metric_a._filter_kwargs(**kwargs))
@@ -919,6 +929,7 @@ def update(self, *args: Any, **kwargs: Any) -> None:
             self.metric_b.update(*args, **self.metric_b._filter_kwargs(**kwargs))
 
     def compute(self) -> Any:
+        """Compute metric."""
         # also some parsing for kwargs?
         val_a = self.metric_a.compute() if isinstance(self.metric_a, Metric) else self.metric_a
@@ -931,6 +942,7 @@ def compute(self) -> Any:
     @torch.jit.unused
     def forward(self, *args: Any, **kwargs: Any) -> Any:
+        """Calculate metric on current batch and accumulate to global state."""
         val_a = (
             self.metric_a(*args, **self.metric_a._filter_kwargs(**kwargs))
             if isinstance(self.metric_a, Metric)
             else self.metric_a
@@ -956,6 +968,7 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
         return self.op(val_a, val_b)
 
     def reset(self) -> None:
+        """Reset metric state."""
         if isinstance(self.metric_a, Metric):
             self.metric_a.reset()
@@ -963,6 +976,12 @@ def reset(self) -> None:
         if isinstance(self.metric_b, Metric):
             self.metric_b.reset()
 
     def persistent(self, mode: bool = False) -> None:
+        """Change if metric state is persistent (save as part of state_dict) or not.
+
+        Args:
+            mode: bool indicating if all states should be persistent or not
+
+        """
         if isinstance(self.metric_a, Metric):
             self.metric_a.persistent(mode=mode)
         if isinstance(self.metric_b, Metric):
diff --git a/src/torchmetrics/regression/cosine_similarity.py b/src/torchmetrics/regression/cosine_similarity.py
index 93e4a6d0f83..33cf0908df6 100644
--- a/src/torchmetrics/regression/cosine_similarity.py
+++ b/src/torchmetrics/regression/cosine_similarity.py
@@ -80,6 +80,7 @@ def update(self, preds: Tensor, target: Tensor) -> None:
         self.target.append(target)
 
     def compute(self) -> Tensor:
+        """Compute metric."""
         preds = dim_zero_cat(self.preds)
         target = dim_zero_cat(self.target)
         return _cosine_similarity_compute(preds, target, self.reduction)
diff --git a/src/torchmetrics/regression/kl_divergence.py b/src/torchmetrics/regression/kl_divergence.py
index 581ed539078..7d33ff8199b 100644
--- a/src/torchmetrics/regression/kl_divergence.py
+++ b/src/torchmetrics/regression/kl_divergence.py
@@ -98,6 +98,7 @@ def __init__(
         self.add_state("total", torch.tensor(0), dist_reduce_fx="sum")
 
     def update(self, p: Tensor, q: Tensor) -> None:  # type: ignore
+        """Update metric states with predictions and targets."""
         measures, total = _kld_update(p, q, self.log_prob)
         if self.reduction is None or self.reduction == "none":
             self.measures.append(measures)
@@ -106,5 +107,6 @@ def update(self, p: Tensor, q: Tensor) -> None:  # type: ignore
         self.total += total
 
     def compute(self) -> Tensor:
+        """Compute metric."""
         measures = dim_zero_cat(self.measures) if self.reduction is None or self.reduction == "none" else self.measures
         return _kld_compute(measures, self.total, self.reduction)
diff --git a/src/torchmetrics/regression/tweedie_deviance.py b/src/torchmetrics/regression/tweedie_deviance.py
index e0acbc83dcc..77dee0c4626 100644
--- a/src/torchmetrics/regression/tweedie_deviance.py
+++ b/src/torchmetrics/regression/tweedie_deviance.py
@@ -97,4 +97,5 @@ def update(self, preds: Tensor, targets: Tensor) -> None:
         self.num_observations += num_observations
 
     def compute(self) -> Tensor:
+        """Compute metric."""
         return _tweedie_deviance_score_compute(self.sum_deviance_score, self.num_observations)
diff --git a/src/torchmetrics/retrieval/precision_recall_curve.py b/src/torchmetrics/retrieval/precision_recall_curve.py
index aecbb5bca4d..d2dafa1a12c 100644
--- a/src/torchmetrics/retrieval/precision_recall_curve.py
+++ b/src/torchmetrics/retrieval/precision_recall_curve.py
@@ -172,6 +172,7 @@ def update(self, preds: Tensor, target: Tensor, indexes: Tensor) -> None:  # typ
         self.target.append(target)
 
     def compute(self) -> Tuple[Tensor, Tensor, Tensor]:
+        """Compute metric."""
         # concat all data
         indexes = dim_zero_cat(self.indexes)
         preds = dim_zero_cat(self.preds)
@@ -304,6 +305,7 @@ def __init__(
         self.min_precision = min_precision
 
     def compute(self) -> Tuple[Tensor, Tensor]:  # type: ignore
+        """Compute metric."""
         precisions, recalls, top_k = super().compute()
         return _retrieval_recall_at_fixed_precision(precisions, recalls, top_k, self.min_precision)
diff --git a/src/torchmetrics/wrappers/classwise.py b/src/torchmetrics/wrappers/classwise.py
index fa17e009981..ba571b60c2c 100644
--- a/src/torchmetrics/wrappers/classwise.py
+++ b/src/torchmetrics/wrappers/classwise.py
@@ -90,15 +90,19 @@ def _convert(self, x: Tensor) -> Dict[str, Any]:
         return {f"{name}_{lab}": val for lab, val in zip(self.labels, x)}
 
     def forward(self, *args: Any, **kwargs: Any) -> Any:
+        """Calculate on batch and accumulate to global state."""
         return self._convert(self.metric(*args, **kwargs))
 
     def update(self, *args: Any, **kwargs: Any) -> None:
+        """Update state."""
         self.metric.update(*args, **kwargs)
 
     def compute(self) -> Dict[str, Tensor]:
+        """Compute metric."""
         return self._convert(self.metric.compute())
 
     def reset(self) -> None:
+        """Reset metric."""
         self.metric.reset()
 
     def _wrap_update(self, update: Callable) -> Callable:
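
For context on what this diff enforces: removing "D102" from the per-file ignore list in pyproject.toml turns on the pydocstyle rule "Missing docstring in public method" for everything under src/**, which is why every public update/compute/forward/reset/__new__ touched above gains at least a one-line docstring. A minimal sketch of the convention, using a hypothetical ExampleMetric class that is not part of this change set:

    class ExampleMetric:
        """Toy running-mean accumulator, used only to illustrate the D102 docstring rule."""

        def __init__(self) -> None:
            """Initialize internal sum and counter."""
            self.total = 0.0
            self.count = 0

        def update(self, value: float) -> None:
            """Update running state with a new value."""
            self.total += value
            self.count += 1

        def compute(self) -> float:
            """Compute the mean of all values seen so far."""
            return self.total / max(self.count, 1)

Assuming ruff is configured as in the pyproject.toml hunk, an invocation along the lines of "ruff check --select D102 src/" would flag each public method above if its docstring were removed.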