Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Support TTA for Segmentor #2382

Merged
merged 8 commits into from
Apr 24, 2023
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 44 additions & 19 deletions configs/_base_/datasets/s3dis-seg.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,25 +73,6 @@
with_seg_3d=True,
backend_args=backend_args),
dict(type='NormalizePointsColor', color_mean=None),
dict(
# a wrapper in order to successfully call test function
# actually we don't perform test-time-aug
type='MultiScaleFlipAug3D',
img_scale=(1333, 800),
pts_scale_ratio=1,
flip=False,
transforms=[
dict(
type='GlobalRotScaleTrans',
rot_range=[0, 0],
scale_ratio_range=[1., 1.],
translation_std=[0, 0, 0]),
dict(
type='RandomFlip3D',
sync_2d=False,
flip_ratio_bev_horizontal=0.0,
flip_ratio_bev_vertical=0.0),
]),
dict(type='Pack3DDetInputs', keys=['points'])
]
# construct a pipeline for data and gt loading in show function
Expand All @@ -109,6 +90,48 @@
dict(type='NormalizePointsColor', color_mean=None),
dict(type='Pack3DDetInputs', keys=['points'])
]
# Test-time-augmentation (TTA) pipeline: load points and segmentation
# annotations, normalize point colors, then hand over to `TestTimeAug`,
# which holds the augmentation variants and the final packing step.
tta_pipeline = [
# Read 6 values per point (load_dim=6) and keep all six (use_dim 0..5);
# use_color=True -- presumably xyz + rgb, confirm against the dataset format.
dict(
type='LoadPointsFromFile',
coord_type='DEPTH',
shift_height=False,
use_color=True,
load_dim=6,
use_dim=[0, 1, 2, 3, 4, 5],
backend_args=backend_args),
# Only the 3D semantic-segmentation annotation is requested; boxes, labels
# and instance masks are switched off.
dict(
type='LoadAnnotations3D',
with_bbox_3d=False,
with_label_3d=False,
with_mask_3d=False,
with_seg_3d=True,
backend_args=backend_args),
dict(type='NormalizePointsColor', color_mean=None),
# `transforms` is a list of two lists: four RandomFlip3D configs, then the
# packing step.
# NOTE(review): presumably TestTimeAug builds one augmented view per
# combination of one entry from each inner list (4 x 1 = 4 views here) --
# confirm against the TestTimeAug documentation.
dict(
type='TestTimeAug',
transforms=[[
# The four entries enumerate every combination of horizontal / vertical
# BEV flip ratio in {0., 1.}: none, vertical only, horizontal only, both.
dict(
type='RandomFlip3D',
sync_2d=False,
flip_ratio_bev_horizontal=0.,
flip_ratio_bev_vertical=0.),
dict(
type='RandomFlip3D',
sync_2d=False,
flip_ratio_bev_horizontal=0.,
flip_ratio_bev_vertical=1.),
dict(
type='RandomFlip3D',
sync_2d=False,
flip_ratio_bev_horizontal=1.,
flip_ratio_bev_vertical=0.),
dict(
type='RandomFlip3D',
sync_2d=False,
flip_ratio_bev_horizontal=1.,
flip_ratio_bev_vertical=1.)
], [dict(type='Pack3DDetInputs', keys=['points'])]])
]

# train on area 1, 2, 3, 4, 6
# test on area 5
Expand Down Expand Up @@ -157,3 +180,5 @@
vis_backends = [dict(type='LocalVisBackend')]
visualizer = dict(
type='Det3DLocalVisualizer', vis_backends=vis_backends, name='visualizer')

tta_model = dict(type='Seg3DTTAModel')
63 changes: 44 additions & 19 deletions configs/_base_/datasets/scannet-seg.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,25 +73,6 @@
with_seg_3d=True,
backend_args=backend_args),
dict(type='NormalizePointsColor', color_mean=None),
dict(
# a wrapper in order to successfully call test function
# actually we don't perform test-time-aug
type='MultiScaleFlipAug3D',
img_scale=(1333, 800),
pts_scale_ratio=1,
flip=False,
transforms=[
dict(
type='GlobalRotScaleTrans',
rot_range=[0, 0],
scale_ratio_range=[1., 1.],
translation_std=[0, 0, 0]),
dict(
type='RandomFlip3D',
sync_2d=False,
flip_ratio_bev_horizontal=0.0,
flip_ratio_bev_vertical=0.0),
]),
dict(type='Pack3DDetInputs', keys=['points'])
]
# construct a pipeline for data and gt loading in show function
Expand All @@ -109,6 +90,48 @@
dict(type='NormalizePointsColor', color_mean=None),
dict(type='Pack3DDetInputs', keys=['points'])
]
tta_pipeline = [
Xiangxu-0103 marked this conversation as resolved.
Show resolved Hide resolved
dict(
type='LoadPointsFromFile',
coord_type='DEPTH',
shift_height=False,
use_color=True,
load_dim=6,
use_dim=[0, 1, 2, 3, 4, 5],
backend_args=backend_args),
dict(
type='LoadAnnotations3D',
with_bbox_3d=False,
with_label_3d=False,
with_mask_3d=False,
with_seg_3d=True,
backend_args=backend_args),
dict(type='NormalizePointsColor', color_mean=None),
dict(
type='TestTimeAug',
transforms=[[
dict(
type='RandomFlip3D',
sync_2d=False,
flip_ratio_bev_horizontal=0.,
flip_ratio_bev_vertical=0.),
dict(
type='RandomFlip3D',
sync_2d=False,
flip_ratio_bev_horizontal=0.,
flip_ratio_bev_vertical=1.),
dict(
type='RandomFlip3D',
sync_2d=False,
flip_ratio_bev_horizontal=1.,
flip_ratio_bev_vertical=0.),
dict(
type='RandomFlip3D',
sync_2d=False,
flip_ratio_bev_horizontal=1.,
flip_ratio_bev_vertical=1.)
], [dict(type='Pack3DDetInputs', keys=['points'])]])
]

