Skip to content

Commit

Permalink
[Fix] Fix BaseDataPreprocessor.cast_data could not handle string da…
Browse files Browse the repository at this point in the history
…ta (open-mmlab#602)

* [Fix] Fix `BaseDataPreprocessor.cast_data` could not handle string data

* Minor refine

* Refine type hint

Refine type hint

* fix as comment

* Minor refine

* Update mmengine/model/base_model/data_preprocessor.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
  • Loading branch information
2 people authored and ly015 committed Nov 9, 2022
1 parent 3548d14 commit 66d0865
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 18 deletions.
18 changes: 11 additions & 7 deletions mmengine/model/base_model/data_preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
from mmengine.utils import is_list_of
from ..utils import stack_batch

CastData = Union[tuple, dict, BaseDataElement, torch.Tensor, list]
CastData = Union[tuple, dict, BaseDataElement, torch.Tensor, list, bytes, str,
None]


@MODELS.register_module()
Expand Down Expand Up @@ -48,17 +49,20 @@ def cast_data(self, data: CastData) -> CastData:
"""
if isinstance(data, Mapping):
return {key: self.cast_data(data[key]) for key in data}
elif isinstance(data, (str, bytes)) or data is None:
return data
elif isinstance(data, tuple) and hasattr(data, '_fields'):
# namedtuple
return type(data)(*(self.cast_data(sample)for sample in data)) # type: ignore # noqa: E501 # yapf:disable
return type(data)(*(self.cast_data(sample) for sample in data)) # type: ignore # noqa: E501 # yapf:disable
elif isinstance(data, Sequence):
return [self.cast_data(sample) for sample in data]
elif isinstance(data, torch.Tensor):
return data.to(self.device, non_blocking=self._non_blocking)
elif isinstance(data, BaseDataElement):
return type(data)(self.cast_data(sample) for sample in data) # type: ignore # noqa: E501 # yapf:disable
elif isinstance(data, (torch.Tensor, BaseDataElement)):
return data.to(self.device, non_blocking=self._non_blocking)
else:
return data
raise TypeError(
'`BaseDataPreprocessor.cast_data`: batch data must contain '
'tensors, numpy arrays, numbers, dicts or lists, but '
f'found {type(data)}')

def forward(self, data: dict, training: bool = False) -> Union[dict, list]:
"""Preprocesses the data into the model input format.
Expand Down
36 changes: 25 additions & 11 deletions tests/test_model/test_base_model/test_data_preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ def test_forward(self):
label1 = torch.randn(1)
label2 = torch.randn(1)

# Test with dict of batch inputs and batch data samples
data = dict(inputs=[input1, input2], data_sample=[label1, label2])

output = base_data_preprocessor(data)
batch_inputs, batch_labels = output['inputs'], output['data_sample']
self.assertTrue(torch.is_floating_point(batch_inputs[0]))
Expand All @@ -41,40 +41,54 @@ def test_forward(self):
assert_allclose(label2, batch_labels[1])

# Test with tuple of batch inputs and batch data samples
data = dict(
inputs=torch.stack([input1, input2]), data_sample=[label1, label2])
output = base_data_preprocessor(data)['inputs']
data = (torch.stack([input1, input2]), (label1, label2))
batch_inputs, batch_labels = base_data_preprocessor(data)
self.assertTrue(torch.is_floating_point(batch_inputs))
self.assertEqual(batch_inputs[0].shape, (1, 3, 5))
self.assertEqual(batch_inputs[1].shape, (1, 3, 5))
self.assertTrue(torch.is_floating_point(batch_inputs[0]))

# Test cuda forward
if torch.cuda.is_available():
# Test with list of data samples.
data = dict(inputs=[input1, input2], data_sample=[label1, label2])
base_data_preprocessor = base_data_preprocessor.cuda()
output = base_data_preprocessor(data)
batch_inputs, batch_labels = output['inputs'], output[
'data_sample']
self.assertTrue(torch.is_floating_point(batch_inputs))
self.assertEqual(batch_inputs.device.type, 'cuda')
self.assertTrue(torch.is_floating_point(batch_inputs[0]))
self.assertEqual(batch_inputs[0].device.type, 'cuda')

# Fallback to test with cpu.
base_data_preprocessor = base_data_preprocessor.cpu()
output = base_data_preprocessor(data)
batch_inputs, batch_labels = output['inputs'], output[
'data_sample']
self.assertTrue(torch.is_floating_point(batch_inputs))
self.assertEqual(batch_inputs.device.type, 'cpu')
self.assertTrue(torch.is_floating_point(batch_inputs[0]))
self.assertEqual(batch_inputs[0].device.type, 'cpu')

# Test `base_data_preprocessor` can be moved to cuda again.
base_data_preprocessor = base_data_preprocessor.to('cuda:0')
output = base_data_preprocessor(data)
batch_inputs, batch_labels = output['inputs'], output[
'data_sample']
self.assertTrue(torch.is_floating_point(batch_inputs))
self.assertEqual(batch_inputs.device.type, 'cuda')
self.assertTrue(torch.is_floating_point(batch_inputs[0]))
self.assertEqual(batch_inputs[0].device.type, 'cuda')

# device of `base_data_preprocessor` is cuda, output should be
# cuda tensor.
self.assertEqual(batch_inputs.device.type, 'cuda')
self.assertEqual(batch_inputs[0].device.type, 'cuda')
self.assertEqual(batch_labels[0].device.type, 'cuda')

# Test forward with string value
data = dict(string='abc')
base_data_preprocessor(data)

with self.assertRaisesRegex(TypeError,
'`BaseDataPreprocessor.cast_data`:'):
data = dict(string=object())
base_data_preprocessor(data)


class TestImgDataPreprocessor(TestBaseDataPreprocessor):

Expand Down

0 comments on commit 66d0865

Please sign in to comment.