Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] ImVoteNet complete model #352

Merged
merged 89 commits into from
Mar 24, 2021
Merged
Show file tree
Hide file tree
Changes from 66 commits
Commits
Show all changes
89 commits
Select commit Hold shift + click to select a range
3be0f68
Added image loading in SUNRGB-D dataset (#195)
yezhen17 Nov 25, 2020
f7514bc
Added imvotenet image branch pretrain (#217)
yezhen17 Dec 9, 2020
3a7ff0f
Merge branch 'master' of github.com:open-mmlab/mmdetection3d into imv…
ZwwWayne Dec 24, 2020
b373499
integrated vote fusion
yezhen17 Feb 3, 2021
3be3690
coords transform and unit test
yezhen17 Feb 5, 2021
cce6534
Update docstring
yezhen17 Feb 5, 2021
824b85a
Merge branch 'master' into imvotenet
ZwwWayne Feb 10, 2021
b88085b
Merge branch 'imvotenet_fork' into imvotenet
yezhen17 Feb 10, 2021
747cfd9
refactor and add unit test
yezhen17 Feb 10, 2021
99e2f7f
fix bug caused by mmcv upgrade; delete pdb breakpoint
yezhen17 Feb 10, 2021
1e4f12c
add point fusion unittest
yezhen17 Feb 15, 2021
b9cf3ca
remove unused file
yezhen17 Feb 15, 2021
e9c6df3
Merge branch 'master_fork'
yezhen17 Feb 16, 2021
cb23822
fix typos
yezhen17 Feb 16, 2021
f29fb1b
Merge branch 'master_fork'
yezhen17 Feb 18, 2021
4415f49
updates
yezhen17 Feb 19, 2021
a96d949
add assertion info
yezhen17 Feb 23, 2021
c419f9b
update
yezhen17 Feb 24, 2021
c8c466f
add unittest
yezhen17 Feb 24, 2021
97552b5
add vote fusion unittest
yezhen17 Feb 25, 2021
770ed19
add vote fusion unittest
yezhen17 Feb 25, 2021
01ff7a6
[Refactor] VoteNet refactor (#322)
yezhen17 Mar 1, 2021
f289bc5
minor update
yezhen17 Mar 3, 2021
abf54e8
docstring
yezhen17 Mar 3, 2021
a47a6ff
Merge branch 'imvotenet_fork' into _imvotenet_backbone
yezhen17 Mar 4, 2021
a80769e
Merge branch '_imvotenet_fusion' into _imvotenet_backbone
yezhen17 Mar 4, 2021
68f07fc
initial update of imvotenet
yezhen17 Mar 4, 2021
8ed6ed9
[Feature] Support vote fusion (#297)
yezhen17 Mar 4, 2021
8f341a3
change np ops to torch
yezhen17 Mar 4, 2021
c235234
Merge branch 'imvotenet_fork' into imvotenet
yezhen17 Mar 4, 2021
e973261
refactor test
yezhen17 Mar 4, 2021
6cf6b47
Merge branch '_imvotenet_backbone' into imvotenet
yezhen17 Mar 4, 2021
689ff2b
update
yezhen17 Mar 5, 2021
c613710
Merge branch '_imvotenet_backbone' into imvotenet
yezhen17 Mar 5, 2021
ba16291
refactor of image mlp and np random ops to torch
yezhen17 Mar 5, 2021
7862419
Merge branch '_imvotenet_backbone' into imvotenet
yezhen17 Mar 5, 2021
ef04ac5
add docstring
yezhen17 Mar 7, 2021
078edb1
add config and mod dataset
yezhen17 Mar 7, 2021
4d1b09a
Merge branch '_imvotenet_backbone' into _imvotenet_complete
yezhen17 Mar 7, 2021
0da8c3f
Merge branch '_imvotenet_backbone' into imvotenet
yezhen17 Mar 7, 2021
8b66bf9
fix bugs
yezhen17 Mar 7, 2021
a850f26
add_comments
yezhen17 Mar 7, 2021
ba361de
fix bugs
yezhen17 Mar 7, 2021
122d829
Merge branch '_imvotenet_backbone' into imvotenet
yezhen17 Mar 7, 2021
6bd5a3d
fix_bug
yezhen17 Mar 8, 2021
a542dde
fix bug
yezhen17 Mar 8, 2021
b955263
fix bug
yezhen17 Mar 8, 2021
4464dd9
Merge branch '_imvotenet_backbone' into imvotenet
yezhen17 Mar 8, 2021
93b06d7
fix bug
yezhen17 Mar 8, 2021
cce3e38
fix bug
yezhen17 Mar 8, 2021
5bb4136
Merge branch '_imvotenet_backbone' into imvotenet
yezhen17 Mar 8, 2021
233918d
final fix
yezhen17 Mar 9, 2021
5baaa1f
fix bug
yezhen17 Mar 9, 2021
69c8d9b
Merge branch '_imvotenet_backbone' into imvotenet
yezhen17 Mar 9, 2021
85af306
?
yezhen17 Mar 10, 2021
7714432
add docstring
yezhen17 Mar 10, 2021
27c478e
Merge branch '_imvotenet_backbone' into imvotenet
yezhen17 Mar 10, 2021
453a952
merge conflicts
yezhen17 Mar 10, 2021
c8749c9
move train/test cfg
yezhen17 Mar 10, 2021
654b46b
Merge branch 'master_fork'
yezhen17 Mar 11, 2021
0e6b8ba
Merge branch 'imvotenet'
yezhen17 Mar 11, 2021
86c170f
change img mlp default param
yezhen17 Mar 11, 2021
e80978a
rename config
yezhen17 Mar 11, 2021
1264bb3
Merge branch 'imvotenet'
yezhen17 Mar 11, 2021
622d34a
minor mod
yezhen17 Mar 13, 2021
70ae208
resolve conflict
yezhen17 Mar 13, 2021
f3dd1c8
change config name
yezhen17 Mar 13, 2021
aefac13
move train/test cfg
yezhen17 Mar 13, 2021
f0cfcf1
some fixes and 2d utils
yezhen17 Mar 17, 2021
600a48a
Merge branch 'imvotenet'
yezhen17 Mar 17, 2021
6dd3e73
fix config name
yezhen17 Mar 17, 2021
9ad3712
fix config override issue
yezhen17 Mar 18, 2021
e8c4f5d
Merge branch 'imvotenet' into _imvotenet_complete
yezhen17 Mar 18, 2021
4621db3
config simplify & reformat
yezhen17 Mar 18, 2021
b0cdc1f
Merge branch 'imvotenet' into _imvotenet_complete
yezhen17 Mar 18, 2021
a240ee3
explicitly set eval mode->override train()
yezhen17 Mar 18, 2021
4199c0e
Merge branch 'imvotenet' into _imvotenet_complete
yezhen17 Mar 18, 2021
4e037b6
add fix_img_branch to config
yezhen17 Mar 18, 2021
ada2dde
remove set_img_branch_eval_mode
yezhen17 Mar 18, 2021
3acb3e7
Merge branch 'imvotenet' into _imvotenet_complete
yezhen17 Mar 18, 2021
db14c67
temporal fix, change calibs to calib
yezhen17 Mar 18, 2021
d9d9701
more docstring and view/reshape, expand/repeat change
yezhen17 Mar 19, 2021
60a9c37
Merge branch 'imvotenet' into _imvotenet_complete
yezhen17 Mar 19, 2021
e3d59fc
complete imvotenet docstring
yezhen17 Mar 19, 2021
8440099
resolve conflict
yezhen17 Mar 19, 2021
db2daf2
fix docstring
yezhen17 Mar 19, 2021
204ad0d
Merge branch 'imvotenet' into _imvotenet_complete
yezhen17 Mar 19, 2021
450607f
add config and some minor fix
yezhen17 Mar 23, 2021
2ccc72e
rename config
yezhen17 Mar 24, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
_base_ = [
'../_base_/datasets/sunrgbd-3d-10class.py', '../_base_/default_runtime.py'
]

model = dict(
type='ImVoteNet',
img_backbone=dict(
type='ResNet',
depth=50,
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=1,
norm_cfg=dict(type='BN', requires_grad=False),
norm_eval=True,
style='caffe'),
img_neck=dict(
type='FPN',
in_channels=[256, 512, 1024, 2048],
out_channels=256,
num_outs=5),
img_rpn_head=dict(
type='RPNHead',
in_channels=256,
feat_channels=256,
anchor_generator=dict(
type='AnchorGenerator',
scales=[8],
ratios=[0.5, 1.0, 2.0],
strides=[4, 8, 16, 32, 64]),
bbox_coder=dict(
type='DeltaXYWHBBoxCoder',
target_means=[.0, .0, .0, .0],
target_stds=[1.0, 1.0, 1.0, 1.0]),
loss_cls=dict(
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
img_roi_head=dict(
type='StandardRoIHead',
bbox_roi_extractor=dict(
type='SingleRoIExtractor',
roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
out_channels=256,
featmap_strides=[4, 8, 16, 32]),
bbox_head=dict(
type='Shared2FCBBoxHead',
in_channels=256,
fc_out_channels=1024,
roi_feat_size=7,
num_classes=10,
bbox_coder=dict(
type='DeltaXYWHBBoxCoder',
target_means=[0., 0., 0., 0.],
target_stds=[0.1, 0.1, 0.2, 0.2]),
reg_class_agnostic=False,
loss_cls=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
loss_bbox=dict(type='L1Loss', loss_weight=1.0))),

# model training and testing settings
train_cfg=dict(
img_rpn=dict(
assigner=dict(
type='MaxIoUAssigner',
pos_iou_thr=0.7,
neg_iou_thr=0.3,
min_pos_iou=0.3,
match_low_quality=True,
ignore_iof_thr=-1),
sampler=dict(
type='RandomSampler',
num=256,
pos_fraction=0.5,
neg_pos_ub=-1,
add_gt_as_proposals=False),
allowed_border=-1,
pos_weight=-1,
debug=False),
img_rpn_proposal=dict(
nms_across_levels=False,
nms_pre=2000,
nms_post=1000,
max_num=1000,
nms_thr=0.7,
min_bbox_size=0),
img_rcnn=dict(
assigner=dict(
type='MaxIoUAssigner',
pos_iou_thr=0.5,
neg_iou_thr=0.5,
min_pos_iou=0.5,
match_low_quality=False,
ignore_iof_thr=-1),
sampler=dict(
type='RandomSampler',
num=512,
pos_fraction=0.25,
neg_pos_ub=-1,
add_gt_as_proposals=True),
pos_weight=-1,
debug=False)),
test_cfg=dict(
img_rpn=dict(
nms_across_levels=False,
nms_pre=1000,
nms_post=1000,
max_num=1000,
nms_thr=0.7,
min_bbox_size=0),
img_rcnn=dict(
score_thr=0.05,
nms=dict(type='nms', iou_threshold=0.5),
max_per_img=100)))

# use caffe img_norm
img_norm_cfg = dict(
mean=[103.530, 116.280, 123.675], std=[1.0, 1.0, 1.0], to_rgb=False)

train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True),
dict(
type='Resize',
img_scale=[(1333, 480), (1333, 504), (1333, 528), (1333, 552),
(1333, 576), (1333, 600)],
multiscale_mode='value',
keep_ratio=True),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(1333, 600),
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]

data = dict(
samples_per_gpu=2,
workers_per_gpu=2,
train=dict(times=1, dataset=dict(pipeline=train_pipeline)),
val=dict(pipeline=test_pipeline),
test=dict(pipeline=test_pipeline))

optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=None)
lr_config = dict(
policy='step',
warmup='linear',
warmup_iters=500,
warmup_ratio=0.001,
step=[6])
total_epochs = 8

load_from = 'http://download.openmmlab.com/mmdetection/v2.0/mask_rcnn/mask_rcnn_r50_caffe_fpn_mstrain-poly_3x_coco/mask_rcnn_r50_caffe_fpn_mstrain-poly_3x_coco_bbox_mAP-0.408__segm_mAP-0.37_20200504_163245-42aa3d00.pth' # noqa
22 changes: 15 additions & 7 deletions mmdet3d/core/bbox/structures/coord_3d_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def convert_point(point, src, dst, rt_mat=None):
"""Convert points from `src` mode to `dst` mode.

Args:
box (tuple | list | np.dnarray |
point (tuple | list | np.dnarray |
torch.Tensor | BasePoints):
Can be a k-tuple, k-list or an Nxk array/tensor.
src (:obj:`CoordMode`): The src Point mode.
Expand Down Expand Up @@ -218,17 +218,25 @@ def convert_point(point, src, dst, rt_mat=None):
arr = point.clone()

# convert point from `src` mode to `dst` mode.
if rt_mat is not None:
if not isinstance(rt_mat, torch.Tensor):
rt_mat = arr.new_tensor(rt_mat)
# TODO: LIDAR
# only implemented provided Rt matrix in cam-depth conversion
if src == Coord3DMode.LIDAR and dst == Coord3DMode.CAM:
rt_mat = arr.new_tensor([[0, -1, 0], [0, 0, -1], [1, 0, 0]])
elif src == Coord3DMode.CAM and dst == Coord3DMode.LIDAR:
rt_mat = arr.new_tensor([[0, 0, 1], [-1, 0, 0], [0, -1, 0]])
elif src == Coord3DMode.DEPTH and dst == Coord3DMode.CAM:
rt_mat = arr.new_tensor([[1, 0, 0], [0, 0, 1], [0, -1, 0]])
if rt_mat is None:
rt_mat = arr.new_tensor([[1, 0, 0], [0, 0, -1], [0, 1, 0]])
else:
rt_mat = rt_mat.new_tensor(
[[1, 0, 0], [0, 0, -1], [0, 1, 0]]) @ \
rt_mat.transpose(1, 0)
elif src == Coord3DMode.CAM and dst == Coord3DMode.DEPTH:
rt_mat = arr.new_tensor([[1, 0, 0], [0, 0, -1], [0, 1, 0]])
if rt_mat is None:
rt_mat = arr.new_tensor([[1, 0, 0], [0, 0, 1], [0, -1, 0]])
else:
rt_mat = rt_mat @ rt_mat.new_tensor([[1, 0, 0], [0, 0, 1],
[0, -1, 0]])
elif src == Coord3DMode.LIDAR and dst == Coord3DMode.DEPTH:
rt_mat = arr.new_tensor([[0, -1, 0], [1, 0, 0], [0, 0, 1]])
elif src == Coord3DMode.DEPTH and dst == Coord3DMode.LIDAR:
Expand All @@ -245,7 +253,7 @@ def convert_point(point, src, dst, rt_mat=None):
else:
xyz = arr[:, :3] @ rt_mat.t()

remains = arr[..., 3:]
remains = arr[:, 3:]
arr = torch.cat([xyz[:, :3], remains], dim=-1)

# convert arr to the original type
Expand Down
13 changes: 13 additions & 0 deletions mmdet3d/core/bbox/structures/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,20 @@ def points_cam2img(points_3d, proj_mat):
torch.Tensor: Points in image coordinates with shape [N, 2].
"""
points_num = list(points_3d.shape)[:-1]

points_shape = np.concatenate([points_num, [1]], axis=0).tolist()
assert len(proj_mat.shape) == 2, f'The dimension of the projection'\
f'matrix should be 2 instead of {len(proj_mat.shape)}.'
d1, d2 = proj_mat.shape[:2]
assert (d1 == 3 and d2 == 3) or (d1 == 3 and d2 == 4) or (
d1 == 4 and d2 == 4), f'The shape of the projection matrix'\
f' ({d1}*{d2}) is not supported.'
if d1 == 3:
proj_mat_expanded = torch.eye(
4, device=proj_mat.device, dtype=proj_mat.dtype)
proj_mat_expanded[:d1, :d2] = proj_mat
proj_mat = proj_mat_expanded

# previous implementation use new_zeros, new_one yeilds better results
points_4 = torch.cat(
[points_3d, points_3d.new_ones(*points_shape)], dim=-1)
Expand Down
4 changes: 2 additions & 2 deletions mmdet3d/core/points/base_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@ class BasePoints(object):
tensor (torch.Tensor | np.ndarray | list): a N x points_dim matrix.
points_dim (int): Number of the dimension of a point.
Each row is (x, y, z). Default to 3.
attribute_dims (dict): Dictinory to indicate the meaning of extra
attribute_dims (dict): Dictionary to indicate the meaning of extra
dimension. Default to None.

Attributes:
tensor (torch.Tensor): Float matrix of N x points_dim.
points_dim (int): Integer indicating the dimension of a point.
Each row is (x, y, z, ...).
attribute_dims (bool): Dictinory to indicate the meaning of extra
attribute_dims (bool): Dictionary to indicate the meaning of extra
dimension. Default to None.
rotation_axis (int): Default rotation axis for points rotation.
"""
Expand Down
4 changes: 2 additions & 2 deletions mmdet3d/core/points/cam_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@ class CameraPoints(BasePoints):
tensor (torch.Tensor | np.ndarray | list): a N x points_dim matrix.
points_dim (int): Number of the dimension of a point.
Each row is (x, y, z). Default to 3.
attribute_dims (dict): Dictinory to indicate the meaning of extra
attribute_dims (dict): Dictionary to indicate the meaning of extra
dimension. Default to None.

Attributes:
tensor (torch.Tensor): Float matrix of N x points_dim.
points_dim (int): Integer indicating the dimension of a point.
Each row is (x, y, z, ...).
attribute_dims (bool): Dictinory to indicate the meaning of extra
attribute_dims (bool): Dictionary to indicate the meaning of extra
dimension. Default to None.
rotation_axis (int): Default rotation axis for points rotation.
"""
Expand Down
4 changes: 2 additions & 2 deletions mmdet3d/core/points/depth_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@ class DepthPoints(BasePoints):
tensor (torch.Tensor | np.ndarray | list): a N x points_dim matrix.
points_dim (int): Number of the dimension of a point.
Each row is (x, y, z). Default to 3.
attribute_dims (dict): Dictinory to indicate the meaning of extra
attribute_dims (dict): Dictionary to indicate the meaning of extra
dimension. Default to None.

Attributes:
tensor (torch.Tensor): Float matrix of N x points_dim.
points_dim (int): Integer indicating the dimension of a point.
Each row is (x, y, z, ...).
attribute_dims (bool): Dictinory to indicate the meaning of extra
attribute_dims (bool): Dictionary to indicate the meaning of extra
dimension. Default to None.
rotation_axis (int): Default rotation axis for points rotation.
"""
Expand Down
4 changes: 2 additions & 2 deletions mmdet3d/core/points/lidar_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@ class LiDARPoints(BasePoints):
tensor (torch.Tensor | np.ndarray | list): a N x points_dim matrix.
points_dim (int): Number of the dimension of a point.
Each row is (x, y, z). Default to 3.
attribute_dims (dict): Dictinory to indicate the meaning of extra
attribute_dims (dict): Dictionary to indicate the meaning of extra
dimension. Default to None.

Attributes:
tensor (torch.Tensor): Float matrix of N x points_dim.
points_dim (int): Integer indicating the dimension of a point.
Each row is (x, y, z, ...).
attribute_dims (bool): Dictinory to indicate the meaning of extra
attribute_dims (bool): Dictionary to indicate the meaning of extra
dimension. Default to None.
rotation_axis (int): Default rotation axis for points rotation.
"""
Expand Down
10 changes: 10 additions & 0 deletions mmdet3d/datasets/pipelines/transforms_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,10 +96,15 @@ def __call__(self, input_dict):
) < self.flip_ratio_bev_vertical else False
input_dict['pcd_vertical_flip'] = flip_vertical

if 'transformation_3d_flow' not in input_dict:
input_dict['transformation_3d_flow'] = []

if input_dict['pcd_horizontal_flip']:
self.random_flip_data_3d(input_dict, 'horizontal')
input_dict['transformation_3d_flow'].extend(['HF'])
if input_dict['pcd_vertical_flip']:
self.random_flip_data_3d(input_dict, 'vertical')
input_dict['transformation_3d_flow'].extend(['VF'])
return input_dict

def __repr__(self):
Expand Down Expand Up @@ -405,13 +410,18 @@ def __call__(self, input_dict):
'pcd_scale_factor', 'pcd_trans' and keys in \
input_dict['bbox3d_fields'] are updated in the result dict.
"""
if 'transformation_3d_flow' not in input_dict:
input_dict['transformation_3d_flow'] = []

self._rot_bbox_points(input_dict)

if 'pcd_scale_factor' not in input_dict:
self._random_scale(input_dict)
self._scale_bbox_points(input_dict)

self._trans_bbox_points(input_dict)

input_dict['transformation_3d_flow'].extend(['R', 'S', 'T'])
return input_dict

def __repr__(self):
Expand Down
Loading