train_dataloader = dict(
batch_size=8,
Expand Down Expand Up @@ -152,3 +175,5 @@
vis_backends = [dict(type='LocalVisBackend')]
visualizer = dict(
type='Det3DLocalVisualizer', vis_backends=vis_backends, name='visualizer')

tta_model = dict(type='Seg3DTTAModel')
100 changes: 70 additions & 30 deletions configs/_base_/datasets/semantickitti.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@
seg_offset=2**16,
dataset_type='semantickitti',
backend_args=backend_args),
dict(type='PointSegClassMapping', ),
dict(type='PointSegClassMapping'),
dict(
type='RandomFlip3D',
sync_2d=False,
Expand Down Expand Up @@ -112,12 +112,21 @@
seg_offset=2**16,
dataset_type='semantickitti',
backend_args=backend_args),
dict(type='PointSegClassMapping', ),
dict(type='Pack3DDetInputs', keys=['points', 'pts_semantic_mask'])
Xiangxu-0103 marked this conversation as resolved.
Show resolved Hide resolved
dict(type='PointSegClassMapping'),
dict(type='Pack3DDetInputs', keys=['points'])
]
# construct a pipeline for data and gt loading in show function
# please keep its loading function consistent with test_pipeline (e.g. client)
# Two steps only: load the point cloud (load_dim=4, use_dim=4) and pack the
# `points` key. No annotation-loading transform is configured in this list.
eval_pipeline = [
dict(
type='LoadPointsFromFile',
coord_type='LIDAR',
load_dim=4,
use_dim=4,
backend_args=backend_args),
dict(type='Pack3DDetInputs', keys=['points'])
]
tta_pipeline = [
dict(
type='LoadPointsFromFile',
coord_type='LIDAR',
Expand All @@ -133,46 +142,75 @@
seg_offset=2**16,
dataset_type='semantickitti',
backend_args=backend_args),
dict(type='PointSegClassMapping', ),
dict(type='Pack3DDetInputs', keys=['points', 'pts_semantic_mask'])
dict(type='PointSegClassMapping'),
dict(
type='TestTimeAug',
transforms=[[
dict(
type='RandomFlip3D',
sync_2d=False,
flip_ratio_bev_horizontal=0.,
flip_ratio_bev_vertical=0.),
dict(
type='RandomFlip3D',
sync_2d=False,
flip_ratio_bev_horizontal=0.,
flip_ratio_bev_vertical=1.),
dict(
type='RandomFlip3D',
sync_2d=False,
flip_ratio_bev_horizontal=1.,
flip_ratio_bev_vertical=0.),
dict(
type='RandomFlip3D',
sync_2d=False,
flip_ratio_bev_horizontal=1.,
flip_ratio_bev_vertical=1.)
],
[
dict(
type='GlobalRotScaleTrans',
rot_range=[pcd_rotate_range, pcd_rotate_range],
scale_ratio_range=[
pcd_scale_factor, pcd_scale_factor
],
translation_std=[0, 0, 0])
for pcd_rotate_range in [-0.78539816, 0.0, 0.78539816]
for pcd_scale_factor in [0.95, 1.0, 1.05]
], [dict(type='Pack3DDetInputs', keys=['points'])]])
]

train_dataloader = dict(
batch_size=2,
num_workers=4,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=True),
dataset=dict(
type='RepeatDataset',
times=1,
dataset=dict(
type=dataset_type,
data_root=data_root,
ann_file='semantickitti_infos_train.pkl',
pipeline=train_pipeline,
metainfo=metainfo,
modality=input_modality,
ignore_index=19,
backend_args=backend_args)),
)
type=dataset_type,
data_root=data_root,
ann_file='semantickitti_infos_train.pkl',
pipeline=train_pipeline,
metainfo=metainfo,
modality=input_modality,
ignore_index=19,
backend_args=backend_args))

