diff --git a/CHANGELOG.md b/CHANGELOG.md index 4256c8623b..08496b9146 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added GIoU loss ([#347](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/347)) +- Added IoU loss ([#469](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/469)) + ### Changed - Decoupled datamodules from models ([#332](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/332), [#270](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/270)) diff --git a/docs/source/losses.rst b/docs/source/losses.rst index 901bb71964..7a0f09aee0 100644 --- a/docs/source/losses.rst +++ b/docs/source/losses.rst @@ -27,6 +27,14 @@ GIoU Loss --------------- +IoU Loss +-------- + +.. autofunction:: pl_bolts.losses.object_detection.iou_loss + :noindex: + +--------------- + Reinforcement Learning ====================== These are common losses used in RL. diff --git a/pl_bolts/losses/object_detection.py b/pl_bolts/losses/object_detection.py index 81d0404813..ccbba5e707 100644 --- a/pl_bolts/losses/object_detection.py +++ b/pl_bolts/losses/object_detection.py @@ -1,10 +1,34 @@ """ -Generalized Intersection over Union (GIoU) loss (Rezatofighi et. al) +Loss functions for Object Detection task """ import torch -from pl_bolts.metrics.object_detection import giou +from pl_bolts.metrics.object_detection import giou, iou + + +def iou_loss(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """ + Calculates the intersection over union loss. + + Args: + preds: batch of prediction bounding boxes with representation ``[x_min, y_min, x_max, y_max]`` + target: batch of target bounding boxes with representation ``[x_min, y_min, x_max, y_max]`` + + Example: + + >>> import torch + >>> from pl_bolts.losses.object_detection import iou_loss + >>> preds = torch.tensor([[100, 100, 200, 200]]) + >>> target = torch.tensor([[150, 150, 250, 250]]) + >>> iou_loss(preds, target) + tensor([[0.8571]]) + + Returns: + IoU loss + """ + loss = 1 - iou(preds, target) + return loss def giou_loss(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor: diff --git a/pl_bolts/metrics/object_detection.py b/pl_bolts/metrics/object_detection.py index 3175f3ce24..21352888b8 100644 --- a/pl_bolts/metrics/object_detection.py +++ b/pl_bolts/metrics/object_detection.py @@ -1,6 +1,39 @@ import torch +def iou(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """ + Calculates the intersection over union. + + Args: + preds: an Nx4 batch of prediction bounding boxes with representation ``[x_min, y_min, x_max, y_max]`` + target: an Mx4 batch of target bounding boxes with representation ``[x_min, y_min, x_max, y_max]`` + + Example: + + >>> import torch + >>> from pl_bolts.metrics.object_detection import iou + >>> preds = torch.tensor([[100, 100, 200, 200]]) + >>> target = torch.tensor([[150, 150, 250, 250]]) + >>> iou(preds, target) + tensor([[0.1429]]) + + Returns: + IoU tensor: an NxM tensor containing the pairwise IoU values for every element in preds and target, + where N is the number of prediction bounding boxes and M is the number of target bounding boxes + """ + x_min = torch.max(preds[:, None, 0], target[:, 0]) + y_min = torch.max(preds[:, None, 1], target[:, 1]) + x_max = torch.min(preds[:, None, 2], target[:, 2]) + y_max = torch.min(preds[:, None, 3], target[:, 3]) + intersection = (x_max - x_min).clamp(min=0) * (y_max - y_min).clamp(min=0) + pred_area = (preds[:, 2] - preds[:, 0]) * (preds[:, 3] - preds[:, 1]) + target_area = (target[:, 2] - target[:, 0]) * (target[:, 3] - target[:, 1]) + union = pred_area[:, None] + target_area - intersection + iou = torch.true_divide(intersection, union) + return iou + + def giou(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ Calculates the generalized intersection over union. diff --git a/tests/losses/test_object_detection.py b/tests/losses/test_object_detection.py index 30f0ab4576..a6117f7eb8 100644 --- a/tests/losses/test_object_detection.py +++ b/tests/losses/test_object_detection.py @@ -5,7 +5,22 @@ import pytest import torch -from pl_bolts.losses.object_detection import giou_loss +from pl_bolts.losses.object_detection import giou_loss, iou_loss + + +@pytest.mark.parametrize("preds, target, expected_loss", [ + (torch.tensor([[100, 100, 200, 200]]), torch.tensor([[100, 100, 200, 200]]), torch.tensor([0.0])) +]) +def test_iou_complete_overlap(preds, target, expected_loss): + torch.testing.assert_allclose(iou_loss(preds, target), expected_loss) + + +@pytest.mark.parametrize("preds, target, expected_loss", [ + (torch.tensor([[100, 100, 200, 200]]), torch.tensor([[100, 200, 200, 300]]), torch.tensor([1.0])), + (torch.tensor([[100, 100, 200, 200]]), torch.tensor([[200, 200, 300, 300]]), torch.tensor([1.0])), +]) +def test_iou_no_overlap(preds, target, expected_loss): + torch.testing.assert_allclose(iou_loss(preds, target), expected_loss) @pytest.mark.parametrize("preds, target, expected_loss", [ diff --git a/tests/metrics/test_object_detection.py b/tests/metrics/test_object_detection.py index a998502314..59b2d8f32e 100644 --- a/tests/metrics/test_object_detection.py +++ b/tests/metrics/test_object_detection.py @@ -5,7 +5,33 @@ import pytest import torch -from pl_bolts.metrics.object_detection import giou +from pl_bolts.metrics.object_detection import giou, iou + + +@pytest.mark.parametrize("preds, target, expected_iou", [ + (torch.tensor([[100, 100, 200, 200]]), torch.tensor([[100, 100, 200, 200]]), torch.tensor([1.0])) +]) +def test_iou_complete_overlap(preds, target, expected_iou): + torch.testing.assert_allclose(iou(preds, target), expected_iou) + + +@pytest.mark.parametrize("preds, target, expected_iou", [ + (torch.tensor([[100, 100, 200, 200]]), torch.tensor([[100, 200, 200, 300]]), torch.tensor([0.0])), + (torch.tensor([[100, 100, 200, 200]]), torch.tensor([[200, 200, 300, 300]]), torch.tensor([0.0])), +]) +def test_iou_no_overlap(preds, target, expected_iou): + torch.testing.assert_allclose(iou(preds, target), expected_iou) + + +@pytest.mark.parametrize("preds, target, expected_iou", [ + ( + torch.tensor([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]]), + torch.tensor([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]]), + torch.tensor([[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0]]) + ) +]) +def test_iou_multi(preds, target, expected_iou): + torch.testing.assert_allclose(iou(preds, target), expected_iou) @pytest.mark.parametrize("preds, target, expected_giou", [ @@ -21,3 +47,14 @@ def test_complete_overlap(preds, target, expected_giou): ]) def test_no_overlap(preds, target, expected_giou): torch.testing.assert_allclose(giou(preds, target), expected_giou) + + +@pytest.mark.parametrize("preds, target, expected_giou", [ + ( + torch.tensor([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]]), + torch.tensor([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]]), + torch.tensor([[1.0, 0.25, -0.7778], [0.25, 1.0, -0.8611], [-0.7778, -0.8611, 1.0]]) + ) +]) +def test_giou_multi(preds, target, expected_giou): + torch.testing.assert_allclose(giou(preds, target), expected_giou)