diff --git a/configs/_base_/datasets/s3dis-seg.py b/configs/_base_/datasets/s3dis-seg.py index 0158e8ba9e..fdebc94b4a 100644 --- a/configs/_base_/datasets/s3dis-seg.py +++ b/configs/_base_/datasets/s3dis-seg.py @@ -73,25 +73,6 @@ with_seg_3d=True, backend_args=backend_args), dict(type='NormalizePointsColor', color_mean=None), - dict( - # a wrapper in order to successfully call test function - # actually we don't perform test-time-aug - type='MultiScaleFlipAug3D', - img_scale=(1333, 800), - pts_scale_ratio=1, - flip=False, - transforms=[ - dict( - type='GlobalRotScaleTrans', - rot_range=[0, 0], - scale_ratio_range=[1., 1.], - translation_std=[0, 0, 0]), - dict( - type='RandomFlip3D', - sync_2d=False, - flip_ratio_bev_horizontal=0.0, - flip_ratio_bev_vertical=0.0), - ]), dict(type='Pack3DDetInputs', keys=['points']) ] # construct a pipeline for data and gt loading in show function @@ -109,6 +90,33 @@ dict(type='NormalizePointsColor', color_mean=None), dict(type='Pack3DDetInputs', keys=['points']) ] +tta_pipeline = [ + dict( + type='LoadPointsFromFile', + coord_type='DEPTH', + shift_height=False, + use_color=True, + load_dim=6, + use_dim=[0, 1, 2, 3, 4, 5], + backend_args=backend_args), + dict( + type='LoadAnnotations3D', + with_bbox_3d=False, + with_label_3d=False, + with_mask_3d=False, + with_seg_3d=True, + backend_args=backend_args), + dict(type='NormalizePointsColor', color_mean=None), + dict( + type='TestTimeAug', + transforms=[[ + dict( + type='RandomFlip3D', + sync_2d=False, + flip_ratio_bev_horizontal=0., + flip_ratio_bev_vertical=0.) + ], [dict(type='Pack3DDetInputs', keys=['points'])]]) +] # train on area 1, 2, 3, 4, 6 # test on area 5 @@ -157,3 +165,5 @@ vis_backends = [dict(type='LocalVisBackend')] visualizer = dict( type='Det3DLocalVisualizer', vis_backends=vis_backends, name='visualizer') + +tta_model = dict(type='Seg3DTTAModel') diff --git a/configs/_base_/datasets/scannet-seg.py b/configs/_base_/datasets/scannet-seg.py index 6e94b34117..759e5e2a38 100644 --- a/configs/_base_/datasets/scannet-seg.py +++ b/configs/_base_/datasets/scannet-seg.py @@ -73,25 +73,6 @@ with_seg_3d=True, backend_args=backend_args), dict(type='NormalizePointsColor', color_mean=None), - dict( - # a wrapper in order to successfully call test function - # actually we don't perform test-time-aug - type='MultiScaleFlipAug3D', - img_scale=(1333, 800), - pts_scale_ratio=1, - flip=False, - transforms=[ - dict( - type='GlobalRotScaleTrans', - rot_range=[0, 0], - scale_ratio_range=[1., 1.], - translation_std=[0, 0, 0]), - dict( - type='RandomFlip3D', - sync_2d=False, - flip_ratio_bev_horizontal=0.0, - flip_ratio_bev_vertical=0.0), - ]), dict(type='Pack3DDetInputs', keys=['points']) ] # construct a pipeline for data and gt loading in show function @@ -109,6 +90,33 @@ dict(type='NormalizePointsColor', color_mean=None), dict(type='Pack3DDetInputs', keys=['points']) ] +tta_pipeline = [ + dict( + type='LoadPointsFromFile', + coord_type='DEPTH', + shift_height=False, + use_color=True, + load_dim=6, + use_dim=[0, 1, 2, 3, 4, 5], + backend_args=backend_args), + dict( + type='LoadAnnotations3D', + with_bbox_3d=False, + with_label_3d=False, + with_mask_3d=False, + with_seg_3d=True, + backend_args=backend_args), + dict(type='NormalizePointsColor', color_mean=None), + dict( + type='TestTimeAug', + transforms=[[ + dict( + type='RandomFlip3D', + sync_2d=False, + flip_ratio_bev_horizontal=0., + flip_ratio_bev_vertical=0.) + ], [dict(type='Pack3DDetInputs', keys=['points'])]]) +] train_dataloader = dict( batch_size=8, @@ -152,3 +160,5 @@ vis_backends = [dict(type='LocalVisBackend')] visualizer = dict( type='Det3DLocalVisualizer', vis_backends=vis_backends, name='visualizer') + +tta_model = dict(type='Seg3DTTAModel') diff --git a/configs/_base_/datasets/semantickitti.py b/configs/_base_/datasets/semantickitti.py index aa253e5393..61c9ef5b66 100644 --- a/configs/_base_/datasets/semantickitti.py +++ b/configs/_base_/datasets/semantickitti.py @@ -82,7 +82,7 @@ seg_offset=2**16, dataset_type='semantickitti', backend_args=backend_args), - dict(type='PointSegClassMapping', ), + dict(type='PointSegClassMapping'), dict( type='RandomFlip3D', sync_2d=False, @@ -112,12 +112,21 @@ seg_offset=2**16, dataset_type='semantickitti', backend_args=backend_args), - dict(type='PointSegClassMapping', ), - dict(type='Pack3DDetInputs', keys=['points', 'pts_semantic_mask']) + dict(type='PointSegClassMapping'), + dict(type='Pack3DDetInputs', keys=['points']) ] # construct a pipeline for data and gt loading in show function # please keep its loading function consistent with test_pipeline (e.g. client) eval_pipeline = [ + dict( + type='LoadPointsFromFile', + coord_type='LIDAR', + load_dim=4, + use_dim=4, + backend_args=backend_args), + dict(type='Pack3DDetInputs', keys=['points']) +] +tta_pipeline = [ dict( type='LoadPointsFromFile', coord_type='LIDAR', @@ -133,46 +142,75 @@ seg_offset=2**16, dataset_type='semantickitti', backend_args=backend_args), - dict(type='PointSegClassMapping', ), - dict(type='Pack3DDetInputs', keys=['points', 'pts_semantic_mask']) + dict(type='PointSegClassMapping'), + dict( + type='TestTimeAug', + transforms=[[ + dict( + type='RandomFlip3D', + sync_2d=False, + flip_ratio_bev_horizontal=0., + flip_ratio_bev_vertical=0.), + dict( + type='RandomFlip3D', + sync_2d=False, + flip_ratio_bev_horizontal=0., + flip_ratio_bev_vertical=1.), + dict( + type='RandomFlip3D', + sync_2d=False, + flip_ratio_bev_horizontal=1., + flip_ratio_bev_vertical=0.), + dict( + type='RandomFlip3D', + sync_2d=False, + flip_ratio_bev_horizontal=1., + flip_ratio_bev_vertical=1.) + ], + [ + dict( + type='GlobalRotScaleTrans', + rot_range=[pcd_rotate_range, pcd_rotate_range], + scale_ratio_range=[ + pcd_scale_factor, pcd_scale_factor + ], + translation_std=[0, 0, 0]) + for pcd_rotate_range in [-0.78539816, 0.0, 0.78539816] + for pcd_scale_factor in [0.95, 1.0, 1.05] + ], [dict(type='Pack3DDetInputs', keys=['points'])]]) ] train_dataloader = dict( batch_size=2, num_workers=4, + persistent_workers=True, sampler=dict(type='DefaultSampler', shuffle=True), dataset=dict( - type='RepeatDataset', - times=1, - dataset=dict( - type=dataset_type, - data_root=data_root, - ann_file='semantickitti_infos_train.pkl', - pipeline=train_pipeline, - metainfo=metainfo, - modality=input_modality, - ignore_index=19, - backend_args=backend_args)), -) + type=dataset_type, + data_root=data_root, + ann_file='semantickitti_infos_train.pkl', + pipeline=train_pipeline, + metainfo=metainfo, + modality=input_modality, + ignore_index=19, + backend_args=backend_args)) test_dataloader = dict( batch_size=1, num_workers=1, + persistent_workers=True, + drop_last=False, sampler=dict(type='DefaultSampler', shuffle=False), dataset=dict( - type='RepeatDataset', - times=1, - dataset=dict( - type=dataset_type, - data_root=data_root, - ann_file='semantickitti_infos_val.pkl', - pipeline=test_pipeline, - metainfo=metainfo, - modality=input_modality, - ignore_index=19, - test_mode=True, - backend_args=backend_args)), -) + type=dataset_type, + data_root=data_root, + ann_file='semantickitti_infos_val.pkl', + pipeline=test_pipeline, + metainfo=metainfo, + modality=input_modality, + ignore_index=19, + test_mode=True, + backend_args=backend_args)) val_dataloader = test_dataloader @@ -182,3 +220,5 @@ vis_backends = [dict(type='LocalVisBackend')] visualizer = dict( type='Det3DLocalVisualizer', vis_backends=vis_backends, name='visualizer') + +tta_model = dict(type='Seg3DTTAModel') diff --git a/configs/minkunet/minkunet_w32_8xb2-15e_semantickitti.py b/configs/minkunet/minkunet_w32_8xb2-15e_semantickitti.py index 80f5283ce8..22c70cf553 100644 --- a/configs/minkunet/minkunet_w32_8xb2-15e_semantickitti.py +++ b/configs/minkunet/minkunet_w32_8xb2-15e_semantickitti.py @@ -24,7 +24,7 @@ ] train_dataloader = dict( - sampler=dict(seed=0), dataset=dict(dataset=dict(pipeline=train_pipeline))) + sampler=dict(seed=0), dataset=dict(pipeline=train_pipeline)) lr = 0.24 optim_wrapper = dict( diff --git a/configs/spvcnn/spvcnn_w32_8xb2-15e_semantickitti.py b/configs/spvcnn/spvcnn_w32_8xb2-15e_semantickitti.py index 0d3f30e103..090576f432 100644 --- a/configs/spvcnn/spvcnn_w32_8xb2-15e_semantickitti.py +++ b/configs/spvcnn/spvcnn_w32_8xb2-15e_semantickitti.py @@ -24,7 +24,7 @@ ] train_dataloader = dict( - sampler=dict(seed=0), dataset=dict(dataset=dict(pipeline=train_pipeline))) + sampler=dict(seed=0), dataset=dict(pipeline=train_pipeline)) lr = 0.24 optim_wrapper = dict( diff --git a/mmdet3d/datasets/transforms/__init__.py b/mmdet3d/datasets/transforms/__init__.py index 4c0587f80e..cf91ba2352 100644 --- a/mmdet3d/datasets/transforms/__init__.py +++ b/mmdet3d/datasets/transforms/__init__.py @@ -1,10 +1,10 @@ # Copyright (c) OpenMMLab. All rights reserved. from .dbsampler import DataBaseSampler from .formating import Pack3DDetInputs -from .loading import (LoadAnnotations3D, LoadImageFromFileMono3D, - LoadMultiViewImageFromFiles, LoadPointsFromDict, - LoadPointsFromFile, LoadPointsFromMultiSweeps, - MonoDet3DInferencerLoader, +from .loading import (LidarDet3DInferencerLoader, LoadAnnotations3D, + LoadImageFromFileMono3D, LoadMultiViewImageFromFiles, + LoadPointsFromDict, LoadPointsFromFile, + LoadPointsFromMultiSweeps, MonoDet3DInferencerLoader, MultiModalityDet3DInferencerLoader, NormalizePointsColor, PointSegClassMapping) from .test_time_aug import MultiScaleFlipAug3D diff --git a/mmdet3d/models/segmentors/__init__.py b/mmdet3d/models/segmentors/__init__.py index 5e43985d82..ce0a555c35 100644 --- a/mmdet3d/models/segmentors/__init__.py +++ b/mmdet3d/models/segmentors/__init__.py @@ -3,5 +3,9 @@ from .cylinder3d import Cylinder3D from .encoder_decoder import EncoderDecoder3D from .minkunet import MinkUNet +from .seg3d_tta import Seg3DTTAModel -__all__ = ['Base3DSegmentor', 'EncoderDecoder3D', 'Cylinder3D', 'MinkUNet'] +__all__ = [ + 'Base3DSegmentor', 'EncoderDecoder3D', 'Cylinder3D', 'MinkUNet', + 'Seg3DTTAModel' +] diff --git a/mmdet3d/models/segmentors/base.py b/mmdet3d/models/segmentors/base.py index e881c2aadb..f68c2a3799 100644 --- a/mmdet3d/models/segmentors/base.py +++ b/mmdet3d/models/segmentors/base.py @@ -132,17 +132,12 @@ def _forward(self, """ pass - @abstractmethod - def aug_test(self, batch_inputs, batch_data_samples): - """Placeholder for augmentation test.""" - pass - - def postprocess_result(self, seg_pred_list: List[dict], + def postprocess_result(self, seg_logits_list: List[Tensor], batch_data_samples: SampleList) -> SampleList: """Convert results list to `Det3DDataSample`. Args: - seg_logits_list (List[dict]): List of segmentation results, + seg_logits_list (List[Tensor]): List of segmentation results, seg_logits from model of each input point clouds sample. batch_data_samples (List[:obj:`Det3DDataSample`]): The det3d data samples. It usually includes information such as `metainfo` and @@ -152,12 +147,19 @@ def postprocess_result(self, seg_pred_list: List[dict], List[:obj:`Det3DDataSample`]: Segmentation results of the input points. Each Det3DDataSample usually contains: - - ``pred_pts_seg`` (PixelData): Prediction of 3D semantic + - ``pred_pts_seg`` (PointData): Prediction of 3D semantic segmentation. + - ``pts_seg_logits`` (PointData): Predicted logits of 3D semantic + segmentation before normalization. """ - for i in range(len(seg_pred_list)): - seg_pred = seg_pred_list[i] - batch_data_samples[i].set_data( - {'pred_pts_seg': PointData(**{'pts_semantic_mask': seg_pred})}) + for i in range(len(seg_logits_list)): + seg_logits = seg_logits_list[i] + seg_pred = seg_logits.argmax(dim=0) + batch_data_samples[i].set_data({ + 'pts_seg_logits': + PointData(**{'pts_seg_logits': seg_logits}), + 'pred_pts_seg': + PointData(**{'pts_semantic_mask': seg_pred}) + }) return batch_data_samples diff --git a/mmdet3d/models/segmentors/cylinder3d.py b/mmdet3d/models/segmentors/cylinder3d.py index b126607e6f..d4177dd60f 100644 --- a/mmdet3d/models/segmentors/cylinder3d.py +++ b/mmdet3d/models/segmentors/cylinder3d.py @@ -127,16 +127,18 @@ def predict(self, List[:obj:`Det3DDataSample`]: Segmentation results of the input points. Each Det3DDataSample usually contains: - - ``pred_pts_seg`` (PixelData): Prediction of 3D semantic + - ``pred_pts_seg`` (PointData): Prediction of 3D semantic segmentation. + - ``pts_seg_logits`` (PointData): Predicted logits of 3D semantic + segmentation before normalization. """ # 3D segmentation requires per-point prediction, so it's impossible # to use down-sampling to get a batch of scenes with same num_points # therefore, we only support testing one scene every time x = self.extract_feat(batch_inputs_dict) - seg_pred_list = self.decode_head.predict(x, batch_inputs_dict, - batch_data_samples) - for i in range(len(seg_pred_list)): - seg_pred_list[i] = seg_pred_list[i].argmax(1).cpu() + seg_logits_list = self.decode_head.predict(x, batch_inputs_dict, + batch_data_samples) + for i in range(len(seg_logits_list)): + seg_logits_list[i] = seg_logits_list[i].transpose(0, 1) - return self.postprocess_result(seg_pred_list, batch_data_samples) + return self.postprocess_result(seg_logits_list, batch_data_samples) diff --git a/mmdet3d/models/segmentors/encoder_decoder.py b/mmdet3d/models/segmentors/encoder_decoder.py index 8554dc82fc..168d0e65a5 100644 --- a/mmdet3d/models/segmentors/encoder_decoder.py +++ b/mmdet3d/models/segmentors/encoder_decoder.py @@ -5,7 +5,6 @@ import torch from torch import Tensor from torch import nn as nn -from torch.nn import functional as F from mmdet3d.registry import MODELS from mmdet3d.utils import ConfigType, OptConfigType, OptMultiConfig @@ -477,8 +476,7 @@ def inference(self, points: Tensor, batch_input_metas: List[dict], else: seg_logit = self.whole_inference(points, batch_input_metas, rescale) - output = F.softmax(seg_logit, dim=1) - return output + return seg_logit def predict(self, batch_inputs_dict: dict, @@ -503,27 +501,26 @@ def predict(self, List[:obj:`Det3DDataSample`]: Segmentation results of the input points. Each Det3DDataSample usually contains: - - ``pred_pts_seg`` (PixelData): Prediction of 3D semantic + - ``pred_pts_seg`` (PointData): Prediction of 3D semantic segmentation. + - ``pts_seg_logits`` (PointData): Predicted logits of 3D semantic + segmentation before normalization. """ # 3D segmentation requires per-point prediction, so it's impossible # to use down-sampling to get a batch of scenes with same num_points # therefore, we only support testing one scene every time - seg_pred_list = [] + seg_logits_list = [] batch_input_metas = [] for data_sample in batch_data_samples: batch_input_metas.append(data_sample.metainfo) points = batch_inputs_dict['points'] for point, input_meta in zip(points, batch_input_metas): - seg_prob = self.inference( + seg_logits = self.inference( point.unsqueeze(0), [input_meta], rescale)[0] - seg_map = seg_prob.argmax(0) # [N] - # to cpu tensor for consistency with det3d - seg_map = seg_map.cpu() - seg_pred_list.append(seg_map) + seg_logits_list.append(seg_logits) - return self.postprocess_result(seg_pred_list, batch_data_samples) + return self.postprocess_result(seg_logits_list, batch_data_samples) def _forward(self, batch_inputs_dict: dict, @@ -546,7 +543,3 @@ def _forward(self, points = torch.stack(batch_inputs_dict['points']) x = self.extract_feat(points) return self.decode_head.forward(x) - - def aug_test(self, batch_inputs, batch_img_metas): - """Placeholder for augmentation test.""" - pass diff --git a/mmdet3d/models/segmentors/minkunet.py b/mmdet3d/models/segmentors/minkunet.py index fcf8c22421..b40708cda2 100644 --- a/mmdet3d/models/segmentors/minkunet.py +++ b/mmdet3d/models/segmentors/minkunet.py @@ -50,7 +50,8 @@ def loss(self, inputs: dict, data_samples: SampleList): losses = self.decode_head.loss(x, data_samples, self.train_cfg) return losses - def predict(self, inputs: dict, data_samples: SampleList) -> SampleList: + def predict(self, inputs: dict, + batch_data_samples: SampleList) -> SampleList: """Simple test with single scene. Args: @@ -67,14 +68,17 @@ def predict(self, inputs: dict, data_samples: SampleList) -> SampleList: List[:obj:`Det3DDataSample`]: Segmentation results of the input points. Each Det3DDataSample usually contains: - - ``pred_pts_seg`` (PixelData): Prediction of 3D semantic + - ``pred_pts_seg`` (PointData): Prediction of 3D semantic segmentation. + - ``pts_seg_logits`` (PointData): Predicted logits of 3D semantic + segmentation before normalization. """ x = self.extract_feat(inputs) - seg_logits = self.decode_head.predict(x, data_samples) - seg_preds = [seg_logit.argmax(dim=1) for seg_logit in seg_logits] + seg_logits_list = self.decode_head.predict(x, batch_data_samples) + for i in range(len(seg_logits_list)): + seg_logits_list[i] = seg_logits_list[i].transpose(0, 1) - return self.postprocess_result(seg_preds, data_samples) + return self.postprocess_result(seg_logits_list, batch_data_samples) def _forward(self, batch_inputs_dict: dict, diff --git a/mmdet3d/models/segmentors/seg3d_tta.py b/mmdet3d/models/segmentors/seg3d_tta.py new file mode 100644 index 0000000000..be93562ff7 --- /dev/null +++ b/mmdet3d/models/segmentors/seg3d_tta.py @@ -0,0 +1,35 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List + +import torch +from mmengine.model import BaseTTAModel + +from mmdet3d.registry import MODELS +from mmdet3d.structures.det3d_data_sample import SampleList + + +@MODELS.register_module() +class Seg3DTTAModel(BaseTTAModel): + + def merge_preds(self, data_samples_list: List[SampleList]) -> SampleList: + """Merge predictions of enhanced data to one prediction. + + Args: + data_samples_list (List[List[:obj:`Det3DDataSample`]]): List of + predictions of all enhanced data. + + Returns: + List[:obj:`Det3DDataSample`]: Merged prediction. + """ + predictions = [] + for data_samples in data_samples_list: + seg_logits = data_samples[0].pts_seg_logits.pts_seg_logits + logits = torch.zeros(seg_logits.shape).to(seg_logits) + for data_sample in data_samples: + seg_logit = data_sample.pts_seg_logits.pts_seg_logits + logits += seg_logit.softmax(dim=0) + logits /= len(data_samples) + seg_pred = logits.argmax(dim=0) + data_samples[0].pred_pts_seg.pts_semantic_mask = seg_pred + predictions.append(data_samples[0]) + return predictions diff --git a/tests/test_datasets/test_tta.py b/tests/test_datasets/test_tta.py new file mode 100644 index 0000000000..fa191cba96 --- /dev/null +++ b/tests/test_datasets/test_tta.py @@ -0,0 +1,210 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from unittest import TestCase + +import numpy as np +import pytest +from mmengine import DefaultScope + +from mmdet3d.datasets.transforms import * # noqa +from mmdet3d.registry import TRANSFORMS +from mmdet3d.structures.points import LiDARPoints + +DefaultScope.get_instance('test_multi_scale_flip_aug_3d', scope_name='mmdet3d') + + +class TestMuitiScaleFlipAug3D(TestCase): + + def test_exception(self): + with pytest.raises(TypeError): + tta_transform = dict( + type='TestTimeAug', + transforms=[ + dict( + type='RandomFlip3D', + flip_ratio_bev_horizontal=0.0, + flip_ratio_bev_vertical=0.0) + ]) + TRANSFORMS.build(tta_transform) + + def test_multi_scale_flip_aug(self): + tta_transform = dict( + type='TestTimeAug', + transforms=[[ + dict( + type='RandomFlip3D', + flip_ratio_bev_horizontal=0.0, + flip_ratio_bev_vertical=0.0), + dict( + type='RandomFlip3D', + flip_ratio_bev_horizontal=0.0, + flip_ratio_bev_vertical=1.0), + dict( + type='RandomFlip3D', + flip_ratio_bev_horizontal=1.0, + flip_ratio_bev_vertical=0.0), + dict( + type='RandomFlip3D', + flip_ratio_bev_horizontal=1.0, + flip_ratio_bev_vertical=1.0) + ], [dict(type='Pack3DDetInputs', keys=['points'])]]) + tta_module = TRANSFORMS.build(tta_transform) + + results = dict() + points = LiDARPoints(np.random.random((100, 4)), 4) + results['points'] = points + + tta_results = tta_module(results.copy()) + assert [ + data_sample.metainfo['pcd_horizontal_flip'] + for data_sample in tta_results['data_samples'] + ] == [False, False, True, True] + assert [ + data_sample.metainfo['pcd_vertical_flip'] + for data_sample in tta_results['data_samples'] + ] == [False, True, False, True] + + tta_transform = dict( + type='TestTimeAug', + transforms=[[ + dict( + type='GlobalRotScaleTrans', + rot_range=[-0.78539816, -0.78539816], + scale_ratio_range=[1.0, 1.0], + translation_std=[0, 0, 0]), + dict( + type='GlobalRotScaleTrans', + rot_range=[0, 0], + scale_ratio_range=[1.0, 1.0], + translation_std=[0, 0, 0]), + dict( + type='GlobalRotScaleTrans', + rot_range=[0.78539816, 0.78539816], + scale_ratio_range=[1.0, 1.0], + translation_std=[0, 0, 0]) + ], [dict(type='Pack3DDetInputs', keys=['points'])]]) + tta_module = TRANSFORMS.build(tta_transform) + + results = dict() + points = LiDARPoints(np.random.random((100, 4)), 4) + results['points'] = points + + tta_results = tta_module(results.copy()) + assert [ + data_sample.metainfo['pcd_rotation_angle'] + for data_sample in tta_results['data_samples'] + ] == [-0.78539816, 0, 0.78539816] + assert [ + data_sample.metainfo['pcd_scale_factor'] + for data_sample in tta_results['data_samples'] + ] == [1.0, 1.0, 1.0] + + tta_transform = dict( + type='TestTimeAug', + transforms=[[ + dict( + type='GlobalRotScaleTrans', + rot_range=[0, 0], + scale_ratio_range=[0.95, 0.95], + translation_std=[0, 0, 0]), + dict( + type='GlobalRotScaleTrans', + rot_range=[0, 0], + scale_ratio_range=[1.0, 1.0], + translation_std=[0, 0, 0]), + dict( + type='GlobalRotScaleTrans', + rot_range=[0, 0], + scale_ratio_range=[1.05, 1.05], + translation_std=[0, 0, 0]) + ], [dict(type='Pack3DDetInputs', keys=['points'])]]) + tta_module = TRANSFORMS.build(tta_transform) + + results = dict() + points = LiDARPoints(np.random.random((100, 4)), 4) + results['points'] = points + + tta_results = tta_module(results.copy()) + assert [ + data_sample.metainfo['pcd_rotation_angle'] + for data_sample in tta_results['data_samples'] + ] == [0, 0, 0] + assert [ + data_sample.metainfo['pcd_scale_factor'] + for data_sample in tta_results['data_samples'] + ] == [0.95, 1, 1.05] + + tta_transform = dict( + type='TestTimeAug', + transforms=[ + [ + dict( + type='RandomFlip3D', + flip_ratio_bev_horizontal=0.0, + flip_ratio_bev_vertical=0.0), + dict( + type='RandomFlip3D', + flip_ratio_bev_horizontal=0.0, + flip_ratio_bev_vertical=1.0), + dict( + type='RandomFlip3D', + flip_ratio_bev_horizontal=1.0, + flip_ratio_bev_vertical=0.0), + dict( + type='RandomFlip3D', + flip_ratio_bev_horizontal=1.0, + flip_ratio_bev_vertical=1.0) + ], + [ + dict( + type='GlobalRotScaleTrans', + rot_range=[pcd_rotate_range, pcd_rotate_range], + scale_ratio_range=[pcd_scale_factor, pcd_scale_factor], + translation_std=[0, 0, 0]) + for pcd_rotate_range in [-0.78539816, 0.0, 0.78539816] + for pcd_scale_factor in [0.95, 1.0, 1.05] + ], [dict(type='Pack3DDetInputs', keys=['points'])] + ]) + tta_module = TRANSFORMS.build(tta_transform) + + results = dict() + points = LiDARPoints(np.random.random((100, 4)), 4) + results['points'] = points + + tta_results = tta_module(results.copy()) + assert [ + data_sample.metainfo['pcd_horizontal_flip'] + for data_sample in tta_results['data_samples'] + ] == [ + False, False, False, False, False, False, False, False, False, + False, False, False, False, False, False, False, False, False, + True, True, True, True, True, True, True, True, True, True, True, + True, True, True, True, True, True, True + ] + assert [ + data_sample.metainfo['pcd_vertical_flip'] + for data_sample in tta_results['data_samples'] + ] == [ + False, False, False, False, False, False, False, False, False, + True, True, True, True, True, True, True, True, True, False, False, + False, False, False, False, False, False, False, True, True, True, + True, True, True, True, True, True + ] + assert [ + data_sample.metainfo['pcd_rotation_angle'] + for data_sample in tta_results['data_samples'] + ] == [ + -0.78539816, -0.78539816, -0.78539816, 0.0, 0.0, 0.0, 0.78539816, + 0.78539816, 0.78539816, -0.78539816, -0.78539816, -0.78539816, 0.0, + 0.0, 0.0, 0.78539816, 0.78539816, 0.78539816, -0.78539816, + -0.78539816, -0.78539816, 0.0, 0.0, 0.0, 0.78539816, 0.78539816, + 0.78539816, -0.78539816, -0.78539816, -0.78539816, 0.0, 0.0, 0.0, + 0.78539816, 0.78539816, 0.78539816 + ] + assert [ + data_sample.metainfo['pcd_scale_factor'] + for data_sample in tta_results['data_samples'] + ] == [ + 0.95, 1.0, 1.05, 0.95, 1.0, 1.05, 0.95, 1.0, 1.05, 0.95, 1.0, 1.05, + 0.95, 1.0, 1.05, 0.95, 1.0, 1.05, 0.95, 1.0, 1.05, 0.95, 1.0, 1.05, + 0.95, 1.0, 1.05, 0.95, 1.0, 1.05, 0.95, 1.0, 1.05, 0.95, 1.0, 1.05 + ] diff --git a/tests/test_models/test_segmentors/test_cylinder3d.py b/tests/test_models/test_segmentors/test_cylinder3d.py index 084918f54b..3c0c8ba3c4 100644 --- a/tests/test_models/test_segmentors/test_cylinder3d.py +++ b/tests/test_models/test_segmentors/test_cylinder3d.py @@ -1,3 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. import unittest import torch diff --git a/tests/test_models/test_segmentor/test_minkunet.py b/tests/test_models/test_segmentors/test_minkunet.py similarity index 100% rename from tests/test_models/test_segmentor/test_minkunet.py rename to tests/test_models/test_segmentors/test_minkunet.py diff --git a/tests/test_models/test_segmentors/test_seg3d_tta_model.py b/tests/test_models/test_segmentors/test_seg3d_tta_model.py new file mode 100644 index 0000000000..b6a02f7bd7 --- /dev/null +++ b/tests/test_models/test_segmentors/test_seg3d_tta_model.py @@ -0,0 +1,40 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from unittest import TestCase + +import torch +from mmengine import ConfigDict, DefaultScope + +from mmdet3d.models import Seg3DTTAModel +from mmdet3d.registry import MODELS +from mmdet3d.structures import Det3DDataSample +from mmdet3d.testing import get_detector_cfg + + +class TestSeg3DTTAModel(TestCase): + + def test_seg3d_tta_model(self): + import mmdet3d.models + + assert hasattr(mmdet3d.models, 'Cylinder3D') + DefaultScope.get_instance('test_cylinder3d', scope_name='mmdet3d') + segmentor3d_cfg = get_detector_cfg( + 'cylinder3d/cylinder3d_4xb4_3x_semantickitti.py') + cfg = ConfigDict(type='Seg3DTTAModel', module=segmentor3d_cfg) + + model: Seg3DTTAModel = MODELS.build(cfg) + + points = [] + data_samples = [] + pcd_horizontal_flip_list = [False, False, True, True] + pcd_vertical_flip_list = [False, True, False, True] + for i in range(4): + points.append({'points': [torch.randn(200, 4)]}) + data_samples.append([ + Det3DDataSample( + metainfo=dict( + pcd_horizontal_flip=pcd_horizontal_flip_list[i], + pcd_vertical_flip=pcd_vertical_flip_list[i])) + ]) + if torch.cuda.is_available(): + model.eval() + model.test_step(dict(inputs=points, data_samples=data_samples)) diff --git a/tools/test.py b/tools/test.py index d49477cfd8..7760cfd8c6 100644 --- a/tools/test.py +++ b/tools/test.py @@ -3,7 +3,7 @@ import os import os.path as osp -from mmengine.config import Config, DictAction +from mmengine.config import Config, ConfigDict, DictAction from mmengine.registry import RUNNERS from mmengine.runner import Runner @@ -53,6 +53,8 @@ def parse_args(): choices=['none', 'pytorch', 'slurm', 'mpi'], default='none', help='job launcher') + parser.add_argument( + '--tta', action='store_true', help='Test time augmentation') parser.add_argument('--local_rank', type=int, default=0) args = parser.parse_args() if 'LOCAL_RANK' not in os.environ: @@ -109,6 +111,14 @@ def main(): if args.show or args.show_dir: cfg = trigger_visualization_hook(cfg, args) + if args.tta: + # Currently, we only support tta for 3D segmentation + # TODO: Support tta for 3D detection + assert 'tta_model' in cfg, 'Cannot find ``tta_model`` in config.' + assert 'tta_pipeline' in cfg, 'Cannot find ``tta_pipeline`` in config.' + cfg.test_dataloader.dataset.pipeline = cfg.tta_pipeline + cfg.model = ConfigDict(**cfg.tta_model, module=cfg.model) + # build the runner from config if 'runner_type' not in cfg: # build the default runner