diff --git a/mmdet3d/datasets/transforms/formating.py b/mmdet3d/datasets/transforms/formating.py index 8e20f0fc15..87661f7d93 100644 --- a/mmdet3d/datasets/transforms/formating.py +++ b/mmdet3d/datasets/transforms/formating.py @@ -147,15 +147,19 @@ def pack_single_results(self, results: dict) -> dict: if 'img' in results: if isinstance(results['img'], list): # process multiple imgs in single frame - imgs = [img.transpose(2, 0, 1) for img in results['img']] - imgs = np.ascontiguousarray(np.stack(imgs, axis=0)) + imgs = [to_tensor(img) for img in results['img']] + imgs = torch.stack( + imgs, dim=0).permute(0, 3, 1, 2).contiguous() results['img'] = to_tensor(imgs) else: img = results['img'] if len(img.shape) < 3: img = np.expand_dims(img, -1) - results['img'] = to_tensor( - np.ascontiguousarray(img.transpose(2, 0, 1))) + # To improve the computational speed by 3-5 times, apply: + # `torch.permute()` rather than `np.transpose()`. + # Refer to https://github.com/open-mmlab/mmdetection/pull/9533 + # for more details + results['img'] = to_tensor(img).permute(2, 0, 1).contiguous() for key in [ 'proposals', 'gt_bboxes', 'gt_bboxes_ignore', 'gt_labels',