Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implemented GIoU #347

Merged
merged 54 commits into from
Dec 20, 2020
Merged
Show file tree
Hide file tree
Changes from 52 commits
Commits
Show all changes
54 commits
Select commit Hold shift + click to select a range
09bb792
Implemented GIoU
briankosw Nov 9, 2020
53ba53c
Changing max/min to torch.max/torch.min
briankosw Nov 15, 2020
f886be0
Adding tests for GIoU
briankosw Nov 15, 2020
604c213
Adding documentation for GIoU
briankosw Nov 15, 2020
bdb9d2a
Updated docstring
briankosw Nov 15, 2020
ec6a390
Fixed isort error and added link to paper in docstring
briankosw Nov 16, 2020
2b7fb78
Parametrizing tests using pytest
briankosw Nov 17, 2020
e64d51a
Adding changelog and removing eps
briankosw Nov 18, 2020
18ab5db
Fixing error
briankosw Nov 18, 2020
d44be4c
isort test
briankosw Nov 18, 2020
8fe5f2e
Renaming file to object_detection.py
briankosw Dec 2, 2020
f9fc3c7
Reflecting module name change in test
briankosw Dec 2, 2020
08bd127
Fixing doc to reflect module name change
briankosw Dec 2, 2020
f25e4f5
Implemented GIoU
briankosw Nov 9, 2020
a4878db
Changing max/min to torch.max/torch.min
briankosw Nov 15, 2020
417f554
Adding tests for GIoU
briankosw Nov 15, 2020
fb7dce5
Updated docstring
briankosw Nov 15, 2020
09c0514
Fixed isort error and added link to paper in docstring
briankosw Nov 16, 2020
3b6ea80
Parametrizing tests using pytest
briankosw Nov 17, 2020
1d23707
Adding changelog and removing eps
briankosw Nov 18, 2020
aa9f846
Fixing error
briankosw Nov 18, 2020
c871a4e
isort test
briankosw Nov 18, 2020
a2d915e
Renaming file to object_detection.py
briankosw Dec 2, 2020
6723ee5
Reflecting module name change in test
briankosw Dec 2, 2020
a7a3aac
Implemented GIoU
briankosw Nov 9, 2020
4f3db7f
Changing max/min to torch.max/torch.min
briankosw Nov 15, 2020
be61639
Adding tests for GIoU
briankosw Nov 15, 2020
d9b6e64
Updated docstring
briankosw Nov 15, 2020
394261f
Fixed isort error and added link to paper in docstring
briankosw Nov 16, 2020
cc940ab
Parametrizing tests using pytest
briankosw Nov 17, 2020
cc2f486
Adding changelog and removing eps
briankosw Nov 18, 2020
b80828c
Fixing error
briankosw Nov 18, 2020
d289d9d
isort test
briankosw Nov 18, 2020
e0a68c4
Renaming file to object_detection.py
briankosw Dec 2, 2020
7f7bec9
Reflecting module name change in test
briankosw Dec 2, 2020
8ebd93b
Implemented GIoU
briankosw Nov 9, 2020
b4ac0e1
Changing max/min to torch.max/torch.min
briankosw Nov 15, 2020
b32b369
Adding tests for GIoU
briankosw Nov 15, 2020
b5e5d85
Updated docstring
briankosw Nov 15, 2020
2e29634
Fixed isort error and added link to paper in docstring
briankosw Nov 16, 2020
d61ed9b
Parametrizing tests using pytest
briankosw Nov 17, 2020
b6a4dbb
Adding changelog and removing eps
briankosw Nov 18, 2020
a42e111
Fixing error
briankosw Nov 18, 2020
8370947
isort test
briankosw Nov 18, 2020
721e734
Renaming file to object_detection.py
briankosw Dec 2, 2020
edffe9f
Reflecting module name change in test
briankosw Dec 2, 2020
f1d1646
Update pl_bolts/losses/object_detection.py
briankosw Dec 15, 2020
637792b
Updating docstring for giou
briankosw Dec 15, 2020
34c1987
Fixing code formatting
briankosw Dec 15, 2020
730bcdc
Adding doctest
briankosw Dec 15, 2020
8fa52f7
Adding tests for giou
briankosw Dec 15, 2020
ebda4b0
refactor
Borda Dec 20, 2020
d684d2b
fix
Borda Dec 20, 2020
1e0b2d1
format
Borda Dec 20, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `VisionDataModule` as parent class for `BinaryMNISTDataModule`, `CIFAR10DataModule`, `FashionMNISTDataModule`,
and `MNISTDataModule` ([#400](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/400))

- Added GIoU loss ([#347](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/347))

### Changed

- Decoupled datamodules from models ([#332](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/332), [#270](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/270))
Expand Down
14 changes: 14 additions & 0 deletions docs/source/losses.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,20 @@ We're cleaning up many of our losses, but in the meantime, submit a PR to add yo

-------------

Object Detection
======================
These are common losses used in object detection.

---------------

GIoU Loss
---------

.. autofunction:: pl_bolts.losses.object_detection.giou_loss
:noindex:

---------------

Reinforcement Learning
======================
These are common losses used in RL.
Expand Down
34 changes: 34 additions & 0 deletions pl_bolts/losses/object_detection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
"""
Generalized Intersection over Union (GIoU) loss (Rezatofighi et. al)
"""

import torch

from pl_bolts.metrics.object_detection import giou


def giou_loss(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
Calculates the generalized intersection over union loss.

It has been proposed in `Generalized Intersection over Union: A Metric and A
Loss for Bounding Box Regression <https://arxiv.org/abs/1902.09630>`_.

Args:
preds: batch of prediction bounding boxes with representation ``[x_min, y_min, x_max, y_max]``
target: batch of target bounding boxes with representation ``[x_min, y_min, x_max, y_max]``

Example:

>>> import torch
>>> from pl_bolts.losses.object_detection import giou_loss
>>> preds = torch.tensor([[100, 100, 200, 200]])
>>> target = torch.tensor([[150, 150, 250, 250]])
>>> giou_loss(preds, target)
tensor([[1.0794]])

Returns:
GIoU loss
"""
Borda marked this conversation as resolved.
Show resolved Hide resolved
loss = 1 - giou(preds, target)
return loss
42 changes: 42 additions & 0 deletions pl_bolts/metrics/object_detection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import torch


def giou(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
Calculates the generalized intersection over union.

It has been proposed in `Generalized Intersection over Union: A Metric and A
Loss for Bounding Box Regression <https://arxiv.org/abs/1902.09630>`_.

Args:
preds: batch of prediction bounding boxes with representation ``[x_min, y_min, x_max, y_max]``
target: batch of target bounding boxes with representation ``[x_min, y_min, x_max, y_max]``

Example:

>>> import torch
>>> from pl_bolts.metrics.object_detection import giou
>>> preds = torch.tensor([[100, 100, 200, 200]])
>>> target = torch.tensor([[150, 150, 250, 250]])
>>> giou(preds, target)
tensor([[-0.0794]])

Returns:
GIoU value
"""
x_min = torch.max(preds[:, None, 0], target[:, 0])
y_min = torch.max(preds[:, None, 1], target[:, 1])
x_max = torch.min(preds[:, None, 2], target[:, 2])
y_max = torch.min(preds[:, None, 3], target[:, 3])
intersection = (x_max - x_min).clamp(min=0) * (y_max - y_min).clamp(min=0)
pred_area = (preds[:, 2] - preds[:, 0]) * (preds[:, 3] - preds[:, 1])
target_area = (target[:, 2] - target[:, 0]) * (target[:, 3] - target[:, 1])
union = pred_area[:, None] + target_area - intersection
C_x_min = torch.min(preds[:, None, 0], target[:, 0])
C_y_min = torch.min(preds[:, None, 1], target[:, 1])
C_x_max = torch.max(preds[:, None, 2], target[:, 2])
C_y_max = torch.max(preds[:, None, 3], target[:, 3])
C_area = (C_x_max - C_x_min).clamp(min=0) * (C_y_max - C_y_min).clamp(min=0)
iou = torch.true_divide(intersection, union)
giou = iou - torch.true_divide((C_area - union), C_area)
return giou
41 changes: 41 additions & 0 deletions tests/losses/test_object_detection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
"""
Test Object Detection Loss Functions
"""

import pytest
import torch

from pl_bolts.losses.object_detection import giou_loss


@pytest.mark.parametrize(
"preds, target, expected_loss",
[
(
torch.tensor([[100, 100, 200, 200]]),
torch.tensor([[100, 100, 200, 200]]),
torch.tensor([0.0]),
)
],
)
def test_complete_overlap(pred, target, expected_loss):
torch.testing.assert_allclose(giou_loss(pred, target), expected_loss)


@pytest.mark.parametrize(
"preds, target, expected_loss",
[
(
torch.tensor([[100, 100, 200, 200]]),
torch.tensor([[100, 200, 200, 300]]),
torch.tensor([1.0]),
),
(
torch.tensor([[100, 100, 200, 200]]),
torch.tensor([[200, 200, 300, 300]]),
torch.tensor([1.5]),
),
],
)
def test_no_overlap(pred, target, expected_loss):
torch.testing.assert_allclose(giou_loss(pred, target), expected_loss)
Empty file added tests/metrics/__init__.py
Empty file.
41 changes: 41 additions & 0 deletions tests/metrics/test_object_detection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
"""
Test Object Detection Metric Functions
"""

import pytest
import torch

from pl_bolts.metrics.object_detection import giou


@pytest.mark.parametrize(
"preds, target, expected_giou",
[
(
torch.tensor([[100, 100, 200, 200]]),
torch.tensor([[100, 100, 200, 200]]),
torch.tensor([1.0]),
)
],
)
def test_complete_overlap(pred, target, expected_giou):
torch.testing.assert_allclose(giou(pred, target), expected_giou)


@pytest.mark.parametrize(
"preds, target, expected_giou",
[
(
torch.tensor([[100, 100, 200, 200]]),
torch.tensor([[100, 200, 200, 300]]),
torch.tensor([0.0]),
),
(
torch.tensor([[100, 100, 200, 200]]),
torch.tensor([[200, 200, 300, 300]]),
torch.tensor([-0.5]),
),
],
)
def test_no_overlap(pred, target, expected_giou):
torch.testing.assert_allclose(giou(pred, target), expected_giou)