diff --git a/docs/conf.py b/docs/conf.py index aa54a1526b3..67eebbc2865 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -121,6 +121,7 @@ "sklearn": ("https://scikit-learn.org/stable/", None), "timm": ("https://huggingface.co/docs/timm/main/en/", None), "torch": ("https://pytorch.org/docs/stable", None), + "torchmetrics": ("https://lightning.ai/docs/torchmetrics/stable/", None), "torchvision": ("https://pytorch.org/vision/stable", None), } diff --git a/torchgeo/trainers/classification.py b/torchgeo/trainers/classification.py index 9c3ca6abcd3..ab802827e4c 100644 --- a/torchgeo/trainers/classification.py +++ b/torchgeo/trainers/classification.py @@ -98,14 +98,15 @@ def configure_losses(self) -> None: def configure_metrics(self) -> None: """Initialize the performance metrics. - * Multiclass Overall Accuracy (OA): Ratio of correctly classified pixels. - Uses 'micro' averaging. Higher values are better. - * Multiclass Average Accuracy (AA): Ratio of correctly classified classes. - Uses 'macro' averaging. Higher values are better. - * Multiclass Jaccard Index (IoU): Per-class overlap between predicted and - actual classes. Uses 'macro' averaging. Higher valuers are better. - * Multiclass F1 Score: The harmonic mean of precision and recall. - Uses 'micro' averaging. Higher values are better. + * :class:`~torchmetrics.classification.MulticlassAccuracy`: The number of + true positives divided by the dataset size. Both overall accuracy (OA) + using 'micro' averaging and average accuracy (AA) using 'macro' averaging + are reported. Higher values are better. + * :class:`~torchmetrics.classification.MulticlassJaccardIndex`: Intersection + over union (IoU). Uses 'macro' averaging. Higher values are better. + * :class:`~torchmetrics.classification.MulticlassFBetaScore`: F1 score. + The harmonic mean of precision and recall. Uses 'micro' averaging. + Higher values are better. .. 
note:: * 'Micro' averaging suits overall performance evaluation but may not reflect @@ -270,12 +271,13 @@ class MultiLabelClassificationTask(ClassificationTask): def configure_metrics(self) -> None: """Initialize the performance metrics. - * Multiclass Overall Accuracy (OA): Ratio of correctly classified pixels. - Uses 'micro' averaging. Higher values are better. - * Multiclass Average Accuracy (AA): Ratio of correctly classified classes. - Uses 'macro' averaging. Higher values are better. - * Multiclass F1 Score: The harmonic mean of precision and recall. - Uses 'micro' averaging. Higher values are better. + * :class:`~torchmetrics.classification.MultilabelAccuracy`: The number of + true positives divided by the dataset size. Both overall accuracy (OA) + using 'micro' averaging and average accuracy (AA) using 'macro' averaging + are reported. Higher values are better. + * :class:`~torchmetrics.classification.MultilabelFBetaScore`: F1 score. + The harmonic mean of precision and recall. Uses 'micro' averaging. + Higher values are better. .. note:: * 'Micro' averaging suits overall performance evaluation but may not diff --git a/torchgeo/trainers/detection.py b/torchgeo/trainers/detection.py index b0918d1140b..b783719f6e0 100644 --- a/torchgeo/trainers/detection.py +++ b/torchgeo/trainers/detection.py @@ -206,10 +206,11 @@ def configure_models(self) -> None: def configure_metrics(self) -> None: """Initialize the performance metrics. - * Mean Average Precision (mAP): Computes the Mean-Average-Precision (mAP) and - Mean-Average-Recall (mAR) for object detection. Prediction is based on the - intersection over union (IoU) between the predicted bounding boxes and the - ground truth bounding boxes. Uses 'macro' averaging. Higher values are better. + * :class:`~torchmetrics.detection.mean_ap.MeanAveragePrecision`: Mean average + precision (mAP) and mean average recall (mAR). 
Precision is the number of + true positives divided by the number of true positives + false positives. + Recall is the number of true positives divided by the number of true positives + + false negatives. Uses 'macro' averaging. Higher values are better. .. note:: * 'Micro' averaging suits overall performance evaluation but may not * 'Macro' averaging gives equal weight to each class, and is useful for balanced performance assessment across imbalanced classes. """ - metrics = MetricCollection([MeanAveragePrecision()]) + metrics = MetricCollection([MeanAveragePrecision(average="macro")]) self.val_metrics = metrics.clone(prefix="val_") self.test_metrics = metrics.clone(prefix="test_") diff --git a/torchgeo/trainers/regression.py b/torchgeo/trainers/regression.py index 8495a823f6c..5d7b3061383 100644 --- a/torchgeo/trainers/regression.py +++ b/torchgeo/trainers/regression.py @@ -99,18 +99,12 @@ def configure_losses(self) -> None: def configure_metrics(self) -> None: """Initialize the performance metrics. - * Root Mean Squared Error (RMSE): The square root of the average of the squared - differences between the predicted and actual values. Lower values are better. - * Mean Squared Error (MSE): The average of the squared differences between the - predicted and actual values. Lower values are better. - * Mean Absolute Error (MAE): The average of the absolute differences between the - predicted and actual values. Lower values are better. - - .. note:: - * 'Micro' averaging suits overall performance evaluation but may not reflect - minority class accuracy. - * 'Macro' averaging gives equal weight to each class, and is useful for - balanced performance assessment across imbalanced classes. + * :class:`~torchmetrics.MeanSquaredError`: The average of the squared + differences between the predicted and actual values (MSE) and its + square root (RMSE). Lower values are better. 
+ * :class:`~torchmetrics.MeanAbsoluteError`: The average of the absolute + differences between the predicted and actual values (MAE). + Lower values are better. """ metrics = MetricCollection( { diff --git a/torchgeo/trainers/segmentation.py b/torchgeo/trainers/segmentation.py index 8d5dd591140..bc1e88747a0 100644 --- a/torchgeo/trainers/segmentation.py +++ b/torchgeo/trainers/segmentation.py @@ -124,10 +124,11 @@ def configure_losses(self) -> None: def configure_metrics(self) -> None: """Initialize the performance metrics. - * Multiclass Pixel Accuracy: Ratio of correctly classified pixels. - Uses 'micro' averaging. Higher values are better. - * Multiclass Jaccard Index (IoU): Per-pixel overlap between predicted and - actual segments. Uses 'macro' averaging. Higher values are better. + * :class:`~torchmetrics.classification.MulticlassAccuracy`: Overall accuracy + (OA) using 'micro' averaging. The number of true positives divided by the + dataset size. Higher values are better. + * :class:`~torchmetrics.classification.MulticlassJaccardIndex`: Intersection + over union (IoU). Uses 'micro' averaging. Higher values are better. .. note:: * 'Micro' averaging suits overall performance evaluation but may not reflect