diff --git a/CHANGELOG.md b/CHANGELOG.md
index b6697a22..b052b1b7 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -19,6 +19,7 @@ No changes to highlight.
 - Refactor RT-DETR and generalize CSPRepLayer and RepVGG block by `@hglee98` in [PR 581](https://github.com/Nota-NetsPresso/netspresso-trainer/pull/581)
 - Generalize 2d pooling layers and define as custom layer by `@hglee98` in [PR 583](https://github.com/Nota-NetsPresso/netspresso-trainer/pull/583)
+- Unify bbox transformation and IoU computing methods by `@hglee98` in [PR 587](https://github.com/Nota-NetsPresso/netspresso-trainer/pull/587)
 
 # v1.0.3
 
diff --git a/src/netspresso_trainer/losses/detection/rtdetr.py b/src/netspresso_trainer/losses/detection/rtdetr.py
index 73a5b836..716d712e 100644
--- a/src/netspresso_trainer/losses/detection/rtdetr.py
+++ b/src/netspresso_trainer/losses/detection/rtdetr.py
@@ -26,6 +26,8 @@ from scipy.optimize import linear_sum_assignment
 from torchvision.ops.boxes import box_area
 
+from netspresso_trainer.utils.bbox_utils import transform_bbox
+
 
 def box_iou(boxes1, boxes2):
     area1 = box_area(boxes1)
@@ -42,17 +44,6 @@ def box_iou(boxes1, boxes2):
     iou = inter / union
     return iou, union
 
-def box_cxcywh_to_xyxy(x):
-    x_c, y_c, w, h = x.unbind(-1)
-    b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
-         (x_c + 0.5 * w), (y_c + 0.5 * h)]
-    return torch.stack(b, dim=-1)
-
-def box_xyxy_to_cxcywh(x):
-    x0, y0, x1, y1 = x.unbind(-1)
-    b = [(x0 + x1) / 2, (y0 + y1) / 2,
-         (x1 - x0), (y1 - y0)]
-    return torch.stack(b, dim=-1)
 
 def generalized_box_iou(boxes1, boxes2):
     """
@@ -122,10 +113,10 @@ def forward(self, outputs, targets):
         cost_class = -out_prob[:, tgt_ids]
 
         # Compute the L1 cost between boxes
-        cost_bbox = torch.cdist(out_bbox, box_xyxy_to_cxcywh(tgt_bbox), p=1)
+        cost_bbox = torch.cdist(out_bbox, transform_bbox(tgt_bbox, "xyxy -> cxcywh"), p=1)
 
         # Compute the giou cost between boxes
-        cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox), tgt_bbox)
+        cost_giou = -generalized_box_iou(transform_bbox(out_bbox, "cxcywh -> xyxy"), tgt_bbox)
 
         # Final cost matrix
         C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou
@@ -204,7 +195,7 @@ def loss_labels_vfl(self, outputs, targets, indices, num_boxes, log=True):
         src_boxes = outputs['pred_boxes'][idx]
         target_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0)
 
-        ious, _ = box_iou(box_cxcywh_to_xyxy(src_boxes), target_boxes)
+        ious, _ = box_iou(transform_bbox(src_boxes, "cxcywh -> xyxy"), target_boxes)
         ious = torch.diag(ious).detach()
 
         src_logits = outputs['pred_logits']
@@ -250,11 +241,11 @@ def loss_boxes(self, outputs, targets, indices, num_boxes):
         target_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0)
 
         losses = {}
-        loss_bbox = F.l1_loss(src_boxes, box_xyxy_to_cxcywh(target_boxes), reduction='none')
+        loss_bbox = F.l1_loss(src_boxes, transform_bbox(target_boxes, "xyxy -> cxcywh"), reduction='none')
         losses['loss_bbox'] = loss_bbox.sum() / num_boxes
 
         loss_giou = 1 - torch.diag(generalized_box_iou(
-            box_cxcywh_to_xyxy(src_boxes),
+            transform_bbox(src_boxes, "cxcywh -> xyxy"),
             target_boxes))
         losses['loss_giou'] = loss_giou.sum() / num_boxes
         return losses
diff --git a/src/netspresso_trainer/losses/detection/yolo.py b/src/netspresso_trainer/losses/detection/yolo.py
index 127fd0a2..47d32f50 100644
--- a/src/netspresso_trainer/losses/detection/yolo.py
+++ b/src/netspresso_trainer/losses/detection/yolo.py
@@ -21,41 +21,9 @@ import torch.nn as nn
 import torch.nn.functional as F
 
-from .yolox import IOUloss, YOLOXLoss, xyxy2cxcywh
-
-
-def xyxy2cxcywhn(bboxes, img_size):
-    new_bboxes = bboxes.clone() / img_size
-    new_bboxes[:, 2] = new_bboxes[:, 2] - new_bboxes[:, 0]
-    new_bboxes[:, 3] = new_bboxes[:, 3] - new_bboxes[:, 1]
-    new_bboxes[:, 0] = new_bboxes[:, 0] + new_bboxes[:, 2] * 0.5
-    new_bboxes[:, 1] = new_bboxes[:, 1] + new_bboxes[:, 3] * 0.5
-    return new_bboxes
-
-def bboxes_iou(bboxes_a, bboxes_b, xyxy=False):
-    if bboxes_a.shape[1] != 4 or bboxes_b.shape[1] != 4:
-        raise IndexError
-
-    if xyxy:
-        tl = torch.max(bboxes_a[:, None, :2], bboxes_b[:, :2])
-        br = torch.min(bboxes_a[:, None, 2:], bboxes_b[:, 2:])
-        area_a = torch.prod(bboxes_a[:, 2:] - bboxes_a[:, :2], 1)
-        area_b = torch.prod(bboxes_b[:, 2:] - bboxes_b[:, :2], 1)
-    else:
-        tl = torch.max(
-            (bboxes_a[:, None, :2] - bboxes_a[:, None, 2:] / 2),
-            (bboxes_b[:, :2] - bboxes_b[:, 2:] / 2),
-        )
-        br = torch.min(
-            (bboxes_a[:, None, :2] + bboxes_a[:, None, 2:] / 2),
-            (bboxes_b[:, :2] + bboxes_b[:, 2:] / 2),
-        )
+from netspresso_trainer.utils.bbox_utils import transform_bbox
 
-        area_a = torch.prod(bboxes_a[:, 2:], 1)
-        area_b = torch.prod(bboxes_b[:, 2:], 1)
-    en = (tl < br).type(tl.type()).prod(dim=2)
-    area_i = torch.prod(br - tl, 2) * en  # * ((tl < br).all())
-    return area_i / (area_a[:, None] + area_b - area_i)
+from .yolox import IOUloss, YOLOXLoss
 
 
 class YOLOFastestLoss(YOLOXLoss):
@@ -121,7 +89,7 @@ def forward(self, out: List, target: Dict) -> torch.Tensor:
         # YOLOX model learns box cxcywh format directly,
         # but our detection dataloader gives xyxy format.
         for i in range(len(target)):
-            target[i]['boxes'] = xyxy2cxcywh(target[i]['boxes'])
+            target[i]['boxes'] = transform_bbox(target[i]['boxes'], "xyxy -> cxcywh")
 
         # Ready for l1 loss
         origin_preds = []
diff --git a/src/netspresso_trainer/losses/detection/yolox.py b/src/netspresso_trainer/losses/detection/yolox.py
index 5403d693..25067ee1 100644
--- a/src/netspresso_trainer/losses/detection/yolox.py
+++ b/src/netspresso_trainer/losses/detection/yolox.py
@@ -20,39 +20,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 
-
-def bboxes_iou(bboxes_a, bboxes_b, xyxy=True):
-    if bboxes_a.shape[1] != 4 or bboxes_b.shape[1] != 4:
-        raise IndexError
-
-    if xyxy:
-        tl = torch.max(bboxes_a[:, None, :2], bboxes_b[:, :2])
-        br = torch.min(bboxes_a[:, None, 2:], bboxes_b[:, 2:])
-        area_a = torch.prod(bboxes_a[:, 2:] - bboxes_a[:, :2], 1)
-        area_b = torch.prod(bboxes_b[:, 2:] - bboxes_b[:, :2], 1)
-    else:
-        tl = torch.max(
-            (bboxes_a[:, None, :2] - bboxes_a[:, None, 2:] / 2),
-            (bboxes_b[:, :2] - bboxes_b[:, 2:] / 2),
-        )
-        br = torch.min(
-            (bboxes_a[:, None, :2] + bboxes_a[:, None, 2:] / 2),
-            (bboxes_b[:, :2] + bboxes_b[:, 2:] / 2),
-        )
-
-        area_a = torch.prod(bboxes_a[:, 2:], 1)
-        area_b = torch.prod(bboxes_b[:, 2:], 1)
-    en = (tl < br).type(tl.type()).prod(dim=2)
-    area_i = torch.prod(br - tl, 2) * en  # * ((tl < br).all())
-    return area_i / (area_a[:, None] + area_b - area_i)
-
-
-def xyxy2cxcywh(bboxes):
-    bboxes[:, 2] = bboxes[:, 2] - bboxes[:, 0]
-    bboxes[:, 3] = bboxes[:, 3] - bboxes[:, 1]
-    bboxes[:, 0] = bboxes[:, 0] + bboxes[:, 2] * 0.5
-    bboxes[:, 1] = bboxes[:, 1] + bboxes[:, 3] * 0.5
-    return bboxes
+from netspresso_trainer.utils.bbox_utils import bboxes_iou, calculate_iou, transform_bbox
 
 
 class YOLOXLoss(nn.Module):
@@ -107,7 +75,7 @@ def forward(self, out: List, target: Dict) -> torch.Tensor:
         # YOLOX model learns box cxcywh format directly,
         # but our detection dataloader gives xyxy format.
         for i in range(len(target)):
-            target[i]['boxes'] = xyxy2cxcywh(target[i]['boxes'])
+            target[i]['boxes'] = transform_bbox(target[i]['boxes'], "xyxy -> cxcywh")
 
         # Ready for l1 loss
         origin_preds = []
@@ -505,33 +473,12 @@ def forward(self, pred, target):
         pred = pred.view(-1, 4)
         target = target.view(-1, 4)
 
-        tl = torch.max(
-            (pred[:, :2] - pred[:, 2:] / 2), (target[:, :2] - target[:, 2:] / 2)
-        )
-        br = torch.min(
-            (pred[:, :2] + pred[:, 2:] / 2), (target[:, :2] + target[:, 2:] / 2)
-        )
-
-        area_p = torch.prod(pred[:, 2:], 1)
-        area_g = torch.prod(target[:, 2:], 1)
-
-        en = (tl < br).type(tl.type()).prod(dim=1)
-        area_i = torch.prod(br - tl, 1) * en
-        area_u = area_p + area_g - area_i
-        iou = (area_i) / (area_u + 1e-16)
+        iou = calculate_iou(pred, target, metric=self.loss_type)
 
         if self.loss_type == "iou":
             loss = 1 - iou ** 2
-        elif self.loss_type == "giou":
-            c_tl = torch.min(
-                (pred[:, :2] - pred[:, 2:] / 2), (target[:, :2] - target[:, 2:] / 2)
-            )
-            c_br = torch.max(
-                (pred[:, :2] + pred[:, 2:] / 2), (target[:, :2] + target[:, 2:] / 2)
-            )
-            area_c = torch.prod(c_br - c_tl, 1)
-            giou = iou - (area_c - area_u) / area_c.clamp(1e-16)
-            loss = 1 - giou.clamp(min=-1.0, max=1.0)
+        elif self.loss_type in ("giou", "diou", "ciou"):
+            loss = 1 - iou.clamp(min=-1.0, max=1.0)
 
         if self.reduction == "mean":
             loss = loss.mean()
diff --git a/src/netspresso_trainer/postprocessors/detection.py b/src/netspresso_trainer/postprocessors/detection.py
index cfae27e7..604fc699 100644
--- a/src/netspresso_trainer/postprocessors/detection.py
+++ b/src/netspresso_trainer/postprocessors/detection.py
@@ -22,6 +22,8 @@
 from torchvision.models.detection._utils import BoxCoder, _topk_min
 from torchvision.ops import boxes as box_ops
 
+from netspresso_trainer.utils.bbox_utils import transform_bbox
+
 from ..models.utils import ModelOutput
 
 
@@ -30,12 +32,8 @@ def rtdetr_decode(pred, original_shape, num_top_queries=300, score_thresh=0.0):
     boxes, logits = pred[..., :4], pred[..., 4:]
     num_classes = logits.shape[-1]
-
-    boxes = torchvision.ops.box_convert(boxes, in_fmt='cxcywh', out_fmt='xyxy')
-
     h, w = original_shape[1], original_shape[2]
-    boxes[..., ::2] *= w
-    boxes[..., 1::2] *= h
+    boxes = transform_bbox(boxes, "cxcywhn -> xyxy", image_shape=(h, w))
 
     scores = torch.sigmoid(logits)
     scores, index = torch.topk(scores.flatten(1), num_top_queries, axis=-1)
@@ -141,10 +139,7 @@ def anchor_free_decoupled_head_decode(pred, original_shape, score_thresh=0.7):
     ], dim=-1)
 
     box_corner = pred.new(pred.shape)
-    box_corner[:, :, 0] = pred[:, :, 0] - pred[:, :, 2] / 2
-    box_corner[:, :, 1] = pred[:, :, 1] - pred[:, :, 3] / 2
-    box_corner[:, :, 2] = pred[:, :, 0] + pred[:, :, 2] / 2
-    box_corner[:, :, 3] = pred[:, :, 1] + pred[:, :, 3] / 2
+    box_corner[:, :, :4] = transform_bbox(pred[:, :, :4], "cxcywh -> xyxy")
     pred[:, :, :4] = box_corner[:, :, :4]
 
     # Discard boxes with low score
@@ -193,10 +188,7 @@ def yolo_fastest_head_decode(pred, original_shape, score_thresh=0.7, anchors=None):
     pred = torch.cat(preds, dim=1)
 
     box_corner = pred.new(pred.shape)
-    box_corner[:, :, 0] = pred[:, :, 0] - pred[:, :, 2] / 2
-    box_corner[:, :, 1] = pred[:, :, 1] - pred[:, :, 3] / 2
-    box_corner[:, :, 2] = pred[:, :, 0] + pred[:, :, 2] / 2
-    box_corner[:, :, 3] = pred[:, :, 1] + pred[:, :, 3] / 2
+    box_corner[:, :, :4] = transform_bbox(pred[:, :, :4], "cxcywh -> xyxy")
     pred[:, :, :4] = box_corner[:, :, :4]
 
     # Discard boxes with low score
diff --git a/src/netspresso_trainer/utils/bbox_utils.py b/src/netspresso_trainer/utils/bbox_utils.py
new file mode 100644
index 00000000..d76fa0b0
--- /dev/null
+++ b/src/netspresso_trainer/utils/bbox_utils.py
@@ -0,0 +1,163 @@
+import math
+from typing import Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import Tensor
+from torch.fx.proxy import Proxy
+
+
+def transform_bbox(bboxes: Union[Tensor, Proxy],
+                   indicator="xywh -> xyxy",
+                   image_shape: Optional[Union[int, Tuple[int, int]]] = None):
+    def is_normalized(fmt: str) -> bool:
+        return fmt.endswith('n')
+
+    VALID_IN_TYPE = VALID_OUT_TYPE = ["xyxy", "xyxyn", "xywh", "xywhn", "cxcywh", "cxcywhn"]
+    dtype = bboxes.dtype
+    in_type, out_type = indicator.replace(" ", "").split("->")
+    assert in_type in VALID_IN_TYPE, f"Invalid in_type: '{in_type}'. Must be one of {VALID_IN_TYPE}."
+    assert out_type in VALID_OUT_TYPE, f"Invalid out_type: '{out_type}'. Must be one of {VALID_OUT_TYPE}."
+
+    if is_normalized(in_type):
+        assert image_shape is not None, f"image_shape is required for normalized conversion: {indicator}"
+        if isinstance(image_shape, int):
+            img_height = img_width = image_shape
+        else:
+            img_height, img_width = image_shape
+        assert isinstance(img_height, int) and isinstance(img_width, int), \
+            f"Invalid type: (height: {type(img_height)}, width: {type(img_width)}). Must be (int, int)."
+        in_type = in_type[:-1]
+    else:
+        img_height = img_width = 1.0
+
+    if in_type == "xyxy":
+        x_min, y_min, x_max, y_max = bboxes.unbind(-1)
+    elif in_type == "xywh":
+        x_min, y_min, w, h = bboxes.unbind(-1)
+        x_max = x_min + w
+        y_max = y_min + h
+    elif in_type == "cxcywh":
+        cx, cy, w, h = bboxes.unbind(-1)
+        x_min = cx - w / 2
+        y_min = cy - h / 2
+        x_max = cx + w / 2
+        y_max = cy + h / 2
+
+    # Keep these ops out-of-place: unbind() returns views of the input tensor,
+    # so an in-place `*=` here would silently modify the caller's boxes.
+    x_min = x_min * img_width
+    y_min = y_min * img_height
+    x_max = x_max * img_width
+    y_max = y_max * img_height
+    assert (x_max >= x_min).all(), "Invalid box: x_max < x_min"
+    assert (y_max >= y_min).all(), "Invalid box: y_max < y_min"
+
+    if is_normalized(out_type):
+        assert image_shape is not None, f"image_shape is required for normalized conversion: {indicator}"
+        if isinstance(image_shape, int):
+            img_height = img_width = image_shape
+        else:
+            img_height, img_width = image_shape
+        assert isinstance(img_height, int) and isinstance(img_width, int), \
+            f"Invalid type: (height: {type(img_height)}, width: {type(img_width)}). Must be (int, int)."
+        out_type = out_type[:-1]
+    else:
+        img_height = img_width = 1.0
+
+    x_min = x_min / img_width
+    y_min = y_min / img_height
+    x_max = x_max / img_width
+    y_max = y_max / img_height
+
+    if out_type == "xywh":
+        bbox = torch.stack([x_min, y_min, x_max - x_min, y_max - y_min], dim=-1)
+    elif out_type == "xyxy":
+        bbox = torch.stack([x_min, y_min, x_max, y_max], dim=-1)
+    elif out_type == "cxcywh":
+        bbox = torch.stack([(x_min + x_max) / 2, (y_min + y_max) / 2, x_max - x_min, y_max - y_min], dim=-1)
+
+    return bbox.to(dtype=dtype)
+
+
+def bboxes_iou(bboxes_a, bboxes_b, xyxy=True):
+    if bboxes_a.shape[1] != 4 or bboxes_b.shape[1] != 4:
+        raise IndexError
+
+    if xyxy:
+        tl = torch.max(bboxes_a[:, None, :2], bboxes_b[:, :2])
+        br = torch.min(bboxes_a[:, None, 2:], bboxes_b[:, 2:])
+        area_a = torch.prod(bboxes_a[:, 2:] - bboxes_a[:, :2], 1)
+        area_b = torch.prod(bboxes_b[:, 2:] - bboxes_b[:, :2], 1)
+    else:
+        tl = torch.max(
+            (bboxes_a[:, None, :2] - bboxes_a[:, None, 2:] / 2),
+            (bboxes_b[:, :2] - bboxes_b[:, 2:] / 2),
+        )
+        br = torch.min(
+            (bboxes_a[:, None, :2] + bboxes_a[:, None, 2:] / 2),
+            (bboxes_b[:, :2] + bboxes_b[:, 2:] / 2),
+        )
+
+        area_a = torch.prod(bboxes_a[:, 2:], 1)
+        area_b = torch.prod(bboxes_b[:, 2:], 1)
+    en = (tl < br).type(tl.type()).prod(dim=2)
+    area_i = torch.prod(br - tl, 2) * en  # * ((tl < br).all())
+    return area_i / (area_a[:, None] + area_b - area_i)
+
+
+def calculate_iou(bbox1, bbox2, metric="iou", EPS=1e-7) -> Tensor:
+    VALID_METRICS = {"iou", "giou", "diou", "ciou"}
+    assert metric.lower() in VALID_METRICS, f"Invalid IoU metric: '{metric}'. Must be one of {VALID_METRICS}."
+    metric = metric.lower()
+
+    tl = torch.max(
+        (bbox1[:, :2] - bbox1[:, 2:] / 2), (bbox2[:, :2] - bbox2[:, 2:] / 2)
+    )
+    br = torch.min(
+        (bbox1[:, :2] + bbox1[:, 2:] / 2), (bbox2[:, :2] + bbox2[:, 2:] / 2)
+    )
+
+    area_p = torch.prod(bbox1[:, 2:], 1)
+    area_g = torch.prod(bbox2[:, 2:], 1)
+
+    en = (tl < br).type(tl.type()).prod(dim=1)
+    area_i = torch.prod(br - tl, 1) * en
+    area_u = area_p + area_g - area_i
+    iou = area_i / (area_u + EPS)
+
+    if metric == "iou":
+        return iou
+    elif metric == "giou":
+        c_tl = torch.min(
+            (bbox1[:, :2] - bbox1[:, 2:] / 2), (bbox2[:, :2] - bbox2[:, 2:] / 2)
+        )
+        c_br = torch.max(
+            (bbox1[:, :2] + bbox1[:, 2:] / 2), (bbox2[:, :2] + bbox2[:, 2:] / 2)
+        )
+        area_c = torch.prod(c_br - c_tl, 1)
+        giou = iou - (area_c - area_u) / area_c.clamp(EPS)
+        return giou
+    elif metric in ("diou", "ciou"):
+        cent1 = bbox1[..., :2]  # (cx1, cy1)
+        cent2 = bbox2[..., :2]  # (cx2, cy2)
+
+        cent_dist = torch.sum((cent1 - cent2) * (cent1 - cent2), dim=-1)
+
+        c_tl = torch.min(
+            bbox1[..., :2] - bbox1[..., 2:] / 2,
+            bbox2[..., :2] - bbox2[..., 2:] / 2
+        )
+        c_br = torch.max(
+            bbox1[..., :2] + bbox1[..., 2:] / 2,
+            bbox2[..., :2] + bbox2[..., 2:] / 2
+        )
+
+        diag_dist = torch.sum((c_br - c_tl) ** 2, dim=-1) + EPS
+
+        diou = iou - (cent_dist / diag_dist)
+        if metric == "diou":
+            return diou
+        arctan = torch.atan(bbox1[..., 2] / (bbox1[..., 3] + EPS)) - torch.atan(bbox2[..., 2] / (bbox2[..., 3] + EPS))
+        v = (4 / (math.pi ** 2)) * (arctan ** 2)
+        with torch.no_grad():
+            alpha = v / (v - iou + 1 + EPS)
+        ciou = diou - alpha * v
+        return ciou
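
A quick sanity-check sketch of the unified `transform_bbox` helper introduced above (the box values are made-up illustrations, not taken from the PR): the `indicator` string names the input and output formats, and a trailing `n` marks a normalized format, which is the only case that needs `image_shape`.

    import torch
    from netspresso_trainer.utils.bbox_utils import transform_bbox

    # One box in pixel xyxy coordinates: (x_min, y_min, x_max, y_max).
    boxes = torch.tensor([[10.0, 20.0, 50.0, 80.0]])

    # Pixel xyxy -> pixel cxcywh: tensor([[30., 50., 40., 60.]])
    print(transform_bbox(boxes, "xyxy -> cxcywh"))

    # Pixel xyxy -> normalized cxcywh; image_shape is (height, width).
    # For a 100x200 image: tensor([[0.1500, 0.5000, 0.2000, 0.6000]])
    print(transform_bbox(boxes, "xyxy -> cxcywhn", image_shape=(100, 200)))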
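
Likewise for `calculate_iou`, which replaces the IoU/GIoU math previously inlined in `IOUloss.forward`. It expects boxes in cxcywh form (corners are rebuilt from center plus or minus half the extent), and the CIoU branch follows the standard Zheng et al. formulation, with v = (4/pi^2) * (arctan(w1/h1) - arctan(w2/h2))^2 and alpha = v / ((1 - IoU) + v). A minimal check with made-up boxes:

    import torch
    from netspresso_trainer.utils.bbox_utils import calculate_iou

    # cxcywh boxes: two 20x20 squares whose centers are 10px apart.
    a = torch.tensor([[50.0, 50.0, 20.0, 20.0]])
    b = torch.tensor([[60.0, 50.0, 20.0, 20.0]])

    print(calculate_iou(a, b, metric="iou"))   # ~0.3333 (overlap 200 / union 600)
    print(calculate_iou(a, b, metric="giou"))  # equals IoU here: the enclosing box adds no empty area
    print(calculate_iou(a, b, metric="diou"))  # IoU minus squared center distance over squared diagonal
    print(calculate_iou(a, b, metric="ciou"))  # equals DIoU here: identical aspect ratios give v = 0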