Skip to content

Commit

Permalink
Merge e31d8c6 into 276ca24
Browse files Browse the repository at this point in the history
  • Loading branch information
HAOCHENYE authored Oct 12, 2022
2 parents 276ca24 + e31d8c6 commit 078375d
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 16 deletions.
14 changes: 9 additions & 5 deletions mmengine/model/base_model/data_preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
from mmengine.utils import is_list_of
from ..utils import stack_batch

CastData = Union[tuple, dict, BaseDataElement, torch.Tensor, list]
CastData = Union[tuple, dict, BaseDataElement, torch.Tensor, list, bytes, str,
None]


@MODELS.register_module()
Expand Down Expand Up @@ -42,17 +43,20 @@ def cast_data(self, data: CastData) -> CastData:
"""
if isinstance(data, Mapping):
return {key: self.cast_data(data[key]) for key in data}
elif isinstance(data, (str, bytes)) or data is None:
return data
elif isinstance(data, tuple) and hasattr(data, '_fields'):
# namedtuple
return type(data)(*(self.cast_data(sample)for sample in data)) # type: ignore # noqa: E501 # yapf:disable
elif isinstance(data, Sequence):
return [self.cast_data(sample) for sample in data]
elif isinstance(data, torch.Tensor):
return data.to(self.device)
elif isinstance(data, BaseDataElement):
elif isinstance(data, (torch.Tensor, BaseDataElement)):
return data.to(self.device)
else:
return data
raise TypeError(
'`BaseDataPreprocessor.cast_data`: batch must contain '
'tensors, numpy arrays, numbers, dicts or lists, but '
f'found {type(data)}')

def forward(self, data: dict, training: bool = False) -> Union[dict, list]:
"""Preprocesses the data into the model input format.
Expand Down
36 changes: 25 additions & 11 deletions tests/test_model/test_base_model/test_data_preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ def test_forward(self):
label1 = torch.randn(1)
label2 = torch.randn(1)

# Test with dict of batch inputs and batch data samples
data = dict(inputs=[input1, input2], data_sample=[label1, label2])

output = base_data_preprocessor(data)
batch_inputs, batch_labels = output['inputs'], output['data_sample']
self.assertTrue(torch.is_floating_point(batch_inputs[0]))
Expand All @@ -36,40 +36,54 @@ def test_forward(self):
assert_allclose(label2, batch_labels[1])

# Test with tuple of batch inputs and batch data samples
data = dict(
inputs=torch.stack([input1, input2]), data_sample=[label1, label2])
output = base_data_preprocessor(data)['inputs']
data = (torch.stack([input1, input2]), (label1, label2))
batch_inputs, batch_labels = base_data_preprocessor(data)
self.assertTrue(torch.is_floating_point(batch_inputs))
self.assertEqual(batch_inputs[0].shape, (1, 3, 5))
self.assertEqual(batch_inputs[1].shape, (1, 3, 5))
self.assertTrue(torch.is_floating_point(batch_inputs[0]))

# Test cuda forward
if torch.cuda.is_available():
# Test with list of data samples.
data = dict(inputs=[input1, input2], data_sample=[label1, label2])
base_data_preprocessor = base_data_preprocessor.cuda()
output = base_data_preprocessor(data)
batch_inputs, batch_labels = output['inputs'], output[
'data_sample']
self.assertTrue(torch.is_floating_point(batch_inputs))
self.assertEqual(batch_inputs.device.type, 'cuda')
self.assertTrue(torch.is_floating_point(batch_inputs[0]))
self.assertEqual(batch_inputs[0].device.type, 'cuda')

# Fallback to test with cpu.
base_data_preprocessor = base_data_preprocessor.cpu()
output = base_data_preprocessor(data)
batch_inputs, batch_labels = output['inputs'], output[
'data_sample']
self.assertTrue(torch.is_floating_point(batch_inputs))
self.assertEqual(batch_inputs.device.type, 'cpu')
self.assertTrue(torch.is_floating_point(batch_inputs[0]))
self.assertEqual(batch_inputs[0].device.type, 'cpu')

# Test `base_data_preprocessor` can be moved to cuda again.
base_data_preprocessor = base_data_preprocessor.to('cuda:0')
output = base_data_preprocessor(data)
batch_inputs, batch_labels = output['inputs'], output[
'data_sample']
self.assertTrue(torch.is_floating_point(batch_inputs))
self.assertEqual(batch_inputs.device.type, 'cuda')
self.assertTrue(torch.is_floating_point(batch_inputs[0]))
self.assertEqual(batch_inputs[0].device.type, 'cuda')

# device of `base_data_preprocessor` is cuda, output should be
# cuda tensor.
self.assertEqual(batch_inputs.device.type, 'cuda')
self.assertEqual(batch_inputs[0].device.type, 'cuda')
self.assertEqual(batch_labels[0].device.type, 'cuda')

# Test forward with string value
data = dict(string='abc')
base_data_preprocessor(data)

with self.assertRaisesRegex(TypeError,
'`BaseDataPreprocessor.cast_data`:'):
data = dict(string=object())
base_data_preprocessor(data)


class TestImgDataPreprocessor(TestBaseDataPreprocessor):

Expand Down

0 comments on commit 078375d

Please sign in to comment.