-
Notifications
You must be signed in to change notification settings - Fork 1.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Feature] Support PointNet++ Segmentor (#528)
* build BaseSegmentor for point sem seg * add encoder-decoder segmentor * update mmseg dependency * fix linting errors * wrap predicted seg_mask in dict * add unit test * use build_model to wrap detector and segmentor * fix compatibility with mmseg * faster sliding inference * merge master * configs for training on ScanNet * fix CI errors * add comments & fix typos * hard-code class_weight into configs * fix logger bugs * update segmentor unit test * logger use mmdet3d * use eps to replace hard-coded 1e-3 * add comments * replace np operation with torch code * add comments for class_weight * add comment for BaseSegmentor.simple_test * rewrite EncoderDecoder3D to avoid inheriting from mmseg
- Loading branch information
Showing
27 changed files
with
1,034 additions
and
44 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
_base_ = './pointnet2_ssg.py'

# model settings: replace the inherited single-scale-grouping (SSG) backbone
# with a multi-scale-grouping (MSG) PointNet++ backbone.
model = {
    'backbone': {
        '_delete_': True,  # discard every inherited backbone key
        'type': 'PointNet2SAMSG',
        'in_channels': 6,  # [xyz, rgb], should be modified with dataset
        'num_points': (1024, 256, 64, 16),
        # two grouping radii per SA stage (multi-scale grouping)
        'radii': ((0.05, 0.1), (0.1, 0.2), (0.2, 0.4), (0.4, 0.8)),
        'num_samples': ((16, 32), (16, 32), (16, 32), (16, 32)),
        'sa_channels': (((16, 16, 32), (32, 32, 64)),
                        ((64, 64, 128), (64, 96, 128)),
                        ((128, 196, 256), (128, 196, 256)),
                        ((256, 256, 512), (256, 384, 512))),
        'aggregation_channels': (None, None, None, None),
        # NOTE: the original spelled these as ('D-FPS') / (-1); single
        # parentheses are not tuples, so the values are plain scalars.
        'fps_mods': ('D-FPS', 'D-FPS', 'D-FPS', 'D-FPS'),
        'fps_sample_range_lists': (-1, -1, -1, -1),
        'dilated_group': (False, False, False, False),
        'out_indices': (0, 1, 2, 3),
        'sa_cfg': {
            'type': 'PointSAModuleMSG',
            'pool_mod': 'max',
            'use_xyz': True,
            'normalize_xyz': False,
        },
    },
    'decode_head': {
        # FP input widths follow the MSG backbone's concatenated outputs
        'fp_channels': ((1536, 256, 256), (512, 256, 256), (352, 256, 128),
                        (128, 128, 128, 128)),
    },
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
# model settings: PointNet++ (single-scale grouping) encoder-decoder
# for 3D point-cloud semantic segmentation.
model = {
    'type': 'EncoderDecoder3D',
    'backbone': {
        'type': 'PointNet2SASSG',
        'in_channels': 6,  # [xyz, rgb], should be modified with dataset
        'num_points': (1024, 256, 64, 16),
        'radius': (0.1, 0.2, 0.4, 0.8),
        'num_samples': (32, 32, 32, 32),
        'sa_channels': ((32, 32, 64), (64, 64, 128), (128, 128, 256),
                        (256, 256, 512)),
        'fp_channels': (),  # no FP modules in the backbone itself
        'norm_cfg': {'type': 'BN2d'},
        'sa_cfg': {
            'type': 'PointSAModule',
            'pool_mod': 'max',
            'use_xyz': True,
            'normalize_xyz': False,
        },
    },
    'decode_head': {
        'type': 'PointNet2Head',
        'fp_channels': ((768, 256, 256), (384, 256, 256), (320, 256, 128),
                        (128, 128, 128, 128)),
        'channels': 128,
        'dropout_ratio': 0.5,
        'conv_cfg': {'type': 'Conv1d'},
        'norm_cfg': {'type': 'BN1d'},
        'act_cfg': {'type': 'ReLU'},
        'loss_decode': {
            'type': 'CrossEntropyLoss',
            'use_sigmoid': False,
            'class_weight': None,  # should be modified with dataset
            'loss_weight': 1.0,
        },
    },
    # model training and testing settings
    'train_cfg': {},
    'test_cfg': {'mode': 'slide'},
}
40 changes: 40 additions & 0 deletions
40
configs/pointnet2/pointnet2_msg_16x2_scannet-3d-20class.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
# PointNet++ (MSG) trained on ScanNet semantic segmentation (20 classes),
# 16 samples/GPU x 2 GPUs.
_base_ = [
    '../_base_/datasets/scannet_seg-3d-20class.py',
    '../_base_/models/pointnet2_msg.py', '../_base_/default_runtime.py'
]

# data settings
data = {'samples_per_gpu': 16}
evaluation = {'interval': 5}

# model settings
model = {
    'decode_head': {
        'num_classes': 20,
        'ignore_index': 20,
        # `class_weight` is generated in data pre-processing, saved in
        # `data/scannet/seg_info/train_label_weight.npy`
        # you can copy paste the values here, or input the file path as
        # `class_weight=data/scannet/seg_info/train_label_weight.npy`
        'loss_decode': {
            'class_weight': [
                2.389689, 2.7215734, 4.5944676, 4.8543367, 4.096086,
                4.907941, 4.690836, 4.512031, 4.623311, 4.9242644,
                5.358117, 5.360071, 5.019636, 4.967126, 5.3502126,
                5.4023647, 5.4027233, 5.4169416, 5.3954206, 4.6971426
            ]
        },
    },
    'test_cfg': {
        'num_points': 8192,
        'block_size': 1.5,
        'sample_rate': 0.5,
        'use_normalized_coord': False,
        'batch_size': 24,
    },
}

# optimizer
lr = 0.001  # max learning rate
optimizer = {'type': 'Adam', 'lr': lr, 'weight_decay': 1e-4}
optimizer_config = {'grad_clip': None}
lr_config = {'policy': 'CosineAnnealing', 'warmup': None, 'min_lr': 1e-5}

# runtime settings
checkpoint_config = {'interval': 5}
runner = {'type': 'EpochBasedRunner', 'max_epochs': 150}
40 changes: 40 additions & 0 deletions
40
configs/pointnet2/pointnet2_ssg_16x2_scannet-3d-20class.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
# PointNet++ (SSG) trained on ScanNet semantic segmentation (20 classes),
# 16 samples/GPU x 2 GPUs.
_base_ = [
    '../_base_/datasets/scannet_seg-3d-20class.py',
    '../_base_/models/pointnet2_ssg.py', '../_base_/default_runtime.py'
]

# data settings
data = {'samples_per_gpu': 16}
evaluation = {'interval': 5}

# model settings
model = {
    'decode_head': {
        'num_classes': 20,
        'ignore_index': 20,
        # `class_weight` is generated in data pre-processing, saved in
        # `data/scannet/seg_info/train_label_weight.npy`
        # you can copy paste the values here, or input the file path as
        # `class_weight=data/scannet/seg_info/train_label_weight.npy`
        'loss_decode': {
            'class_weight': [
                2.389689, 2.7215734, 4.5944676, 4.8543367, 4.096086,
                4.907941, 4.690836, 4.512031, 4.623311, 4.9242644,
                5.358117, 5.360071, 5.019636, 4.967126, 5.3502126,
                5.4023647, 5.4027233, 5.4169416, 5.3954206, 4.6971426
            ]
        },
    },
    'test_cfg': {
        'num_points': 8192,
        'block_size': 1.5,
        'sample_rate': 0.5,
        'use_normalized_coord': False,
        'batch_size': 24,
    },
}

# optimizer
lr = 0.001  # max learning rate
optimizer = {'type': 'Adam', 'lr': lr, 'weight_decay': 1e-4}
optimizer_config = {'grad_clip': None}
lr_config = {'policy': 'CosineAnnealing', 'warmup': None, 'min_lr': 1e-5}

# runtime settings
checkpoint_config = {'interval': 5}
runner = {'type': 'EpochBasedRunner', 'max_epochs': 150}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
from mmdet.apis import train_detector | ||
from mmseg.apis import train_segmentor | ||
|
||
|
||
def train_model(model,
                dataset,
                cfg,
                distributed=False,
                validate=False,
                timestamp=None,
                meta=None):
    """A function wrapper for launching model training according to cfg.

    Segmentors and detectors need different eval hooks in their runners,
    so this dispatches to the matching mmseg/mmdet training entry point.
    Should be deprecated in the future.
    """
    # Segmentor configs are trained through mmseg; everything else
    # (detectors) goes through mmdet.
    if cfg.model.type in ['EncoderDecoder3D']:
        launcher = train_segmentor
    else:
        launcher = train_detector
    launcher(
        model,
        dataset,
        cfg,
        distributed=distributed,
        validate=validate,
        timestamp=timestamp,
        meta=meta)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.