-
Notifications
You must be signed in to change notification settings - Fork 517
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Feature/sg 858 ignore multiple labels segmentation metrics support (#…
…1177) * added unit tests * finalized unit tests * updated docs and test suite * updated docs * updated update() * updated update() test * renamed var in _handle_multiple_ignored_inds * faster index mapping for pixel accuracy * faster index mapping for pixel accuracy fix * fixed non inplace op in pixel accuracy * type checking for ignore index expanded to iterable --------- Co-authored-by: Eugene Khvedchenya <ekhvedchenya@gmail.com> Co-authored-by: Ofri Masad <ofrimasad@users.noreply.github.com>
- Loading branch information
1 parent
2ca8647
commit a2201ae
Showing
3 changed files
with
211 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
57 changes: 57 additions & 0 deletions
57
tests/unit_tests/multiple_ignore_indices_segmentation_metrics_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
import unittest | ||
|
||
import torch | ||
|
||
from super_gradients.training.metrics import IoU, PixelAccuracy, Dice | ||
|
||
|
||
class TestSegmentationMetricsMultipleIgnored(unittest.TestCase): | ||
def test_iou_with_multiple_ignored_classes_and_absent_score(self): | ||
metric_multi_ignored = IoU(num_classes=5, ignore_index=[3, 1, 2]) | ||
target_multi_ignored = torch.tensor([[3, 1, 2, 4, 4, 4]]) | ||
pred = torch.zeros((1, 5, 6)) | ||
pred[:, 4] = 1 | ||
|
||
# preds after onehot -> [4,4,4,4,4,4] | ||
# (1 + 0)/2 : 1.0 for class 4 score and 0 for absent score for class 0 | ||
self.assertEqual(metric_multi_ignored(pred, target_multi_ignored), 0.5) | ||
|
||
def test_iou_with_multiple_ignored_classes_no_absent_score(self): | ||
metric_multi_ignored = IoU(num_classes=5, ignore_index=[3, 1, 2]) | ||
target_multi_ignored = torch.tensor([[3, 1, 2, 0, 4, 4]]) | ||
pred = torch.zeros((1, 5, 6)) | ||
pred[:, 4] = 1 | ||
pred[0, 0, 3] = 2 | ||
|
||
# preds after onehot -> [4,4,4,0,4,4] | ||
# (1 + 1)/2 : 1.0 for class 4 score and 1 for class 0 | ||
self.assertEqual(metric_multi_ignored(pred, target_multi_ignored), 1) | ||
|
||
def test_dice_with_multiple_ignored_classes_and_absent_score(self): | ||
metric_multi_ignored = Dice(num_classes=5, ignore_index=[3, 1, 2]) | ||
target_multi_ignored = torch.tensor([[3, 1, 2, 4, 4, 4]]) | ||
pred = torch.zeros((1, 5, 6)) | ||
pred[:, 4] = 1 | ||
|
||
self.assertEqual(metric_multi_ignored(pred, target_multi_ignored), 0.5) | ||
|
||
def test_dice_with_multiple_ignored_classes_no_absent_score(self): | ||
metric_multi_ignored = Dice(num_classes=5, ignore_index=[3, 1, 2]) | ||
target_multi_ignored = torch.tensor([[3, 1, 2, 0, 4, 4]]) | ||
pred = torch.zeros((1, 5, 6)) | ||
pred[:, 4] = 1 | ||
pred[0, 0, 3] = 2 | ||
|
||
self.assertEqual(metric_multi_ignored(pred, target_multi_ignored), 1.0) | ||
|
||
def test_pixelaccuracy_with_multiple_ignored_classes(self): | ||
metric_multi_ignored = PixelAccuracy(ignore_label=[3, 1, 2]) | ||
target_multi_ignored = torch.tensor([[3, 1, 2, 4, 4, 4]]) | ||
pred = torch.zeros((1, 5, 6)) | ||
pred[:, 4] = 1 | ||
|
||
self.assertEqual(metric_multi_ignored(pred, target_multi_ignored), 1.0) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |