#

"""Module for OTX3DObjectDetectionDataset."""
-# mypy: ignore-errors
-
from __future__ import annotations

from copy import deepcopy
from functools import partial
from typing import TYPE_CHECKING, Any, Callable, List, Union

import numpy as np
-import torch
from datumaro import Image
-from PIL import Image as PILImage
-from torchvision import tv_tensors

-from otx.core.data.dataset.utils.kitti_utils import (
-    affine_transform,
-    angle2class,
-    get_affine_transform,
-    get_calib_from_file,
-    rect_to_img,
-    ry2alpha,
-)
from otx.core.data.entity.base import ImageInfo
from otx.core.data.entity.object_detection_3d import Det3DBatchDataEntity, Det3DDataEntity
from otx.core.data.mem_cache import NULL_MEM_CACHE_HANDLER, MemCacheHandlerBase
from .base import OTXDataset

if TYPE_CHECKING:
-    from datumaro import Bbox, DatasetSubset
+    from datumaro import DatasetSubset


Transforms = Union[Compose, Callable, List[Callable], dict[str, Compose | Callable | List[Callable]]]
@@ -54,8 +41,6 @@ def __init__(
        stack_images: bool = True,
        to_tv_image: bool = False,
        max_objects: int = 50,
-        depth_threshold: int = 65,
-        resolution: tuple[int, int] = (1280, 384),  # (W, H)
    ) -> None:
        super().__init__(
            dm_subset,
@@ -68,239 +53,56 @@ def __init__(
            to_tv_image,
        )
        self.max_objects = max_objects
-        self.depth_threshold = depth_threshold
-        self.resolution = np.array(resolution)  # TODO(Kirill): make it configurable
        self.subset_type = list(self.dm_subset.get_subset_info())[-1].split(":")[0]
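Note on the `subset_type` line kept above: it assumes each of Datumaro's subset info entries looks like `"<subset>:<...>"` and keeps only the part before the colon. A toy illustration (the literal string is invented; the real value comes from `dm_subset.get_subset_info()`):

```python
subset_info = "train:kitti"  # illustrative stand-in for a get_subset_info() entry
subset_type = subset_info.split(":")[0]
assert subset_type == "train"
```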

    def _get_item_impl(self, index: int) -> Det3DDataEntity | None:
        entity = self.dm_subset[index]
        image = entity.media_as(Image)
-        image = self._get_img_data_and_shape(image)[0]
-        calib = get_calib_from_file(entity.attributes["calib_path"])
-        original_kitti_format = None  # don't use for training
-        if self.subset_type != "train":
-            # TODO (Kirill): remove this or duplication of the inputs
-            annotations_copy = deepcopy(entity.annotations)
-            original_kitti_format = [obj.attributes for obj in annotations_copy]
-            # decode original kitti format for metric calculation
-            for i, anno_dict in enumerate(original_kitti_format):
-                anno_dict["name"] = self.label_info.label_names[annotations_copy[i].label]
-                anno_dict["bbox"] = annotations_copy[i].points
-                dimension = anno_dict["dimensions"]
-                anno_dict["dimensions"] = [dimension[2], dimension[0], dimension[1]]
-            original_kitti_format = self._reformate_for_kitti_metric(original_kitti_format)
-        # decode labels for training
-        inputs, targets, ori_img_shape = self._decode_item(
-            PILImage.fromarray(image),
-            entity.annotations,
-            calib,
-        )
-        # normalize image
-        inputs = self._apply_transforms(torch.as_tensor(inputs, dtype=torch.float32))
-        return Det3DDataEntity(
-            image=inputs,
+        image, ori_img_shape = self._get_img_data_and_shape(image)
+        calib = self.get_calib_from_file(entity.attributes["calib_path"])
+        annotations_copy = deepcopy(entity.annotations)
+        datumaro_kitti_format = [obj.attributes for obj in annotations_copy]
+
+        # decode original kitti format for metric calculation
+        for i, anno_dict in enumerate(datumaro_kitti_format):
+            anno_dict["name"] = (
+                self.label_info.label_names[annotations_copy[i].label]
+                if self.subset_type != "train"
+                else annotations_copy[i].label
+            )
+            anno_dict["bbox"] = annotations_copy[i].points
+            dimension = anno_dict["dimensions"]
+            anno_dict["dimensions"] = [dimension[2], dimension[0], dimension[1]]
+        original_kitti_format = self._reformate_for_kitti_metric(datumaro_kitti_format)
+
+        entity = Det3DDataEntity(
+            image=image,
            img_info=ImageInfo(
                img_idx=index,
-                img_shape=inputs.shape[1:],
-                ori_shape=ori_img_shape,  # TODO(Kirill): currently we use WxH here, make it HxW
+                img_shape=ori_img_shape,
+                ori_shape=ori_img_shape,
                image_color_channel=self.image_color_channel,
                ignored_labels=[],
            ),
-            boxes=tv_tensors.BoundingBoxes(
-                targets["boxes"],
-                format=tv_tensors.BoundingBoxFormat.XYXY,
-                canvas_size=inputs.shape[1:],
-                dtype=torch.float32,
-            ),
-            labels=torch.as_tensor(targets["labels"], dtype=torch.long),
-            calib_matrix=torch.as_tensor(calib, dtype=torch.float32),
-            boxes_3d=torch.as_tensor(targets["boxes_3d"], dtype=torch.float32),
-            size_2d=torch.as_tensor(targets["size_2d"], dtype=torch.float32),
-            size_3d=torch.as_tensor(targets["size_3d"], dtype=torch.float32),
-            depth=torch.as_tensor(targets["depth"], dtype=torch.float32),
-            heading_angle=torch.as_tensor(
-                np.concatenate([targets["heading_bin"], targets["heading_res"]], axis=1),
-                dtype=torch.float32,
-            ),
+            boxes=np.zeros((self.max_objects, 4), dtype=np.float32),
+            labels=np.zeros((self.max_objects), dtype=np.int8),
+            calib_matrix=calib,
+            boxes_3d=np.zeros((self.max_objects, 6), dtype=np.float32),
+            size_2d=np.zeros((self.max_objects, 2), dtype=np.float32),
+            size_3d=np.zeros((self.max_objects, 3), dtype=np.float32),
+            depth=np.zeros((self.max_objects, 1), dtype=np.float32),
+            heading_angle=np.zeros((self.max_objects, 2), dtype=np.float32),
            original_kitti_format=original_kitti_format,
        )

+        return self._apply_transforms(entity)
+
    @property
    def collate_fn(self) -> Callable:
        """Collection function to collect DetDataEntity into DetBatchDataEntity in data loader."""
        return partial(Det3DBatchDataEntity.collate_fn, stack_images=self.stack_images)
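The property binds `stack_images` with `functools.partial` so the data loader receives a one-argument callable. A minimal sketch of the same pattern with a toy collate function (illustrative only, not the actual `Det3DBatchDataEntity.collate_fn`):

```python
from functools import partial

import numpy as np

def toy_collate(items: list[dict], stack_images: bool = True) -> dict:
    """Merge per-item arrays into a single batch dict."""
    images = [item["image"] for item in items]
    return {
        "images": np.stack(images) if stack_images else images,  # (N, C, H, W) when stacked
        "labels": [item["labels"] for item in items],
    }

collate_fn = partial(toy_collate, stack_images=True)  # what DataLoader(collate_fn=...) receives
batch = collate_fn([{"image": np.zeros((3, 4, 4)), "labels": np.array([1])}])
```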

-    def _decode_item(self, img: PILImage, annotations: list[Bbox], calib: np.ndarray) -> tuple:  # noqa: C901
-        """Decode item for training."""
-        # data augmentation for image
-        img_size = np.array(img.size)
-        bbox2d = np.array([ann.points for ann in annotations])
-        center = img_size / 2
-        crop_size, crop_scale = img_size, 1
-        random_flip_flag = False
-        # TODO(Kirill): add data augmentation for 3d, remove them from here.
-        if self.subset_type == "train":
-            if np.random.random() < 0.5:
-                random_flip_flag = True
-                img = img.transpose(PILImage.FLIP_LEFT_RIGHT)
-
-            if np.random.random() < 0.5:
-                scale = 0.05
-                shift = 0.05
-                crop_scale = np.clip(np.random.randn() * scale + 1, 1 - scale, 1 + scale)
-                crop_size = img_size * crop_scale
-                center[0] += img_size[0] * np.clip(np.random.randn() * shift, -2 * shift, 2 * shift)
-                center[1] += img_size[1] * np.clip(np.random.randn() * shift, -2 * shift, 2 * shift)
-
-        # add affine transformation for 2d images.
-        trans, trans_inv = get_affine_transform(center, crop_size, 0, self.resolution, inv=1)
-        img = img.transform(
-            tuple(self.resolution.tolist()),
-            method=PILImage.AFFINE,
-            data=tuple(trans_inv.reshape(-1).tolist()),
-            resample=PILImage.BILINEAR,
-        )
-        img = np.array(img).astype(np.float32)
-        img = img.transpose(2, 0, 1)  # C * H * W -> (384 * 1280)
-        # ============================ get labels ==============================
-        # data augmentation for labels
-        annotations_list: list[dict[str, Any]] = [ann.attributes for ann in annotations]
-        for i, obj in enumerate(annotations_list):
-            obj["label"] = annotations[i].label
-            obj["location"] = np.array(obj["location"])
-
-        if random_flip_flag:
-            for i in range(bbox2d.shape[0]):
-                [x1, _, x2, _] = bbox2d[i]
-                bbox2d[i][0], bbox2d[i][2] = img_size[0] - x2, img_size[0] - x1
-                annotations_list[i]["alpha"] = np.pi - annotations_list[i]["alpha"]
-                annotations_list[i]["rotation_y"] = np.pi - annotations_list[i]["rotation_y"]
-                if annotations_list[i]["alpha"] > np.pi:
-                    annotations_list[i]["alpha"] -= 2 * np.pi  # check range
-                if annotations_list[i]["alpha"] < -np.pi:
-                    annotations_list[i]["alpha"] += 2 * np.pi
-                if annotations_list[i]["rotation_y"] > np.pi:
-                    annotations_list[i]["rotation_y"] -= 2 * np.pi
-                if annotations_list[i]["rotation_y"] < -np.pi:
-                    annotations_list[i]["rotation_y"] += 2 * np.pi
-
-        # labels encoding
-        mask_2d = np.zeros((self.max_objects), dtype=bool)
-        labels = np.zeros((self.max_objects), dtype=np.int8)
-        depth = np.zeros((self.max_objects, 1), dtype=np.float32)
-        heading_bin = np.zeros((self.max_objects, 1), dtype=np.int64)
-        heading_res = np.zeros((self.max_objects, 1), dtype=np.float32)
-        size_2d = np.zeros((self.max_objects, 2), dtype=np.float32)
-        size_3d = np.zeros((self.max_objects, 3), dtype=np.float32)
-        src_size_3d = np.zeros((self.max_objects, 3), dtype=np.float32)
-        boxes = np.zeros((self.max_objects, 4), dtype=np.float32)
-        boxes_3d = np.zeros((self.max_objects, 6), dtype=np.float32)
-
-        object_num = len(annotations) if len(annotations) < self.max_objects else self.max_objects
-        for i in range(object_num):
-            cur_obj = annotations_list[i]
-            # ignore the samples beyond the threshold [hard encoding]
-            if cur_obj["location"][-1] > self.depth_threshold and cur_obj["location"][-1] < 2:
-                continue
-
-            # process 2d bbox & get 2d center
-            bbox_2d = bbox2d[i].copy()
-
-            # add affine transformation for 2d boxes.
-            bbox_2d[:2] = affine_transform(bbox_2d[:2], trans)
-            bbox_2d[2:] = affine_transform(bbox_2d[2:], trans)
-
-            # process 3d center
-            center_2d = np.array(
-                [(bbox_2d[0] + bbox_2d[2]) / 2, (bbox_2d[1] + bbox_2d[3]) / 2],
-                dtype=np.float32,
-            )  # W * H
-            corner_2d = bbox_2d.copy()
-
-            center_3d = np.array(
-                cur_obj["location"]
-                + [
-                    0,
-                    -cur_obj["dimensions"][0] / 2,
-                    0,
-                ],
-            )  # real 3D center in 3D space
-            center_3d = center_3d.reshape(-1, 3)  # shape adjustment (N, 3)
-            center_3d, _ = rect_to_img(calib, center_3d)  # project 3D center to image plane
-            center_3d = center_3d[0]  # shape adjustment
-            if random_flip_flag:  # random flip for center3d
-                center_3d[0] = img_size[0] - center_3d[0]
-            center_3d = affine_transform(center_3d.reshape(-1), trans)
-
-            # filter 3d center out of img
-            proj_inside_img = True
-
-            if center_3d[0] < 0 or center_3d[0] >= self.resolution[0]:
-                proj_inside_img = False
-            if center_3d[1] < 0 or center_3d[1] >= self.resolution[1]:
-                proj_inside_img = False
-
-            if not proj_inside_img:
-                continue
-
-            # class
-            labels[i] = cur_obj["label"]
-
-            # encoding 2d/3d boxes
-            w, h = bbox_2d[2] - bbox_2d[0], bbox_2d[3] - bbox_2d[1]
-            size_2d[i] = 1.0 * w, 1.0 * h
-
-            center_2d_norm = center_2d / self.resolution
-            size_2d_norm = size_2d[i] / self.resolution
-
-            corner_2d_norm = corner_2d
-            corner_2d_norm[0:2] = corner_2d[0:2] / self.resolution
-            corner_2d_norm[2:4] = corner_2d[2:4] / self.resolution
-            center_3d_norm = center_3d / self.resolution
-
-            k, r = center_3d_norm[0] - corner_2d_norm[0], corner_2d_norm[2] - center_3d_norm[0]
-            t, b = center_3d_norm[1] - corner_2d_norm[1], corner_2d_norm[3] - center_3d_norm[1]
-
-            if k < 0 or r < 0 or t < 0 or b < 0:
-                continue
-
-            boxes[i] = center_2d_norm[0], center_2d_norm[1], size_2d_norm[0], size_2d_norm[1]
-            boxes_3d[i] = center_3d_norm[0], center_3d_norm[1], k, r, t, b
-
-            # encoding depth
-            depth[i] = cur_obj["location"][-1] * crop_scale
-
-            # encoding heading angle
-            heading_angle = ry2alpha(calib, cur_obj["rotation_y"], (bbox2d[i][0] + bbox2d[i][2]) / 2)
-            if heading_angle > np.pi:
-                heading_angle -= 2 * np.pi  # check range
-            if heading_angle < -np.pi:
-                heading_angle += 2 * np.pi
-            heading_bin[i], heading_res[i] = angle2class(heading_angle)
-
-            # encoding size_3d
-            src_size_3d[i] = np.array([cur_obj["dimensions"]], dtype=np.float32)
-            size_3d[i] = src_size_3d[i]
-
-            # filter out the samples with truncated or occluded
-            if cur_obj["truncated"] <= 0.5 and cur_obj["occluded"] <= 2:
-                mask_2d[i] = 1
-
-        # collect return data
-        targets_for_train = {
-            "labels": labels[mask_2d],
-            "boxes": boxes[mask_2d],
-            "boxes_3d": boxes_3d[mask_2d],
-            "depth": depth[mask_2d],
-            "size_2d": size_2d[mask_2d],
-            "size_3d": size_3d[mask_2d],
-            "heading_bin": heading_bin[mask_2d],
-            "heading_res": heading_res[mask_2d],
-        }
-
-        return img, targets_for_train, img_size
-
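The removed `_decode_item` relied on `kitti_utils.angle2class` to turn a continuous heading angle into a bin index plus a residual from the bin center. A generic sketch of that discretization, assuming 12 uniform bins over [0, 2π) as in the common KITTI/Frustum-PointNets convention (the bin count is an assumption, not stated in this diff):

```python
import numpy as np

NUM_HEADING_BINS = 12  # assumed; the real constant lives in kitti_utils

def angle_to_class(angle: float, num_bins: int = NUM_HEADING_BINS) -> tuple[int, float]:
    """Map an angle to (bin index, residual from the bin center)."""
    angle = angle % (2 * np.pi)                      # wrap into [0, 2*pi)
    bin_width = 2 * np.pi / num_bins
    shifted = (angle + bin_width / 2) % (2 * np.pi)  # center bins on multiples of bin_width
    bin_id = int(shifted / bin_width)
    residual = shifted - (bin_id * bin_width + bin_width / 2)
    return bin_id, residual
```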
-    def _reformate_for_kitti_metric(self, annotations: dict[str, Any]) -> dict[str, np.array]:
+    def _reformate_for_kitti_metric(self, annotations: list[Any]) -> dict[str, np.array]:
        """Reformat the annotation for KITTI metric."""
        return {
            "name": np.array([obj["name"] for obj in annotations]),
@@ -312,3 +114,13 @@ def _reformate_for_kitti_metric(self, annotations: dict[str, Any]) -> dict[str,
            "occluded": np.array([obj["occluded"] for obj in annotations]),
            "truncated": np.array([obj["truncated"] for obj in annotations]),
        }
+
+    @staticmethod
+    def get_calib_from_file(calib_file: str) -> np.ndarray:
+        """Get calibration matrix from txt file (KITTI format)."""
+        with open(calib_file) as f:  # noqa: PTH123
+            lines = f.readlines()
+
+        obj = lines[2].strip().split(" ")[1:]
+
+        return np.array(obj, dtype=np.float32).reshape(3, 4)
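The new `get_calib_from_file` assumes the standard KITTI calibration layout: one projection matrix per line, with the third line (`P2:`) holding the left color camera's 3x4 projection matrix as twelve floats. A self-contained sketch of the same parsing (numbers are illustrative):

```python
import numpy as np

sample = (
    "P0: 721.5 0.0 609.6 0.0 0.0 721.5 172.9 0.0 0.0 0.0 1.0 0.0\n"
    "P1: 721.5 0.0 609.6 -387.6 0.0 721.5 172.9 0.0 0.0 0.0 1.0 0.0\n"
    "P2: 721.5 0.0 609.6 44.9 0.0 721.5 172.9 0.2 0.0 0.0 1.0 0.0\n"
)
lines = sample.splitlines()
values = lines[2].strip().split(" ")[1:]  # drop the "P2:" tag, keep the 12 numbers
p2 = np.array(values, dtype=np.float32).reshape(3, 4)
assert p2.shape == (3, 4)
```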