Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor: Consolidate Bounding box Transformation and IoU Calculation Functions #587

Merged
merged 21 commits into from
Nov 26, 2024
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 7 additions & 16 deletions src/netspresso_trainer/losses/detection/rtdetr.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
from scipy.optimize import linear_sum_assignment
from torchvision.ops.boxes import box_area

from netspresso_trainer.utils.bbox_utils import transform_bbox


def box_iou(boxes1, boxes2):
area1 = box_area(boxes1)
Expand All @@ -42,17 +44,6 @@ def box_iou(boxes1, boxes2):
iou = inter / union
return iou, union

def box_cxcywh_to_xyxy(x):
    """Convert boxes from (cx, cy, w, h) to corner (x0, y0, x1, y1) format.

    The last dimension of ``x`` is unpacked into exactly four components;
    any leading dimensions are kept. A new tensor is returned.
    """
    center_x, center_y, width, height = x.unbind(-1)
    half_w = 0.5 * width
    half_h = 0.5 * height
    corners = (
        center_x - half_w,
        center_y - half_h,
        center_x + half_w,
        center_y + half_h,
    )
    return torch.stack(corners, dim=-1)

def box_xyxy_to_cxcywh(x):
    """Convert boxes from corner (x0, y0, x1, y1) to (cx, cy, w, h) format.

    The last dimension of ``x`` is unpacked into exactly four components;
    any leading dimensions are kept. A new tensor is returned.
    """
    left, top, right, bottom = x.unbind(-1)
    center_x = (left + right) / 2
    center_y = (top + bottom) / 2
    width = right - left
    height = bottom - top
    return torch.stack((center_x, center_y, width, height), dim=-1)

