-
Notifications
You must be signed in to change notification settings - Fork 323
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Implemented GIoU * Changing max/min to torch.max/torch.min * Adding tests for GIoU * Adding documentation for GIoU * Updated docstring Co-authored-by: Jeff Yang <ydcjeff@outlook.com> * Fixed isort error and added link to paper in docstring * Parametrizing tests using pytest * Adding changelog and removing eps * Fixing error * isort test * Renaming file to object_detection.py * Reflecting module name change in test * Fixing doc to reflect module name change * Implemented GIoU * Changing max/min to torch.max/torch.min * Adding tests for GIoU * Updated docstring Co-authored-by: Jeff Yang <ydcjeff@outlook.com> * Fixed isort error and added link to paper in docstring * Parametrizing tests using pytest * Adding changelog and removing eps * Fixing error * isort test * Renaming file to object_detection.py * Reflecting module name change in test * Implemented GIoU * Changing max/min to torch.max/torch.min * Adding tests for GIoU * Updated docstring Co-authored-by: Jeff Yang <ydcjeff@outlook.com> * Fixed isort error and added link to paper in docstring * Parametrizing tests using pytest * Adding changelog and removing eps * Fixing error * isort test * Renaming file to object_detection.py * Reflecting module name change in test * Implemented GIoU * Changing max/min to torch.max/torch.min * Adding tests for GIoU * Updated docstring Co-authored-by: Jeff Yang <ydcjeff@outlook.com> * Fixed isort error and added link to paper in docstring * Parametrizing tests using pytest * Adding changelog and removing eps * Fixing error * isort test * Renaming file to object_detection.py * Reflecting module name change in test * Update pl_bolts/losses/object_detection.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Updating docstring for giou * Fixing code formatting * Adding doctest * Adding tests for giou * refactor * fix * format Co-authored-by: Jeff Yang <ydcjeff@outlook.com> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: Jirka Borovec <jirka.borovec@seznam.cz>
- Loading branch information
1 parent
4309cab
commit a72766c
Showing
7 changed files
with
138 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
""" | ||
Generalized Intersection over Union (GIoU) loss (Rezatofighi et. al) | ||
""" | ||
|
||
import torch | ||
|
||
from pl_bolts.metrics.object_detection import giou | ||
|
||
|
||
def giou_loss(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor: | ||
""" | ||
Calculates the generalized intersection over union loss. | ||
It has been proposed in `Generalized Intersection over Union: A Metric and A | ||
Loss for Bounding Box Regression <https://arxiv.org/abs/1902.09630>`_. | ||
Args: | ||
preds: batch of prediction bounding boxes with representation ``[x_min, y_min, x_max, y_max]`` | ||
target: batch of target bounding boxes with representation ``[x_min, y_min, x_max, y_max]`` | ||
Example: | ||
>>> import torch | ||
>>> from pl_bolts.losses.object_detection import giou_loss | ||
>>> preds = torch.tensor([[100, 100, 200, 200]]) | ||
>>> target = torch.tensor([[150, 150, 250, 250]]) | ||
>>> giou_loss(preds, target) | ||
tensor([[1.0794]]) | ||
Returns: | ||
GIoU loss | ||
""" | ||
loss = 1 - giou(preds, target) | ||
return loss |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
import torch | ||
|
||
|
||
def giou(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor: | ||
""" | ||
Calculates the generalized intersection over union. | ||
It has been proposed in `Generalized Intersection over Union: A Metric and A | ||
Loss for Bounding Box Regression <https://arxiv.org/abs/1902.09630>`_. | ||
Args: | ||
preds: batch of prediction bounding boxes with representation ``[x_min, y_min, x_max, y_max]`` | ||
target: batch of target bounding boxes with representation ``[x_min, y_min, x_max, y_max]`` | ||
Example: | ||
>>> import torch | ||
>>> from pl_bolts.metrics.object_detection import giou | ||
>>> preds = torch.tensor([[100, 100, 200, 200]]) | ||
>>> target = torch.tensor([[150, 150, 250, 250]]) | ||
>>> giou(preds, target) | ||
tensor([[-0.0794]]) | ||
Returns: | ||
GIoU value | ||
""" | ||
x_min = torch.max(preds[:, None, 0], target[:, 0]) | ||
y_min = torch.max(preds[:, None, 1], target[:, 1]) | ||
x_max = torch.min(preds[:, None, 2], target[:, 2]) | ||
y_max = torch.min(preds[:, None, 3], target[:, 3]) | ||
intersection = (x_max - x_min).clamp(min=0) * (y_max - y_min).clamp(min=0) | ||
pred_area = (preds[:, 2] - preds[:, 0]) * (preds[:, 3] - preds[:, 1]) | ||
target_area = (target[:, 2] - target[:, 0]) * (target[:, 3] - target[:, 1]) | ||
union = pred_area[:, None] + target_area - intersection | ||
C_x_min = torch.min(preds[:, None, 0], target[:, 0]) | ||
C_y_min = torch.min(preds[:, None, 1], target[:, 1]) | ||
C_x_max = torch.max(preds[:, None, 2], target[:, 2]) | ||
C_y_max = torch.max(preds[:, None, 3], target[:, 3]) | ||
C_area = (C_x_max - C_x_min).clamp(min=0) * (C_y_max - C_y_min).clamp(min=0) | ||
iou = torch.true_divide(intersection, union) | ||
giou = iou - torch.true_divide((C_area - union), C_area) | ||
return giou |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
""" | ||
Test Object Detection Loss Functions | ||
""" | ||
|
||
import pytest | ||
import torch | ||
|
||
from pl_bolts.losses.object_detection import giou_loss | ||
|
||
|
||
@pytest.mark.parametrize("preds, target, expected_loss", [ | ||
(torch.tensor([[100, 100, 200, 200]]), torch.tensor([[100, 100, 200, 200]]), torch.tensor([0.0])) | ||
]) | ||
def test_complete_overlap(preds, target, expected_loss): | ||
torch.testing.assert_allclose(giou_loss(preds, target), expected_loss) | ||
|
||
|
||
@pytest.mark.parametrize("preds, target, expected_loss", [ | ||
(torch.tensor([[100, 100, 200, 200]]), torch.tensor([[100, 200, 200, 300]]), torch.tensor([1.0])), | ||
(torch.tensor([[100, 100, 200, 200]]), torch.tensor([[200, 200, 300, 300]]), torch.tensor([1.5])), | ||
]) | ||
def test_no_overlap(preds, target, expected_loss): | ||
torch.testing.assert_allclose(giou_loss(preds, target), expected_loss) |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
""" | ||
Test Object Detection Metric Functions | ||
""" | ||
|
||
import pytest | ||
import torch | ||
|
||
from pl_bolts.metrics.object_detection import giou | ||
|
||
|
||
@pytest.mark.parametrize("preds, target, expected_giou", [ | ||
(torch.tensor([[100, 100, 200, 200]]), torch.tensor([[100, 100, 200, 200]]), torch.tensor([1.0])) | ||
]) | ||
def test_complete_overlap(preds, target, expected_giou): | ||
torch.testing.assert_allclose(giou(preds, target), expected_giou) | ||
|
||
|
||
@pytest.mark.parametrize("preds, target, expected_giou", [ | ||
(torch.tensor([[100, 100, 200, 200]]), torch.tensor([[100, 200, 200, 300]]), torch.tensor([0.0])), | ||
(torch.tensor([[100, 100, 200, 200]]), torch.tensor([[200, 200, 300, 300]]), torch.tensor([-0.5])), | ||
]) | ||
def test_no_overlap(preds, target, expected_giou): | ||
torch.testing.assert_allclose(giou(preds, target), expected_giou) |