diff --git a/test/test_transforms.py b/test/test_transforms.py index 45871276073..945bf5fde4b 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -299,13 +299,22 @@ def test_pad(self): width = random.randint(10, 32) * 2 img = torch.ones(3, height, width) padding = random.randint(1, 20) + fill = random.randint(1, 50) result = transforms.Compose([ transforms.ToPILImage(), - transforms.Pad(padding), + transforms.Pad(padding, fill=fill), transforms.ToTensor(), ])(img) self.assertEqual(result.size(1), height + 2 * padding) self.assertEqual(result.size(2), width + 2 * padding) + # check that all elements in the padded region correspond + # to the pad value + fill_v = fill / 255 + eps = 1e-5 + self.assertTrue((result[:, :padding, :] - fill_v).abs().max() < eps) + self.assertTrue((result[:, :, :padding] - fill_v).abs().max() < eps) + self.assertRaises(ValueError, transforms.Pad(padding, fill=(1, 2)), + transforms.ToPILImage()(img)) def test_pad_with_tuple_of_pad_values(self): height = random.randint(10, 32) * 2 diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 7f22fc51391..eb4bc95f66d 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -329,6 +329,12 @@ def pad(img, padding, fill=0, padding_mode='constant'): 'Padding mode should be either constant, edge, reflect or symmetric' if padding_mode == 'constant': + if isinstance(fill, numbers.Number): + fill = (fill,) * len(img.getbands()) + if len(fill) != len(img.getbands()): + raise ValueError('fill should have the same number of elements ' + 'as the number of channels in the image ' + '({}), got {} instead'.format(len(img.getbands()), len(fill))) if img.mode == 'P': palette = img.getpalette() image = ImageOps.expand(img, border=padding, fill=fill)