[Feature] Support PointNet++ decode head (#479)
* support PN2 decode head

* add mmseg dependency in github workflow

* complete PN2 decode head

* modify backbone pn2 to support seg task & its unit test

* add unit test for PN2 decode_head
Wuziyi616 authored May 8, 2021
1 parent 3640070 commit 53e0622
Showing 10 changed files with 378 additions and 15 deletions.
1 change: 1 addition & 0 deletions .github/workflows/build.yml
@@ -92,6 +92,7 @@ jobs:
run: |
pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cu101/${{matrix.torch_version}}/index.html
pip install mmdet==2.11.0
pip install -q git+https://github.com/open-mmlab/mmsegmentation.git
pip install -r requirements.txt
- name: Build and install
run: |
1 change: 1 addition & 0 deletions mmdet3d/models/__init__.py
@@ -4,6 +4,7 @@
build_head, build_loss, build_middle_encoder, build_neck,
build_roi_extractor, build_shared_head,
build_voxel_encoder)
from .decode_heads import * # noqa: F401,F403
from .dense_heads import * # noqa: F401,F403
from .detectors import * # noqa: F401,F403
from .fusion_layers import * # noqa: F401,F403
33 changes: 20 additions & 13 deletions mmdet3d/models/backbones/pointnet2_sa_msg.py
@@ -102,15 +102,21 @@ def __init__(self,
cfg=sa_cfg,
bias=True))
skip_channel_list.append(sa_out_channel)
-            self.aggregation_mlps.append(
-                ConvModule(
-                    sa_out_channel,
-                    aggregation_channels[sa_index],
-                    conv_cfg=dict(type='Conv1d'),
-                    norm_cfg=dict(type='BN1d'),
-                    kernel_size=1,
-                    bias=True))
-            sa_in_channel = aggregation_channels[sa_index]
+            cur_aggregation_channel = aggregation_channels[sa_index]
+            if cur_aggregation_channel is None:
+                self.aggregation_mlps.append(None)
+                sa_in_channel = sa_out_channel
+            else:
+                self.aggregation_mlps.append(
+                    ConvModule(
+                        sa_out_channel,
+                        cur_aggregation_channel,
+                        conv_cfg=dict(type='Conv1d'),
+                        norm_cfg=dict(type='BN1d'),
+                        kernel_size=1,
+                        bias=True))
+                sa_in_channel = cur_aggregation_channel

@auto_fp16(apply_to=('points', ))
def forward(self, points):
@@ -139,14 +145,15 @@ def forward(self, points):
sa_features = [features]
sa_indices = [indices]

-        out_sa_xyz = []
-        out_sa_features = []
-        out_sa_indices = []
+        out_sa_xyz = [xyz]
+        out_sa_features = [features]
+        out_sa_indices = [indices]

for i in range(self.num_sa):
cur_xyz, cur_features, cur_indices = self.SA_modules[i](
sa_xyz[i], sa_features[i])
-            cur_features = self.aggregation_mlps[i](cur_features)
+            if self.aggregation_mlps[i] is not None:
+                cur_features = self.aggregation_mlps[i](cur_features)
sa_xyz.append(cur_xyz)
sa_features.append(cur_features)
sa_indices.append(
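With this change, each entry of `aggregation_channels` becomes an optional per-stage compression. A minimal sketch of the resulting channel bookkeeping (the stage widths here are illustrative, not from a shipped config):

sa_out_channels = [96, 256]        # concatenated MSG branch widths (example)
aggregation_channels = (None, 64)  # None skips the 1x1 aggregation ConvModule

sa_in_channel = 3                  # feature channels entering stage 0
for sa_out, agg in zip(sa_out_channels, aggregation_channels):
    if agg is None:
        sa_in_channel = sa_out     # features pass through unchanged
    else:
        sa_in_channel = agg        # compressed by Conv1d(sa_out, agg, 1)
print(sa_in_channel)               # 64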
7 changes: 6 additions & 1 deletion mmdet3d/models/backbones/pointnet2_sa_ssg.py
@@ -132,5 +132,10 @@ def forward(self, points):
fp_indices.append(sa_indices[self.num_sa - i - 1])

ret = dict(
-            fp_xyz=fp_xyz, fp_features=fp_features, fp_indices=fp_indices)
+            fp_xyz=fp_xyz,
+            fp_features=fp_features,
+            fp_indices=fp_indices,
+            sa_xyz=sa_xyz,
+            sa_features=sa_features,
+            sa_indices=sa_indices)
return ret
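With the encoder-side `sa_*` outputs exposed, the SSG backbone can be paired with the `PointNet2Head` added below. A hedged config sketch — the pairing, the channel widths, and `num_classes=13` are illustrative assumptions, not part of this commit:

model = dict(
    backbone=dict(
        type='PointNet2SASSG',
        in_channels=6,  # [xyz, rgb] -> 3 feature channels
        num_points=(1024, 256, 64, 16),
        radius=(0.1, 0.2, 0.4, 0.8),
        num_samples=(32, 32, 32, 32),
        sa_channels=((32, 32, 64), (64, 64, 128), (128, 128, 256),
                     (256, 256, 512)),
        fp_channels=()),  # decoding is delegated to the head
    decode_head=dict(
        type='PointNet2Head',
        fp_channels=((768, 256, 256), (384, 256, 256), (320, 256, 128),
                     (128, 128, 128, 128)),
        channels=128,
        num_classes=13,
        dropout_ratio=0.5))

Each FP input width is the sum of the skip and propagated feature channels, e.g. 768 = 256 (SA stage 3) + 512 (SA stage 4).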
3 changes: 3 additions & 0 deletions mmdet3d/models/decode_heads/__init__.py
@@ -0,0 +1,3 @@
from .pointnet2_head import PointNet2Head

__all__ = ['PointNet2Head']
118 changes: 118 additions & 0 deletions mmdet3d/models/decode_heads/decode_head.py
@@ -0,0 +1,118 @@
from abc import ABCMeta, abstractmethod
from mmcv.cnn import normal_init
from mmcv.runner import auto_fp16, force_fp32
from torch import nn as nn

from mmseg.models.builder import build_loss


class Base3DDecodeHead(nn.Module, metaclass=ABCMeta):
"""Base class for BaseDecodeHead.
Args:
channels (int): Channels after modules, before conv_seg.
num_classes (int): Number of classes.
dropout_ratio (float): Ratio of dropout layer. Default: 0.5.
conv_cfg (dict|None): Config of conv layers.
Default: dict(type='Conv1d').
norm_cfg (dict|None): Config of norm layers.
Default: dict(type='BN1d').
act_cfg (dict): Config of activation layers.
Default: dict(type='ReLU').
loss_decode (dict): Config of decode loss.
Default: dict(type='CrossEntropyLoss').
ignore_index (int | None): The label index to be ignored. When using
masked BCE loss, ignore_index should be set to None. Default: 255.
"""

def __init__(self,
channels,
num_classes,
dropout_ratio=0.5,
conv_cfg=dict(type='Conv1d'),
norm_cfg=dict(type='BN1d'),
act_cfg=dict(type='ReLU'),
loss_decode=dict(
type='CrossEntropyLoss',
use_sigmoid=False,
class_weight=None,
loss_weight=1.0),
ignore_index=255):
super(Base3DDecodeHead, self).__init__()
self.channels = channels
self.num_classes = num_classes
self.dropout_ratio = dropout_ratio
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.act_cfg = act_cfg
self.loss_decode = build_loss(loss_decode)
self.ignore_index = ignore_index

self.conv_seg = nn.Conv1d(channels, num_classes, kernel_size=1)
if dropout_ratio > 0:
self.dropout = nn.Dropout(dropout_ratio)
else:
self.dropout = None
self.fp16_enabled = False

def init_weights(self):
"""Initialize weights of classification layer."""
normal_init(self.conv_seg, mean=0, std=0.01)

@auto_fp16()
@abstractmethod
def forward(self, inputs):
"""Placeholder of forward function."""
pass

def forward_train(self, inputs, img_metas, gt_semantic_seg, train_cfg):
"""Forward function for training.
Args:
inputs (list[Tensor]): List of multi-level point features.
img_metas (list[dict]): Meta information of each sample.
gt_semantic_seg (torch.Tensor): Semantic segmentation masks
                used if the architecture supports the semantic segmentation task.
train_cfg (dict): The training config.
Returns:
dict[str, Tensor]: a dictionary of loss components
"""
seg_logits = self.forward(inputs)
losses = self.losses(seg_logits, gt_semantic_seg)
return losses

def forward_test(self, inputs, img_metas, test_cfg):
"""Forward function for testing.
Args:
inputs (list[Tensor]): List of multi-level point features.
img_metas (list[dict]): Meta information of each sample.
test_cfg (dict): The testing config.
Returns:
Tensor: Output segmentation map.
"""
return self.forward(inputs)

def cls_seg(self, feat):
"""Classify each points."""
if self.dropout is not None:
feat = self.dropout(feat)
output = self.conv_seg(feat)
return output

@force_fp32(apply_to=('seg_logit', ))
def losses(self, seg_logit, seg_label):
"""Compute semantic segmentation loss.
Args:
seg_logit (torch.Tensor): Predicted per-point segmentation logits \
of shape [B, num_classes, N].
seg_label (torch.Tensor): Ground-truth segmentation label of \
shape [B, N].
"""
loss = dict()
loss['loss_sem_seg'] = self.loss_decode(
seg_logit, seg_label, ignore_index=self.ignore_index)
return loss
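Concrete heads only need to implement `forward` and finish with `cls_seg`. A minimal hypothetical subclass as a sketch (the class name and feature layout are assumptions for illustration; building the default loss requires mmseg to be installed):

import torch

from mmdet3d.models.decode_heads.decode_head import Base3DDecodeHead


class ToyDecodeHead(Base3DDecodeHead):
    """Hypothetical head that classifies the finest-level FP features."""

    def forward(self, feat_dict):
        feat = feat_dict['fp_features'][-1]  # [B, C, N], finest resolution
        return self.cls_seg(feat)  # [B, num_classes, N]


head = ToyDecodeHead(channels=16, num_classes=13)
logits = head(dict(fp_features=[torch.rand(2, 16, 100)]))
assert logits.shape == (2, 13, 100)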
80 changes: 80 additions & 0 deletions mmdet3d/models/decode_heads/pointnet2_head.py
@@ -0,0 +1,80 @@
from mmcv.cnn.bricks import ConvModule
from torch import nn as nn

from mmdet3d.ops import PointFPModule
from mmdet.models import HEADS
from .decode_head import Base3DDecodeHead


@HEADS.register_module()
class PointNet2Head(Base3DDecodeHead):
r"""PointNet2 decoder head.
Decoder head used in `PointNet++ <https://arxiv.org/abs/1706.02413>`_.
    Refer to the `official code <https://github.com/charlesq34/pointnet2>`_.
Args:
fp_channels (tuple[tuple[int]]): Tuple of mlp channels in FP modules.
"""

def __init__(self,
fp_channels=((768, 256, 256), (384, 256, 256),
(320, 256, 128), (128, 128, 128, 128)),
**kwargs):
super(PointNet2Head, self).__init__(**kwargs)

self.num_fp = len(fp_channels)
self.FP_modules = nn.ModuleList()
for cur_fp_mlps in fp_channels:
self.FP_modules.append(PointFPModule(mlp_channels=cur_fp_mlps))

        # https://github.com/charlesq34/pointnet2/blob/master/models/pointnet2_sem_seg.py#L40
self.pre_seg_conv = ConvModule(
fp_channels[-1][-1],
self.channels,
kernel_size=1,
bias=True,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)

def _extract_input(self, feat_dict):
"""Extract inputs from features dictionary.
Args:
feat_dict (dict): Feature dict from backbone.
Returns:
list[torch.Tensor]: Coordinates of multiple levels of points.
list[torch.Tensor]: Features of multiple levels of points.
"""
sa_xyz = feat_dict['sa_xyz']
sa_features = feat_dict['sa_features']
assert len(sa_xyz) == len(sa_features)

return sa_xyz, sa_features

def forward(self, feat_dict):
"""Forward pass.
Args:
feat_dict (dict): Feature dict from backbone.
Returns:
torch.Tensor: Segmentation map of shape [B, num_classes, N].
"""
sa_xyz, sa_features = self._extract_input(feat_dict)

        # https://github.com/charlesq34/pointnet2/blob/master/models/pointnet2_sem_seg.py#L24
sa_features[0] = None

fp_feature = sa_features[-1]

for i in range(self.num_fp):
# consume the points in a bottom-up manner
fp_feature = self.FP_modules[i](sa_xyz[-(i + 2)], sa_xyz[-(i + 1)],
sa_features[-(i + 2)], fp_feature)
output = self.pre_seg_conv(fp_feature)
output = self.cls_seg(output)

return output
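The negative indices in the FP loop read bottom-up: with the default four FP modules and five SA levels (the input plus four stages), features travel from the coarsest level back to the input resolution. A standalone trace (pure Python, no mmdet3d needed) of which levels each module bridges:

num_levels, num_fp = 5, 4  # input level + 4 SA stages, default fp_channels
for i in range(num_fp):
    source = num_levels - (i + 1)  # sa_xyz[-(i + 1)], known features
    target = num_levels - (i + 2)  # sa_xyz[-(i + 2)], interpolation target
    print(f'FP module {i}: level {source} -> level {target}')
# FP module 0: level 4 -> level 3
# FP module 1: level 3 -> level 2
# FP module 2: level 2 -> level 1
# FP module 3: level 1 -> level 0  (input resolution)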
2 changes: 1 addition & 1 deletion setup.cfg
@@ -7,7 +7,7 @@ SPLIT_BEFORE_EXPRESSION_AFTER_OPENING_PAREN = true
line_length = 79
multi_line_output = 0
known_standard_library = setuptools
-known_first_party = mmdet,mmdet3d
+known_first_party = mmdet,mmseg,mmdet3d
known_third_party = cv2,indoor3d_util,load_scannet_data,lyft_dataset_sdk,m2r,matplotlib,mmcv,nuimages,numba,numpy,nuscenes,pandas,plyfile,pycocotools,pyquaternion,pytest,recommonmark,scannet_utils,scipy,seaborn,shapely,skimage,tensorflow,terminaltables,torch,trimesh,waymo_open_dataset
no_lines_before = STDLIB,LOCALFOLDER
default_section = THIRDPARTY
66 changes: 66 additions & 0 deletions tests/test_models/test_backbones.py
@@ -34,12 +34,29 @@ def test_pointnet2_sa_ssg():
fp_xyz = ret_dict['fp_xyz']
fp_features = ret_dict['fp_features']
fp_indices = ret_dict['fp_indices']
sa_xyz = ret_dict['sa_xyz']
sa_features = ret_dict['sa_features']
sa_indices = ret_dict['sa_indices']
assert len(fp_xyz) == len(fp_features) == len(fp_indices) == 3
assert len(sa_xyz) == len(sa_features) == len(sa_indices) == 3
assert fp_xyz[0].shape == torch.Size([1, 16, 3])
assert fp_xyz[1].shape == torch.Size([1, 32, 3])
assert fp_xyz[2].shape == torch.Size([1, 100, 3])
assert fp_features[0].shape == torch.Size([1, 16, 16])
assert fp_features[1].shape == torch.Size([1, 16, 32])
assert fp_features[2].shape == torch.Size([1, 16, 100])
assert fp_indices[0].shape == torch.Size([1, 16])
assert fp_indices[1].shape == torch.Size([1, 32])
assert fp_indices[2].shape == torch.Size([1, 100])
assert sa_xyz[0].shape == torch.Size([1, 100, 3])
assert sa_xyz[1].shape == torch.Size([1, 32, 3])
assert sa_xyz[2].shape == torch.Size([1, 16, 3])
assert sa_features[0].shape == torch.Size([1, 3, 100])
assert sa_features[1].shape == torch.Size([1, 16, 32])
assert sa_features[2].shape == torch.Size([1, 16, 16])
assert sa_indices[0].shape == torch.Size([1, 100])
assert sa_indices[1].shape == torch.Size([1, 32])
assert sa_indices[2].shape == torch.Size([1, 16])


def test_multi_backbone():
@@ -156,6 +173,8 @@ def test_multi_backbone():
def test_pointnet2_sa_msg():
if not torch.cuda.is_available():
pytest.skip()

# PN2MSG used in 3DSSD
cfg = dict(
type='PointNet2SAMSG',
in_channels=4,
@@ -216,3 +235,50 @@ def test_pointnet2_sa_msg():
pool_mod='max',
use_xyz=True,
normalize_xyz=False)))

# PN2MSG used in segmentation
cfg = dict(
type='PointNet2SAMSG',
in_channels=6, # [xyz, rgb]
num_points=(1024, 256, 64, 16),
radii=((0.05, 0.1), (0.1, 0.2), (0.2, 0.4), (0.4, 0.8)),
num_samples=((16, 32), (16, 32), (16, 32), (16, 32)),
        sa_channels=(((16, 16, 32), (32, 32, 64)),
                     ((64, 64, 128), (64, 96, 128)),
                     ((128, 196, 256), (128, 196, 256)),
                     ((256, 256, 512), (256, 384, 512))),
aggregation_channels=(None, None, None, None),
fps_mods=(('D-FPS'), ('D-FPS'), ('D-FPS'), ('D-FPS')),
fps_sample_range_lists=((-1), (-1), (-1), (-1)),
dilated_group=(False, False, False, False),
out_indices=(0, 1, 2, 3),
norm_cfg=dict(type='BN2d'),
sa_cfg=dict(
type='PointSAModuleMSG',
pool_mod='max',
use_xyz=True,
normalize_xyz=False))

self = build_backbone(cfg)
self.cuda()
ret_dict = self(xyz)
sa_xyz = ret_dict['sa_xyz']
sa_features = ret_dict['sa_features']
sa_indices = ret_dict['sa_indices']

assert len(sa_xyz) == len(sa_features) == len(sa_indices) == 5
assert sa_xyz[0].shape == torch.Size([1, 100, 3])
assert sa_xyz[1].shape == torch.Size([1, 1024, 3])
assert sa_xyz[2].shape == torch.Size([1, 256, 3])
assert sa_xyz[3].shape == torch.Size([1, 64, 3])
assert sa_xyz[4].shape == torch.Size([1, 16, 3])
assert sa_features[0].shape == torch.Size([1, 3, 100])
assert sa_features[1].shape == torch.Size([1, 96, 1024])
assert sa_features[2].shape == torch.Size([1, 256, 256])
assert sa_features[3].shape == torch.Size([1, 512, 64])
assert sa_features[4].shape == torch.Size([1, 1024, 16])
assert sa_indices[0].shape == torch.Size([1, 100])
assert sa_indices[1].shape == torch.Size([1, 1024])
assert sa_indices[2].shape == torch.Size([1, 256])
assert sa_indices[3].shape == torch.Size([1, 64])
assert sa_indices[4].shape == torch.Size([1, 16])
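The feature widths asserted above follow directly from the config: with every entry of `aggregation_channels` set to `None`, each stage outputs the concatenation of its MSG branch widths. A quick check:

sa_channels = (((16, 16, 32), (32, 32, 64)), ((64, 64, 128), (64, 96, 128)),
               ((128, 196, 256), (128, 196, 256)),
               ((256, 256, 512), (256, 384, 512)))
widths = [sum(branch[-1] for branch in stage) for stage in sa_channels]
print(widths)  # [96, 256, 512, 1024] -> matches sa_features[1:] above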