-
Notifications
You must be signed in to change notification settings - Fork 1.5k
/
centerpoint_head.py
926 lines (807 loc) · 36.3 KB
/
centerpoint_head.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
# Copyright (c) OpenMMLab. All rights reserved.
import copy
from typing import Dict, List, Optional, Tuple, Union
import torch
from mmcv.cnn import ConvModule, build_conv_layer
from mmdet.models.utils import multi_apply
from mmengine.model import BaseModule
from mmengine.structures import InstanceData
from torch import Tensor, nn
from mmdet3d.models.utils import (clip_sigmoid, draw_heatmap_gaussian,
gaussian_radius)
from mmdet3d.registry import MODELS, TASK_UTILS
from mmdet3d.structures import Det3DDataSample, xywhr2xyxyr
from ..layers import circle_nms, nms_bev
@MODELS.register_module()
class SeparateHead(BaseModule):
    """SeparateHead for CenterHead.

    Builds one independent conv branch per entry of ``heads``. A branch
    consists of ``num_conv - 1`` ``ConvModule`` layers followed by a single
    plain conv layer producing the requested number of output channels.

    Args:
        in_channels (int): Input channels for conv_layer.
        heads (dict): Conv information. Maps a head name to a tuple
            ``(out_channels, num_conv)``.
        head_conv (int, optional): Output channels of the intermediate
            ``ConvModule`` layers. Default: 64.
        final_kernel (int, optional): Kernel size used by every conv layer
            of a branch (intermediate and last). Default: 1.
        init_bias (float, optional): Initial bias of the last conv layer of
            the ``'heatmap'`` branch, if there is one. Default: -2.19.
        conv_cfg (dict, optional): Config of conv layer.
            Default: dict(type='Conv2d').
        norm_cfg (dict, optional): Config of norm layer.
            Default: dict(type='BN2d').
        bias (str, optional): Type of bias of the intermediate
            ``ConvModule`` layers. Default: 'auto'.
        init_cfg (dict, optional): Must be None; a Kaiming init for
            ``Conv2d`` layers is configured internally.
        **kwargs: Ignored. Allows callers such as ``CenterHead`` to pass
            extra keys (e.g. ``num_cls``) shared with ``DCNSeparateHead``.
    """

    def __init__(self,
                 in_channels,
                 heads,
                 head_conv=64,
                 final_kernel=1,
                 init_bias=-2.19,
                 conv_cfg=dict(type='Conv2d'),
                 norm_cfg=dict(type='BN2d'),
                 bias='auto',
                 init_cfg=None,
                 **kwargs):
        assert init_cfg is None, 'To prevent abnormal initialization ' \
            'behavior, init_cfg is not allowed to be set'
        super(SeparateHead, self).__init__(init_cfg=init_cfg)
        self.heads = heads
        self.init_bias = init_bias
        for head in self.heads:
            classes, num_conv = self.heads[head]

            conv_layers = []
            c_in = in_channels
            for _ in range(num_conv - 1):
                conv_layers.append(
                    ConvModule(
                        c_in,
                        head_conv,
                        kernel_size=final_kernel,
                        stride=1,
                        padding=final_kernel // 2,
                        bias=bias,
                        conv_cfg=conv_cfg,
                        norm_cfg=norm_cfg))
                c_in = head_conv

            # The last layer consumes whatever the previous stage produced:
            # ``head_conv`` channels after at least one ConvModule, or the
            # raw ``in_channels`` when ``num_conv == 1``. Hard-coding
            # ``head_conv`` here broke single-conv branches whenever
            # ``in_channels != head_conv``.
            conv_layers.append(
                build_conv_layer(
                    conv_cfg,
                    c_in,
                    classes,
                    kernel_size=final_kernel,
                    stride=1,
                    padding=final_kernel // 2,
                    bias=True))

            # Register the branch as a submodule named after the head.
            setattr(self, head, nn.Sequential(*conv_layers))

        if init_cfg is None:
            self.init_cfg = dict(type='Kaiming', layer='Conv2d')

    def init_weights(self):
        """Initialize weights.

        Calls the parent ``init_weights`` first, then fills the bias of the
        last conv layer of the ``'heatmap'`` branch (if present) with
        ``self.init_bias``.
        """
        super().init_weights()
        for head in self.heads:
            if head == 'heatmap':
                getattr(self, head)[-1].bias.data.fill_(self.init_bias)

    def forward(self, x):
        """Forward function for SepHead.

        Args:
            x (torch.Tensor): Input feature map with ``in_channels``
                channels, e.g. of shape [B, 512, 128, 128].

        Returns:
            dict[str: torch.Tensor]: One entry per key of ``self.heads``,
                holding that branch's output. With a typical CenterPoint
                config the keys are:

                -reg (torch.Tensor): 2D regression value with the
                    shape of [B, 2, H, W].
                -height (torch.Tensor): Height value with the
                    shape of [B, 1, H, W].
                -dim (torch.Tensor): Size value with the shape
                    of [B, 3, H, W].
                -rot (torch.Tensor): Rotation value with the
                    shape of [B, 2, H, W].
                -vel (torch.Tensor): Velocity value with the
                    shape of [B, 2, H, W].
                -heatmap (torch.Tensor): Heatmap with the shape of
                    [B, N, H, W].
        """
        return {head: getattr(self, head)(x) for head in self.heads}