test_dataloader = dict(
batch_size=1,
num_workers=1,
persistent_workers=True,
drop_last=False,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=dict(
type='RepeatDataset',
times=1,
dataset=dict(
type=dataset_type,
data_root=data_root,
ann_file='semantickitti_infos_val.pkl',
pipeline=test_pipeline,
metainfo=metainfo,
modality=input_modality,
ignore_index=19,
test_mode=True,
backend_args=backend_args)),
)
type=dataset_type,
data_root=data_root,
ann_file='semantickitti_infos_val.pkl',
pipeline=test_pipeline,
metainfo=metainfo,
modality=input_modality,
ignore_index=19,
test_mode=True,
backend_args=backend_args))

val_dataloader = test_dataloader

Expand All @@ -182,3 +220,5 @@
vis_backends = [dict(type='LocalVisBackend')]
visualizer = dict(
type='Det3DLocalVisualizer', vis_backends=vis_backends, name='visualizer')

tta_model = dict(type='Seg3DTTAModel')
2 changes: 1 addition & 1 deletion configs/minkunet/minkunet_w32_8xb2-15e_semantickitti.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
]

train_dataloader = dict(
sampler=dict(seed=0), dataset=dict(dataset=dict(pipeline=train_pipeline)))
sampler=dict(seed=0), dataset=dict(pipeline=train_pipeline))

lr = 0.24
optim_wrapper = dict(
Expand Down
2 changes: 1 addition & 1 deletion configs/spvcnn/spvcnn_w32_8xb2-15e_semantickitti.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
]

train_dataloader = dict(
sampler=dict(seed=0), dataset=dict(dataset=dict(pipeline=train_pipeline)))
sampler=dict(seed=0), dataset=dict(pipeline=train_pipeline))

lr = 0.24
optim_wrapper = dict(
Expand Down
8 changes: 4 additions & 4 deletions mmdet3d/datasets/transforms/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .dbsampler import DataBaseSampler
from .formating import Pack3DDetInputs
from .loading import (LoadAnnotations3D, LoadImageFromFileMono3D,
LoadMultiViewImageFromFiles, LoadPointsFromDict,
LoadPointsFromFile, LoadPointsFromMultiSweeps,
MonoDet3DInferencerLoader,
from .loading import (LidarDet3DInferencerLoader, LoadAnnotations3D,
LoadImageFromFileMono3D, LoadMultiViewImageFromFiles,
LoadPointsFromDict, LoadPointsFromFile,
LoadPointsFromMultiSweeps, MonoDet3DInferencerLoader,
MultiModalityDet3DInferencerLoader, NormalizePointsColor,
PointSegClassMapping)
from .test_time_aug import MultiScaleFlipAug3D
Expand Down
6 changes: 5 additions & 1 deletion mmdet3d/models/segmentors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,9 @@
from .cylinder3d import Cylinder3D
from .encoder_decoder import EncoderDecoder3D
from .minkunet import MinkUNet
from .seg3d_tta import Seg3DTTAModel

__all__ = ['Base3DSegmentor', 'EncoderDecoder3D', 'Cylinder3D', 'MinkUNet']
__all__ = [
'Base3DSegmentor', 'EncoderDecoder3D', 'Cylinder3D', 'MinkUNet',
'Seg3DTTAModel'
]
26 changes: 14 additions & 12 deletions mmdet3d/models/segmentors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,17 +132,12 @@ def _forward(self,
"""
pass

@abstractmethod
def aug_test(self, batch_inputs, batch_data_samples):
"""Placeholder for augmentation test."""
pass

def postprocess_result(self, seg_pred_list: List[dict],
def postprocess_result(self, seg_logits_list: List[Tensor],
batch_data_samples: SampleList) -> SampleList:
"""Convert results list to `Det3DDataSample`.

Args:
seg_logits_list (List[dict]): List of segmentation results,
seg_logits_list (List[Tensor]): List of segmentation results,
seg_logits from model of each input point cloud sample.
batch_data_samples (List[:obj:`Det3DDataSample`]): The det3d data
samples. It usually includes information such as `metainfo` and
Expand All @@ -152,12 +147,19 @@ def postprocess_result(self, seg_pred_list: List[dict],
List[:obj:`Det3DDataSample`]: Segmentation results of the input
points. Each Det3DDataSample usually contains:

- ``pred_pts_seg`` (PixelData): Prediction of 3D semantic
- ``pred_pts_seg`` (PointData): Prediction of 3D semantic
segmentation.
- ``pts_seg_logits`` (PointData): Predicted logits of 3D semantic
segmentation before normalization.
"""

for i in range(len(seg_pred_list)):
seg_pred = seg_pred_list[i]
batch_data_samples[i].set_data(
{'pred_pts_seg': PointData(**{'pts_semantic_mask': seg_pred})})
for i in range(len(seg_logits_list)):
seg_logits = seg_logits_list[i]
seg_pred = seg_logits.argmax(dim=0)
batch_data_samples[i].set_data({
'pts_seg_logits':
PointData(**{'pts_seg_logits': seg_logits}),
'pred_pts_seg':
PointData(**{'pts_semantic_mask': seg_pred})
})
return batch_data_samples
Loading