Skip to content

Commit

Permalink
Implemented GIoU (#347)
Browse files Browse the repository at this point in the history
* Implemented GIoU

* Changing max/min to torch.max/torch.min

* Adding tests for GIoU

* Adding documentation for GIoU

* Updated docstring

Co-authored-by: Jeff Yang <ydcjeff@outlook.com>

* Fixed isort error and added link to paper in docstring

* Parametrizing tests using pytest

* Adding changelog and removing eps

* Fixing error

* isort test

* Renaming file to object_detection.py

* Reflecting module name change in test

* Fixing doc to reflect module name change

* Implemented GIoU

* Changing max/min to torch.max/torch.min

* Adding tests for GIoU

* Updated docstring

Co-authored-by: Jeff Yang <ydcjeff@outlook.com>

* Fixed isort error and added link to paper in docstring

* Parametrizing tests using pytest

* Adding changelog and removing eps

* Fixing error

* isort test

* Renaming file to object_detection.py

* Reflecting module name change in test

* Implemented GIoU

* Changing max/min to torch.max/torch.min

* Adding tests for GIoU

* Updated docstring

Co-authored-by: Jeff Yang <ydcjeff@outlook.com>

* Fixed isort error and added link to paper in docstring

* Parametrizing tests using pytest

* Adding changelog and removing eps

* Fixing error

* isort test

* Renaming file to object_detection.py

* Reflecting module name change in test

* Implemented GIoU

* Changing max/min to torch.max/torch.min

* Adding tests for GIoU

* Updated docstring

Co-authored-by: Jeff Yang <ydcjeff@outlook.com>

* Fixed isort error and added link to paper in docstring

* Parametrizing tests using pytest

* Adding changelog and removing eps

* Fixing error

* isort test

* Renaming file to object_detection.py

* Reflecting module name change in test

* Update pl_bolts/losses/object_detection.py

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* Updating docstring for giou

* Fixing code formatting

* Adding doctest

* Adding tests for giou

* refactor

* fix

* format

Co-authored-by: Jeff Yang <ydcjeff@outlook.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Jirka Borovec <jirka.borovec@seznam.cz>
  • Loading branch information
4 people authored Dec 20, 2020
1 parent 4309cab commit a72766c
Show file tree
Hide file tree
Showing 7 changed files with 138 additions and 0 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `VisionDataModule` as parent class for `BinaryMNISTDataModule`, `CIFAR10DataModule`, `FashionMNISTDataModule`,
and `MNISTDataModule` ([#400](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/400))

- Added GIoU loss ([#347](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/347))

### Changed

- Decoupled datamodules from models ([#332](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/332), [#270](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/270))
Expand Down
14 changes: 14 additions & 0 deletions docs/source/losses.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,20 @@ We're cleaning up many of our losses, but in the meantime, submit a PR to add yo

-------------

Object Detection
======================
These are common losses used in object detection.

---------------

GIoU Loss
---------

.. autofunction:: pl_bolts.losses.object_detection.giou_loss
:noindex:

---------------

Reinforcement Learning
======================
These are common losses used in RL.
Expand Down
34 changes: 34 additions & 0 deletions pl_bolts/losses/object_detection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
"""
Generalized Intersection over Union (GIoU) loss (Rezatofighi et. al)
"""

import torch

from pl_bolts.metrics.object_detection import giou


def giou_loss(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
Calculates the generalized intersection over union loss.
It has been proposed in `Generalized Intersection over Union: A Metric and A
Loss for Bounding Box Regression <https://arxiv.org/abs/1902.09630>`_.
Args:
preds: batch of prediction bounding boxes with representation ``[x_min, y_min, x_max, y_max]``
target: batch of target bounding boxes with representation ``[x_min, y_min, x_max, y_max]``
Example:
>>> import torch
>>> from pl_bolts.losses.object_detection import giou_loss
>>> preds = torch.tensor([[100, 100, 200, 200]])
>>> target = torch.tensor([[150, 150, 250, 250]])
>>> giou_loss(preds, target)
tensor([[1.0794]])
Returns:
GIoU loss
"""
loss = 1 - giou(preds, target)
return loss
42 changes: 42 additions & 0 deletions pl_bolts/metrics/object_detection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import torch


def giou(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
Calculates the generalized intersection over union.
It has been proposed in `Generalized Intersection over Union: A Metric and A
Loss for Bounding Box Regression <https://arxiv.org/abs/1902.09630>`_.
Args:
preds: batch of prediction bounding boxes with representation ``[x_min, y_min, x_max, y_max]``
target: batch of target bounding boxes with representation ``[x_min, y_min, x_max, y_max]``
Example:
>>> import torch
>>> from pl_bolts.metrics.object_detection import giou
>>> preds = torch.tensor([[100, 100, 200, 200]])
>>> target = torch.tensor([[150, 150, 250, 250]])
>>> giou(preds, target)
tensor([[-0.0794]])
Returns:
GIoU value
"""
x_min = torch.max(preds[:, None, 0], target[:, 0])
y_min = torch.max(preds[:, None, 1], target[:, 1])
x_max = torch.min(preds[:, None, 2], target[:, 2])
y_max = torch.min(preds[:, None, 3], target[:, 3])
intersection = (x_max - x_min).clamp(min=0) * (y_max - y_min).clamp(min=0)
pred_area = (preds[:, 2] - preds[:, 0]) * (preds[:, 3] - preds[:, 1])
target_area = (target[:, 2] - target[:, 0]) * (target[:, 3] - target[:, 1])
union = pred_area[:, None] + target_area - intersection
C_x_min = torch.min(preds[:, None, 0], target[:, 0])
C_y_min = torch.min(preds[:, None, 1], target[:, 1])
C_x_max = torch.max(preds[:, None, 2], target[:, 2])
C_y_max = torch.max(preds[:, None, 3], target[:, 3])
C_area = (C_x_max - C_x_min).clamp(min=0) * (C_y_max - C_y_min).clamp(min=0)
iou = torch.true_divide(intersection, union)
giou = iou - torch.true_divide((C_area - union), C_area)
return giou
23 changes: 23 additions & 0 deletions tests/losses/test_object_detection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
"""
Test Object Detection Loss Functions
"""

import pytest
import torch

from pl_bolts.losses.object_detection import giou_loss


@pytest.mark.parametrize("preds, target, expected_loss", [
(torch.tensor([[100, 100, 200, 200]]), torch.tensor([[100, 100, 200, 200]]), torch.tensor([0.0]))
])
def test_complete_overlap(preds, target, expected_loss):
torch.testing.assert_allclose(giou_loss(preds, target), expected_loss)


@pytest.mark.parametrize("preds, target, expected_loss", [
(torch.tensor([[100, 100, 200, 200]]), torch.tensor([[100, 200, 200, 300]]), torch.tensor([1.0])),
(torch.tensor([[100, 100, 200, 200]]), torch.tensor([[200, 200, 300, 300]]), torch.tensor([1.5])),
])
def test_no_overlap(preds, target, expected_loss):
torch.testing.assert_allclose(giou_loss(preds, target), expected_loss)
Empty file added tests/metrics/__init__.py
Empty file.
23 changes: 23 additions & 0 deletions tests/metrics/test_object_detection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
"""
Test Object Detection Metric Functions
"""

import pytest
import torch

from pl_bolts.metrics.object_detection import giou


@pytest.mark.parametrize("preds, target, expected_giou", [
(torch.tensor([[100, 100, 200, 200]]), torch.tensor([[100, 100, 200, 200]]), torch.tensor([1.0]))
])
def test_complete_overlap(preds, target, expected_giou):
torch.testing.assert_allclose(giou(preds, target), expected_giou)


@pytest.mark.parametrize("preds, target, expected_giou", [
(torch.tensor([[100, 100, 200, 200]]), torch.tensor([[100, 200, 200, 300]]), torch.tensor([0.0])),
(torch.tensor([[100, 100, 200, 200]]), torch.tensor([[200, 200, 300, 300]]), torch.tensor([-0.5])),
])
def test_no_overlap(preds, target, expected_giou):
torch.testing.assert_allclose(giou(preds, target), expected_giou)

0 comments on commit a72766c

Please sign in to comment.