From a32c079415d4f05971298f054d8872b5d465ddc5 Mon Sep 17 00:00:00 2001 From: Sun Jiahao <72679458+sunjiahao1999@users.noreply.github.com> Date: Sat, 18 Feb 2023 21:42:48 +0800 Subject: [PATCH] [Feature] dev-1.x change np.transpose to torch.permute for speed up (#2277) * change np.transpose to torch.permute * add comments --- mmdet3d/datasets/transforms/formating.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/mmdet3d/datasets/transforms/formating.py b/mmdet3d/datasets/transforms/formating.py index 8e20f0fc15..87661f7d93 100644 --- a/mmdet3d/datasets/transforms/formating.py +++ b/mmdet3d/datasets/transforms/formating.py @@ -147,15 +147,19 @@ def pack_single_results(self, results: dict) -> dict: if 'img' in results: if isinstance(results['img'], list): # process multiple imgs in single frame - imgs = [img.transpose(2, 0, 1) for img in results['img']] - imgs = np.ascontiguousarray(np.stack(imgs, axis=0)) + imgs = [to_tensor(img) for img in results['img']] + imgs = torch.stack( + imgs, dim=0).permute(0, 3, 1, 2).contiguous() results['img'] = to_tensor(imgs) else: img = results['img'] if len(img.shape) < 3: img = np.expand_dims(img, -1) - results['img'] = to_tensor( - np.ascontiguousarray(img.transpose(2, 0, 1))) + # To improve the computational speed by 3-5 times, apply: + # `torch.permute()` rather than `np.transpose()`. + # Refer to https://github.com/open-mmlab/mmdetection/pull/9533 + # for more details + results['img'] = to_tensor(img).permute(2, 0, 1).contiguous() for key in [ 'proposals', 'gt_bboxes', 'gt_bboxes_ignore', 'gt_labels',