@MODELS.register_module()
class DCNSeparateHead(BaseModule):
    r"""DCNSeparateHead for CenterHead.

    .. code-block:: none
            /-----> DCN for heatmap task -----> heatmap task.
    feature
            \-----> DCN for regression tasks -----> regression tasks

    Args:
        in_channels (int): Input channels for conv_layer.
        num_cls (int): Number of classes.
        heads (dict): Conv information. Maps a head name to a tuple
            ``(out_channels, num_conv)``. A ``'heatmap'`` entry, if present,
            is ignored because the heatmap branch is built here from
            ``num_cls``. The given dict is not modified.
        dcn_config (dict): Config of dcn layer.
        head_conv (int, optional): Output channels.
            Default: 64.
        final_kernel (int, optional): Kernel size for the last conv
            layer. Default: 1.
        init_bias (float, optional): Initial bias of the last heatmap conv
            layer. Default: -2.19.
        conv_cfg (dict, optional): Config of conv layer.
            Default: dict(type='Conv2d')
        norm_cfg (dict, optional): Config of norm layer.
            Default: dict(type='BN2d').
        bias (str, optional): Type of bias. Default: 'auto'.
        init_cfg (dict, optional): Must be None; a Kaiming init for
            ``Conv2d`` layers is configured internally.
    """  # noqa: W605

    def __init__(self,
                 in_channels,
                 num_cls,
                 heads,
                 dcn_config,
                 head_conv=64,
                 final_kernel=1,
                 init_bias=-2.19,
                 conv_cfg=dict(type='Conv2d'),
                 norm_cfg=dict(type='BN2d'),
                 bias='auto',
                 init_cfg=None,
                 **kwargs):
        assert init_cfg is None, 'To prevent abnormal initialization ' \
            'behavior, init_cfg is not allowed to be set'
        super(DCNSeparateHead, self).__init__(init_cfg=init_cfg)
        # The heatmap branch is built explicitly below, so drop it from the
        # regression heads. Filter into a new dict rather than ``pop`` so the
        # caller's ``heads`` (typically part of a config) is left untouched.
        heads = {
            name: spec
            for name, spec in heads.items() if name != 'heatmap'
        }
        # feature adaptation with dcn
        # use separate features for classification / regression
        self.feature_adapt_cls = build_conv_layer(dcn_config)

        self.feature_adapt_reg = build_conv_layer(dcn_config)

        # heatmap prediction head
        cls_head = [
            ConvModule(
                in_channels,
                head_conv,
                kernel_size=3,
                padding=1,
                conv_cfg=conv_cfg,
                bias=bias,
                norm_cfg=norm_cfg),
            build_conv_layer(
                conv_cfg,
                head_conv,
                num_cls,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=bias)
        ]
        self.cls_head = nn.Sequential(*cls_head)
        self.init_bias = init_bias
        # other regression target
        # NOTE(review): conv_cfg / norm_cfg are not forwarded here, so the
        # regression branches use SeparateHead's own defaults for them.
        self.task_head = SeparateHead(
            in_channels,
            heads,
            head_conv=head_conv,
            final_kernel=final_kernel,
            bias=bias)
        if init_cfg is None:
            self.init_cfg = dict(type='Kaiming', layer='Conv2d')

    def init_weights(self):
        """Initialize weights.

        Calls the parent ``init_weights`` first, then fills the bias of the
        last heatmap conv layer with ``self.init_bias``.
        """
        super().init_weights()
        self.cls_head[-1].bias.data.fill_(self.init_bias)

    def forward(self, x):
        """Forward function for DCNSepHead.

        Args:
            x (torch.Tensor): Input feature map, e.g. with the shape of
                [B, 512, 128, 128].

        Returns:
            dict[str: torch.Tensor]: The regression outputs of
                ``self.task_head`` (one entry per regression head) plus a
                ``'heatmap'`` entry. With a typical CenterPoint config:

                -reg (torch.Tensor): 2D regression value with the
                    shape of [B, 2, H, W].
                -height (torch.Tensor): Height value with the
                    shape of [B, 1, H, W].
                -dim (torch.Tensor): Size value with the shape
                    of [B, 3, H, W].
                -rot (torch.Tensor): Rotation value with the
                    shape of [B, 2, H, W].
                -vel (torch.Tensor): Velocity value with the
                    shape of [B, 2, H, W].
                -heatmap (torch.Tensor): Heatmap with the shape of
                    [B, N, H, W].
        """
        center_feat = self.feature_adapt_cls(x)
        reg_feat = self.feature_adapt_reg(x)

        cls_score = self.cls_head(center_feat)
        ret = self.task_head(reg_feat)
        ret['heatmap'] = cls_score

        return ret
