diff --git a/mmdet3d/core/bbox/coders/__init__.py b/mmdet3d/core/bbox/coders/__init__.py index 0e44042212..e2a9d9f7ff 100644 --- a/mmdet3d/core/bbox/coders/__init__.py +++ b/mmdet3d/core/bbox/coders/__init__.py @@ -6,10 +6,11 @@ from .fcos3d_bbox_coder import FCOS3DBBoxCoder from .groupfree3d_bbox_coder import GroupFree3DBBoxCoder from .partial_bin_based_bbox_coder import PartialBinBasedBBoxCoder +from .pgd_bbox_coder import PGDBBoxCoder from .point_xyzwhlr_bbox_coder import PointXYZWHLRBBoxCoder __all__ = [ 'build_bbox_coder', 'DeltaXYZWLHRBBoxCoder', 'PartialBinBasedBBoxCoder', 'CenterPointBBoxCoder', 'AnchorFreeBBoxCoder', 'GroupFree3DBBoxCoder', - 'PointXYZWHLRBBoxCoder', 'FCOS3DBBoxCoder' + 'PointXYZWHLRBBoxCoder', 'FCOS3DBBoxCoder', 'PGDBBoxCoder' ] diff --git a/mmdet3d/core/bbox/coders/fcos3d_bbox_coder.py b/mmdet3d/core/bbox/coders/fcos3d_bbox_coder.py index 7245a57e09..ae90c5d60b 100644 --- a/mmdet3d/core/bbox/coders/fcos3d_bbox_coder.py +++ b/mmdet3d/core/bbox/coders/fcos3d_bbox_coder.py @@ -44,7 +44,7 @@ def decode(self, bbox, scale, stride, training, cls_score=None): bbox (torch.Tensor): Raw bounding box predictions in shape [N, C, H, W]. scale (tuple[`Scale`]): Learnable scale parameters. - stride (tuple[int]): Stride for a specific feature level. + stride (int): Stride for a specific feature level. training (bool): Whether the decoding is in the training procedure. cls_score (torch.Tensor): Classification score map for deciding diff --git a/mmdet3d/core/bbox/coders/pgd_bbox_coder.py b/mmdet3d/core/bbox/coders/pgd_bbox_coder.py new file mode 100644 index 0000000000..5e42128189 --- /dev/null +++ b/mmdet3d/core/bbox/coders/pgd_bbox_coder.py @@ -0,0 +1,125 @@ +import numpy as np +from torch.nn import functional as F + +from mmdet.core.bbox.builder import BBOX_CODERS +from .fcos3d_bbox_coder import FCOS3DBBoxCoder + + +@BBOX_CODERS.register_module() +class PGDBBoxCoder(FCOS3DBBoxCoder): + """Bounding box coder for PGD.""" + + def encode(self, gt_bboxes_3d, gt_labels_3d, gt_bboxes, gt_labels): + # TODO: refactor the encoder codes in the FCOS3D and PGD head + pass + + def decode_2d(self, + bbox, + scale, + stride, + max_regress_range, + training, + pred_keypoints=False, + pred_bbox2d=True): + """Decode regressed 2D attributes. + + Args: + bbox (torch.Tensor): Raw bounding box predictions in shape + [N, C, H, W]. + scale (tuple[`Scale`]): Learnable scale parameters. + stride (int): Stride for a specific feature level. + max_regress_range (int): Maximum regression range for a specific + feature level. + training (bool): Whether the decoding is in the training + procedure. + pred_keypoints (bool, optional): Whether to predict keypoints. + Defaults to False. + pred_bbox2d (bool, optional): Whether to predict 2D bounding + boxes. Defaults to False. + + Returns: + torch.Tensor: Decoded boxes. + """ + clone_bbox = bbox.clone() + if pred_keypoints: + scale_kpts = scale[3] + # 2 dimension of offsets x 8 corners of a 3D bbox + bbox[:, self.bbox_code_size:self.bbox_code_size + 16] = \ + scale_kpts(clone_bbox[ + :, self.bbox_code_size:self.bbox_code_size + 16]).float() + if pred_bbox2d: + scale_bbox2d = scale[-1] + # The last four dimensions are offsets to four sides of a 2D bbox + bbox[:, -4:] = scale_bbox2d(clone_bbox[:, -4:]).float() + + if self.norm_on_bbox: + if pred_bbox2d: + bbox[:, -4:] = F.relu(bbox.clone()[:, -4:]) + if not training: + if pred_keypoints: + bbox[ + :, self.bbox_code_size:self.bbox_code_size + 16] *= \ + max_regress_range + if pred_bbox2d: + bbox[:, -4:] *= stride + else: + if pred_bbox2d: + bbox[:, -4:] = bbox.clone()[:, -4:].exp() + return bbox + + def decode_prob_depth(self, depth_cls_preds, depth_range, depth_unit, + division, num_depth_cls): + """Decode probabilistic depth map. + + Args: + depth_cls_preds (torch.Tensor): Depth probabilistic map in shape + [..., self.num_depth_cls] (raw output before softmax). + depth_range (tuple[float]): Range of depth estimation. + depth_unit (int): Unit of depth range division. + division (str): Depth division method. Options include 'uniform', + 'linear', 'log', 'loguniform'. + num_depth_cls (int): Number of depth classes. + + Returns: + torch.Tensor: Decoded probabilistic depth estimation. + """ + if division == 'uniform': + depth_multiplier = depth_unit * \ + depth_cls_preds.new_tensor( + list(range(num_depth_cls))).reshape([1, -1]) + prob_depth_preds = (F.softmax(depth_cls_preds.clone(), dim=-1) * + depth_multiplier).sum(dim=-1) + return prob_depth_preds + elif division == 'linear': + split_pts = depth_cls_preds.new_tensor(list( + range(num_depth_cls))).reshape([1, -1]) + depth_multiplier = depth_range[0] + ( + depth_range[1] - depth_range[0]) / \ + (num_depth_cls * (num_depth_cls - 1)) * \ + (split_pts * (split_pts+1)) + prob_depth_preds = (F.softmax(depth_cls_preds.clone(), dim=-1) * + depth_multiplier).sum(dim=-1) + return prob_depth_preds + elif division == 'log': + split_pts = depth_cls_preds.new_tensor(list( + range(num_depth_cls))).reshape([1, -1]) + start = max(depth_range[0], 1) + end = depth_range[1] + depth_multiplier = (np.log(start) + + split_pts * np.log(end / start) / + (num_depth_cls - 1)).exp() + prob_depth_preds = (F.softmax(depth_cls_preds.clone(), dim=-1) * + depth_multiplier).sum(dim=-1) + return prob_depth_preds + elif division == 'loguniform': + split_pts = depth_cls_preds.new_tensor(list( + range(num_depth_cls))).reshape([1, -1]) + start = max(depth_range[0], 1) + end = depth_range[1] + log_multiplier = np.log(start) + \ + split_pts * np.log(end / start) / (num_depth_cls - 1) + prob_depth_preds = (F.softmax(depth_cls_preds.clone(), dim=-1) * + log_multiplier).sum(dim=-1).exp() + return prob_depth_preds + else: + raise NotImplementedError diff --git a/tests/test_utils/test_bbox_coders.py b/tests/test_utils/test_bbox_coders.py index 385f27609d..e9928bc1cf 100644 --- a/tests/test_utils/test_bbox_coders.py +++ b/tests/test_utils/test_bbox_coders.py @@ -398,26 +398,24 @@ def test_fcos3d_bbox_coder(): # test decode # [2, 7, 1, 1] - batch_bbox_out = torch.tensor([[[[0.3130]], [[0.7094]], [[0.8743]], - [[0.0570]], [[0.5579]], [[0.1593]], - [[0.4553]]], - [[[0.7758]], [[0.2298]], [[0.3925]], - [[0.6307]], [[0.4377]], [[0.3339]], - [[0.1966]]]]) + batch_bbox = torch.tensor([[[[0.3130]], [[0.7094]], [[0.8743]], [[0.0570]], + [[0.5579]], [[0.1593]], [[0.4553]]], + [[[0.7758]], [[0.2298]], [[0.3925]], [[0.6307]], + [[0.4377]], [[0.3339]], [[0.1966]]]]) batch_scale = nn.ModuleList([Scale(1.0) for _ in range(3)]) stride = 2 training = False cls_score = torch.randn([2, 2, 1, 1]).sigmoid() - decode_bbox_out = bbox_coder.decode(batch_bbox_out, batch_scale, stride, - training, cls_score) + decode_bbox = bbox_coder.decode(batch_bbox, batch_scale, stride, training, + cls_score) - expected_bbox_out = torch.tensor([[[[0.6261]], [[1.4188]], [[2.3971]], - [[1.0586]], [[1.7470]], [[1.1727]], - [[0.4553]]], - [[[1.5516]], [[0.4596]], [[1.4806]], - [[1.8790]], [[1.5492]], [[1.3965]], - [[0.1966]]]]) - assert torch.allclose(decode_bbox_out, expected_bbox_out, atol=1e-3) + expected_bbox = torch.tensor([[[[0.6261]], [[1.4188]], [[2.3971]], + [[1.0586]], [[1.7470]], [[1.1727]], + [[0.4553]]], + [[[1.5516]], [[0.4596]], [[1.4806]], + [[1.8790]], [[1.5492]], [[1.3965]], + [[0.1966]]]]) + assert torch.allclose(decode_bbox, expected_bbox, atol=1e-3) # test a config with priors prior_bbox_coder_cfg = dict( @@ -429,39 +427,140 @@ def test_fcos3d_bbox_coder(): prior_bbox_coder = build_bbox_coder(prior_bbox_coder_cfg) # test decode - batch_bbox_out = torch.tensor([[[[0.3130]], [[0.7094]], [[0.8743]], - [[0.0570]], [[0.5579]], [[0.1593]], - [[0.4553]]], - [[[0.7758]], [[0.2298]], [[0.3925]], - [[0.6307]], [[0.4377]], [[0.3339]], - [[0.1966]]]]) + batch_bbox = torch.tensor([[[[0.3130]], [[0.7094]], [[0.8743]], [[0.0570]], + [[0.5579]], [[0.1593]], [[0.4553]]], + [[[0.7758]], [[0.2298]], [[0.3925]], [[0.6307]], + [[0.4377]], [[0.3339]], [[0.1966]]]]) batch_scale = nn.ModuleList([Scale(1.0) for _ in range(3)]) stride = 2 training = False cls_score = torch.tensor([[[[0.5811]], [[0.6198]]], [[[0.4889]], [[0.8142]]]]) - decode_bbox_out = prior_bbox_coder.decode(batch_bbox_out, batch_scale, - stride, training, cls_score) - expected_bbox_out = torch.tensor([[[[0.6260]], [[1.4188]], [[35.4916]], - [[1.0587]], [[3.4940]], [[3.5181]], - [[0.4553]]], - [[[1.5516]], [[0.4596]], [[29.7100]], - [[1.8789]], [[3.0983]], [[4.1892]], - [[0.1966]]]]) - assert torch.allclose(decode_bbox_out, expected_bbox_out, atol=1e-3) + decode_bbox = prior_bbox_coder.decode(batch_bbox, batch_scale, stride, + training, cls_score) + expected_bbox = torch.tensor([[[[0.6260]], [[1.4188]], [[35.4916]], + [[1.0587]], [[3.4940]], [[3.5181]], + [[0.4553]]], + [[[1.5516]], [[0.4596]], [[29.7100]], + [[1.8789]], [[3.0983]], [[4.1892]], + [[0.1966]]]]) + assert torch.allclose(decode_bbox, expected_bbox, atol=1e-3) # test decode_yaw - decode_bbox_out = decode_bbox_out.permute(0, 2, 3, 1).view(-1, 7) + decode_bbox = decode_bbox.permute(0, 2, 3, 1).view(-1, 7) batch_centers2d = torch.tensor([[100., 150.], [200., 100.]]) batch_dir_cls = torch.tensor([0., 1.]) dir_offset = 0.7854 cam2img = torch.tensor([[700., 0., 450., 0.], [0., 700., 200., 0.], [0., 0., 1., 0.], [0., 0., 0., 1.]]) - decode_bbox_out = prior_bbox_coder.decode_yaw(decode_bbox_out, - batch_centers2d, - batch_dir_cls, dir_offset, - cam2img) - expected_bbox_out = torch.tensor( + decode_bbox = prior_bbox_coder.decode_yaw(decode_bbox, batch_centers2d, + batch_dir_cls, dir_offset, + cam2img) + expected_bbox = torch.tensor( [[0.6260, 1.4188, 35.4916, 1.0587, 3.4940, 3.5181, 3.1332], [1.5516, 0.4596, 29.7100, 1.8789, 3.0983, 4.1892, 6.1368]]) - assert torch.allclose(decode_bbox_out, expected_bbox_out, atol=1e-3) + assert torch.allclose(decode_bbox, expected_bbox, atol=1e-3) + + +def test_pgd_bbox_coder(): + # test a config without priors + bbox_coder_cfg = dict( + type='PGDBBoxCoder', + base_depths=None, + base_dims=None, + code_size=7, + norm_on_bbox=True) + bbox_coder = build_bbox_coder(bbox_coder_cfg) + + # test decode_2d + # [2, 27, 1, 1] + batch_bbox = torch.tensor([[[[0.0103]], [[0.7394]], [[0.3296]], [[0.4708]], + [[0.1439]], [[0.0778]], [[0.9399]], [[0.8366]], + [[0.1264]], [[0.3030]], [[0.1898]], [[0.0714]], + [[0.4144]], [[0.4341]], [[0.6442]], [[0.2951]], + [[0.2890]], [[0.4486]], [[0.2848]], [[0.1071]], + [[0.9530]], [[0.9460]], [[0.3822]], [[0.9320]], + [[0.2611]], [[0.5580]], [[0.0397]]], + [[[0.8612]], [[0.1680]], [[0.5167]], [[0.8502]], + [[0.0377]], [[0.3615]], [[0.9550]], [[0.5219]], + [[0.1402]], [[0.6843]], [[0.2121]], [[0.9468]], + [[0.6238]], [[0.7918]], [[0.1646]], [[0.0500]], + [[0.6290]], [[0.3956]], [[0.2901]], [[0.4612]], + [[0.7333]], [[0.1194]], [[0.6999]], [[0.3980]], + [[0.3262]], [[0.7185]], [[0.4474]]]]) + batch_scale = nn.ModuleList([Scale(1.0) for _ in range(5)]) + stride = 2 + training = False + cls_score = torch.randn([2, 2, 1, 1]).sigmoid() + decode_bbox = bbox_coder.decode(batch_bbox, batch_scale, stride, training, + cls_score) + max_regress_range = 16 + pred_keypoints = True + pred_bbox2d = True + decode_bbox_w2d = bbox_coder.decode_2d(decode_bbox, batch_scale, stride, + max_regress_range, training, + pred_keypoints, pred_bbox2d) + expected_decode_bbox_w2d = torch.tensor( + [[[[0.0206]], [[1.4788]], [[1.3904]], [[1.6013]], [[1.1548]], + [[1.0809]], [[0.9399]], [[13.3856]], [[2.0224]], [[4.8480]], + [[3.0368]], [[1.1424]], [[6.6304]], [[6.9456]], [[10.3072]], + [[4.7216]], [[4.6240]], [[7.1776]], [[4.5568]], [[1.7136]], + [[15.2480]], [[15.1360]], [[6.1152]], [[1.8640]], [[0.5222]], + [[1.1160]], [[0.0794]]], + [[[1.7224]], [[0.3360]], [[1.6765]], [[2.3401]], [[1.0384]], + [[1.4355]], [[0.9550]], [[8.3504]], [[2.2432]], [[10.9488]], + [[3.3936]], [[15.1488]], [[9.9808]], [[12.6688]], [[2.6336]], + [[0.8000]], [[10.0640]], [[6.3296]], [[4.6416]], [[7.3792]], + [[11.7328]], [[1.9104]], [[11.1984]], [[0.7960]], [[0.6524]], + [[1.4370]], [[0.8948]]]]) + assert torch.allclose(expected_decode_bbox_w2d, decode_bbox_w2d, atol=1e-3) + + # test decode_prob_depth + # [10, 8] + depth_cls_preds = torch.tensor([ + [-0.4383, 0.7207, -0.4092, 0.4649, 0.8526, 0.6186, -1.4312, -0.7150], + [0.0621, 0.2369, 0.5170, 0.8484, -0.1099, 0.1829, -0.0072, 1.0618], + [-1.6114, -0.1057, 0.5721, -0.5986, -2.0471, 0.8140, -0.8385, -0.4822], + [0.0742, -0.3261, 0.4607, 1.8155, -0.3571, -0.0234, 0.3787, 2.3251], + [1.0492, -0.6881, -0.0136, -1.8291, 0.8460, -1.0171, 2.5691, -0.8114], + [0.0968, -0.5601, 1.0458, 0.2560, 1.3018, 0.1635, 0.0680, -1.0263], + [-0.0765, 0.1498, -2.7321, 1.0047, -0.2505, 0.0871, -0.4820, -0.3003], + [-0.4123, 0.2298, -0.1330, -0.6008, 0.6526, 0.7118, 0.9728, -0.7793], + [1.6940, 0.3355, 1.4661, 0.5477, 0.8667, 0.0527, -0.9975, -0.0689], + [0.4724, -0.3632, -0.0654, 0.4034, -0.3494, -0.7548, 0.7297, 1.2754] + ]) + depth_range = (0, 70) + depth_unit = 10 + num_depth_cls = 8 + uniform_prob_depth_preds = bbox_coder.decode_prob_depth( + depth_cls_preds, depth_range, depth_unit, 'uniform', num_depth_cls) + expected_preds = torch.tensor([ + 32.0441, 38.4689, 36.1831, 48.2096, 46.1560, 32.7973, 33.2155, 39.9822, + 21.9905, 43.0161 + ]) + assert torch.allclose(uniform_prob_depth_preds, expected_preds, atol=1e-3) + + linear_prob_depth_preds = bbox_coder.decode_prob_depth( + depth_cls_preds, depth_range, depth_unit, 'linear', num_depth_cls) + expected_preds = torch.tensor([ + 21.1431, 30.2421, 25.8964, 41.6116, 38.6234, 21.4582, 23.2993, 30.1111, + 13.9273, 36.8419 + ]) + assert torch.allclose(linear_prob_depth_preds, expected_preds, atol=1e-3) + + log_prob_depth_preds = bbox_coder.decode_prob_depth( + depth_cls_preds, depth_range, depth_unit, 'log', num_depth_cls) + expected_preds = torch.tensor([ + 12.6458, 24.2487, 17.4015, 36.9375, 27.5982, 12.5510, 15.6635, 19.8408, + 9.1605, 31.3765 + ]) + assert torch.allclose(log_prob_depth_preds, expected_preds, atol=1e-3) + + loguniform_prob_depth_preds = bbox_coder.decode_prob_depth( + depth_cls_preds, depth_range, depth_unit, 'loguniform', num_depth_cls) + expected_preds = torch.tensor([ + 6.9925, 10.3273, 8.9895, 18.6524, 16.4667, 7.3196, 7.5078, 11.3207, + 3.7987, 13.6095 + ]) + assert torch.allclose( + loguniform_prob_depth_preds, expected_preds, atol=1e-3)