diff --git a/test/test_transforms.py b/test/test_transforms.py
index 72100d0feac..d6651816cd2 100644
--- a/test/test_transforms.py
+++ b/test/test_transforms.py
@@ -987,19 +987,27 @@ def test_2d_ndarray_to_pil_image(self):
         self.assertTrue(np.allclose(img_data, img))
 
     def test_tensor_bad_types_to_pil_image(self):
-        with self.assertRaises(ValueError):
+        with self.assertRaisesRegex(ValueError, r'pic should be 2/3 dimensional. Got \d+ dimensions.'):
             transforms.ToPILImage()(torch.ones(1, 3, 4, 4))
+        with self.assertRaisesRegex(ValueError, r'pic should not have > 4 channels. Got \d+ channels.'):
+            transforms.ToPILImage()(torch.ones(6, 4, 4))
 
     def test_ndarray_bad_types_to_pil_image(self):
         trans = transforms.ToPILImage()
-        with self.assertRaises(TypeError):
+        reg_msg = r'Input type \w+ is not supported'
+        with self.assertRaisesRegex(TypeError, reg_msg):
             trans(np.ones([4, 4, 1], np.int64))
+        with self.assertRaisesRegex(TypeError, reg_msg):
             trans(np.ones([4, 4, 1], np.uint16))
+        with self.assertRaisesRegex(TypeError, reg_msg):
             trans(np.ones([4, 4, 1], np.uint32))
+        with self.assertRaisesRegex(TypeError, reg_msg):
             trans(np.ones([4, 4, 1], np.float64))
 
-        with self.assertRaises(ValueError):
+        with self.assertRaisesRegex(ValueError, r'pic should be 2/3 dimensional. Got \d+ dimensions.'):
             transforms.ToPILImage()(np.ones([1, 4, 4, 3]))
+        with self.assertRaisesRegex(ValueError, r'pic should not have > 4 channels. Got \d+ channels.'):
+            transforms.ToPILImage()(np.ones([4, 4, 6]))
 
     @unittest.skipIf(stats is None, 'scipy.stats not available')
     def test_random_vertical_flip(self):
diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py
index 13135807091..4d8a0c09e34 100644
--- a/torchvision/transforms/functional.py
+++ b/torchvision/transforms/functional.py
@@ -183,6 +183,10 @@ def to_pil_image(pic, mode=None):
             # if 2D image, add channel dimension (CHW)
             pic = pic.unsqueeze(0)
 
+        # check number of channels
+        if pic.shape[-3] > 4:
+            raise ValueError('pic should not have > 4 channels. Got {} channels.'.format(pic.shape[-3]))
+
     elif isinstance(pic, np.ndarray):
         if pic.ndim not in {2, 3}:
             raise ValueError('pic should be 2/3 dimensional. Got {} dimensions.'.format(pic.ndim))
@@ -191,6 +195,10 @@ def to_pil_image(pic, mode=None):
             # if 2D image, add channel dimension (HWC)
             pic = np.expand_dims(pic, 2)
 
+        # check number of channels
+        if pic.shape[-1] > 4:
+            raise ValueError('pic should not have > 4 channels. Got {} channels.'.format(pic.shape[-1]))
+
     npimg = pic
     if isinstance(pic, torch.Tensor):
         if pic.is_floating_point() and mode != 'F':