@MODELS.register_module()
class CenterHead(BaseModule):
    """CenterHead for CenterPoint.

    A shared 3x3 ``ConvModule`` is applied to every input feature map,
    followed by one task head (built from ``separate_head``) per entry of
    ``tasks``.

    Args:
        in_channels (list[int] | int, optional): Channels of the input
            feature map. Default: [128].
        tasks (list[dict], optional): Task information including class number
            and class names. Every dict must provide ``'class_names'``.
            Default: None.
        bbox_coder (dict, optional): Bbox coder configs. Default: None.
        common_heads (dict, optional): Conv information for common heads,
            mapping a head name to ``(out_channels, num_conv)``. It is
            deep-copied for every task. Default: dict().
        loss_cls (dict, optional): Config of classification loss function.
            Default: dict(type='mmdet.GaussianFocalLoss', reduction='mean').
        loss_bbox (dict, optional): Config of regression loss function.
            Default: dict(type='mmdet.L1Loss', reduction='none',
            loss_weight=0.25).
        separate_head (dict, optional): Config of separate head. A private
            copy is made for every task, so the given dict is never
            modified. Default: dict(type='SeparateHead', init_bias=-2.19,
            final_kernel=3).
        share_conv_channel (int, optional): Output channels for share_conv
            layer. Default: 64.
        num_heatmap_convs (int, optional): Number of conv layers for heatmap
            conv layer. Default: 2.
        conv_cfg (dict, optional): Config of conv layer.
            Default: dict(type='Conv2d').
        norm_cfg (dict, optional): Config of norm layer.
            Default: dict(type='BN2d').
        bias (str): Type of bias. Default: 'auto'.
        norm_bbox (bool): Whether normalize the bbox predictions. When True,
            box dimensions are regressed in log space. Defaults to True.
        train_cfg (dict, optional): Train-time configs. Default: None.
        test_cfg (dict, optional): Test-time configs. Default: None.
        init_cfg (dict, optional): Config for initialization. Must be None.
    """

    def __init__(self,
                 in_channels: Union[List[int], int] = [128],
                 tasks: Optional[List[dict]] = None,
                 bbox_coder: Optional[dict] = None,
                 common_heads: dict = dict(),
                 loss_cls: dict = dict(
                     type='mmdet.GaussianFocalLoss', reduction='mean'),
                 loss_bbox: dict = dict(
                     type='mmdet.L1Loss', reduction='none', loss_weight=0.25),
                 separate_head: dict = dict(
                     type='SeparateHead', init_bias=-2.19, final_kernel=3),
                 share_conv_channel: int = 64,
                 num_heatmap_convs: int = 2,
                 conv_cfg: dict = dict(type='Conv2d'),
                 norm_cfg: dict = dict(type='BN2d'),
                 bias: str = 'auto',
                 norm_bbox: bool = True,
                 train_cfg: Optional[dict] = None,
                 test_cfg: Optional[dict] = None,
                 init_cfg: Optional[dict] = None,
                 **kwargs):
        assert init_cfg is None, 'To prevent abnormal initialization ' \
            'behavior, init_cfg is not allowed to be set'
        super(CenterHead, self).__init__(init_cfg=init_cfg, **kwargs)

        # Number of classes per task, e.g. 2 for a task with
        # {'class_names': ['pedestrian', 'traffic_cone']}.
        # TODO: consider renaming to ``num_classes_per_task``.
        num_classes = [len(t['class_names']) for t in tasks]
        self.class_names = [t['class_names'] for t in tasks]
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.norm_bbox = norm_bbox

        self.loss_cls = MODELS.build(loss_cls)
        self.loss_bbox = MODELS.build(loss_bbox)
        self.bbox_coder = TASK_UTILS.build(bbox_coder)
        self.num_anchor_per_locs = list(num_classes)
        self.fp16_enabled = False

        # a shared convolution
        self.shared_conv = ConvModule(
            in_channels,
            share_conv_channel,
            kernel_size=3,
            padding=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            bias=bias)

        self.task_heads = nn.ModuleList()

        for num_cls in num_classes:
            heads = copy.deepcopy(common_heads)
            heads.update(dict(heatmap=(num_cls, num_heatmap_convs)))
            # Build each task head from a private copy: updating
            # ``separate_head`` in place would mutate the shared default
            # dict (and any config dict passed in by the caller).
            head_cfg = copy.deepcopy(separate_head)
            head_cfg.update(
                in_channels=share_conv_channel, heads=heads, num_cls=num_cls)
            self.task_heads.append(MODELS.build(head_cfg))

    def forward_single(self, x: Tensor) -> List[dict]:
        """Forward function for CenterPoint.

        Args:
            x (torch.Tensor): Input feature map, e.g. with the shape of
                [B, 512, 128, 128].

        Returns:
            list[dict]: Output results for tasks, one dict per task head.
        """
        x = self.shared_conv(x)
        return [task(x) for task in self.task_heads]

    def forward(self, feats: List[Tensor]) -> Tuple[List[dict]]:
        """Forward pass.

        Args:
            feats (list[torch.Tensor]): Multi-level features, e.g.,
                features produced by FPN.

        Returns:
            tuple(list[dict]): Output results for tasks. The outer tuple is
                indexed by task, the inner list by feature level.
        """
        return multi_apply(self.forward_single, feats)

    def _gather_feat(self, feat, ind, mask=None):
        """Gather feature map.

        Given feature map and index, return indexed feature map.

        Args:
            feat (torch.Tensor): Feature map with the shape of [B, H*W, C]
                (C is 10 for the regression targets used in
                ``loss_by_feat``).
            ind (torch.Tensor): Index of the ground truth boxes with the
                shape of [B, max_obj].
            mask (torch.Tensor, optional): Mask of the feature map with the
                shape of [B, max_obj]. Default: None.

        Returns:
            torch.Tensor: Feature map after gathering with the shape
                of [B, max_obj, C], or [num_selected, C] when ``mask`` is
                given.
        """
        dim = feat.size(2)
        ind = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim)
        feat = feat.gather(1, ind)
        if mask is not None:
            mask = mask.unsqueeze(2).expand_as(feat)
            feat = feat[mask]
            feat = feat.view(-1, dim)
        return feat

    def get_targets(
        self,
        batch_gt_instances_3d: List[InstanceData],
    ) -> Tuple[List[Tensor]]:
        """Generate targets.

        How each output is transformed:

            Each nested list is transposed so that all same-index elements in
            each sub-list (1, ..., N) become the new sub-lists.
                [ [a0, a1, a2, ... ], [b0, b1, b2, ... ], ... ]
                ==> [ [a0, b0, ... ], [a1, b1, ... ], [a2, b2, ... ] ]

            The new transposed nested list is converted into a list of N
            tensors generated by stacking the tensors in the new sub-lists.
                [ tensor0, tensor1, tensor2, ... ]

        Args:
            batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of
                gt_instances. It usually includes ``bboxes_3d`` and
                ``labels_3d`` attributes.

        Returns:
            tuple[list[torch.Tensor]]: Tuple of target including
                the following results in order.

                - list[torch.Tensor]: Heatmap scores.
                - list[torch.Tensor]: Ground truth boxes.
                - list[torch.Tensor]: Indexes indicating the
                    position of the valid boxes.
                - list[torch.Tensor]: Masks indicating which
                    boxes are valid.
        """
        heatmaps, anno_boxes, inds, masks = multi_apply(
            self.get_targets_single, batch_gt_instances_3d)

        def _stack_per_task(per_sample):
            # [sample][task] -> [task], each entry stacked over the batch.
            return [torch.stack(per_task) for per_task in zip(*per_sample)]

        return (_stack_per_task(heatmaps), _stack_per_task(anno_boxes),
                _stack_per_task(inds), _stack_per_task(masks))

    def get_targets_single(self,
                           gt_instances_3d: InstanceData) -> Tuple[Tensor]:
        """Generate training targets for a single sample.

        Args:
            gt_instances_3d (:obj:`InstanceData`): Gt_instances of
                single data sample. It usually includes
                ``bboxes_3d`` and ``labels_3d`` attributes.

        Returns:
            tuple[list[torch.Tensor]]: Tuple of target including
                the following results in order.

                - list[torch.Tensor]: Heatmap scores.
                - list[torch.Tensor]: Ground truth boxes.
                - list[torch.Tensor]: Indexes indicating the position
                    of the valid boxes.
                - list[torch.Tensor]: Masks indicating which boxes
                    are valid.
        """
        gt_labels_3d = gt_instances_3d.labels_3d
        gt_bboxes_3d = gt_instances_3d.bboxes_3d
        device = gt_labels_3d.device
        # Use the gravity center as box location and keep columns 3: of the
        # box tensor unchanged.
        gt_bboxes_3d = torch.cat(
            (gt_bboxes_3d.gravity_center, gt_bboxes_3d.tensor[:, 3:]),
            dim=1).to(device)
        max_objs = self.train_cfg['max_objs'] * self.train_cfg['dense_reg']
        grid_size = torch.tensor(self.train_cfg['grid_size'])
        pc_range = torch.tensor(self.train_cfg['point_cloud_range'])
        voxel_size = torch.tensor(self.train_cfg['voxel_size'])
        out_size_factor = self.train_cfg['out_size_factor']

        feature_map_size = grid_size[:2] // out_size_factor

        # Reorganize the gt by tasks. Labels are global across tasks, so
        # each task's class index is offset by the classes of earlier tasks.
        task_masks = []
        flag = 0
        for class_name in self.class_names:
            task_masks.append([
                torch.where(gt_labels_3d == class_name.index(i) + flag)
                for i in class_name
            ])
            flag += len(class_name)

        task_boxes = []
        task_classes = []
        flag2 = 0
        for mask in task_masks:
            task_box = []
            task_class = []
            for m in mask:
                task_box.append(gt_bboxes_3d[m])
                # 0 is background for each task, so we need to add 1 here.
                task_class.append(gt_labels_3d[m] + 1 - flag2)
            task_boxes.append(torch.cat(task_box, axis=0).to(device))
            task_classes.append(torch.cat(task_class).long().to(device))
            flag2 += len(mask)

        heatmaps, anno_boxes, inds, masks = [], [], [], []

        for idx, _ in enumerate(self.task_heads):
            heatmap = gt_bboxes_3d.new_zeros(
                (len(self.class_names[idx]), feature_map_size[1],
                 feature_map_size[0]))

            anno_box = gt_bboxes_3d.new_zeros((max_objs, 10),
                                              dtype=torch.float32)

            ind = gt_labels_3d.new_zeros((max_objs), dtype=torch.int64)
            mask = gt_bboxes_3d.new_zeros((max_objs), dtype=torch.uint8)

            num_objs = min(task_boxes[idx].shape[0], max_objs)

            for k in range(num_objs):
                cls_id = task_classes[idx][k] - 1

                # box extents converted to feature-map cells
                width = task_boxes[idx][k][3]
                length = task_boxes[idx][k][4]
                width = width / voxel_size[0] / out_size_factor
                length = length / voxel_size[1] / out_size_factor

                if width > 0 and length > 0:
                    radius = gaussian_radius(
                        (length, width),
                        min_overlap=self.train_cfg['gaussian_overlap'])
                    radius = max(self.train_cfg['min_radius'], int(radius))

                    # be really careful for the coordinate system of
                    # your box annotation.
                    x, y, z = task_boxes[idx][k][0], task_boxes[idx][k][
                        1], task_boxes[idx][k][2]

                    coor_x = (x - pc_range[0]) / voxel_size[0] / out_size_factor
                    coor_y = (y - pc_range[1]) / voxel_size[1] / out_size_factor

                    center = torch.tensor([coor_x, coor_y],
                                          dtype=torch.float32,
                                          device=device)
                    center_int = center.to(torch.int32)

                    # throw out not in range objects to avoid out of array
                    # area when creating the heatmap
                    if not (0 <= center_int[0] < feature_map_size[0]
                            and 0 <= center_int[1] < feature_map_size[1]):
                        continue

                    draw_heatmap_gaussian(heatmap[cls_id], center_int, radius)

                    new_idx = k
                    x, y = center_int[0], center_int[1]

                    assert (y * feature_map_size[0] + x <
                            feature_map_size[0] * feature_map_size[1])

                    # flattened (row-major) position of the center cell
                    ind[new_idx] = y * feature_map_size[0] + x
                    mask[new_idx] = 1
                    # TODO: support other outdoor dataset
                    # The unpacking below requires exactly two values after
                    # index 7 (velocity), i.e. 9-column boxes.
                    vx, vy = task_boxes[idx][k][7:]
                    rot = task_boxes[idx][k][6]
                    box_dim = task_boxes[idx][k][3:6]
                    if self.norm_bbox:
                        box_dim = box_dim.log()
                    # 10 targets: sub-cell offset (2), z (1), dims (3),
                    # sin/cos of yaw (2), velocity (2).
                    anno_box[new_idx] = torch.cat([
                        center - torch.tensor([x, y], device=device),
                        z.unsqueeze(0), box_dim,
                        torch.sin(rot).unsqueeze(0),
                        torch.cos(rot).unsqueeze(0),
                        vx.unsqueeze(0),
                        vy.unsqueeze(0)
                    ])

            heatmaps.append(heatmap)
            anno_boxes.append(anno_box)
            masks.append(mask)
            inds.append(ind)
        return heatmaps, anno_boxes, inds, masks

    def loss(self, pts_feats: List[Tensor],
             batch_data_samples: List[Det3DDataSample], *args,
             **kwargs) -> Dict[str, Tensor]:
        """Forward function for point cloud branch.

        Args:
            pts_feats (list[torch.Tensor]): Features of point cloud branch
            batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance_3d`.

        Returns:
            dict: Losses of each branch.
        """
        outs = self(pts_feats)
        batch_gt_instance_3d = [
            data_sample.gt_instances_3d for data_sample in batch_data_samples
        ]
        return self.loss_by_feat(outs, batch_gt_instance_3d)

    def loss_by_feat(self, preds_dicts: Tuple[List[dict]],
                     batch_gt_instances_3d: List[InstanceData], *args,
                     **kwargs):
        """Loss function for CenterHead.

        Args:
            preds_dicts (tuple[list[dict]]): Prediction results of
                multiple tasks. The outer tuple indicate different
                tasks head, and the internal list indicate different
                FPN level.
            batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of
                gt_instances. It usually includes ``bboxes_3d`` and
                ``labels_3d`` attributes.

        Returns:
            dict[str,torch.Tensor]: Loss of heatmap and bbox of each task.
        """
        heatmaps, anno_boxes, inds, masks = self.get_targets(
            batch_gt_instances_3d)
        loss_dict = dict()
        for task_id, preds_dict in enumerate(preds_dicts):
            # heatmap focal loss
            preds_dict[0]['heatmap'] = clip_sigmoid(preds_dict[0]['heatmap'])
            num_pos = heatmaps[task_id].eq(1).float().sum().item()
            loss_heatmap = self.loss_cls(
                preds_dict[0]['heatmap'],
                heatmaps[task_id],
                avg_factor=max(num_pos, 1))
            target_box = anno_boxes[task_id]
            # reconstruct the anno_box from multiple reg heads
            preds_dict[0]['anno_box'] = torch.cat(
                (preds_dict[0]['reg'], preds_dict[0]['height'],
                 preds_dict[0]['dim'], preds_dict[0]['rot'],
                 preds_dict[0]['vel']),
                dim=1)

            # Regression loss for dimension, offset, height, rotation
            ind = inds[task_id]
            num = masks[task_id].float().sum()
            pred = preds_dict[0]['anno_box'].permute(0, 2, 3, 1).contiguous()
            pred = pred.view(pred.size(0), -1, pred.size(3))
            pred = self._gather_feat(pred, ind)
            mask = masks[task_id].unsqueeze(2).expand_as(target_box).float()
            isnotnan = (~torch.isnan(target_box)).float()
            mask *= isnotnan

            code_weights = self.train_cfg.get('code_weights', None)
            if code_weights is None:
                # No per-channel weights configured: weight all regression
                # channels equally instead of failing on new_tensor(None).
                bbox_weights = mask
            else:
                bbox_weights = mask * mask.new_tensor(code_weights)
            loss_bbox = self.loss_bbox(
                pred, target_box, bbox_weights, avg_factor=(num + 1e-4))
            loss_dict[f'task{task_id}.loss_heatmap'] = loss_heatmap
            loss_dict[f'task{task_id}.loss_bbox'] = loss_bbox
        return loss_dict

    def predict(self,
                pts_feats: List[Tensor],
                batch_data_samples: List[Det3DDataSample],
                rescale=True,
                **kwargs) -> List[InstanceData]:
        """Predict 3D boxes from point cloud features.

        Args:
            pts_feats (list[torch.Tensor]): Point features, passed to
                ``forward``.
            batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
                Samples. It usually includes meta information of data.
            rescale (bool): Whether rescale the results to
                the original scale.

        Returns:
            list[:obj:`InstanceData`]: List of processed predictions. Each
                InstanceData contains 3d Bounding boxes and corresponding
                scores and labels.
        """
        preds_dict = self(pts_feats)
        batch_input_metas = [
            data_sample.metainfo for data_sample in batch_data_samples
        ]
        return self.predict_by_feat(
            preds_dict, batch_input_metas, rescale=rescale, **kwargs)

    def predict_by_feat(self, preds_dicts: Tuple[List[dict]],
                        batch_input_metas: List[dict], *args,
                        **kwargs) -> List[InstanceData]:
        """Generate bboxes from bbox head predictions.

        Args:
            preds_dicts (tuple[list[dict]]): Prediction results of
                multiple tasks. The outer tuple indicate different
                tasks head, and the internal list indicate different
                FPN level.
            batch_input_metas (list[dict]): Meta info of multiple
                inputs.

        Returns:
            list[:obj:`InstanceData`]: Instance prediction
            results of each sample after the post process.
            Each item usually contains following keys.

                - scores_3d (Tensor): Classification scores, has a shape
                  (num_instance, )
                - labels_3d (Tensor): Labels of bboxes, has a shape
                  (num_instances, ).
                - bboxes_3d (:obj:`LiDARInstance3DBoxes`): Prediction
                  of bboxes, contains a tensor with shape
                  (num_instances, 7) or (num_instances, 9), and
                  the last 2 dimensions of 9 is
                  velocity.
        """
        rets = []
        for task_id, preds_dict in enumerate(preds_dicts):
            num_class_with_bg = self.num_classes[task_id]
            batch_size = preds_dict[0]['heatmap'].shape[0]
            batch_heatmap = preds_dict[0]['heatmap'].sigmoid()

            batch_reg = preds_dict[0]['reg']
            batch_hei = preds_dict[0]['height']

            if self.norm_bbox:
                # dims were regressed in log space during training
                batch_dim = torch.exp(preds_dict[0]['dim'])
            else:
                batch_dim = preds_dict[0]['dim']

            batch_rots = preds_dict[0]['rot'][:, 0].unsqueeze(1)
            batch_rotc = preds_dict[0]['rot'][:, 1].unsqueeze(1)

            if 'vel' in preds_dict[0]:
                batch_vel = preds_dict[0]['vel']
            else:
                batch_vel = None
            temp = self.bbox_coder.decode(
                batch_heatmap,
                batch_rots,
                batch_rotc,
                batch_hei,
                batch_dim,
                batch_vel,
                reg=batch_reg,
                task_id=task_id)
            assert self.test_cfg['nms_type'] in ['circle', 'rotate']
            batch_reg_preds = [box['bboxes'] for box in temp]
            batch_cls_preds = [box['scores'] for box in temp]
            batch_cls_labels = [box['labels'] for box in temp]
            if self.test_cfg['nms_type'] == 'circle':
                ret_task = []
                for i in range(batch_size):
                    boxes3d = temp[i]['bboxes']
                    scores = temp[i]['scores']
                    labels = temp[i]['labels']
                    # circle NMS works on (x, y, score) triples
                    centers = boxes3d[:, [0, 1]]
                    boxes = torch.cat([centers, scores.view(-1, 1)], dim=1)
                    keep = torch.tensor(
                        circle_nms(
                            boxes.detach().cpu().numpy(),
                            self.test_cfg['min_radius'][task_id],
                            post_max_size=self.test_cfg['post_max_size']),
                        dtype=torch.long,
                        device=boxes.device)

                    boxes3d = boxes3d[keep]
                    scores = scores[keep]
                    labels = labels[keep]
                    ret = dict(bboxes=boxes3d, scores=scores, labels=labels)
                    ret_task.append(ret)
                rets.append(ret_task)
            else:
                rets.append(
                    self.get_task_detections(num_class_with_bg,
                                             batch_cls_preds, batch_reg_preds,
                                             batch_cls_labels,
                                             batch_input_metas))

        # Merge branches results
        num_samples = len(rets[0])

        ret_list = []
        for i in range(num_samples):
            temp_instances = InstanceData()
            for k in rets[0][i].keys():
                if k == 'bboxes':
                    bboxes = torch.cat([ret[i][k] for ret in rets])
                    # Presumably converts the gravity-center z (used for the
                    # targets) back to the box-bottom z expected by
                    # ``box_type_3d`` — TODO confirm against the box class.
                    bboxes[:, 2] = bboxes[:, 2] - bboxes[:, 5] * 0.5
                    bboxes = batch_input_metas[i]['box_type_3d'](
                        bboxes, self.bbox_coder.code_size)
                elif k == 'scores':
                    scores = torch.cat([ret[i][k] for ret in rets])
                elif k == 'labels':
                    # Shift per-task labels into the global label space
                    # (modifies the per-task label tensors in place).
                    flag = 0
                    for j, num_class in enumerate(self.num_classes):
                        rets[j][i][k] += flag
                        flag += num_class
                    labels = torch.cat([ret[i][k].int() for ret in rets])

            temp_instances.bboxes_3d = bboxes
            temp_instances.scores_3d = scores
            temp_instances.labels_3d = labels
            ret_list.append(temp_instances)
        return ret_list

    def get_task_detections(self, num_class_with_bg, batch_cls_preds,
                            batch_reg_preds, batch_cls_labels, img_metas):
        """Rotate nms for each task.

        Args:
            num_class_with_bg (int): Number of classes for the current task.
            batch_cls_preds (list[torch.Tensor]): Prediction score with the
                shape of [N].
            batch_reg_preds (list[torch.Tensor]): Prediction bbox with the
                shape of [N, 9].
            batch_cls_labels (list[torch.Tensor]): Prediction label with the
                shape of [N].
            img_metas (list[dict]): Meta information of each sample.

        Returns:
            list[dict[str: torch.Tensor]]: contains the following keys:

                -bboxes (torch.Tensor): Prediction bboxes after nms with the
                    shape of [N, 9].
                -scores (torch.Tensor): Prediction scores after nms with the
                    shape of [N].
                -labels (torch.Tensor): Prediction labels after nms with the
                    shape of [N].
        """
        predictions_dicts = []
        post_center_range = self.test_cfg['post_center_limit_range']
        if post_center_range is not None and len(post_center_range) > 0:
            post_center_range = torch.tensor(
                post_center_range,
                dtype=batch_reg_preds[0].dtype,
                device=batch_reg_preds[0].device)
        else:
            # A missing or empty range disables the center filter below.
            # Previously an empty list reached the tensor comparison there
            # and crashed, and ``None`` crashed on ``len``.
            post_center_range = None

        for i, (box_preds, cls_preds, cls_labels) in enumerate(
                zip(batch_reg_preds, batch_cls_preds, batch_cls_labels)):

            # Apply NMS in bird eye view

            # get the highest score per prediction, then apply nms
            # to remove overlapped box.
            if num_class_with_bg == 1:
                top_scores = cls_preds.squeeze(-1)
                top_labels = torch.zeros(
                    cls_preds.shape[0],
                    device=cls_preds.device,
                    dtype=torch.long)

            else:
                top_labels = cls_labels.long()
                top_scores = cls_preds.squeeze(-1)

            if self.test_cfg['score_threshold'] > 0.0:
                thresh = torch.tensor(
                    [self.test_cfg['score_threshold']],
                    device=cls_preds.device).type_as(cls_preds)
                top_scores_keep = top_scores >= thresh
                top_scores = top_scores.masked_select(top_scores_keep)

            if top_scores.shape[0] != 0:
                if self.test_cfg['score_threshold'] > 0.0:
                    box_preds = box_preds[top_scores_keep]
                    top_labels = top_labels[top_scores_keep]

                boxes_for_nms = xywhr2xyxyr(img_metas[i]['box_type_3d'](
                    box_preds[:, :], self.bbox_coder.code_size).bev)
                # the nms in 3d detection just remove overlap boxes.

                selected = nms_bev(
                    boxes_for_nms,
                    top_scores,
                    thresh=self.test_cfg['nms_thr'],
                    pre_max_size=self.test_cfg['pre_max_size'],
                    post_max_size=self.test_cfg['post_max_size'])
            else:
                selected = []

            selected_boxes = box_preds[selected]
            selected_labels = top_labels[selected]
            selected_scores = top_scores[selected]

            # finally generate predictions.
            if selected_boxes.shape[0] != 0:
                final_box_preds = selected_boxes
                final_scores = selected_scores
                final_labels = selected_labels
                if post_center_range is not None:
                    # keep boxes whose center lies inside the range
                    mask = (final_box_preds[:, :3] >=
                            post_center_range[:3]).all(1)
                    mask &= (final_box_preds[:, :3] <=
                             post_center_range[3:]).all(1)
                    predictions_dict = dict(
                        bboxes=final_box_preds[mask],
                        scores=final_scores[mask],
                        labels=final_labels[mask])
                else:
                    predictions_dict = dict(
                        bboxes=final_box_preds,
                        scores=final_scores,
                        labels=final_labels)
            else:
                dtype = batch_reg_preds[0].dtype
                device = batch_reg_preds[0].device
                predictions_dict = dict(
                    bboxes=torch.zeros([0, self.bbox_coder.code_size],
                                       dtype=dtype,
                                       device=device),
                    scores=torch.zeros([0], dtype=dtype, device=device),
                    labels=torch.zeros([0],
                                       dtype=top_labels.dtype,
                                       device=device))

            predictions_dicts.append(predictions_dict)
        return predictions_dicts