diff --git a/mmseg/models/data_preprocessor.py b/mmseg/models/data_preprocessor.py
index 34087d0c074..ffbc8c8b40d 100644
--- a/mmseg/models/data_preprocessor.py
+++ b/mmseg/models/data_preprocessor.py
@@ -48,18 +48,24 @@ class SegDataPreProcessor(BaseDataPreprocessor):
         rgb_to_bgr (bool): whether to convert image from RGB to RGB.
             Defaults to False.
         batch_augments (list[dict], optional): Batch-level augmentations
+        test_cfg (dict, optional): The padding size config in testing; if not
+            specified, `size` and `size_divisor` params are used as defaults.
+            Defaults to None; only supports keys `size` or `size_divisor`.
     """
 
-    def __init__(self,
-                 mean: Sequence[Number] = None,
-                 std: Sequence[Number] = None,
-                 size: Optional[tuple] = None,
-                 size_divisor: Optional[int] = None,
-                 pad_val: Number = 0,
-                 seg_pad_val: Number = 255,
-                 bgr_to_rgb: bool = False,
-                 rgb_to_bgr: bool = False,
-                 batch_augments: Optional[List[dict]] = None):
+    def __init__(
+        self,
+        mean: Sequence[Number] = None,
+        std: Sequence[Number] = None,
+        size: Optional[tuple] = None,
+        size_divisor: Optional[int] = None,
+        pad_val: Number = 0,
+        seg_pad_val: Number = 255,
+        bgr_to_rgb: bool = False,
+        rgb_to_bgr: bool = False,
+        batch_augments: Optional[List[dict]] = None,
+        test_cfg: dict = None,
+    ):
         super().__init__()
         self.size = size
         self.size_divisor = size_divisor
@@ -86,6 +92,9 @@ def __init__(self,
         # TODO: support batch augmentations.
         self.batch_augments = batch_augments
 
+        # Support different padding methods in testing
+        self.test_cfg = test_cfg
+
     def forward(self, data: dict, training: bool = False) -> Dict[str, Any]:
         """Perform normalization、padding and bgr2rgb conversion based on
         ``BaseDataPreprocessor``.
@@ -122,10 +131,19 @@ def forward(self, data: dict, training: bool = False) -> Dict[str, Any]:
             if self.batch_augments is not None:
                 inputs, data_samples = self.batch_augments(
                     inputs, data_samples)
-            return dict(inputs=inputs, data_samples=data_samples)
         else:
             assert len(inputs) == 1, (
                 'Batch inference is not support currently, '
                 'as the image size might be different in a batch')
-            return dict(
-                inputs=torch.stack(inputs, dim=0), data_samples=data_samples)
+            # pad images when testing
+            if self.test_cfg:
+                inputs, _ = stack_batch(
+                    inputs=inputs,
+                    size=self.test_cfg.get('size', None),
+                    size_divisor=self.test_cfg.get('size_divisor', None),
+                    pad_val=self.pad_val,
+                    seg_pad_val=self.seg_pad_val)
+            else:
+                inputs = torch.stack(inputs, dim=0)
+
+        return dict(inputs=inputs, data_samples=data_samples)
diff --git a/mmseg/models/segmentors/base.py b/mmseg/models/segmentors/base.py
index dfceddd99f7..9fd496f27b8 100644
--- a/mmseg/models/segmentors/base.py
+++ b/mmseg/models/segmentors/base.py
@@ -165,6 +165,7 @@ def postprocess_result(self,
                 i_seg_logits = seg_logits[i:i + 1, :,
                                           padding_top:H - padding_bottom,
                                           padding_left:W - padding_right]
+
                 # resize as original shape
                 i_seg_logits = resize(
                     i_seg_logits,
diff --git a/tests/test_models/test_data_preprocessor.py b/tests/test_models/test_data_preprocessor.py
index 6b2903ff324..d05eef1c7d8 100644
--- a/tests/test_models/test_data_preprocessor.py
+++ b/tests/test_models/test_data_preprocessor.py
@@ -46,3 +46,19 @@ def test_forward(self):
         out = processor(data, training=True)
         self.assertEqual(out['inputs'].shape, (2, 3, 20, 20))
         self.assertEqual(len(out['data_samples']), 2)
+
+        # test predict with padding
+        processor = SegDataPreProcessor(
+            mean=[0, 0, 0],
+            std=[1, 1, 1],
+            size=(20, 20),
+            test_cfg=dict(size_divisor=15))
+        data = {
+            'inputs': [
+                torch.randint(0, 256, (3, 11, 10)),
+            ],
+            'data_samples': [data_sample]
+        }
+        out = processor(data, training=False)
+        self.assertEqual(out['inputs'].shape[2] % 15, 0)
+        self.assertEqual(out['inputs'].shape[3] % 15, 0)