Skip to content

Commit df0be64

Browse files
authored
[Enhancement] Speedup formatting by replacing np.transpose with torch.permute (#1719)
1 parent f820470 commit df0be64

File tree

2 files changed

+55
-6
lines changed

2 files changed

+55
-6
lines changed

mmocr/datasets/transforms/formatting.py

+33-6
Original file line numberDiff line numberDiff line change
@@ -90,8 +90,17 @@ def transform(self, results: dict) -> dict:
9090
img = results['img']
9191
if len(img.shape) < 3:
9292
img = np.expand_dims(img, -1)
93-
img = np.ascontiguousarray(img.transpose(2, 0, 1))
94-
packed_results['inputs'] = to_tensor(img)
93+
# A simple trick to speedup formatting by 3-5 times when
94+
# OMP_NUM_THREADS != 1
95+
# Refer to https://github.com/open-mmlab/mmdetection/pull/9533
96+
# for more details
97+
if img.flags.c_contiguous:
98+
img = to_tensor(img)
99+
img = img.permute(2, 0, 1).contiguous()
100+
else:
101+
img = np.ascontiguousarray(img.transpose(2, 0, 1))
102+
img = to_tensor(img)
103+
packed_results['inputs'] = img
95104

96105
data_sample = TextDetDataSample()
97106
instance_data = InstanceData()
@@ -174,8 +183,17 @@ def transform(self, results: dict) -> dict:
174183
img = results['img']
175184
if len(img.shape) < 3:
176185
img = np.expand_dims(img, -1)
177-
img = np.ascontiguousarray(img.transpose(2, 0, 1))
178-
packed_results['inputs'] = to_tensor(img)
186+
# A simple trick to speedup formatting by 3-5 times when
187+
# OMP_NUM_THREADS != 1
188+
# Refer to https://github.com/open-mmlab/mmdetection/pull/9533
189+
# for more details
190+
if img.flags.c_contiguous:
191+
img = to_tensor(img)
192+
img = img.permute(2, 0, 1).contiguous()
193+
else:
194+
img = np.ascontiguousarray(img.transpose(2, 0, 1))
195+
img = to_tensor(img)
196+
packed_results['inputs'] = img
179197

180198
data_sample = TextRecogDataSample()
181199
gt_text = LabelData()
@@ -272,8 +290,17 @@ def transform(self, results: dict) -> dict:
272290
img = results['img']
273291
if len(img.shape) < 3:
274292
img = np.expand_dims(img, -1)
275-
img = np.ascontiguousarray(img.transpose(2, 0, 1))
276-
packed_results['inputs'] = to_tensor(img)
293+
# A simple trick to speedup formatting by 3-5 times when
294+
# OMP_NUM_THREADS != 1
295+
# Refer to https://github.com/open-mmlab/mmdetection/pull/9533
296+
# for more details
297+
if img.flags.c_contiguous:
298+
img = to_tensor(img)
299+
img = img.permute(2, 0, 1).contiguous()
300+
else:
301+
img = np.ascontiguousarray(img.transpose(2, 0, 1))
302+
img = to_tensor(img)
303+
packed_results['inputs'] = img
277304
else:
278305
packed_results['inputs'] = torch.FloatTensor().reshape(0, 0, 0)
279306

tests/test_datasets/test_transforms/test_formatting.py

+22
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,17 @@ def test_packdetinput(self):
3636
transform = PackTextDetInputs()
3737
results = transform(copy.deepcopy(datainfo))
3838
self.assertIn('inputs', results)
39+
self.assertEqual(results['inputs'].shape, torch.Size([1, 10, 10]))
3940
self.assertTupleEqual(tuple(results['inputs'].shape), (1, 10, 10))
4041
self.assertIn('data_samples', results)
4142

43+
# test non-contiguous img
44+
nc_datainfo = copy.deepcopy(datainfo)
45+
nc_datainfo['img'] = nc_datainfo['img'].transpose(1, 0)
46+
results = transform(nc_datainfo)
47+
self.assertIn('inputs', results)
48+
self.assertEqual(results['inputs'].shape, torch.Size([1, 10, 10]))
49+
4250
data_sample = results['data_samples']
4351
self.assertIn('bboxes', data_sample.gt_instances)
4452
self.assertIsInstance(data_sample.gt_instances.bboxes, torch.Tensor)
@@ -115,6 +123,13 @@ def test_packrecogtinput(self):
115123
self.assertIn('valid_ratio', data_sample)
116124
self.assertIn('pad_shape', data_sample)
117125

126+
# test non-contiguous img
127+
nc_datainfo = copy.deepcopy(datainfo)
128+
nc_datainfo['img'] = nc_datainfo['img'].transpose(1, 0)
129+
results = transform(nc_datainfo)
130+
self.assertIn('inputs', results)
131+
self.assertEqual(results['inputs'].shape, torch.Size([1, 10, 10]))
132+
118133
transform = PackTextRecogInputs(meta_keys=('img_path', ))
119134
results = transform(copy.deepcopy(datainfo))
120135
self.assertIn('inputs', results)
@@ -174,6 +189,13 @@ def test_transform(self):
174189
torch.int64)
175190
self.assertIsInstance(data_sample.gt_instances.texts, list)
176191

192+
# test non-contiguous img
193+
nc_datainfo = copy.deepcopy(datainfo)
194+
nc_datainfo['img'] = nc_datainfo['img'].transpose(1, 0)
195+
results = self.transform(nc_datainfo)
196+
self.assertIn('inputs', results)
197+
self.assertEqual(results['inputs'].shape, torch.Size([1, 10, 10]))
198+
177199
transform = PackKIEInputs(meta_keys=('img_path', ))
178200
results = transform(copy.deepcopy(datainfo))
179201
self.assertIn('inputs', results)

0 commit comments

Comments
 (0)