
RTMDet-tiny enablement for detection task #3542


Merged · 19 commits · May 31, 2024
2 changes: 1 addition & 1 deletion src/otx/algo/detection/atss.py
@@ -104,7 +104,7 @@ def _customize_outputs(
elif isinstance(v, torch.Tensor):
losses[k] = v
else:
msg = "Loss output should be list or torch.tensor but got {type(v)}"
msg = f"Loss output should be list or torch.tensor but got {type(v)}"
raise TypeError(msg)
return losses

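Why the atss.py one-liner matters: without the `f` prefix, Python never interpolates `{type(v)}`, so the raised TypeError reported the literal placeholder instead of the offending type. A minimal standalone repro of the behavior:

```python
v = {"loss": 1.0}  # stand-in for an unexpected loss output

# Before the fix: plain string, the placeholder stays literal.
msg_plain = "Loss output should be list or torch.tensor but got {type(v)}"
# After the fix: f-string, the actual type is interpolated.
msg_fixed = f"Loss output should be list or torch.tensor but got {type(v)}"

print(msg_plain)  # ... but got {type(v)}
print(msg_fixed)  # ... but got <class 'dict'>
```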
5 changes: 3 additions & 2 deletions src/otx/algo/detection/heads/atss_head.py
@@ -12,6 +12,7 @@
from otx.algo.detection.heads.class_incremental_mixin import (
ClassIncrementalMixin,
)
+ from otx.algo.detection.losses.cross_entropy_loss import CrossEntropyLoss
from otx.algo.detection.losses.cross_focal_loss import (
CrossSigmoidFocalLoss,
)
@@ -57,12 +58,12 @@ def __init__(
self,
num_classes: int,
in_channels: int,
- loss_centerness: nn.Module,
pred_kernel_size: int = 3,
stacked_convs: int = 4,
conv_cfg: dict | None = None,
norm_cfg: dict | None = None,
reg_decoded_bbox: bool = True,
+ loss_centerness: nn.Module | None = None,
init_cfg: dict | None = None,
bg_loss_weight: float = -1.0,
use_qfl: bool = False,
@@ -89,7 +90,7 @@ def __init__(
)

self.sampling = False
- self.loss_centerness = loss_centerness
+ self.loss_centerness = loss_centerness or CrossEntropyLoss(use_sigmoid=True, loss_weight=1.0)

if use_qfl:
kwargs["loss_cls"] = (
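The atss_head.py change turns `loss_centerness` into an optional argument with a lazily constructed `CrossEntropyLoss(use_sigmoid=True, loss_weight=1.0)` default, so callers no longer have to pass the module explicitly. A minimal sketch of the pattern, using hypothetical stand-in classes rather than the OTX ones:

```python
from torch import nn


class BCELossStub(nn.Module):  # hypothetical stand-in for CrossEntropyLoss
    def __init__(self, use_sigmoid: bool = True, loss_weight: float = 1.0) -> None:
        super().__init__()
        self.use_sigmoid = use_sigmoid
        self.loss_weight = loss_weight


class HeadStub:  # hypothetical stand-in for ATSSHead
    def __init__(self, loss_centerness: nn.Module | None = None) -> None:
        # `or` falls through to the default only when the caller passed None;
        # a plain loss module defines no __bool__/__len__, so instances are truthy.
        self.loss_centerness = loss_centerness or BCELossStub()


default_head = HeadStub()  # gets the built-in default loss
custom_head = HeadStub(loss_centerness=BCELossStub(loss_weight=0.5))  # explicit override
```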
52 changes: 3 additions & 49 deletions src/otx/algo/detection/heads/delta_xywh_bbox_coder.py
@@ -8,6 +8,8 @@
import torch
from torch import Tensor

+ from otx.algo.detection.utils.utils import clip_bboxes_export


# This class and its supporting functions below lightly adapted from the mmdet DeltaXYWHBBoxCoder available at:
# https://github.com/open-mmlab/mmdetection/blob/cfd5d3a985b0249de009b67d04f37263e11cdf3d/mmdet/models/task_modules/coders/delta_xywh_bbox_coder.py
@@ -360,54 +362,6 @@ def delta2bbox_export(
y2 = xy2[..., 1]

if clip_border and max_shape is not None:
- x1, y1, x2, y2 = clip_bboxes(x1, y1, x2, y2, max_shape)
+ x1, y1, x2, y2 = clip_bboxes_export(x1, y1, x2, y2, max_shape)

return torch.stack([x1, y1, x2, y2], dim=-1).view(deltas.size())


- def clip_bboxes(
- x1: Tensor,
- y1: Tensor,
- x2: Tensor,
- y2: Tensor,
- max_shape: Tensor | tuple[int, ...],
- ) -> tuple[Tensor, ...]:
- """Clip bboxes for onnx.

- Since torch.clamp cannot have dynamic `min` and `max`, we scale the
- boxes by 1/max_shape and clamp in the range [0, 1] if necessary.

- Args:
- x1 (Tensor): The x1 for bounding boxes.
- y1 (Tensor): The y1 for bounding boxes.
- x2 (Tensor): The x2 for bounding boxes.
- y2 (Tensor): The y2 for bounding boxes.
- max_shape (Tensor | Sequence[int]): The (H,W) of original image.

- Returns:
- tuple(Tensor): The clipped x1, y1, x2, y2.
- """
- if isinstance(max_shape, torch.Tensor):
- # scale by 1/max_shape
- x1 = x1 / max_shape[1]
- y1 = y1 / max_shape[0]
- x2 = x2 / max_shape[1]
- y2 = y2 / max_shape[0]

- # clamp [0, 1]
- x1 = torch.clamp(x1, 0, 1)
- y1 = torch.clamp(y1, 0, 1)
- x2 = torch.clamp(x2, 0, 1)
- y2 = torch.clamp(y2, 0, 1)

- # scale back
- x1 = x1 * max_shape[1]
- y1 = y1 * max_shape[0]
- x2 = x2 * max_shape[1]
- y2 = y2 * max_shape[0]
- else:
- x1 = torch.clamp(x1, 0, max_shape[1])
- y1 = torch.clamp(y1, 0, max_shape[0])
- x2 = torch.clamp(x2, 0, max_shape[1])
- y2 = torch.clamp(y2, 0, max_shape[0])
- return x1, y1, x2, y2
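The deleted `clip_bboxes` helper is not dropped: per the new import at the top of this file, it now lives in `otx.algo.detection.utils.utils` as `clip_bboxes_export`, de-duplicating export-time code across coders. The trick it implements is worth spelling out: exported graphs cannot feed a runtime tensor into `torch.clamp`'s `min`/`max`, so the boxes are normalized by the image size, clamped to the static range [0, 1], then scaled back. A self-contained sketch of that idea (the function name here is mine, not the OTX API):

```python
import torch
from torch import Tensor


def clip_boxes_dynamic_bound(x1: Tensor, y1: Tensor, x2: Tensor, y2: Tensor,
                             max_shape: Tensor) -> tuple[Tensor, ...]:
    """Clamp xyxy coordinates to a (H, W) bound supplied as a runtime tensor."""
    h, w = max_shape[0], max_shape[1]
    # Normalize -> clamp to the static range [0, 1] -> scale back.
    x1 = (x1 / w).clamp(0, 1) * w
    y1 = (y1 / h).clamp(0, 1) * h
    x2 = (x2 / w).clamp(0, 1) * w
    y2 = (y2 / h).clamp(0, 1) * h
    return x1, y1, x2, y2


shape = torch.tensor([480, 640])  # (H, W)
x1, y1, x2, y2 = map(torch.tensor, ([-5.0], [10.0], [700.0], [500.0]))
print(clip_boxes_dynamic_bound(x1, y1, x2, y2, shape))
# (tensor([0.]), tensor([10.]), tensor([640.]), tensor([480.]))
```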
31 changes: 29 additions & 2 deletions src/otx/algo/detection/heads/distance_point_bbox_coder.py
@@ -1,13 +1,13 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) OpenMMLab. All rights reserved.
""""Distance Point BBox coder."""
"""Distance Point BBox coder."""

from __future__ import annotations

from typing import TYPE_CHECKING

- from otx.algo.detection.utils.utils import bbox2distance, distance2bbox
+ from otx.algo.detection.utils.utils import bbox2distance, distance2bbox, distance2bbox_export

if TYPE_CHECKING:
from torch import Tensor
@@ -84,3 +84,30 @@ def decode(
if self.clip_border is False:
max_shape = None
return distance2bbox(points, pred_bboxes, max_shape)

+ def decode_export(
+ self,
+ points: Tensor,
+ pred_bboxes: Tensor,
+ max_shape: tuple[int, ...] | Tensor | tuple[tuple[int, ...], ...] | None = None,
+ ) -> Tensor:
+ """Decode distance prediction to bounding box for export."""
+ if points.size(0) != pred_bboxes.size(0):
+ msg = (
+ f"The batch of points (={points.size(0)}) and the batch of pred_bboxes "
+ f"(={pred_bboxes.size(0)}) should be same."
+ )
+ raise ValueError(msg)

+ if points.size(-1) != 2:
+ msg = f"points should have the format with size of 2, given {points.size(-1)}."
+ raise ValueError(msg)

+ if pred_bboxes.size(-1) != 4:
+ msg = f"pred_bboxes should have the format with size of 4, given {pred_bboxes.size(-1)}."
+ raise ValueError(msg)

+ if self.clip_border is False:
+ max_shape = None

+ return distance2bbox_export(points, pred_bboxes, max_shape)
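The new `decode_export` mirrors `decode` but routes to `distance2bbox_export`, the export-friendly variant of the distance-to-box conversion. For reference, the core math both variants share: each prediction is four distances (left, top, right, bottom) from an anchor point to the box edges. A sketch of that conversion, my own re-derivation that ignores the `max_shape` clipping:

```python
import torch
from torch import Tensor


def distance_to_bbox(points: Tensor, distance: Tensor) -> Tensor:
    """points: (..., 2) anchor centers; distance: (..., 4) offsets
    (left, top, right, bottom). Returns (..., 4) xyxy boxes."""
    x1 = points[..., 0] - distance[..., 0]
    y1 = points[..., 1] - distance[..., 1]
    x2 = points[..., 0] + distance[..., 2]
    y2 = points[..., 1] + distance[..., 3]
    return torch.stack([x1, y1, x2, y2], dim=-1)


pts = torch.tensor([[10.0, 20.0]])
dist = torch.tensor([[3.0, 4.0, 5.0, 6.0]])
print(distance_to_bbox(pts, dist))  # tensor([[ 7., 16., 15., 26.]])
```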