Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace numpy transpose with torch permute to speed-up #9533

Merged
merged 12 commits into from
Jan 4, 2023
23 changes: 17 additions & 6 deletions mmdet/datasets/pipelines/formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,20 +80,20 @@ def __init__(self, keys):

def __call__(self, results):
"""Call function to convert image in results to :obj:`torch.Tensor` and
transpose the channel order.
permute the channel order.

Args:
results (dict): Result dict contains the image data to convert.

Returns:
dict: The result dict contains the image converted
to :obj:`torch.Tensor` and transposed to (C, H, W) order.
to :obj:`torch.Tensor` and permuted to (C, H, W) order.
"""
for key in self.keys:
img = results[key]
if len(img.shape) < 3:
img = np.expand_dims(img, -1)
results[key] = (to_tensor(img.transpose(2, 0, 1))).contiguous()
results[key] = to_tensor(img).permute(2, 0, 1).contiguous()
return results

def __repr__(self):
Expand Down Expand Up @@ -179,7 +179,7 @@ class DefaultFormatBundle:
"proposals", "gt_bboxes", "gt_labels", "gt_masks" and "gt_semantic_seg".
These fields are formatted as follows.

- img: (1)transpose, (2)to tensor, (3)to DataContainer (stack=True)
- img: (1)transpose & to tensor, (2)to DataContainer (stack=True)
- proposals: (1)to tensor, (2)to DataContainer
- gt_bboxes: (1)to tensor, (2)to DataContainer
- gt_bboxes_ignore: (1)to tensor, (2)to DataContainer
Expand Down Expand Up @@ -226,9 +226,20 @@ def __call__(self, results):
results = self._add_default_meta_keys(results)
if len(img.shape) < 3:
img = np.expand_dims(img, -1)
img = np.ascontiguousarray(img.transpose(2, 0, 1))
# To improve the computational speed by by 3-5 times, apply:
# If image is not contiguous, use
# `numpy.transpose()` followed by `numpy.ascontiguousarray()`
# If image is already contiguous, use
# `torch.permute()` followed by `torch.contiguous()`
# Refer to https://github.com/open-mmlab/mmdetection/pull/9533
# for more details
if not img.flags.c_contiguous:
Min-Sheng marked this conversation as resolved.
Show resolved Hide resolved
img = np.ascontiguousarray(img.transpose(2, 0, 1))
img = to_tensor(img)
else:
img = to_tensor(img).permute(2, 0, 1).contiguous()
results['img'] = DC(
to_tensor(img), padding_value=self.pad_val['img'], stack=True)
img, padding_value=self.pad_val['img'], stack=True)
for key in ['proposals', 'gt_bboxes', 'gt_bboxes_ignore', 'gt_labels']:
if key not in results:
continue
Expand Down
2 changes: 1 addition & 1 deletion requirements/optional.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
cityscapesscripts
imagecorruptions
sklearn
scikit-learn