Docs: fix trainer metric definitions #1924

Merged: 3 commits, Mar 3, 2024
30 changes: 16 additions & 14 deletions torchgeo/trainers/classification.py
@@ -98,14 +98,15 @@ def configure_losses(self) -> None:
def configure_metrics(self) -> None:
"""Initialize the performance metrics.

-    * Multiclass Overall Accuracy (OA): Ratio of correctly classified pixels.
-      Uses 'micro' averaging. Higher values are better.
-    * Multiclass Average Accuracy (AA): Ratio of correctly classified classes.
-      Uses 'macro' averaging. Higher values are better.
-    * Multiclass Jaccard Index (IoU): Per-class overlap between predicted and
-      actual classes. Uses 'macro' averaging. Higher valuers are better.
-    * Multiclass F1 Score: The harmonic mean of precision and recall.
-      Uses 'micro' averaging. Higher values are better.
+    * :class:`~torchmetrics.classification.MulticlassAccuracy`: The number of
+      true positives divided by the dataset size. Both overall accuracy (OA)
+      using 'micro' averaging and average accuracy (AA) using 'macro' averaging
+      are reported. Higher values are better.
+    * :class:`~torchmetrics.classification.MulticlassJaccardIndex`: Intersection
+      over union (IoU). Uses 'macro' averaging. Higher values are better.
+    * :class:`~torchmetrics.classification.MulticlassFBetaScore`: F1 score.
+      The harmonic mean of precision and recall. Uses 'micro' averaging.
+      Higher values are better.

.. note::
* 'Micro' averaging suits overall performance evaluation but may not reflect
@@ -270,12 +271,13 @@ class MultiLabelClassificationTask(ClassificationTask):
def configure_metrics(self) -> None:
"""Initialize the performance metrics.

-    * Multiclass Overall Accuracy (OA): Ratio of correctly classified pixels.
-      Uses 'micro' averaging. Higher values are better.
-    * Multiclass Average Accuracy (AA): Ratio of correctly classified classes.
-      Uses 'macro' averaging. Higher values are better.
-    * Multiclass F1 Score: The harmonic mean of precision and recall.
-      Uses 'micro' averaging. Higher values are better.
+    * :class:`~torchmetrics.classification.MultilabelAccuracy`: The number of
+      true positives divided by the dataset size. Both overall accuracy (OA)
+      using 'micro' averaging and average accuracy (AA) using 'macro' averaging
+      are reported. Higher values are better.
+    * :class:`~torchmetrics.classification.MultilabelFBetaScore`: F1 score.
+      The harmonic mean of precision and recall. Uses 'micro' averaging.
+      Higher values are better.

.. note::
* 'Micro' averaging suits overall performance evaluation but may not
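The new docstrings above lean on torchmetrics classes, so a minimal sketch may help show how 'micro' (OA) and 'macro' (AA) averaging diverge on the same predictions. The class count and tensors below are invented for illustration and are not part of this PR:

```python
import torch
from torchmetrics import MetricCollection
from torchmetrics.classification import (
    MulticlassAccuracy,
    MulticlassFBetaScore,
    MulticlassJaccardIndex,
)

num_classes = 3  # hypothetical class count, not a TorchGeo default
metrics = MetricCollection({
    "OverallAccuracy": MulticlassAccuracy(num_classes=num_classes, average="micro"),
    "AverageAccuracy": MulticlassAccuracy(num_classes=num_classes, average="macro"),
    "JaccardIndex": MulticlassJaccardIndex(num_classes=num_classes, average="macro"),
    "F1Score": MulticlassFBetaScore(beta=1.0, num_classes=num_classes, average="micro"),
})

preds = torch.tensor([0, 0, 1, 2])   # predicted labels
target = torch.tensor([0, 1, 1, 2])  # ground-truth labels
print(metrics(preds, target))
# OA weights every sample equally: 3/4 correct = 0.75.
# AA averages per-class recall: (1/1 + 1/2 + 1/1) / 3 ≈ 0.83.
```

The gap between 0.75 and 0.83 is exactly the micro/macro distinction the docstring's note calls out for imbalanced classes.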
11 changes: 6 additions & 5 deletions torchgeo/trainers/detection.py
@@ -206,18 +206,19 @@ def configure_models(self) -> None:
def configure_metrics(self) -> None:
"""Initialize the performance metrics.

-    * Mean Average Precision (mAP): Computes the Mean-Average-Precision (mAP) and
-      Mean-Average-Recall (mAR) for object detection. Prediction is based on the
-      intersection over union (IoU) between the predicted bounding boxes and the
-      ground truth bounding boxes. Uses 'macro' averaging. Higher values are better.
+    * :class:`~torchmetrics.detection.mean_ap.MeanAveragePrecision`: Mean average
+      precision (mAP) and mean average recall (mAR). Precision is the number of
+      true positives divided by the number of true positives + false positives.
+      Recall is the number of true positives divided by the number of true positives
+      + false negatives. Uses 'macro' averaging. Higher values are better.

.. note::
* 'Micro' averaging suits overall performance evaluation but may not
reflect minority class accuracy.
* 'Macro' averaging gives equal weight to each class, and is useful for
balanced performance assessment across imbalanced classes.
"""
-    metrics = MetricCollection([MeanAveragePrecision()])
+    metrics = MetricCollection([MeanAveragePrecision(average="macro")])
self.val_metrics = metrics.clone(prefix="val_")
self.test_metrics = metrics.clone(prefix="test_")

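For context on the `average="macro"` change above, here is a self-contained sketch of how the torchmetrics detection metric is fed; the boxes, scores, and labels are made-up values, not from TorchGeo:

```python
import torch
from torchmetrics.detection.mean_ap import MeanAveragePrecision

metric = MeanAveragePrecision(average="macro")  # macro: equal weight per class

# One image with one predicted box (xyxy format) and one ground-truth box.
preds = [{
    "boxes": torch.tensor([[10.0, 10.0, 50.0, 50.0]]),
    "scores": torch.tensor([0.9]),
    "labels": torch.tensor([0]),
}]
target = [{
    "boxes": torch.tensor([[12.0, 12.0, 48.0, 48.0]]),
    "labels": torch.tensor([0]),
}]

metric.update(preds, target)
results = metric.compute()
print(results["map"], results["mar_100"])  # mAP / mAR over IoU 0.50:0.95
```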
18 changes: 6 additions & 12 deletions torchgeo/trainers/regression.py
@@ -99,18 +99,12 @@ def configure_losses(self) -> None:
def configure_metrics(self) -> None:
"""Initialize the performance metrics.

-    * Root Mean Squared Error (RMSE): The square root of the average of the squared
-      differences between the predicted and actual values. Lower values are better.
-    * Mean Squared Error (MSE): The average of the squared differences between the
-      predicted and actual values. Lower values are better.
-    * Mean Absolute Error (MAE): The average of the absolute differences between the
-      predicted and actual values. Lower values are better.
-
-    .. note::
-       * 'Micro' averaging suits overall performance evaluation but may not reflect
-         minority class accuracy.
-       * 'Macro' averaging gives equal weight to each class, and is useful for
-         balanced performance assessment across imbalanced classes.
+    * :class:`~torchmetrics.MeanSquaredError`: The average of the squared
+      differences between the predicted and actual values (MSE) and its
+      square root (RMSE). Lower values are better.
+    * :class:`~torchmetrics.MeanAbsoluteError`: The average of the absolute
+      differences between the predicted and actual values (MAE).
+      Lower values are better.
"""
metrics = MetricCollection(
{
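The regression docstring maps onto a MetricCollection like the sketch below, assuming the trainer registers these three torchmetrics classes as the docstring describes; the tensors are illustrative, and `squared=False` is what turns MeanSquaredError into RMSE:

```python
import torch
from torchmetrics import MeanAbsoluteError, MeanSquaredError, MetricCollection

metrics = MetricCollection({
    "RMSE": MeanSquaredError(squared=False),  # square root of MSE
    "MSE": MeanSquaredError(),
    "MAE": MeanAbsoluteError(),
})

preds = torch.tensor([2.5, 0.0, 2.0])
target = torch.tensor([3.0, -0.5, 2.0])
print(metrics(preds, target))
# MSE = (0.5² + 0.5² + 0²) / 3 ≈ 0.167, RMSE ≈ 0.408, MAE ≈ 0.333
```

Note the diff also drops the micro/macro averaging note here, since per-class averaging does not apply to regression outputs.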
9 changes: 5 additions & 4 deletions torchgeo/trainers/segmentation.py
@@ -124,10 +124,11 @@ def configure_losses(self) -> None:
def configure_metrics(self) -> None:
"""Initialize the performance metrics.

-    * Multiclass Pixel Accuracy: Ratio of correctly classified pixels.
-      Uses 'micro' averaging. Higher values are better.
-    * Multiclass Jaccard Index (IoU): Per-pixel overlap between predicted and
-      actual segments. Uses 'macro' averaging. Higher values are better.
+    * :class:`~torchmetrics.classification.MulticlassAccuracy`: Overall accuracy
+      (OA) using 'micro' averaging. The number of true positives divided by the
+      dataset size. Higher values are better.
+    * :class:`~torchmetrics.classification.MulticlassJaccardIndex`: Intersection
+      over union (IoU). Uses 'micro' averaging. Higher values are better.

.. note::
* 'Micro' averaging suits overall performance evaluation but may not reflect
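Unlike the classification task, the segmentation metrics operate on dense label masks rather than one label per image. A brief sketch with a made-up 2×2 mask (class count and values are hypothetical):

```python
import torch
from torchmetrics import MetricCollection
from torchmetrics.classification import MulticlassAccuracy, MulticlassJaccardIndex

num_classes = 2  # hypothetical
metrics = MetricCollection({
    "Accuracy": MulticlassAccuracy(num_classes=num_classes, average="micro"),
    "JaccardIndex": MulticlassJaccardIndex(num_classes=num_classes, average="micro"),
})

# (batch, height, width) integer class masks
preds = torch.tensor([[[0, 0], [1, 1]]])
target = torch.tensor([[[0, 1], [1, 1]]])
print(metrics(preds, target))  # 3 of 4 pixels classified correctly
```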