Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Enhance] Update data structures #2155

Merged
12 changes: 3 additions & 9 deletions mmdet3d/structures/bbox_3d/base_box3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,17 +463,11 @@ def overlaps(cls, boxes1, boxes2, mode='iou'):
# height overlap
overlaps_h = cls.height_overlaps(boxes1, boxes2)

# Restrict the min values of W and H to avoid memory overflow in
# ``box_iou_rotated``.
boxes1_bev, boxes2_bev = boxes1.bev, boxes2.bev
Xiangxu-0103 marked this conversation as resolved.
Show resolved Hide resolved
boxes1_bev[:, 2:4] = boxes1_bev[:, 2:4].clamp(min=1e-4)
boxes2_bev[:, 2:4] = boxes2.bev[:, 2:4].clamp(min=1e-4)

# bev overlap
iou2d = box_iou_rotated(boxes1_bev, boxes2_bev)
areas1 = (boxes1_bev[:, 2] * boxes1_bev[:, 3]).unsqueeze(1).expand(
iou2d = box_iou_rotated(boxes1.bev, boxes2.bev)
areas1 = (boxes1.bev[:, 2] * boxes1.bev[:, 3]).unsqueeze(1).expand(
rows, cols)
areas2 = (boxes2_bev[:, 2] * boxes2_bev[:, 3]).unsqueeze(0).expand(
areas2 = (boxes2.bev[:, 2] * boxes2.bev[:, 3]).unsqueeze(0).expand(
rows, cols)
overlaps_bev = iou2d * (areas1 + areas2) / (1 + iou2d)

Expand Down
95 changes: 56 additions & 39 deletions mmdet3d/structures/det3d_data_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,46 +15,47 @@ class Det3DDataSample(DetDataSample):
The attributes in ``Det3DDataSample`` are divided into several parts:

- ``proposals``(InstanceData): Region proposals used in two-stage
detectors.
detectors.
- ``ignored_instances``(InstanceData): Instances to be ignored during
training/testing.
training/testing.
- ``gt_instances_3d``(InstanceData): Ground truth of 3D instance
annotations.
annotations.
- ``gt_instances``(InstanceData): Ground truth of 2D instance
annotations.
annotations.
- ``pred_instances_3d``(InstanceData): 3D instances of model
predictions.
- For point-cloud 3d object detection task whose input modality
predictions.
- For point-cloud 3D object detection task whose input modality
is `use_lidar=True, use_camera=False`, the 3D predictions results
are saved in `pred_instances_3d`.
- For vision-only(monocular/multi-view) 3D object detection task
- For vision-only (monocular/multi-view) 3D object detection task
whose input modality is `use_lidar=False, use_camera=True`, the 3D
predictions are saved in `pred_instances_3d`.
- ``pred_instances``(InstanceData): 2D instances of model
predictions.
- For multi-modality 3D detection task whose input modality is
predictions.
- For multi-modality 3D detection task whose input modality is
`use_lidar=True, use_camera=True`, the 2D predictions
are saved in `pred_instances`.
- ``pts_pred_instances_3d``(InstanceData): 3D instances of model
predictions based on point cloud.
- For multi-modality 3D detection task whose input modality is
predictions based on point cloud.
- For multi-modality 3D detection task whose input modality is
`use_lidar=True, use_camera=True`, the 3D predictions based on
point cloud are saved in `pts_pred_instances_3d` to distinguish
with `img_pred_instances_3d` which based on image.
- ``img_pred_instances_3d``(InstanceData): 3D instances of model
predictions based on image.
- For multi-modality 3D detection task whose input modality is
predictions based on image.
- For multi-modality 3D detection task whose input modality is
`use_lidar=True, use_camera=True`, the 3D predictions based on
image are saved in `img_pred_instances_3d` to distinguish with
`pts_pred_instances_3d` which based on point cloud.
- ``gt_pts_seg``(PointData): Ground truth of point cloud
segmentation.
segmentation.
- ``pred_pts_seg``(PointData): Prediction of point cloud
segmentation.
- ``eval_ann_info``(dict): Raw annotation, which will be passed to
evaluator and do the online evaluation.
segmentation.
- ``eval_ann_info``(dict or None): Raw annotation, which will be passed
to evaluator and do the online evaluation.

Examples:
>>> import torch
>>> from mmengine.structures import InstanceData

>>> from mmdet3d.structures import Det3DDataSample
Expand All @@ -64,8 +65,8 @@ class Det3DDataSample(DetDataSample):
>>> meta_info = dict(img_shape=(800, 1196, 3),
... pad_shape=(800, 1216, 3))
>>> gt_instances_3d = InstanceData(metainfo=meta_info)
>>> gt_instances_3d.bboxes = BaseInstance3DBoxes(torch.rand((5, 7)))
>>> gt_instances_3d.labels = torch.randint(0,3,(5, ))
>>> gt_instances_3d.bboxes_3d = BaseInstance3DBoxes(torch.rand((5, 7)))
>>> gt_instances_3d.labels_3d = torch.randint(0, 3, (5,))
>>> data_sample.gt_instances_3d = gt_instances_3d
>>> assert 'img_shape' in data_sample.gt_instances_3d.metainfo_keys()
>>> print(data_sample)
Expand All @@ -81,8 +82,8 @@ class Det3DDataSample(DetDataSample):
img_shape: (800, 1196, 3)

DATA FIELDS
labels: tensor([0, 0, 1, 0, 2])
bboxes: BaseInstance3DBoxes(
labels_3d: tensor([0, 0, 1, 0, 2])
bboxes_3d: BaseInstance3DBoxes(
tensor([[0.2874, 0.3078, 0.8368, 0.2326, 0.9845, 0.6199, 0.9944],
[0.6222, 0.8778, 0.7306, 0.3320, 0.3973, 0.7662, 0.7326],
[0.8547, 0.6082, 0.1660, 0.1676, 0.9810, 0.3092, 0.0917],
Expand All @@ -96,8 +97,8 @@ class Det3DDataSample(DetDataSample):
img_shape: (800, 1196, 3)

DATA FIELDS
labels: tensor([0, 0, 1, 0, 2])
bboxes: BaseInstance3DBoxes(
labels_3d: tensor([0, 0, 1, 0, 2])
bboxes_3d: BaseInstance3DBoxes(
tensor([[0.2874, 0.3078, 0.8368, 0.2326, 0.9845, 0.6199, 0.9944],
[0.6222, 0.8778, 0.7306, 0.3320, 0.3973, 0.7662, 0.7326],
[0.8547, 0.6082, 0.1660, 0.1676, 0.9810, 0.3092, 0.0917],
Expand All @@ -120,15 +121,16 @@ class Det3DDataSample(DetDataSample):

>>> data_sample = Det3DDataSample()
>>> gt_instances_3d_data = dict(
... bboxes=BaseInstance3DBoxes(torch.rand((2, 7))),
... labels=torch.rand(2))
... bboxes_3d=BaseInstance3DBoxes(torch.rand((2, 7))),
... labels_3d=torch.rand(2))
>>> gt_instances_3d = InstanceData(**gt_instances_3d_data)
>>> data_sample.gt_instances_3d = gt_instances_3d
>>> assert 'gt_instances_3d' in data_sample
>>> assert 'bboxes' in data_sample.gt_instances_3d
>>> assert 'bboxes_3d' in data_sample.gt_instances_3d

>>> from mmdet3d.structures import PointData
>>> data_sample = Det3DDataSample()
... gt_pts_seg_data = dict(
>>> gt_pts_seg_data = dict(
... pts_instance_mask=torch.rand(2),
... pts_semantic_mask=torch.rand(2))
>>> data_sample.gt_pts_seg = PointData(**gt_pts_seg_data)
Expand Down Expand Up @@ -162,73 +164,88 @@ def gt_instances_3d(self) -> InstanceData:
return self._gt_instances_3d

@gt_instances_3d.setter
def gt_instances_3d(self, value: InstanceData):
def gt_instances_3d(self, value: InstanceData) -> None:
ZwwWayne marked this conversation as resolved.
Show resolved Hide resolved
self.set_field(value, '_gt_instances_3d', dtype=InstanceData)

@gt_instances_3d.deleter
def gt_instances_3d(self):
def gt_instances_3d(self) -> None:
del self._gt_instances_3d

@property
def pred_instances_3d(self) -> InstanceData:
return self._pred_instances_3d

@pred_instances_3d.setter
def pred_instances_3d(self, value: InstanceData):
def pred_instances_3d(self, value: InstanceData) -> None:
self.set_field(value, '_pred_instances_3d', dtype=InstanceData)

@pred_instances_3d.deleter
def pred_instances_3d(self):
def pred_instances_3d(self) -> None:
del self._pred_instances_3d

@property
def pts_pred_instances_3d(self) -> InstanceData:
return self._pts_pred_instances_3d

@pts_pred_instances_3d.setter
def pts_pred_instances_3d(self, value: InstanceData):
def pts_pred_instances_3d(self, value: InstanceData) -> None:
self.set_field(value, '_pts_pred_instances_3d', dtype=InstanceData)

@pts_pred_instances_3d.deleter
def pts_pred_instances_3d(self):
def pts_pred_instances_3d(self) -> None:
del self._pts_pred_instances_3d

@property
def img_pred_instances_3d(self) -> InstanceData:
return self._img_pred_instances_3d

@img_pred_instances_3d.setter
def img_pred_instances_3d(self, value: InstanceData):
def img_pred_instances_3d(self, value: InstanceData) -> None:
self.set_field(value, '_img_pred_instances_3d', dtype=InstanceData)

@img_pred_instances_3d.deleter
def img_pred_instances_3d(self):
def img_pred_instances_3d(self) -> None:
del self._img_pred_instances_3d

@property
def gt_pts_seg(self) -> PointData:
return self._gt_pts_seg

@gt_pts_seg.setter
def gt_pts_seg(self, value: PointData):
def gt_pts_seg(self, value: PointData) -> None:
self.set_field(value, '_gt_pts_seg', dtype=PointData)

@gt_pts_seg.deleter
def gt_pts_seg(self):
def gt_pts_seg(self) -> None:
del self._gt_pts_seg

@property
def pred_pts_seg(self) -> PointData:
return self._pred_pts_seg

@pred_pts_seg.setter
def pred_pts_seg(self, value: PointData):
def pred_pts_seg(self, value: PointData) -> None:
self.set_field(value, '_pred_pts_seg', dtype=PointData)

@pred_pts_seg.deleter
def pred_pts_seg(self):
def pred_pts_seg(self) -> None:
del self._pred_pts_seg

@property
Xiangxu-0103 marked this conversation as resolved.
Show resolved Hide resolved
def eval_ann_info(self) -> Union[dict, None]:
return self._eval_ann_info

@eval_ann_info.setter
def eval_ann_info(self, value: Union[dict, None]) -> None:
if value is None:
self.set_field(value, '_eval_ann_info')
else:
self.set_field(value, '_eval_ann_info', dtype=dict)

@eval_ann_info.deleter
def eval_ann_info(self) -> None:
del self._eval_ann_info


SampleList = List[Det3DDataSample]
OptSampleList = Optional[SampleList]
Expand Down
75 changes: 43 additions & 32 deletions mmdet3d/structures/point_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@


class PointData(BaseDataElement):
"""Data structure for point-level annnotations or predictions.
"""Data structure for point-level annotations or predictions.

All data items in ``data_fields`` of ``PointData`` meet the following
requirements:
Expand All @@ -27,58 +27,70 @@ class PointData(BaseDataElement):

Examples:
>>> metainfo = dict(
... sample_id=random.randint(0, 100))
... sample_idx=random.randint(0, 100))
>>> points = np.random.randint(0, 255, (100, 3))
>>> point_data = PointData(metainfo=metainfo,
... points=points)
>>> print(len(point_data))
>>> (100)
100

>>> # slice
>>> slice_data = pixel_data[10:60]
>>> assert slice_data.shape == (50,)
>>> slice_data = point_data[10:60]
>>> assert len(slice_data) == 50

>>> # set
>>> point_data.pts_semantic_mask = torch.randint(0, 255, (100))
>>> point_data.pts_instance_mask = torch.randint(0, 255, (100))
>>> assert tuple(point_data.pts_semantic_mask.shape) == (100)
>>> assert tuple(point_data.pts_instance_mask.shape) == (100)
>>> point_data.pts_semantic_mask = torch.randint(0, 255, (100,))
>>> point_data.pts_instance_mask = torch.randint(0, 255, (100,))
>>> assert tuple(point_data.pts_semantic_mask.shape) == (100,)
>>> assert tuple(point_data.pts_instance_mask.shape) == (100,)
"""

def __setattr__(self, name: str, value: Sized):
def __setattr__(self, name: str, value: Sized) -> None:
"""setattr is only used to set data.

the value must have the attribute of `__len__` and have the same length
of PointData.
The value must have the attribute of `__len__` and have the same length
of `PointData`.
"""
if name in ('_metainfo_fields', '_data_fields'):
if not hasattr(self, name):
super().__setattr__(name, value)
else:
raise AttributeError(
f'{name} has been used as a '
f'private attribute, which is immutable. ')
raise AttributeError(f'{name} has been used as a '
'private attribute, which is immutable.')

else:
assert isinstance(value,
Sized), 'value must contain `_len__` attribute'
Sized), 'value must contain `__len__` attribute'

if len(self) > 0:
assert len(value) == len(self), 'the length of ' \
f'values {len(value)} is ' \
'not consistent with ' \
'the length of this ' \
':obj:`PointData` ' \
f'{len(self)}'
super().__setattr__(name, value)

__setitem__ = __setattr__

def __getitem__(self, item: IndexType) -> 'PointData':
"""
Args:
item (str, obj:`slice`,
obj`torch.LongTensor`, obj:`torch.BoolTensor`):
item (str, :obj:`slice`,
:obj:`torch.LongTensor`, :obj:`torch.BoolTensor`):
get the corresponding values according to item.

Returns:
obj:`PointData`: Corresponding values.
:obj:`PointData`: Corresponding values.
"""
if isinstance(item, list):
item = np.array(item)
if isinstance(item, np.ndarray):
# The default int type of numpy is platform dependent, int32 for
# windows and int64 for linux. `torch.Tensor` requires the index
# should be int64, therefore we simply convert it to int64 here.
# More details in https://github.com/numpy/numpy/issues/9464
item = item.astype(np.int64) if item.dtype == np.int32 else item
item = torch.from_numpy(item)
assert isinstance(
item, (str, slice, int, torch.LongTensor, torch.cuda.LongTensor,
Expand All @@ -87,8 +99,8 @@ def __getitem__(self, item: IndexType) -> 'PointData':
if isinstance(item, str):
return getattr(self, item)

if type(item) == int:
if item >= len(self) or item < -len(self): # type:ignore
if isinstance(item, int):
if item >= len(self) or item < -len(self): # type: ignore
raise IndexError(f'Index {item} out of range!')
else:
# keep the dimension
Expand All @@ -99,14 +111,14 @@ def __getitem__(self, item: IndexType) -> 'PointData':
assert item.dim() == 1, 'Only support to get the' \
' values along the first dimension.'
if isinstance(item, (torch.BoolTensor, torch.cuda.BoolTensor)):
assert len(item) == len(self), f'The shape of the' \
f' input(BoolTensor)) ' \
assert len(item) == len(self), 'The shape of the ' \
'input(BoolTensor) ' \
f'{len(item)} ' \
f' does not match the shape ' \
f'of the indexed tensor ' \
f'in results_filed ' \
'does not match the shape ' \
'of the indexed tensor ' \
'in results_field ' \
f'{len(self)} at ' \
f'first dimension. '
'first dimension.'

for k, v in self.items():
if isinstance(v, torch.Tensor):
Expand All @@ -116,7 +128,7 @@ def __getitem__(self, item: IndexType) -> 'PointData':
elif isinstance(
v, (str, list, tuple)) or (hasattr(v, '__getitem__')
and hasattr(v, 'cat')):
# convert to indexes from boolTensor
# convert to indexes from BoolTensor
if isinstance(item,
(torch.BoolTensor, torch.cuda.BoolTensor)):
indexes = torch.nonzero(item).view(
Expand All @@ -141,16 +153,15 @@ def __getitem__(self, item: IndexType) -> 'PointData':
raise ValueError(
f'The type of `{k}` is `{type(v)}`, which has no '
'attribute of `cat`, so it does not '
f'support slice with `bool`')

'support slice with `bool`')
else:
# item is a slice
for k, v in self.items():
new_data[k] = v[item]
return new_data # type:ignore
return new_data # type: ignore

def __len__(self) -> int:
"""int: the length of PointData"""
"""int: the length of `PointData`."""
if len(self._data_fields) > 0:
return len(self.values()[0])
else:
Expand Down
Loading