Skip to content

Commit

Permalink
[Fix] Centerpoint head nested list transpose (#879)
Browse files Browse the repository at this point in the history
* FIX Transpose nested lists without Numpy

* Removed unused Numpy import
  • Loading branch information
robin-karlsson0 authored Aug 25, 2021
1 parent 319b0e3 commit fc301b9
Showing 1 changed file with 19 additions and 10 deletions.
29 changes: 19 additions & 10 deletions mmdet3d/models/dense_heads/centerpoint_head.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import numpy as np
import torch
from mmcv.cnn import ConvModule, build_conv_layer
from mmcv.runner import BaseModule, force_fp32
Expand Down Expand Up @@ -386,6 +385,17 @@ def _gather_feat(self, feat, ind, mask=None):
def get_targets(self, gt_bboxes_3d, gt_labels_3d):
"""Generate targets.
How each output is transformed:
Each nested list is transposed so that all same-index elements in
each sub-list (1, ..., N) become the new sub-lists.
[ [a0, a1, a2, ... ], [b0, b1, b2, ... ], ... ]
==> [ [a0, b0, ... ], [a1, b1, ... ], [a2, b2, ... ] ]
The new transposed nested list is converted into a list of N
tensors generated by concatenating tensors in the new sub-lists.
[ tensor0, tensor1, tensor2, ... ]
Args:
gt_bboxes_3d (list[:obj:`LiDARInstance3DBoxes`]): Ground
truth gt boxes.
Expand All @@ -405,18 +415,17 @@ def get_targets(self, gt_bboxes_3d, gt_labels_3d):
"""
heatmaps, anno_boxes, inds, masks = multi_apply(
self.get_targets_single, gt_bboxes_3d, gt_labels_3d)
# transpose heatmaps, because the dimension of tensors in each task is
# different, we have to use numpy instead of torch to do the transpose.
heatmaps = np.array(heatmaps).transpose(1, 0).tolist()
# Transpose heatmaps
heatmaps = list(map(list, zip(*heatmaps)))
heatmaps = [torch.stack(hms_) for hms_ in heatmaps]
# transpose anno_boxes
anno_boxes = np.array(anno_boxes).transpose(1, 0).tolist()
# Transpose anno_boxes
anno_boxes = list(map(list, zip(*anno_boxes)))
anno_boxes = [torch.stack(anno_boxes_) for anno_boxes_ in anno_boxes]
# transpose inds
inds = np.array(inds).transpose(1, 0).tolist()
# Transpose inds
inds = list(map(list, zip(*inds)))
inds = [torch.stack(inds_) for inds_ in inds]
# transpose inds
masks = np.array(masks).transpose(1, 0).tolist()
# Transpose inds
masks = list(map(list, zip(*masks)))
masks = [torch.stack(masks_) for masks_ in masks]
return heatmaps, anno_boxes, inds, masks

Expand Down

0 comments on commit fc301b9

Please sign in to comment.