Skip to content

Commit

Permalink
[Feature] PGD BBox Coder (#948)
Browse files Browse the repository at this point in the history
* Support PGD BBox Coder

* Refine docstring
  • Loading branch information
Tai-Wang authored Sep 22, 2021
1 parent 911a333 commit 506f929
Show file tree
Hide file tree
Showing 4 changed files with 264 additions and 39 deletions.
3 changes: 2 additions & 1 deletion mmdet3d/core/bbox/coders/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@
from .fcos3d_bbox_coder import FCOS3DBBoxCoder
from .groupfree3d_bbox_coder import GroupFree3DBBoxCoder
from .partial_bin_based_bbox_coder import PartialBinBasedBBoxCoder
from .pgd_bbox_coder import PGDBBoxCoder
from .point_xyzwhlr_bbox_coder import PointXYZWHLRBBoxCoder

__all__ = [
'build_bbox_coder', 'DeltaXYZWLHRBBoxCoder', 'PartialBinBasedBBoxCoder',
'CenterPointBBoxCoder', 'AnchorFreeBBoxCoder', 'GroupFree3DBBoxCoder',
'PointXYZWHLRBBoxCoder', 'FCOS3DBBoxCoder'
'PointXYZWHLRBBoxCoder', 'FCOS3DBBoxCoder', 'PGDBBoxCoder'
]
2 changes: 1 addition & 1 deletion mmdet3d/core/bbox/coders/fcos3d_bbox_coder.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def decode(self, bbox, scale, stride, training, cls_score=None):
bbox (torch.Tensor): Raw bounding box predictions in shape
[N, C, H, W].
scale (tuple[`Scale`]): Learnable scale parameters.
stride (tuple[int]): Stride for a specific feature level.
stride (int): Stride for a specific feature level.
training (bool): Whether the decoding is in the training
procedure.
cls_score (torch.Tensor): Classification score map for deciding
Expand Down
125 changes: 125 additions & 0 deletions mmdet3d/core/bbox/coders/pgd_bbox_coder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
import numpy as np
from torch.nn import functional as F

from mmdet.core.bbox.builder import BBOX_CODERS
from .fcos3d_bbox_coder import FCOS3DBBoxCoder


@BBOX_CODERS.register_module()
class PGDBBoxCoder(FCOS3DBBoxCoder):
"""Bounding box coder for PGD."""

def encode(self, gt_bboxes_3d, gt_labels_3d, gt_bboxes, gt_labels):
# TODO: refactor the encoder codes in the FCOS3D and PGD head
pass

def decode_2d(self,
bbox,
scale,
stride,
max_regress_range,
training,
pred_keypoints=False,
pred_bbox2d=True):
"""Decode regressed 2D attributes.
Args:
bbox (torch.Tensor): Raw bounding box predictions in shape
[N, C, H, W].
scale (tuple[`Scale`]): Learnable scale parameters.
stride (int): Stride for a specific feature level.
max_regress_range (int): Maximum regression range for a specific
feature level.
training (bool): Whether the decoding is in the training
procedure.
pred_keypoints (bool, optional): Whether to predict keypoints.
Defaults to False.
pred_bbox2d (bool, optional): Whether to predict 2D bounding
boxes. Defaults to False.
Returns:
torch.Tensor: Decoded boxes.
"""
clone_bbox = bbox.clone()
if pred_keypoints:
scale_kpts = scale[3]
# 2 dimension of offsets x 8 corners of a 3D bbox
bbox[:, self.bbox_code_size:self.bbox_code_size + 16] = \
scale_kpts(clone_bbox[
:, self.bbox_code_size:self.bbox_code_size + 16]).float()
if pred_bbox2d:
scale_bbox2d = scale[-1]
# The last four dimensions are offsets to four sides of a 2D bbox
bbox[:, -4:] = scale_bbox2d(clone_bbox[:, -4:]).float()

if self.norm_on_bbox:
if pred_bbox2d:
bbox[:, -4:] = F.relu(bbox.clone()[:, -4:])
if not training:
if pred_keypoints:
bbox[
:, self.bbox_code_size:self.bbox_code_size + 16] *= \
max_regress_range
if pred_bbox2d:
bbox[:, -4:] *= stride
else:
if pred_bbox2d:
bbox[:, -4:] = bbox.clone()[:, -4:].exp()
return bbox

def decode_prob_depth(self, depth_cls_preds, depth_range, depth_unit,
division, num_depth_cls):
"""Decode probabilistic depth map.
Args:
depth_cls_preds (torch.Tensor): Depth probabilistic map in shape
[..., self.num_depth_cls] (raw output before softmax).
depth_range (tuple[float]): Range of depth estimation.
depth_unit (int): Unit of depth range division.
division (str): Depth division method. Options include 'uniform',
'linear', 'log', 'loguniform'.
num_depth_cls (int): Number of depth classes.
Returns:
torch.Tensor: Decoded probabilistic depth estimation.
"""
if division == 'uniform':
depth_multiplier = depth_unit * \
depth_cls_preds.new_tensor(
list(range(num_depth_cls))).reshape([1, -1])
prob_depth_preds = (F.softmax(depth_cls_preds.clone(), dim=-1) *
depth_multiplier).sum(dim=-1)
return prob_depth_preds
elif division == 'linear':
split_pts = depth_cls_preds.new_tensor(list(
range(num_depth_cls))).reshape([1, -1])
depth_multiplier = depth_range[0] + (
depth_range[1] - depth_range[0]) / \
(num_depth_cls * (num_depth_cls - 1)) * \
(split_pts * (split_pts+1))
prob_depth_preds = (F.softmax(depth_cls_preds.clone(), dim=-1) *
depth_multiplier).sum(dim=-1)
return prob_depth_preds
elif division == 'log':
split_pts = depth_cls_preds.new_tensor(list(
range(num_depth_cls))).reshape([1, -1])
start = max(depth_range[0], 1)
end = depth_range[1]
depth_multiplier = (np.log(start) +
split_pts * np.log(end / start) /
(num_depth_cls - 1)).exp()
prob_depth_preds = (F.softmax(depth_cls_preds.clone(), dim=-1) *
depth_multiplier).sum(dim=-1)
return prob_depth_preds
elif division == 'loguniform':
split_pts = depth_cls_preds.new_tensor(list(
range(num_depth_cls))).reshape([1, -1])
start = max(depth_range[0], 1)
end = depth_range[1]
log_multiplier = np.log(start) + \
split_pts * np.log(end / start) / (num_depth_cls - 1)
prob_depth_preds = (F.softmax(depth_cls_preds.clone(), dim=-1) *
log_multiplier).sum(dim=-1).exp()
return prob_depth_preds
else:
raise NotImplementedError
173 changes: 136 additions & 37 deletions tests/test_utils/test_bbox_coders.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,26 +398,24 @@ def test_fcos3d_bbox_coder():

# test decode
# [2, 7, 1, 1]
batch_bbox_out = torch.tensor([[[[0.3130]], [[0.7094]], [[0.8743]],
[[0.0570]], [[0.5579]], [[0.1593]],
[[0.4553]]],
[[[0.7758]], [[0.2298]], [[0.3925]],
[[0.6307]], [[0.4377]], [[0.3339]],
[[0.1966]]]])
batch_bbox = torch.tensor([[[[0.3130]], [[0.7094]], [[0.8743]], [[0.0570]],
[[0.5579]], [[0.1593]], [[0.4553]]],
[[[0.7758]], [[0.2298]], [[0.3925]], [[0.6307]],
[[0.4377]], [[0.3339]], [[0.1966]]]])
batch_scale = nn.ModuleList([Scale(1.0) for _ in range(3)])
stride = 2
training = False
cls_score = torch.randn([2, 2, 1, 1]).sigmoid()
decode_bbox_out = bbox_coder.decode(batch_bbox_out, batch_scale, stride,
training, cls_score)
decode_bbox = bbox_coder.decode(batch_bbox, batch_scale, stride, training,
cls_score)

expected_bbox_out = torch.tensor([[[[0.6261]], [[1.4188]], [[2.3971]],
[[1.0586]], [[1.7470]], [[1.1727]],
[[0.4553]]],
[[[1.5516]], [[0.4596]], [[1.4806]],
[[1.8790]], [[1.5492]], [[1.3965]],
[[0.1966]]]])
assert torch.allclose(decode_bbox_out, expected_bbox_out, atol=1e-3)
expected_bbox = torch.tensor([[[[0.6261]], [[1.4188]], [[2.3971]],
[[1.0586]], [[1.7470]], [[1.1727]],
[[0.4553]]],
[[[1.5516]], [[0.4596]], [[1.4806]],
[[1.8790]], [[1.5492]], [[1.3965]],
[[0.1966]]]])
assert torch.allclose(decode_bbox, expected_bbox, atol=1e-3)

# test a config with priors
prior_bbox_coder_cfg = dict(
Expand All @@ -429,39 +427,140 @@ def test_fcos3d_bbox_coder():
prior_bbox_coder = build_bbox_coder(prior_bbox_coder_cfg)

# test decode
batch_bbox_out = torch.tensor([[[[0.3130]], [[0.7094]], [[0.8743]],
[[0.0570]], [[0.5579]], [[0.1593]],
[[0.4553]]],
[[[0.7758]], [[0.2298]], [[0.3925]],
[[0.6307]], [[0.4377]], [[0.3339]],
[[0.1966]]]])
batch_bbox = torch.tensor([[[[0.3130]], [[0.7094]], [[0.8743]], [[0.0570]],
[[0.5579]], [[0.1593]], [[0.4553]]],
[[[0.7758]], [[0.2298]], [[0.3925]], [[0.6307]],
[[0.4377]], [[0.3339]], [[0.1966]]]])
batch_scale = nn.ModuleList([Scale(1.0) for _ in range(3)])
stride = 2
training = False
cls_score = torch.tensor([[[[0.5811]], [[0.6198]]], [[[0.4889]],
[[0.8142]]]])
decode_bbox_out = prior_bbox_coder.decode(batch_bbox_out, batch_scale,
stride, training, cls_score)
expected_bbox_out = torch.tensor([[[[0.6260]], [[1.4188]], [[35.4916]],
[[1.0587]], [[3.4940]], [[3.5181]],
[[0.4553]]],
[[[1.5516]], [[0.4596]], [[29.7100]],
[[1.8789]], [[3.0983]], [[4.1892]],
[[0.1966]]]])
assert torch.allclose(decode_bbox_out, expected_bbox_out, atol=1e-3)
decode_bbox = prior_bbox_coder.decode(batch_bbox, batch_scale, stride,
training, cls_score)
expected_bbox = torch.tensor([[[[0.6260]], [[1.4188]], [[35.4916]],
[[1.0587]], [[3.4940]], [[3.5181]],
[[0.4553]]],
[[[1.5516]], [[0.4596]], [[29.7100]],
[[1.8789]], [[3.0983]], [[4.1892]],
[[0.1966]]]])
assert torch.allclose(decode_bbox, expected_bbox, atol=1e-3)

# test decode_yaw
decode_bbox_out = decode_bbox_out.permute(0, 2, 3, 1).view(-1, 7)
decode_bbox = decode_bbox.permute(0, 2, 3, 1).view(-1, 7)
batch_centers2d = torch.tensor([[100., 150.], [200., 100.]])
batch_dir_cls = torch.tensor([0., 1.])
dir_offset = 0.7854
cam2img = torch.tensor([[700., 0., 450., 0.], [0., 700., 200., 0.],
[0., 0., 1., 0.], [0., 0., 0., 1.]])
decode_bbox_out = prior_bbox_coder.decode_yaw(decode_bbox_out,
batch_centers2d,
batch_dir_cls, dir_offset,
cam2img)
expected_bbox_out = torch.tensor(
decode_bbox = prior_bbox_coder.decode_yaw(decode_bbox, batch_centers2d,
batch_dir_cls, dir_offset,
cam2img)
expected_bbox = torch.tensor(
[[0.6260, 1.4188, 35.4916, 1.0587, 3.4940, 3.5181, 3.1332],
[1.5516, 0.4596, 29.7100, 1.8789, 3.0983, 4.1892, 6.1368]])
assert torch.allclose(decode_bbox_out, expected_bbox_out, atol=1e-3)
assert torch.allclose(decode_bbox, expected_bbox, atol=1e-3)


def test_pgd_bbox_coder():
# test a config without priors
bbox_coder_cfg = dict(
type='PGDBBoxCoder',
base_depths=None,
base_dims=None,
code_size=7,
norm_on_bbox=True)
bbox_coder = build_bbox_coder(bbox_coder_cfg)

# test decode_2d
# [2, 27, 1, 1]
batch_bbox = torch.tensor([[[[0.0103]], [[0.7394]], [[0.3296]], [[0.4708]],
[[0.1439]], [[0.0778]], [[0.9399]], [[0.8366]],
[[0.1264]], [[0.3030]], [[0.1898]], [[0.0714]],
[[0.4144]], [[0.4341]], [[0.6442]], [[0.2951]],
[[0.2890]], [[0.4486]], [[0.2848]], [[0.1071]],
[[0.9530]], [[0.9460]], [[0.3822]], [[0.9320]],
[[0.2611]], [[0.5580]], [[0.0397]]],
[[[0.8612]], [[0.1680]], [[0.5167]], [[0.8502]],
[[0.0377]], [[0.3615]], [[0.9550]], [[0.5219]],
[[0.1402]], [[0.6843]], [[0.2121]], [[0.9468]],
[[0.6238]], [[0.7918]], [[0.1646]], [[0.0500]],
[[0.6290]], [[0.3956]], [[0.2901]], [[0.4612]],
[[0.7333]], [[0.1194]], [[0.6999]], [[0.3980]],
[[0.3262]], [[0.7185]], [[0.4474]]]])
batch_scale = nn.ModuleList([Scale(1.0) for _ in range(5)])
stride = 2
training = False
cls_score = torch.randn([2, 2, 1, 1]).sigmoid()
decode_bbox = bbox_coder.decode(batch_bbox, batch_scale, stride, training,
cls_score)
max_regress_range = 16
pred_keypoints = True
pred_bbox2d = True
decode_bbox_w2d = bbox_coder.decode_2d(decode_bbox, batch_scale, stride,
max_regress_range, training,
pred_keypoints, pred_bbox2d)
expected_decode_bbox_w2d = torch.tensor(
[[[[0.0206]], [[1.4788]], [[1.3904]], [[1.6013]], [[1.1548]],
[[1.0809]], [[0.9399]], [[13.3856]], [[2.0224]], [[4.8480]],
[[3.0368]], [[1.1424]], [[6.6304]], [[6.9456]], [[10.3072]],
[[4.7216]], [[4.6240]], [[7.1776]], [[4.5568]], [[1.7136]],
[[15.2480]], [[15.1360]], [[6.1152]], [[1.8640]], [[0.5222]],
[[1.1160]], [[0.0794]]],
[[[1.7224]], [[0.3360]], [[1.6765]], [[2.3401]], [[1.0384]],
[[1.4355]], [[0.9550]], [[8.3504]], [[2.2432]], [[10.9488]],
[[3.3936]], [[15.1488]], [[9.9808]], [[12.6688]], [[2.6336]],
[[0.8000]], [[10.0640]], [[6.3296]], [[4.6416]], [[7.3792]],
[[11.7328]], [[1.9104]], [[11.1984]], [[0.7960]], [[0.6524]],
[[1.4370]], [[0.8948]]]])
assert torch.allclose(expected_decode_bbox_w2d, decode_bbox_w2d, atol=1e-3)

# test decode_prob_depth
# [10, 8]
depth_cls_preds = torch.tensor([
[-0.4383, 0.7207, -0.4092, 0.4649, 0.8526, 0.6186, -1.4312, -0.7150],
[0.0621, 0.2369, 0.5170, 0.8484, -0.1099, 0.1829, -0.0072, 1.0618],
[-1.6114, -0.1057, 0.5721, -0.5986, -2.0471, 0.8140, -0.8385, -0.4822],
[0.0742, -0.3261, 0.4607, 1.8155, -0.3571, -0.0234, 0.3787, 2.3251],
[1.0492, -0.6881, -0.0136, -1.8291, 0.8460, -1.0171, 2.5691, -0.8114],
[0.0968, -0.5601, 1.0458, 0.2560, 1.3018, 0.1635, 0.0680, -1.0263],
[-0.0765, 0.1498, -2.7321, 1.0047, -0.2505, 0.0871, -0.4820, -0.3003],
[-0.4123, 0.2298, -0.1330, -0.6008, 0.6526, 0.7118, 0.9728, -0.7793],
[1.6940, 0.3355, 1.4661, 0.5477, 0.8667, 0.0527, -0.9975, -0.0689],
[0.4724, -0.3632, -0.0654, 0.4034, -0.3494, -0.7548, 0.7297, 1.2754]
])
depth_range = (0, 70)
depth_unit = 10
num_depth_cls = 8
uniform_prob_depth_preds = bbox_coder.decode_prob_depth(
depth_cls_preds, depth_range, depth_unit, 'uniform', num_depth_cls)
expected_preds = torch.tensor([
32.0441, 38.4689, 36.1831, 48.2096, 46.1560, 32.7973, 33.2155, 39.9822,
21.9905, 43.0161
])
assert torch.allclose(uniform_prob_depth_preds, expected_preds, atol=1e-3)

linear_prob_depth_preds = bbox_coder.decode_prob_depth(
depth_cls_preds, depth_range, depth_unit, 'linear', num_depth_cls)
expected_preds = torch.tensor([
21.1431, 30.2421, 25.8964, 41.6116, 38.6234, 21.4582, 23.2993, 30.1111,
13.9273, 36.8419
])
assert torch.allclose(linear_prob_depth_preds, expected_preds, atol=1e-3)

log_prob_depth_preds = bbox_coder.decode_prob_depth(
depth_cls_preds, depth_range, depth_unit, 'log', num_depth_cls)
expected_preds = torch.tensor([
12.6458, 24.2487, 17.4015, 36.9375, 27.5982, 12.5510, 15.6635, 19.8408,
9.1605, 31.3765
])
assert torch.allclose(log_prob_depth_preds, expected_preds, atol=1e-3)

loguniform_prob_depth_preds = bbox_coder.decode_prob_depth(
depth_cls_preds, depth_range, depth_unit, 'loguniform', num_depth_cls)
expected_preds = torch.tensor([
6.9925, 10.3273, 8.9895, 18.6524, 16.4667, 7.3196, 7.5078, 11.3207,
3.7987, 13.6095
])
assert torch.allclose(
loguniform_prob_depth_preds, expected_preds, atol=1e-3)

0 comments on commit 506f929

Please sign in to comment.