Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Group-Free-3D head #539

Merged
merged 34 commits into from
Jul 1, 2021
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
709f32c
group-free-3d head
hjin2902 May 12, 2021
4d035f9
GroupFree3DNet->VoteNet
hjin2902 May 12, 2021
d825abe
modify docstring
hjin2902 May 17, 2021
06c8e02
bugfix: calculate pts_instance_label, decoder self/cross posembed init
hjin2902 May 18, 2021
7d8ecaa
support GroupFree3DNet, modify scannnet train config
hjin2902 May 18, 2021
39b0e82
support point cloud input features dim = 0
hjin2902 May 18, 2021
bd744ad
add groupfree3dnet test case
hjin2902 May 18, 2021
7bc85ef
bugfix: softmax in decode boxes
hjin2902 May 19, 2021
62b0e33
support multi-stage predictions
hjin2902 May 19, 2021
9c69b1c
modify GroupFree3DMultiheadAttention input parameters
hjin2902 May 20, 2021
3f2ddc0
refactor: support sunrgbd-based train
hjin2902 May 22, 2021
b10f2d3
refactor: support sunrgbd-based train
hjin2902 May 22, 2021
2bf027f
fix parts of bug
hjin2902 May 26, 2021
ce49f7c
modify multi-stage prediction
hjin2902 May 26, 2021
f43cbc6
fixbug: conv_channels
hjin2902 May 26, 2021
122f716
bugfix: permute
hjin2902 May 26, 2021
305ae0e
bugfix: permute
hjin2902 May 26, 2021
a303c91
bugfix: expand
hjin2902 May 26, 2021
44732cb
fix MAX_NUM_OBJ=64
hjin2902 May 30, 2021
4c5c345
Merge branch 'master' into groupfree3dHead
hjin2902 May 31, 2021
a17fa5f
4 gpu training, score_thr = 0.0
hjin2902 Jun 3, 2021
2d6b7b0
modify config, repeattime=1
hjin2902 Jun 3, 2021
b559bf6
bigfix: expand
hjin2902 Jun 4, 2021
5746034
modify: GroupFree3DMHA, build_positional_encoding
Jun 9, 2021
36245d9
modify: GroupFree3DMHA, build_positional_encoding
hjin2902 Jun 9, 2021
1f65845
bugfix: torch.nn
hjin2902 Jun 9, 2021
6e95fc2
bugfix: mean loss
hjin2902 Jun 9, 2021
7cf7981
residual -> identity
hjin2902 Jun 11, 2021
f392f81
fix name: DropOut -> Dropout
hjin2902 Jun 14, 2021
2ad4952
merge master into and resolve conflicts with ImVoxelNet
hjin2902 Jun 21, 2021
4476bf0
delete sunrgbd-based congfig
hjin2902 Jun 23, 2021
ac21b89
Fix: trailing whitespace
hjin2902 Jun 23, 2021
83d76fa
suffix -> prefix
hjin2902 Jul 1, 2021
f8c46d0
bugfix: groupfree3d config
hjin2902 Jul 1, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions configs/_base_/models/groupfree3d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Base model config for Group-Free-3D: a PointNet2SASSG backbone feeding a
# GroupFree3DHead built from stacked transformer layers. Dataset-specific
# head settings (number of classes, bbox coder, losses) are not defined here
# and are expected to come from the config that inherits this file.
model = dict(
    type='GroupFree3DNet',
    backbone=dict(
        type='PointNet2SASSG',
        in_channels=3,
        # One entry per set-abstraction (SA) stage.
        num_points=(2048, 1024, 512, 256),
        radius=(0.2, 0.4, 0.8, 1.2),
        num_samples=(64, 32, 16, 16),
        sa_channels=((64, 64, 128), (128, 128, 256), (128, 128, 256),
                     (128, 128, 256)),
        # The last FP width (288) equals the head's `in_channels` and the
        # transformer `embed_dims` below; keep the three values in sync.
        fp_channels=((256, 256), (256, 288)),
        norm_cfg=dict(type='BN2d'),
        sa_cfg=dict(
            type='PointSAModule',
            pool_mod='max',
            use_xyz=True,
            normalize_xyz=True)),
    bbox_head=dict(
        type='GroupFree3DHead',
        in_channels=288,
        num_decoder_layers=6,
        num_proposal=256,
        transformerlayers=dict(
            type='BaseTransformerLayer',
            attn_cfgs=dict(
                type='GroupFree3DMultiheadAttention',
                embed_dims=288,
                num_heads=8,
                attn_drop=0.1,
                # mmcv registers this layer under the name 'Dropout'
                # (the earlier spelling 'DropOut' was renamed).
                dropout_layer=dict(type='Dropout', drop_prob=0.1)),
            ffn_cfgs=dict(
                embed_dims=288,
                feedforward_channels=2048,
                ffn_drop=0.1,
                act_cfg=dict(type='ReLU', inplace=True)),
            operation_order=('self_attn', 'norm', 'cross_attn', 'norm', 'ffn',
                             'norm')),
        pred_layer_cfg=dict(
            in_channels=288, shared_conv_channels=(288, 288), bias=True)),
    # model training and testing settings
    train_cfg=dict(sample_mod='kps'),
    test_cfg=dict(
        sample_mod='kps',
        nms_thr=0.25,
        score_thr=0.05,
        per_class_proposal=True,
        # Suffixes selecting which decoder stages are used at test time
        # (presumably the last three of the six decoder layers -- confirm
        # against GroupFree3DHead).
        prediction_stages=('_3', '_4', '_5')))
