diff --git a/CHANGELOG.md b/CHANGELOG.md
index e986a3a5fc..4256c8623b 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -17,6 +17,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 - Added `VisionDataModule` as parent class for `BinaryMNISTDataModule`, `CIFAR10DataModule`, `FashionMNISTDataModule`, and `MNISTDataModule` ([#400](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/400))
 
+- Added GIoU loss ([#347](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/347))
+
 ### Changed
 
 - Decoupled datamodules from models ([#332](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/332), [#270](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/270))
diff --git a/docs/source/losses.rst b/docs/source/losses.rst
index 44b401dfcc..901bb71964 100644
--- a/docs/source/losses.rst
+++ b/docs/source/losses.rst
@@ -13,6 +13,20 @@ We're cleaning up many of our losses, but in the meantime, submit a PR to add yo
 
 -------------
 
+Object Detection
+======================
+These are common losses used in object detection.
+
+---------------
+
+GIoU Loss
+---------
+
+.. autofunction:: pl_bolts.losses.object_detection.giou_loss
+    :noindex:
+
+---------------
+
 Reinforcement Learning
 ======================
 These are common losses used in RL.
diff --git a/pl_bolts/losses/object_detection.py b/pl_bolts/losses/object_detection.py
new file mode 100644
index 0000000000..81d0404813
--- /dev/null
+++ b/pl_bolts/losses/object_detection.py
@@ -0,0 +1,34 @@
+"""
+Generalized Intersection over Union (GIoU) loss (Rezatofighi et al.)
+"""
+
+import torch
+
+from pl_bolts.metrics.object_detection import giou
+
+
+def giou_loss(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+    """
+    Calculates the generalized intersection over union loss.
+
+    It has been proposed in `Generalized Intersection over Union: A Metric and A
+    Loss for Bounding Box Regression <https://arxiv.org/abs/1902.09630>`_.
+
+    Args:
+        preds: batch of prediction bounding boxes with representation ``[x_min, y_min, x_max, y_max]``
+        target: batch of target bounding boxes with representation ``[x_min, y_min, x_max, y_max]``
+
+    Example:
+
+        >>> import torch
+        >>> from pl_bolts.losses.object_detection import giou_loss
+        >>> preds = torch.tensor([[100, 100, 200, 200]])
+        >>> target = torch.tensor([[150, 150, 250, 250]])
+        >>> giou_loss(preds, target)
+        tensor([[1.0794]])
+
+    Returns:
+        Pairwise GIoU loss (``1 - GIoU``) of shape ``[N, M]`` for N prediction boxes and M target boxes
+    """
+    loss = 1 - giou(preds, target)
+    return loss
diff --git a/pl_bolts/metrics/object_detection.py b/pl_bolts/metrics/object_detection.py
new file mode 100644
index 0000000000..3175f3ce24
--- /dev/null
+++ b/pl_bolts/metrics/object_detection.py
@@ -0,0 +1,42 @@
+import torch
+
+
+def giou(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+    """
+    Calculates the generalized intersection over union.
+
+    It has been proposed in `Generalized Intersection over Union: A Metric and A
+    Loss for Bounding Box Regression <https://arxiv.org/abs/1902.09630>`_.
+
+    Args:
+        preds: batch of prediction bounding boxes with representation ``[x_min, y_min, x_max, y_max]``
+        target: batch of target bounding boxes with representation ``[x_min, y_min, x_max, y_max]``
+
+    Example:
+
+        >>> import torch
+        >>> from pl_bolts.metrics.object_detection import giou
+        >>> preds = torch.tensor([[100, 100, 200, 200]])
+        >>> target = torch.tensor([[150, 150, 250, 250]])
+        >>> giou(preds, target)
+        tensor([[-0.0794]])
+
+    Returns:
+        Pairwise GIoU values of shape ``[N, M]`` for N prediction boxes and M target boxes
+    """
+    x_min = torch.max(preds[:, None, 0], target[:, 0])  # intersection box; broadcasts to [N, M]
+    y_min = torch.max(preds[:, None, 1], target[:, 1])
+    x_max = torch.min(preds[:, None, 2], target[:, 2])
+    y_max = torch.min(preds[:, None, 3], target[:, 3])
+    intersection = (x_max - x_min).clamp(min=0) * (y_max - y_min).clamp(min=0)  # 0 when boxes are disjoint
+    pred_area = (preds[:, 2] - preds[:, 0]) * (preds[:, 3] - preds[:, 1])
+    target_area = (target[:, 2] - target[:, 0]) * (target[:, 3] - target[:, 1])
+    union = pred_area[:, None] + target_area - intersection
+    C_x_min = torch.min(preds[:, None, 0], target[:, 0])  # smallest box enclosing each pred/target pair
+    C_y_min = torch.min(preds[:, None, 1], target[:, 1])
+    C_x_max = torch.max(preds[:, None, 2], target[:, 2])
+    C_y_max = torch.max(preds[:, None, 3], target[:, 3])
+    C_area = (C_x_max - C_x_min).clamp(min=0) * (C_y_max - C_y_min).clamp(min=0)
+    iou = torch.true_divide(intersection, union)  # true division, so integer boxes give float results
+    giou = iou - torch.true_divide((C_area - union), C_area)  # penalty: share of enclosing box not covered
+    return giou
diff --git a/tests/losses/test_object_detection.py b/tests/losses/test_object_detection.py
new file mode 100644
index 0000000000..30f0ab4576
--- /dev/null
+++ b/tests/losses/test_object_detection.py
@@ -0,0 +1,23 @@
+"""
+Test Object Detection Loss Functions
+"""
+
+import pytest
+import torch
+
+from pl_bolts.losses.object_detection import giou_loss
+
+
+@pytest.mark.parametrize("preds, target, expected_loss", [
+    (torch.tensor([[100, 100, 200, 200]]), torch.tensor([[100, 100, 200, 200]]), torch.tensor([0.0]))
+])
+def test_complete_overlap(preds, target, expected_loss):
+    torch.testing.assert_allclose(giou_loss(preds, target), expected_loss)
+
+
+@pytest.mark.parametrize("preds, target, expected_loss", [
+    (torch.tensor([[100, 100, 200, 200]]), torch.tensor([[100, 200, 200, 300]]), torch.tensor([1.0])),
+    (torch.tensor([[100, 100, 200, 200]]), torch.tensor([[200, 200, 300, 300]]), torch.tensor([1.5])),
+])
+def test_no_overlap(preds, target, expected_loss):
+    torch.testing.assert_allclose(giou_loss(preds, target), expected_loss)
diff --git a/tests/metrics/__init__.py b/tests/metrics/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/metrics/test_object_detection.py b/tests/metrics/test_object_detection.py
new file mode 100644
index 0000000000..a998502314
--- /dev/null
+++ b/tests/metrics/test_object_detection.py
@@ -0,0 +1,23 @@
+"""
+Test Object Detection Metric Functions
+"""
+
+import pytest
+import torch
+
+from pl_bolts.metrics.object_detection import giou
+
+
+@pytest.mark.parametrize("preds, target, expected_giou", [
+    (torch.tensor([[100, 100, 200, 200]]), torch.tensor([[100, 100, 200, 200]]), torch.tensor([1.0]))
+])
+def test_complete_overlap(preds, target, expected_giou):
+    torch.testing.assert_allclose(giou(preds, target), expected_giou)
+
+
+@pytest.mark.parametrize("preds, target, expected_giou", [
+    (torch.tensor([[100, 100, 200, 200]]), torch.tensor([[100, 200, 200, 300]]), torch.tensor([0.0])),
+    (torch.tensor([[100, 100, 200, 200]]), torch.tensor([[200, 200, 300, 300]]), torch.tensor([-0.5])),
+])
+def test_no_overlap(preds, target, expected_giou):
+    torch.testing.assert_allclose(giou(preds, target), expected_giou)