def generalized_box_iou(boxes1, boxes2):
"""
Expand Down Expand Up @@ -122,10 +113,10 @@ def forward(self, outputs, targets):
cost_class = -out_prob[:, tgt_ids]

# Compute the L1 cost between boxes
cost_bbox = torch.cdist(out_bbox, box_xyxy_to_cxcywh(tgt_bbox), p=1)
cost_bbox = torch.cdist(out_bbox, transform_bbox(tgt_bbox, "xyxy -> cxcywh"), p=1)

# Compute the giou cost between boxes
cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox), tgt_bbox)
cost_giou = -generalized_box_iou(transform_bbox(out_bbox, "cxcywh -> xyxy"), tgt_bbox)

# Final cost matrix
C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou
Expand Down Expand Up @@ -204,7 +195,7 @@ def loss_labels_vfl(self, outputs, targets, indices, num_boxes, log=True):

src_boxes = outputs['pred_boxes'][idx]
target_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0)
ious, _ = box_iou(box_cxcywh_to_xyxy(src_boxes), target_boxes)
ious, _ = box_iou(transform_bbox(src_boxes, "cxcywh -> xyxy"), target_boxes)
ious = torch.diag(ious).detach()

src_logits = outputs['pred_logits']
Expand Down Expand Up @@ -250,11 +241,11 @@ def loss_boxes(self, outputs, targets, indices, num_boxes):
target_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0)
losses = {}

loss_bbox = F.l1_loss(src_boxes, box_xyxy_to_cxcywh(target_boxes), reduction='none')
loss_bbox = F.l1_loss(src_boxes, transform_bbox(target_boxes, "xyxy -> cxcywh"), reduction='none')
losses['loss_bbox'] = loss_bbox.sum() / num_boxes

loss_giou = 1 - torch.diag(generalized_box_iou(
box_cxcywh_to_xyxy(src_boxes),
transform_bbox(src_boxes, "cxcywh -> xyxy"),
target_boxes))
losses['loss_giou'] = loss_giou.sum() / num_boxes
return losses
Expand Down
38 changes: 3 additions & 35 deletions src/netspresso_trainer/losses/detection/yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,41 +21,9 @@
import torch.nn as nn
import torch.nn.functional as F

from .yolox import IOUloss, YOLOXLoss, xyxy2cxcywh


def xyxy2cxcywhn(bboxes, img_size):
    """Return a normalized (cx, cy, w, h) copy of corner-format boxes.

    Every column of ``bboxes`` is divided by ``img_size`` first; columns
    0-3 are then rewritten from (x0, y0, x1, y1) to (cx, cy, w, h).
    The input is not modified.
    """
    scaled = bboxes.clone() / img_size
    # Width and height must be written before the centers, because the
    # centers are derived from the top-left corner plus half the extent.
    for corner_col, extent_col in ((0, 2), (1, 3)):
        scaled[:, extent_col] = scaled[:, extent_col] - scaled[:, corner_col]
    for corner_col, extent_col in ((0, 2), (1, 3)):
        scaled[:, corner_col] = scaled[:, corner_col] + scaled[:, extent_col] * 0.5
    return scaled

# Pairwise IoU between two sets of boxes: the result has one row per box in
# bboxes_a and one column per box in bboxes_b.
# With xyxy=True the boxes are read as corners (x1, y1, x2, y2); with
# xyxy=False (the default here) they are read as (cx, cy, w, h).
def bboxes_iou(bboxes_a, bboxes_b, xyxy=False):
# Both inputs must have exactly 4 columns.
if bboxes_a.shape[1] != 4 or bboxes_b.shape[1] != 4:
raise IndexError

if xyxy:
# Intersection corners: max of the top-lefts, min of the bottom-rights.
tl = torch.max(bboxes_a[:, None, :2], bboxes_b[:, :2])
br = torch.min(bboxes_a[:, None, 2:], bboxes_b[:, 2:])
area_a = torch.prod(bboxes_a[:, 2:] - bboxes_a[:, :2], 1)
area_b = torch.prod(bboxes_b[:, 2:] - bboxes_b[:, :2], 1)
else:
# Center format: corners are center -/+ half of (w, h).
tl = torch.max(
(bboxes_a[:, None, :2] - bboxes_a[:, None, 2:] / 2),
(bboxes_b[:, :2] - bboxes_b[:, 2:] / 2),
)
br = torch.min(
(bboxes_a[:, None, :2] + bboxes_a[:, None, 2:] / 2),
(bboxes_b[:, :2] + bboxes_b[:, 2:] / 2),
)
# NOTE(review): the import below appears to be an added line of the diff that
# landed inside this function's body in this view — confirm.
from netspresso_trainer.utils.bbox_utils import transform_bbox

# In center format the area is simply w * h.
area_a = torch.prod(bboxes_a[:, 2:], 1)
area_b = torch.prod(bboxes_b[:, 2:], 1)
# en is 1 only where tl < br on both axes, so pairs without a positive
# overlap get an intersection area of 0.
en = (tl < br).type(tl.type()).prod(dim=2)
area_i = torch.prod(br - tl, 2) * en # * ((tl < br).all())
# IoU = intersection / (area_a + area_b - intersection).
return area_i / (area_a[:, None] + area_b - area_i)
from .yolox import IOUloss, YOLOXLoss


class YOLOFastestLoss(YOLOXLoss):
Expand Down Expand Up @@ -121,7 +89,7 @@ def forward(self, out: List, target: Dict) -> torch.Tensor:
# YOLOX model learns box cxcywh format directly,
# but our detection dataloader gives xyxy format.
for i in range(len(target)):
target[i]['boxes'] = xyxy2cxcywh(target[i]['boxes'])
target[i]['boxes'] = transform_bbox(target[i]['boxes'], "xyxy -> cxcywh")

# Ready for l1 loss
origin_preds = []
Expand Down
63 changes: 5 additions & 58 deletions src/netspresso_trainer/losses/detection/yolox.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,39 +20,7 @@
import torch.nn as nn
import torch.nn.functional as F


def bboxes_iou(bboxes_a, bboxes_b, xyxy=True):
    """Compute the pairwise IoU matrix between two sets of boxes.

    Args:
        bboxes_a: 2-D tensor with 4 columns, one box per row.
        bboxes_b: 2-D tensor with 4 columns, one box per row.
        xyxy: If True (default), boxes are corners (x1, y1, x2, y2);
            otherwise they are (cx, cy, w, h).

    Returns:
        Tensor with one row per box in ``bboxes_a`` and one column per box
        in ``bboxes_b`` holding intersection / union.

    Raises:
        IndexError: If either input does not have exactly 4 columns.
    """
    if bboxes_a.shape[1] != 4 or bboxes_b.shape[1] != 4:
        # Same exception type as before, now with the offending shapes.
        raise IndexError(
            "bboxes_iou expects boxes with 4 columns, got shapes "
            f"{tuple(bboxes_a.shape)} and {tuple(bboxes_b.shape)}"
        )

    if xyxy:
        # Intersection corners: max of top-lefts, min of bottom-rights.
        tl = torch.max(bboxes_a[:, None, :2], bboxes_b[:, :2])
        br = torch.min(bboxes_a[:, None, 2:], bboxes_b[:, 2:])
        area_a = torch.prod(bboxes_a[:, 2:] - bboxes_a[:, :2], 1)
        area_b = torch.prod(bboxes_b[:, 2:] - bboxes_b[:, :2], 1)
    else:
        # Center format: corners are center -/+ half of (w, h).
        tl = torch.max(
            (bboxes_a[:, None, :2] - bboxes_a[:, None, 2:] / 2),
            (bboxes_b[:, :2] - bboxes_b[:, 2:] / 2),
        )
        br = torch.min(
            (bboxes_a[:, None, :2] + bboxes_a[:, None, 2:] / 2),
            (bboxes_b[:, :2] + bboxes_b[:, 2:] / 2),
        )
        area_a = torch.prod(bboxes_a[:, 2:], 1)
        area_b = torch.prod(bboxes_b[:, 2:], 1)

    # 1 only where tl < br on both axes, so pairs without a positive overlap
    # contribute an intersection area of 0.
    overlap_mask = (tl < br).to(tl.dtype).prod(dim=2)
    area_i = torch.prod(br - tl, 2) * overlap_mask
    return area_i / (area_a[:, None] + area_b - area_i)


def xyxy2cxcywh(bboxes):
    """Convert corner-format boxes (x0, y0, x1, y1) to (cx, cy, w, h).

    The conversion is done on a copy, so the caller's tensor keeps its
    original xyxy values. Previously the input was overwritten in place,
    which silently corrupted any other reference to the same boxes.

    Args:
        bboxes: 2-D tensor whose first four columns are x0, y0, x1, y1.

    Returns:
        A new tensor of the same shape with columns 0-3 rewritten as
        cx, cy, w, h.
    """
    converted = bboxes.clone()
    # Width and height first; the centers below depend on them.
    converted[:, 2] = converted[:, 2] - converted[:, 0]
    converted[:, 3] = converted[:, 3] - converted[:, 1]
    converted[:, 0] = converted[:, 0] + converted[:, 2] * 0.5
    converted[:, 1] = converted[:, 1] + converted[:, 3] * 0.5
    return converted
from netspresso_trainer.utils.bbox_utils import bboxes_iou, calculate_iou, transform_bbox


class YOLOXLoss(nn.Module):
Expand Down Expand Up @@ -107,7 +75,7 @@ def forward(self, out: List, target: Dict) -> torch.Tensor:
# YOLOX model learns box cxcywh format directly,
# but our detection dataloader gives xyxy format.
for i in range(len(target)):
target[i]['boxes'] = xyxy2cxcywh(target[i]['boxes'])
target[i]['boxes'] = transform_bbox(target[i]['boxes'], "xyxy -> cxcywh")

# Ready for l1 loss
origin_preds = []
Expand Down Expand Up @@ -505,33 +473,12 @@ def forward(self, pred, target):

pred = pred.view(-1, 4)
target = target.view(-1, 4)
tl = torch.max(
(pred[:, :2] - pred[:, 2:] / 2), (target[:, :2] - target[:, 2:] / 2)
)
br = torch.min(
(pred[:, :2] + pred[:, 2:] / 2), (target[:, :2] + target[:, 2:] / 2)
)

area_p = torch.prod(pred[:, 2:], 1)
area_g = torch.prod(target[:, 2:], 1)

en = (tl < br).type(tl.type()).prod(dim=1)
area_i = torch.prod(br - tl, 1) * en
area_u = area_p + area_g - area_i
iou = (area_i) / (area_u + 1e-16)
iou = calculate_iou(pred, target, metric=self.loss_type)

if self.loss_type == "iou":
loss = 1 - iou ** 2
elif self.loss_type == "giou":
c_tl = torch.min(
(pred[:, :2] - pred[:, 2:] / 2), (target[:, :2] - target[:, 2:] / 2)
)
c_br = torch.max(
(pred[:, :2] + pred[:, 2:] / 2), (target[:, :2] + target[:, 2:] / 2)
)
area_c = torch.prod(c_br - c_tl, 1)
giou = iou - (area_c - area_u) / area_c.clamp(1e-16)
loss = 1 - giou.clamp(min=-1.0, max=1.0)
elif self.loss_type == "giou" or self.loss_type == "diou" or self.loss_type == "ciou":
loss = 1 - iou.clamp(min=-1.0, max=1.0)

if self.reduction == "mean":
loss = loss.mean()
Expand Down
18 changes: 5 additions & 13 deletions src/netspresso_trainer/postprocessors/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
from torchvision.models.detection._utils import BoxCoder, _topk_min
from torchvision.ops import boxes as box_ops

from netspresso_trainer.utils.bbox_utils import transform_bbox

from ..models.utils import ModelOutput


Expand All @@ -30,12 +32,8 @@ def rtdetr_decode(pred, original_shape, num_top_queries=300, score_thresh=0.0):
boxes, logits = pred[..., :4], pred[..., 4:]

num_classes = logits.shape[-1]

boxes = torchvision.ops.box_convert(boxes, in_fmt='cxcywh', out_fmt='xyxy')

h, w = original_shape[1], original_shape[2]
boxes[..., ::2] *= w
boxes[..., 1::2] *= h
boxes = transform_bbox(boxes, "cxcywhn -> xyxy", img_size=(w, h))

scores = torch.sigmoid(logits)
scores, index = torch.topk(scores.flatten(1), num_top_queries, axis=-1)
Expand Down Expand Up @@ -141,10 +139,7 @@ def anchor_free_decoupled_head_decode(pred, original_shape, score_thresh=0.7):
], dim=-1)

box_corner = pred.new(pred.shape)
box_corner[:, :, 0] = pred[:, :, 0] - pred[:, :, 2] / 2
box_corner[:, :, 1] = pred[:, :, 1] - pred[:, :, 3] / 2
box_corner[:, :, 2] = pred[:, :, 0] + pred[:, :, 2] / 2
box_corner[:, :, 3] = pred[:, :, 1] + pred[:, :, 3] / 2
box_corner[:, :, :4] = transform_bbox(pred[:, :, :4], "cxcywh -> xyxy")
pred[:, :, :4] = box_corner[:, :, :4]

# Discard boxes with low score
Expand Down Expand Up @@ -193,10 +188,7 @@ def yolo_fastest_head_decode(pred, original_shape, score_thresh=0.7, anchors=Non
pred = torch.cat(preds, dim=1)

box_corner = pred.new(pred.shape)
box_corner[:, :, 0] = pred[:, :, 0] - pred[:, :, 2] / 2
box_corner[:, :, 1] = pred[:, :, 1] - pred[:, :, 3] / 2
box_corner[:, :, 2] = pred[:, :, 0] + pred[:, :, 2] / 2
box_corner[:, :, 3] = pred[:, :, 1] + pred[:, :, 3] / 2
box_corner[:, :, :4] = transform_bbox(pred[:, :, :4], "cxcywh -> xyxy")
pred[:, :, :4] = box_corner[:, :, :4]

# Discard boxes with low score
Expand Down
Loading
Loading