164 changes: 164 additions & 0 deletions configs/groupfree3d/groupfree3d_16x8_sunrgbd-3d-10class.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
# Inherit dataset, base model, schedule and runtime defaults; everything
# below overrides or extends those inherited settings.
_base_ = [
    '../_base_/datasets/sunrgbd-3d-10class.py',
    '../_base_/models/groupfree3d.py',
    '../_base_/schedules/schedule_3x.py',
    '../_base_/default_runtime.py',
]

# model settings
# Only the SUN RGB-D specific parts of the head are given here: class count,
# bbox coder and the individual loss terms.
model = dict(
    bbox_head=dict(
        num_classes=10,
        size_cls_agnostic=True,
        bbox_coder=dict(
            type='GroupFree3DBBoxCoder',
            num_sizes=10,
            num_dir_bins=12,
            with_rot=True,
            size_cls_agnostic=True,
            # Ten size triples, one per class (presumably in the same order
            # as `class_names` further down -- confirm against the coder).
            mean_sizes=[
                [2.114256, 1.620300, 0.927272],
                [0.791118, 1.279516, 0.718182],
                [0.923508, 1.867419, 0.845495],
                [0.591958, 0.552978, 0.827272],
                [0.699104, 0.454178, 0.75625],
                [0.69519, 1.346299, 0.736364],
                [0.528526, 1.002642, 1.172878],
                [0.500618, 0.632163, 0.683424],
                [0.404671, 1.071108, 1.688889],
                [0.76584, 1.398258, 0.472728],
            ],
        ),
        # Objectness terms: both focal losses share gamma/alpha and differ
        # only in their weight.
        sampling_objectness_loss=dict(
            type='FocalLoss',
            use_sigmoid=True,
            gamma=2.0,
            alpha=0.25,
            loss_weight=2.0,
        ),
        objectness_loss=dict(
            type='FocalLoss',
            use_sigmoid=True,
            gamma=2.0,
            alpha=0.25,
            loss_weight=4.0,
        ),
        # Box regression terms. For the direction-residual and size losses
        # the weight is written as 10.0 * beta.
        center_loss=dict(
            type='SmoothL1Loss',
            beta=1.0 / 9.0,
            reduction='sum',
            loss_weight=10.0,
        ),
        dir_class_loss=dict(
            type='CrossEntropyLoss',
            reduction='sum',
            loss_weight=1.0,
        ),
        dir_res_loss=dict(
            type='SmoothL1Loss',
            beta=0.04,
            reduction='sum',
            loss_weight=10.0 * 0.04,
        ),
        size_reg_loss=dict(
            type='SmoothL1Loss',
            beta=0.0625,
            reduction='sum',
            loss_weight=10.0 * 0.0625,
        ),
        semantic_loss=dict(
            type='CrossEntropyLoss',
            reduction='sum',
            loss_weight=1.0,
        ),
    ),
)

# Dataset identity and the ten object categories used by this config.
dataset_type = 'SUNRGBDDataset'
data_root = 'data/sunrgbd/'
class_names = (
    'bed', 'table', 'sofa', 'chair', 'toilet',
    'desk', 'dresser', 'night_stand', 'bookshelf', 'bathtub',
)

