Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] Seg tta UT & p2v_map bug when gt is none #2466

Merged
merged 3 commits into from
Apr 26, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 13 additions & 9 deletions mmdet3d/models/data_preprocessors/data_preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def __init__(self,
voxel: bool = False,
voxel_type: str = 'hard',
voxel_layer: OptConfigType = None,
max_voxels: int = 80000,
sunjiahao1999 marked this conversation as resolved.
Show resolved Hide resolved
mean: Sequence[Number] = None,
std: Sequence[Number] = None,
pad_size_divisor: int = 1,
Expand All @@ -103,6 +104,7 @@ def __init__(self,
batch_augments=batch_augments)
self.voxel = voxel
self.voxel_type = voxel_type
self.max_voxels = max_voxels
if voxel:
self.voxel_layer = VoxelizationByGridShape(**voxel_layer)

Expand Down Expand Up @@ -423,20 +425,22 @@ def voxelize(self, points: List[torch.Tensor],
res_coors -= res_coors.min(0)[0]

res_coors_numpy = res_coors.cpu().numpy()
inds, voxel2point_map = self.sparse_quantize(
inds, point2voxel_map = self.sparse_quantize(
res_coors_numpy, return_index=True, return_inverse=True)
voxel2point_map = torch.from_numpy(voxel2point_map).cuda()
point2voxel_map = torch.from_numpy(point2voxel_map).cuda()
if self.training:
if len(inds) > 80000:
inds = np.random.choice(inds, 80000, replace=False)
if len(inds) > self.max_voxels:
inds = np.random.choice(
inds, self.max_voxels, replace=False)
inds = torch.from_numpy(inds).cuda()
data_sample.gt_pts_seg.voxel_semantic_mask \
= data_sample.gt_pts_seg.pts_semantic_mask[inds]
if hasattr(data_sample.gt_pts_seg, 'pts_semantic_mask'):
data_sample.gt_pts_seg.voxel_semantic_mask \
= data_sample.gt_pts_seg.pts_semantic_mask[inds]
res_voxel_coors = res_coors[inds]
res_voxels = res[inds]
res_voxel_coors = F.pad(
res_voxel_coors, (0, 1), mode='constant', value=i)
data_sample.voxel2point_map = voxel2point_map.long()
data_sample.point2voxel_map = point2voxel_map.long()
voxels.append(res_voxels)
coors.append(res_voxel_coors)
voxels = torch.cat(voxels, dim=0)
Expand Down Expand Up @@ -466,12 +470,12 @@ def get_voxel_seg(self, res_coors: torch.Tensor, data_sample: SampleList):
True)
voxel_semantic_mask = torch.argmax(voxel_semantic_mask, dim=-1)
data_sample.gt_pts_seg.voxel_semantic_mask = voxel_semantic_mask
data_sample.gt_pts_seg.point2voxel_map = point2voxel_map
data_sample.point2voxel_map = point2voxel_map
else:
pseudo_tensor = res_coors.new_ones([res_coors.shape[0], 1]).float()
_, _, point2voxel_map = dynamic_scatter_3d(pseudo_tensor,
res_coors, 'mean', True)
data_sample.gt_pts_seg.point2voxel_map = point2voxel_map
data_sample.point2voxel_map = point2voxel_map

def ravel_hash(self, x: np.ndarray) -> np.ndarray:
"""Get voxel coordinates hash for np.unique().
Expand Down
2 changes: 1 addition & 1 deletion mmdet3d/models/decode_heads/cylinder3d_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def predict(
for batch_idx in range(len(batch_data_samples)):
seg_logits_sample = seg_logits[coors[:, 0] == batch_idx]
point2voxel_map = batch_data_samples[
batch_idx].gt_pts_seg.point2voxel_map.long()
batch_idx].point2voxel_map.long()
point_seg_predicts = seg_logits_sample[point2voxel_map]
seg_pred_list.append(point_seg_predicts)

Expand Down
2 changes: 1 addition & 1 deletion mmdet3d/models/decode_heads/minkunet_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def predict(self, inputs: SparseTensor,
seg_logit_list = []
for i, data_sample in enumerate(batch_data_samples):
seg_logit = seg_logits[batch_idx == i]
seg_logit = seg_logit[data_sample.voxel2point_map]
seg_logit = seg_logit[data_sample.point2voxel_map]
seg_logit_list.append(seg_logit)

return seg_logit_list
Expand Down
3 changes: 1 addition & 2 deletions tests/test_models/test_decode_heads/test_cylinder3d_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,7 @@ def test_cylinder3d_head_loss(self):
self.assertGreater(loss_lovasz, 0, 'lovasz loss should be positive')

batch_inputs_dict = dict(voxels=dict(voxel_coors=coors))
datasample.gt_pts_seg.point2voxel_map = torch.randint(
0, 50, (100, )).int().cuda()
datasample.point2voxel_map = torch.randint(0, 50, (100, )).int().cuda()
point_logits = cylinder3d_head.predict(sparse_voxels,
batch_inputs_dict, [datasample])
assert point_logits[0].shape == torch.Size([100, 20])
2 changes: 1 addition & 1 deletion tests/test_models/test_segmentors/test_seg3d_tta_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,5 +36,5 @@ def test_seg3d_tta_model(self):
pcd_vertical_flip=pcd_vertical_flip_list[i]))
])
if torch.cuda.is_available():
model.eval()
model.eval().cuda()
model.test_step(dict(inputs=points, data_samples=data_samples))