From 2cdc91a9576710f48a452573f25a10e25d24323e Mon Sep 17 00:00:00 2001 From: sunjiahao1999 <578431509@qq.com> Date: Tue, 25 Apr 2023 17:28:01 +0800 Subject: [PATCH 1/3] fix p2v_map when no gt & unit name --- mmdet3d/models/data_preprocessors/data_preprocessor.py | 10 +++++----- mmdet3d/models/decode_heads/cylinder3d_head.py | 2 +- mmdet3d/models/decode_heads/minkunet_head.py | 2 +- .../test_decode_heads/test_cylinder3d_head.py | 3 +-- .../test_segmentors/test_seg3d_tta_model.py | 2 +- 5 files changed, 9 insertions(+), 10 deletions(-) diff --git a/mmdet3d/models/data_preprocessors/data_preprocessor.py b/mmdet3d/models/data_preprocessors/data_preprocessor.py index 85286c9b62..87f6d155ce 100644 --- a/mmdet3d/models/data_preprocessors/data_preprocessor.py +++ b/mmdet3d/models/data_preprocessors/data_preprocessor.py @@ -423,9 +423,9 @@ def voxelize(self, points: List[torch.Tensor], res_coors -= res_coors.min(0)[0] res_coors_numpy = res_coors.cpu().numpy() - inds, voxel2point_map = self.sparse_quantize( + inds, point2voxel_map = self.sparse_quantize( res_coors_numpy, return_index=True, return_inverse=True) - voxel2point_map = torch.from_numpy(voxel2point_map).cuda() + point2voxel_map = torch.from_numpy(point2voxel_map).cuda() if self.training: if len(inds) > 80000: inds = np.random.choice(inds, 80000, replace=False) @@ -436,7 +436,7 @@ def voxelize(self, points: List[torch.Tensor], res_voxels = res[inds] res_voxel_coors = F.pad( res_voxel_coors, (0, 1), mode='constant', value=i) - data_sample.voxel2point_map = voxel2point_map.long() + data_sample.point2voxel_map = point2voxel_map.long() voxels.append(res_voxels) coors.append(res_voxel_coors) voxels = torch.cat(voxels, dim=0) @@ -466,12 +466,12 @@ def get_voxel_seg(self, res_coors: torch.Tensor, data_sample: SampleList): True) voxel_semantic_mask = torch.argmax(voxel_semantic_mask, dim=-1) data_sample.gt_pts_seg.voxel_semantic_mask = voxel_semantic_mask - data_sample.gt_pts_seg.point2voxel_map = point2voxel_map + 
data_sample.point2voxel_map = point2voxel_map else: pseudo_tensor = res_coors.new_ones([res_coors.shape[0], 1]).float() _, _, point2voxel_map = dynamic_scatter_3d(pseudo_tensor, res_coors, 'mean', True) - data_sample.gt_pts_seg.point2voxel_map = point2voxel_map + data_sample.point2voxel_map = point2voxel_map def ravel_hash(self, x: np.ndarray) -> np.ndarray: """Get voxel coordinates hash for np.unique(). diff --git a/mmdet3d/models/decode_heads/cylinder3d_head.py b/mmdet3d/models/decode_heads/cylinder3d_head.py index 26c621c5ba..1672a0f209 100644 --- a/mmdet3d/models/decode_heads/cylinder3d_head.py +++ b/mmdet3d/models/decode_heads/cylinder3d_head.py @@ -151,7 +151,7 @@ def predict( for batch_idx in range(len(batch_data_samples)): seg_logits_sample = seg_logits[coors[:, 0] == batch_idx] point2voxel_map = batch_data_samples[ - batch_idx].gt_pts_seg.point2voxel_map.long() + batch_idx].point2voxel_map.long() point_seg_predicts = seg_logits_sample[point2voxel_map] seg_pred_list.append(point_seg_predicts) diff --git a/mmdet3d/models/decode_heads/minkunet_head.py b/mmdet3d/models/decode_heads/minkunet_head.py index 97d8fdf59f..e94c25b610 100644 --- a/mmdet3d/models/decode_heads/minkunet_head.py +++ b/mmdet3d/models/decode_heads/minkunet_head.py @@ -61,7 +61,7 @@ def predict(self, inputs: SparseTensor, seg_logit_list = [] for i, data_sample in enumerate(batch_data_samples): seg_logit = seg_logits[batch_idx == i] - seg_logit = seg_logit[data_sample.voxel2point_map] + seg_logit = seg_logit[data_sample.point2voxel_map] seg_logit_list.append(seg_logit) return seg_logit_list diff --git a/tests/test_models/test_decode_heads/test_cylinder3d_head.py b/tests/test_models/test_decode_heads/test_cylinder3d_head.py index 3bb62c5eef..c8fae827e8 100644 --- a/tests/test_models/test_decode_heads/test_cylinder3d_head.py +++ b/tests/test_models/test_decode_heads/test_cylinder3d_head.py @@ -60,8 +60,7 @@ def test_cylinder3d_head_loss(self): self.assertGreater(loss_lovasz, 0, 'lovasz loss 
should be positive') batch_inputs_dict = dict(voxels=dict(voxel_coors=coors)) - datasample.gt_pts_seg.point2voxel_map = torch.randint( - 0, 50, (100, )).int().cuda() + datasample.point2voxel_map = torch.randint(0, 50, (100, )).int().cuda() point_logits = cylinder3d_head.predict(sparse_voxels, batch_inputs_dict, [datasample]) assert point_logits[0].shape == torch.Size([100, 20]) diff --git a/tests/test_models/test_segmentors/test_seg3d_tta_model.py b/tests/test_models/test_segmentors/test_seg3d_tta_model.py index b6a02f7bd7..e0ebf026ec 100644 --- a/tests/test_models/test_segmentors/test_seg3d_tta_model.py +++ b/tests/test_models/test_segmentors/test_seg3d_tta_model.py @@ -36,5 +36,5 @@ def test_seg3d_tta_model(self): pcd_vertical_flip=pcd_vertical_flip_list[i])) ]) if torch.cuda.is_available(): - model.eval() + model.eval().cuda() model.test_step(dict(inputs=points, data_samples=data_samples)) From b900f2de32cbef0a6e742cedceac25c29ee14f95 Mon Sep 17 00:00:00 2001 From: sunjiahao1999 <578431509@qq.com> Date: Tue, 25 Apr 2023 18:50:37 +0800 Subject: [PATCH 2/3] fix gt bug --- .../models/data_preprocessors/data_preprocessor.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/mmdet3d/models/data_preprocessors/data_preprocessor.py b/mmdet3d/models/data_preprocessors/data_preprocessor.py index 87f6d155ce..2483dc783b 100644 --- a/mmdet3d/models/data_preprocessors/data_preprocessor.py +++ b/mmdet3d/models/data_preprocessors/data_preprocessor.py @@ -77,6 +77,7 @@ def __init__(self, voxel: bool = False, voxel_type: str = 'hard', voxel_layer: OptConfigType = None, + max_voxels: int = 80000, mean: Sequence[Number] = None, std: Sequence[Number] = None, pad_size_divisor: int = 1, @@ -103,6 +104,7 @@ def __init__(self, batch_augments=batch_augments) self.voxel = voxel self.voxel_type = voxel_type + self.max_voxels = max_voxels if voxel: self.voxel_layer = VoxelizationByGridShape(**voxel_layer) @@ -427,11 +429,13 @@ def voxelize(self, points: 
List[torch.Tensor], res_coors_numpy, return_index=True, return_inverse=True) point2voxel_map = torch.from_numpy(point2voxel_map).cuda() if self.training: - if len(inds) > 80000: - inds = np.random.choice(inds, 80000, replace=False) + if len(inds) > self.max_voxels: + inds = np.random.choice( + inds, self.max_voxels, replace=False) inds = torch.from_numpy(inds).cuda() - data_sample.gt_pts_seg.voxel_semantic_mask \ - = data_sample.gt_pts_seg.pts_semantic_mask[inds] + if hasattr(data_sample.gt_pts_seg, 'pts_semantic_mask'): + data_sample.gt_pts_seg.voxel_semantic_mask \ + = data_sample.gt_pts_seg.pts_semantic_mask[inds] res_voxel_coors = res_coors[inds] res_voxels = res[inds] res_voxel_coors = F.pad( From 8e6c33eb0aaf9c2f40287e27d78aa4dd7a554363 Mon Sep 17 00:00:00 2001 From: sunjiahao1999 <578431509@qq.com> Date: Wed, 26 Apr 2023 13:59:06 +0800 Subject: [PATCH 3/3] fix max_voxels --- configs/_base_/models/minkunet.py | 7 +++---- configs/_base_/models/spvcnn.py | 4 ++-- mmdet3d/models/data_preprocessors/data_preprocessor.py | 6 ++++-- 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/configs/_base_/models/minkunet.py b/configs/_base_/models/minkunet.py index 0a691d876d..665aa7ac90 100644 --- a/configs/_base_/models/minkunet.py +++ b/configs/_base_/models/minkunet.py @@ -9,15 +9,14 @@ point_cloud_range=[-100, -100, -20, 100, 100, 20], voxel_size=[0.05, 0.05, 0.05], max_voxels=(-1, -1)), - ), + max_voxels=80000), backbone=dict( type='MinkUNetBackbone', in_channels=4, base_channels=32, - encoder_channels=[32, 64, 128, 256], - decoder_channels=[256, 128, 96, 96], num_stages=4, - init_cfg=None), + encoder_channels=[32, 64, 128, 256], + decoder_channels=[256, 128, 96, 96]), decode_head=dict( type='MinkUNetHead', channels=96, diff --git a/configs/_base_/models/spvcnn.py b/configs/_base_/models/spvcnn.py index 335407d89e..fa40a37be4 100644 --- a/configs/_base_/models/spvcnn.py +++ b/configs/_base_/models/spvcnn.py @@ -9,14 +9,14 @@ point_cloud_range=[-100, -100, 
-20, 100, 100, 20], voxel_size=[0.05, 0.05, 0.05], max_voxels=(-1, -1)), - ), + max_voxels=80000), backbone=dict( type='SPVCNNBackbone', in_channels=4, base_channels=32, + num_stages=4, encoder_channels=[32, 64, 128, 256], decoder_channels=[256, 128, 96, 96], - num_stages=4, drop_ratio=0.3), decode_head=dict( type='MinkUNetHead', diff --git a/mmdet3d/models/data_preprocessors/data_preprocessor.py b/mmdet3d/models/data_preprocessors/data_preprocessor.py index 2483dc783b..248ed7a699 100644 --- a/mmdet3d/models/data_preprocessors/data_preprocessor.py +++ b/mmdet3d/models/data_preprocessors/data_preprocessor.py @@ -49,6 +49,8 @@ class Det3DDataPreprocessor(DetDataPreprocessor): voxelization and dynamic voxelization. Defaults to 'hard'. voxel_layer (dict or :obj:`ConfigDict`, optional): Voxelization layer config. Defaults to None. + max_voxels (int, optional): Maximum number of voxels in each voxel grid. Defaults + to None. mean (Sequence[Number], optional): The pixel mean of R, G, B channels. Defaults to None. std (Sequence[Number], optional): The pixel standard deviation of @@ -77,7 +79,7 @@ def __init__(self, voxel: bool = False, voxel_type: str = 'hard', voxel_layer: OptConfigType = None, - max_voxels: int = 80000, + max_voxels: Optional[int] = None, mean: Sequence[Number] = None, std: Sequence[Number] = None, pad_size_divisor: int = 1, @@ -428,7 +430,7 @@ def voxelize(self, points: List[torch.Tensor], inds, point2voxel_map = self.sparse_quantize( res_coors_numpy, return_index=True, return_inverse=True) point2voxel_map = torch.from_numpy(point2voxel_map).cuda() - if self.training: + if self.training and self.max_voxels is not None: if len(inds) > self.max_voxels: inds = np.random.choice( inds, self.max_voxels, replace=False)