# Training pipeline: load points and 3D annotations, augment (flip, global
# rotation/scale), subsample to a fixed point count, then pack the inputs.
train_pipeline = [
    # Six values are loaded per point but only dims 0-2 are kept.
    dict(
        type='LoadPointsFromFile',
        coord_type='DEPTH',
        load_dim=6,
        use_dim=[0, 1, 2],
    ),
    dict(type='LoadAnnotations3D'),
    dict(
        type='RandomFlip3D',
        sync_2d=False,
        flip_ratio_bev_horizontal=0.5,
    ),
    # rot_range bounds are +/-0.523599 (about pi/6; presumably radians).
    dict(
        type='GlobalRotScaleTrans',
        rot_range=[-0.523599, 0.523599],
        scale_ratio_range=[0.85, 1.15],
    ),
    dict(type='IndoorPointSample', num_points=20000),
    dict(type='DefaultFormatBundle3D', class_names=class_names),
    dict(type='Collect3D', keys=['points', 'gt_bboxes_3d', 'gt_labels_3d']),
]

# Test pipeline: same point loading, then a single-scale, no-flip
# MultiScaleFlipAug3D wrapper whose inner rot/scale/translation ranges are
# all identity values.
test_pipeline = [
    dict(
        type='LoadPointsFromFile',
        coord_type='DEPTH',
        load_dim=6,
        use_dim=[0, 1, 2],
    ),
    dict(
        type='MultiScaleFlipAug3D',
        img_scale=(1333, 800),
        pts_scale_ratio=1,
        flip=False,
        transforms=[
            dict(
                type='GlobalRotScaleTrans',
                rot_range=[0, 0],
                scale_ratio_range=[1.0, 1.0],
                translation_std=[0, 0, 0],
            ),
            dict(
                type='RandomFlip3D',
                sync_2d=False,
                flip_ratio_bev_horizontal=0.5,
            ),
            dict(type='IndoorPointSample', num_points=20000),
            dict(
                type='DefaultFormatBundle3D',
                class_names=class_names,
                with_label=False,
            ),
            dict(type='Collect3D', keys=['points']),
        ],
    ),
]

# Dataloader settings. The training split is wrapped in RepeatDataset
# (times=5); val and test both read the validation info file and share the
# test pipeline.
data = dict(
    samples_per_gpu=16,
    workers_per_gpu=4,
    train=dict(
        type='RepeatDataset',
        times=5,
        dataset=dict(
            type=dataset_type,
            data_root=data_root,
            ann_file=data_root + 'sunrgbd_infos_train.pkl',
            pipeline=train_pipeline,
            classes=class_names,
            filter_empty_gt=False,
            # we use box_type_3d='LiDAR' in kitti and nuscenes dataset
            # and box_type_3d='Depth' in sunrgbd and scannet dataset.
            box_type_3d='Depth',
        ),
    ),
    val=dict(
        type=dataset_type,
        data_root=data_root,
        ann_file=data_root + 'sunrgbd_infos_val.pkl',
        pipeline=test_pipeline,
        classes=class_names,
        test_mode=True,
        box_type_3d='Depth',
    ),
    test=dict(
        type=dataset_type,
        data_root=data_root,
        ann_file=data_root + 'sunrgbd_infos_val.pkl',
        pipeline=test_pipeline,
        classes=class_names,
        test_mode=True,
        box_type_3d='Depth',
    ),
)

# optimizer
# No optimizer `type` is set here; it is presumably inherited from the base
# schedule listed in `_base_` -- confirm there.
lr = 0.004
optimizer = dict(
    lr=lr,
    weight_decay=1e-8,
    paramwise_cfg=dict(
        # Every decoder-related parameter group of the bbox head gets the
        # same multipliers: lr_mult=0.05, decay_mult=1.0.
        custom_keys={
            'bbox_head.' + submodule: dict(lr_mult=0.05, decay_mult=1.0)
            for submodule in (
                'decoder_layers',
                'decoder_self_posembeds',
                'decoder_cross_posembeds',
                'decoder_query_proj',
                'decoder_key_proj',
            )
        },
    ),
)

# Gradient clipping by L2 norm, and a step LR policy without warmup.
optimizer_config = dict(grad_clip=dict(max_norm=0.1, norm_type=2))
lr_config = dict(policy='step', warmup=None, step=[84, 96, 108])

# runtime settings
runner = dict(type='EpochBasedRunner', max_epochs=120)
Loading