
Commit 6672653

kprokofi, sungchul2, and eunwoosh authored
Move transforms from 3d object detection dataset (#4046)
* added coco metric
* fix linter
* added ap_05
* fix perf test
* small fix
* fix some misc comments from previous PRs
* move transforms from dataset. Metric not working
* test metric. still 0
* transforms work. metric is expected. Delete kitty utils
* remove unit test
* fix unit tests
* fix OV IR inference. reply comments
* minor
* fix linter
* Update src/otx/core/data/transform_libs/torchvision.py (Co-authored-by: Kim, Sungchul <sungchul.kim@intel.com>)
* Update src/otx/core/model/detection_3d.py (Co-authored-by: Eunwoo Shin <eunwoo.shin@intel.com>)
* added more unit tests

---------

Co-authored-by: Kim, Sungchul <sungchul.kim@intel.com>
Co-authored-by: Eunwoo Shin <eunwoo.shin@intel.com>
1 parent e92cdfd commit 6672653

File tree

16 files changed: +633, -558 lines changed


src/otx/algo/object_detection_3d/detectors/monodetr.py

+1-1
@@ -235,7 +235,7 @@ def forward(
 
         # depth_geo
         box2d_height_norm = outputs_coord[:, :, 4] + outputs_coord[:, :, 5]
-        box2d_height = torch.clamp(box2d_height_norm * img_sizes[:, 1:2], min=1.0)
+        box2d_height = torch.clamp(box2d_height_norm * img_sizes[:, :1], min=1.0)
         depth_geo = size3d[:, :, 0] / box2d_height * calibs[:, 0, 0].unsqueeze(1)
 
         # depth_reg
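The fix swaps the column used to de-normalize the predicted 2D box height: `img_sizes[:, :1]` (column 0) instead of `img_sizes[:, 1:2]` (column 1), presumably because `img_sizes` stores (height, width) per image, so the normalized box height must be scaled by the image height before the geometric depth `size3d / box2d_height * focal_length` is computed. A minimal, self-contained sketch of the slicing and broadcasting involved (the shapes and the (height, width) layout are assumptions, not taken from the repository):

import torch

# Hypothetical shapes mirroring the hunk above: 2 images, 3 queries each.
outputs_coord = torch.rand(2, 3, 6)          # normalized box parameters per query
img_sizes = torch.tensor([[384.0, 1280.0],   # assumed (height, width) per image
                          [375.0, 1242.0]])

box2d_height_norm = outputs_coord[:, :, 4] + outputs_coord[:, :, 5]  # shape (2, 3)

old_height = box2d_height_norm * img_sizes[:, 1:2]  # scales by column 1 (width)
new_height = box2d_height_norm * img_sizes[:, :1]   # scales by column 0 (height)

# Both slices keep shape (2, 1), so they broadcast over the query dimension;
# only the column they select differs.
print(old_height.shape, new_height.shape)  # torch.Size([2, 3]) torch.Size([2, 3])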

src/otx/core/config/data.py

+2
@@ -29,6 +29,8 @@ class SubsetConfig:
             (`TransformLibType.MMCV`, `TransformLibType.MMPRETRAIN`, ...).
         transform_lib_type (TransformLibType): Transform library type used by this subset.
         num_workers (int): Number of workers for the dataloader of this subset.
+        sampler (SamplerConfig | None): Sampler configuration for the dataloader of this subset.
+        to_tv_image (bool): Whether to convert image to torch tensor.
         input_size (int | tuple[int, int] | None) :
             input size model expects. If $(input_size) exists in transforms, it will be replaced with this value.
 
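The two added lines only document dataloader options (`sampler` and `to_tv_image`) in the `SubsetConfig` docstring. For orientation, a rough, hypothetical stand-in for the documented fields is sketched below; the field names come from the docstring, while the defaults, the simplified `sampler` type, and everything else are assumptions rather than the real `otx.core.config.data.SubsetConfig` definition:

from __future__ import annotations

from dataclasses import dataclass
from typing import Any


@dataclass
class SubsetConfigSketch:
    """Illustrative stand-in for the fields documented in SubsetConfig (not the real class)."""

    num_workers: int = 2                                  # dataloader worker count
    sampler: dict[str, Any] | None = None                 # SamplerConfig | None in the real config
    to_tv_image: bool = False                             # convert images to torchvision tv_tensors
    input_size: int | tuple[int, int] | None = None       # replaces $(input_size) in transforms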

src/otx/core/data/dataset/object_detection_3d.py

+43-231
@@ -3,28 +3,15 @@
 #
 """Module for OTX3DObjectDetectionDataset."""
 
-# mypy: ignore-errors
-
 from __future__ import annotations
 
 from copy import deepcopy
 from functools import partial
 from typing import TYPE_CHECKING, Any, Callable, List, Union
 
 import numpy as np
-import torch
 from datumaro import Image
-from PIL import Image as PILImage
-from torchvision import tv_tensors
 
-from otx.core.data.dataset.utils.kitti_utils import (
-    affine_transform,
-    angle2class,
-    get_affine_transform,
-    get_calib_from_file,
-    rect_to_img,
-    ry2alpha,
-)
 from otx.core.data.entity.base import ImageInfo
 from otx.core.data.entity.object_detection_3d import Det3DBatchDataEntity, Det3DDataEntity
 from otx.core.data.mem_cache import NULL_MEM_CACHE_HANDLER, MemCacheHandlerBase
@@ -34,7 +21,7 @@
 from .base import OTXDataset
 
 if TYPE_CHECKING:
-    from datumaro import Bbox, DatasetSubset
+    from datumaro import DatasetSubset
 
 
 Transforms = Union[Compose, Callable, List[Callable], dict[str, Compose | Callable | List[Callable]]]
@@ -54,8 +41,6 @@ def __init__(
         stack_images: bool = True,
         to_tv_image: bool = False,
         max_objects: int = 50,
-        depth_threshold: int = 65,
-        resolution: tuple[int, int] = (1280, 384),  # (W, H)
     ) -> None:
         super().__init__(
             dm_subset,
@@ -68,239 +53,56 @@
             to_tv_image,
         )
         self.max_objects = max_objects
-        self.depth_threshold = depth_threshold
-        self.resolution = np.array(resolution)  # TODO(Kirill): make it configurable
         self.subset_type = list(self.dm_subset.get_subset_info())[-1].split(":")[0]
 
     def _get_item_impl(self, index: int) -> Det3DDataEntity | None:
         entity = self.dm_subset[index]
         image = entity.media_as(Image)
-        image = self._get_img_data_and_shape(image)[0]
-        calib = get_calib_from_file(entity.attributes["calib_path"])
-        original_kitti_format = None  # don't use for training
-        if self.subset_type != "train":
-            # TODO (Kirill): remove this or duplication of the inputs
-            annotations_copy = deepcopy(entity.annotations)
-            original_kitti_format = [obj.attributes for obj in annotations_copy]
-            # decode original kitti format for metric calculation
-            for i, anno_dict in enumerate(original_kitti_format):
-                anno_dict["name"] = self.label_info.label_names[annotations_copy[i].label]
-                anno_dict["bbox"] = annotations_copy[i].points
-                dimension = anno_dict["dimensions"]
-                anno_dict["dimensions"] = [dimension[2], dimension[0], dimension[1]]
-            original_kitti_format = self._reformate_for_kitti_metric(original_kitti_format)
-        # decode labels for training
-        inputs, targets, ori_img_shape = self._decode_item(
-            PILImage.fromarray(image),
-            entity.annotations,
-            calib,
-        )
-        # normilize image
-        inputs = self._apply_transforms(torch.as_tensor(inputs, dtype=torch.float32))
-        return Det3DDataEntity(
-            image=inputs,
+        image, ori_img_shape = self._get_img_data_and_shape(image)
+        calib = self.get_calib_from_file(entity.attributes["calib_path"])
+        annotations_copy = deepcopy(entity.annotations)
+        datumaro_kitti_format = [obj.attributes for obj in annotations_copy]
+
+        # decode original kitti format for metric calculation
+        for i, anno_dict in enumerate(datumaro_kitti_format):
+            anno_dict["name"] = (
+                self.label_info.label_names[annotations_copy[i].label]
+                if self.subset_type != "train"
+                else annotations_copy[i].label
+            )
+            anno_dict["bbox"] = annotations_copy[i].points
+            dimension = anno_dict["dimensions"]
+            anno_dict["dimensions"] = [dimension[2], dimension[0], dimension[1]]
+        original_kitti_format = self._reformate_for_kitti_metric(datumaro_kitti_format)
+
+        entity = Det3DDataEntity(
+            image=image,
             img_info=ImageInfo(
                 img_idx=index,
-                img_shape=inputs.shape[1:],
-                ori_shape=ori_img_shape,  # TODO(Kirill): curently we use WxH here, make it HxW
+                img_shape=ori_img_shape,
+                ori_shape=ori_img_shape,
                 image_color_channel=self.image_color_channel,
                 ignored_labels=[],
             ),
-            boxes=tv_tensors.BoundingBoxes(
-                targets["boxes"],
-                format=tv_tensors.BoundingBoxFormat.XYXY,
-                canvas_size=inputs.shape[1:],
-                dtype=torch.float32,
-            ),
-            labels=torch.as_tensor(targets["labels"], dtype=torch.long),
-            calib_matrix=torch.as_tensor(calib, dtype=torch.float32),
-            boxes_3d=torch.as_tensor(targets["boxes_3d"], dtype=torch.float32),
-            size_2d=torch.as_tensor(targets["size_2d"], dtype=torch.float32),
-            size_3d=torch.as_tensor(targets["size_3d"], dtype=torch.float32),
-            depth=torch.as_tensor(targets["depth"], dtype=torch.float32),
-            heading_angle=torch.as_tensor(
-                np.concatenate([targets["heading_bin"], targets["heading_res"]], axis=1),
-                dtype=torch.float32,
-            ),
+            boxes=np.zeros((self.max_objects, 4), dtype=np.float32),
+            labels=np.zeros((self.max_objects), dtype=np.int8),
+            calib_matrix=calib,
+            boxes_3d=np.zeros((self.max_objects, 6), dtype=np.float32),
+            size_2d=np.zeros((self.max_objects, 2), dtype=np.float32),
+            size_3d=np.zeros((self.max_objects, 3), dtype=np.float32),
+            depth=np.zeros((self.max_objects, 1), dtype=np.float32),
+            heading_angle=np.zeros((self.max_objects, 2), dtype=np.float32),
             original_kitti_format=original_kitti_format,
         )
 
+        return self._apply_transforms(entity)
+
     @property
     def collate_fn(self) -> Callable:
         """Collection function to collect DetDataEntity into DetBatchDataEntity in data loader."""
         return partial(Det3DBatchDataEntity.collate_fn, stack_images=self.stack_images)
 
-    def _decode_item(self, img: PILImage, annotations: list[Bbox], calib: np.ndarray) -> tuple:  # noqa: C901
-        """Decode item for training."""
-        # data augmentation for image
-        img_size = np.array(img.size)
-        bbox2d = np.array([ann.points for ann in annotations])
-        center = img_size / 2
-        crop_size, crop_scale = img_size, 1
-        random_flip_flag = False
-        # TODO(Kirill): add data augmentation for 3d, remove them from here.
-        if self.subset_type == "train":
-            if np.random.random() < 0.5:
-                random_flip_flag = True
-                img = img.transpose(PILImage.FLIP_LEFT_RIGHT)
-
-            if np.random.random() < 0.5:
-                scale = 0.05
-                shift = 0.05
-                crop_scale = np.clip(np.random.randn() * scale + 1, 1 - scale, 1 + scale)
-                crop_size = img_size * crop_scale
-                center[0] += img_size[0] * np.clip(np.random.randn() * shift, -2 * shift, 2 * shift)
-                center[1] += img_size[1] * np.clip(np.random.randn() * shift, -2 * shift, 2 * shift)
-
-        # add affine transformation for 2d images.
-        trans, trans_inv = get_affine_transform(center, crop_size, 0, self.resolution, inv=1)
-        img = img.transform(
-            tuple(self.resolution.tolist()),
-            method=PILImage.AFFINE,
-            data=tuple(trans_inv.reshape(-1).tolist()),
-            resample=PILImage.BILINEAR,
-        )
-        img = np.array(img).astype(np.float32)
-        img = img.transpose(2, 0, 1)  # C * H * W -> (384 * 1280)
-        # ============================ get labels ==============================
-        # data augmentation for labels
-        annotations_list: list[dict[str, Any]] = [ann.attributes for ann in annotations]
-        for i, obj in enumerate(annotations_list):
-            obj["label"] = annotations[i].label
-            obj["location"] = np.array(obj["location"])
-
-        if random_flip_flag:
-            for i in range(bbox2d.shape[0]):
-                [x1, _, x2, _] = bbox2d[i]
-                bbox2d[i][0], bbox2d[i][2] = img_size[0] - x2, img_size[0] - x1
-                annotations_list[i]["alpha"] = np.pi - annotations_list[i]["alpha"]
-                annotations_list[i]["rotation_y"] = np.pi - annotations_list[i]["rotation_y"]
-                if annotations_list[i]["alpha"] > np.pi:
-                    annotations_list[i]["alpha"] -= 2 * np.pi  # check range
-                if annotations_list[i]["alpha"] < -np.pi:
-                    annotations_list[i]["alpha"] += 2 * np.pi
-                if annotations_list[i]["rotation_y"] > np.pi:
-                    annotations_list[i]["rotation_y"] -= 2 * np.pi
-                if annotations_list[i]["rotation_y"] < -np.pi:
-                    annotations_list[i]["rotation_y"] += 2 * np.pi
-
-        # labels encoding
-        mask_2d = np.zeros((self.max_objects), dtype=bool)
-        labels = np.zeros((self.max_objects), dtype=np.int8)
-        depth = np.zeros((self.max_objects, 1), dtype=np.float32)
-        heading_bin = np.zeros((self.max_objects, 1), dtype=np.int64)
-        heading_res = np.zeros((self.max_objects, 1), dtype=np.float32)
-        size_2d = np.zeros((self.max_objects, 2), dtype=np.float32)
-        size_3d = np.zeros((self.max_objects, 3), dtype=np.float32)
-        src_size_3d = np.zeros((self.max_objects, 3), dtype=np.float32)
-        boxes = np.zeros((self.max_objects, 4), dtype=np.float32)
-        boxes_3d = np.zeros((self.max_objects, 6), dtype=np.float32)
-
-        object_num = len(annotations) if len(annotations) < self.max_objects else self.max_objects
-        for i in range(object_num):
-            cur_obj = annotations_list[i]
-            # ignore the samples beyond the threshold [hard encoding]
-            if cur_obj["location"][-1] > self.depth_threshold and cur_obj["location"][-1] < 2:
-                continue
-
-            # process 2d bbox & get 2d center
-            bbox_2d = bbox2d[i].copy()
-
-            # add affine transformation for 2d boxes.
-            bbox_2d[:2] = affine_transform(bbox_2d[:2], trans)
-            bbox_2d[2:] = affine_transform(bbox_2d[2:], trans)
-
-            # process 3d center
-            center_2d = np.array(
-                [(bbox_2d[0] + bbox_2d[2]) / 2, (bbox_2d[1] + bbox_2d[3]) / 2],
-                dtype=np.float32,
-            )  # W * H
-            corner_2d = bbox_2d.copy()
-
-            center_3d = np.array(
-                cur_obj["location"]
-                + [
-                    0,
-                    -cur_obj["dimensions"][0] / 2,
-                    0,
-                ],
-            )  # real 3D center in 3D space
-            center_3d = center_3d.reshape(-1, 3)  # shape adjustment (N, 3)
-            center_3d, _ = rect_to_img(calib, center_3d)  # project 3D center to image plane
-            center_3d = center_3d[0]  # shape adjustment
-            if random_flip_flag:  # random flip for center3d
-                center_3d[0] = img_size[0] - center_3d[0]
-            center_3d = affine_transform(center_3d.reshape(-1), trans)
-
-            # filter 3d center out of img
-            proj_inside_img = True
-
-            if center_3d[0] < 0 or center_3d[0] >= self.resolution[0]:
-                proj_inside_img = False
-            if center_3d[1] < 0 or center_3d[1] >= self.resolution[1]:
-                proj_inside_img = False
-
-            if not proj_inside_img:
-                continue
-
-            # class
-            labels[i] = cur_obj["label"]
-
-            # encoding 2d/3d boxes
-            w, h = bbox_2d[2] - bbox_2d[0], bbox_2d[3] - bbox_2d[1]
-            size_2d[i] = 1.0 * w, 1.0 * h
-
-            center_2d_norm = center_2d / self.resolution
-            size_2d_norm = size_2d[i] / self.resolution
-
-            corner_2d_norm = corner_2d
-            corner_2d_norm[0:2] = corner_2d[0:2] / self.resolution
-            corner_2d_norm[2:4] = corner_2d[2:4] / self.resolution
-            center_3d_norm = center_3d / self.resolution
-
-            k, r = center_3d_norm[0] - corner_2d_norm[0], corner_2d_norm[2] - center_3d_norm[0]
-            t, b = center_3d_norm[1] - corner_2d_norm[1], corner_2d_norm[3] - center_3d_norm[1]
-
-            if k < 0 or r < 0 or t < 0 or b < 0:
-                continue
-
-            boxes[i] = center_2d_norm[0], center_2d_norm[1], size_2d_norm[0], size_2d_norm[1]
-            boxes_3d[i] = center_3d_norm[0], center_3d_norm[1], k, r, t, b
-
-            # encoding depth
-            depth[i] = cur_obj["location"][-1] * crop_scale
-
-            # encoding heading angle
-            heading_angle = ry2alpha(calib, cur_obj["rotation_y"], (bbox2d[i][0] + bbox2d[i][2]) / 2)
-            if heading_angle > np.pi:
-                heading_angle -= 2 * np.pi  # check range
-            if heading_angle < -np.pi:
-                heading_angle += 2 * np.pi
-            heading_bin[i], heading_res[i] = angle2class(heading_angle)
-
-            # encoding size_3d
-            src_size_3d[i] = np.array([cur_obj["dimensions"]], dtype=np.float32)
-            size_3d[i] = src_size_3d[i]
-
-            # filter out the samples with truncated or occluded
-            if cur_obj["truncated"] <= 0.5 and cur_obj["occluded"] <= 2:
-                mask_2d[i] = 1
-
-        # collect return data
-        targets_for_train = {
-            "labels": labels[mask_2d],
-            "boxes": boxes[mask_2d],
-            "boxes_3d": boxes_3d[mask_2d],
-            "depth": depth[mask_2d],
-            "size_2d": size_2d[mask_2d],
-            "size_3d": size_3d[mask_2d],
-            "heading_bin": heading_bin[mask_2d],
-            "heading_res": heading_res[mask_2d],
-        }
-
-        return img, targets_for_train, img_size
-
-    def _reformate_for_kitti_metric(self, annotations: dict[str, Any]) -> dict[str, np.array]:
+    def _reformate_for_kitti_metric(self, annotations: list[Any]) -> dict[str, np.array]:
         """Reformat the annotation for KITTI metric."""
         return {
             "name": np.array([obj["name"] for obj in annotations]),
@@ -312,3 +114,13 @@ def _reformate_for_kitti_metric(self, annotations: dict[str, Any]) -> dict[str,
             "occluded": np.array([obj["occluded"] for obj in annotations]),
             "truncated": np.array([obj["truncated"] for obj in annotations]),
         }
+
+    @staticmethod
+    def get_calib_from_file(calib_file: str) -> np.ndarray:
+        """Get calibration matrix from txt file (KITTI format)."""
+        with open(calib_file) as f:  # noqa: PTH123
+            lines = f.readlines()
+
+        obj = lines[2].strip().split(" ")[1:]
+
+        return np.array(obj, dtype=np.float32).reshape(3, 4)
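The new `get_calib_from_file` static method replaces the helper that was previously imported from the now-deleted `kitti_utils` module: it reads the third line of a KITTI calibration file (the `P2:` camera projection matrix) and reshapes the twelve values into a 3x4 array. Below is a standalone sketch of the same parsing logic; the file written here is a fabricated placeholder that only mimics the KITTI calib layout, not real calibration data:

import numpy as np


def parse_kitti_calib(calib_file: str) -> np.ndarray:
    """Read the P2 camera projection matrix (3x4) from a KITTI-style calibration file."""
    with open(calib_file) as f:
        lines = f.readlines()
    # Line index 2 holds "P2: <12 floats>"; drop the key and keep the numbers.
    values = lines[2].strip().split(" ")[1:]
    return np.array(values, dtype=np.float32).reshape(3, 4)


if __name__ == "__main__":
    # Fabricated calib file: P0/P1 are dummies, P2 carries twelve placeholder values.
    sample_lines = [
        "P0: " + " ".join(["0.0"] * 12),
        "P1: " + " ".join(["0.0"] * 12),
        "P2: " + " ".join(str(float(i)) for i in range(12)),
    ]
    with open("calib_sample.txt", "w") as f:
        f.write("\n".join(sample_lines) + "\n")
    print(parse_kitti_calib("calib_sample.txt"))  # prints the 3x4 P2 matrix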

src/otx/core/data/dataset/utils/__init__.py

-4
This file was deleted.
