From 886ea4612ac803b8f24d155b5daaa2e89db31bab Mon Sep 17 00:00:00 2001 From: Alexey Demyanchuk Date: Tue, 17 Nov 2020 11:25:08 +0100 Subject: [PATCH 01/10] Add explicit check for number of channels Example why you need to check it: `M = torch.randint(low=0, high=2, size=(6, 64, 64), dtype = torch.float)` When you put this input through to_pil_image without mode argument, it converts to uint8 here: ``` if pic.is_floating_point() and mode != 'F': pic = pic.mul(255).byte() ``` and change the mode to RGB here: ``` if mode is None and npimg.dtype == np.uint8: mode = 'RGB' ``` Image.fromarray doesn't raise if provided with mode RGB and just cut number of channels from what you have to 3 --- torchvision/transforms/functional.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 13135807091..995290f2054 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -200,6 +200,8 @@ def to_pil_image(pic, mode=None): if not isinstance(npimg, np.ndarray): raise TypeError('Input pic must be a torch.Tensor or NumPy ndarray, ' + 'not {}'.format(type(npimg))) + if npimg.shape[2] > 4: + raise ValueError('pic should not have > 4 channels. Got {} channels.'.format(npimg.shape[2])) if npimg.shape[2] == 1: expected_mode = None From 5cb7f5b59c569889a52962376eeedc6a70e67b99 Mon Sep 17 00:00:00 2001 From: Demyanchuk Date: Wed, 18 Nov 2020 12:15:00 +0100 Subject: [PATCH 02/10] Check number of channels before processing --- torchvision/transforms/functional.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 995290f2054..899fd1199db 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -179,6 +179,11 @@ def to_pil_image(pic, mode=None): if pic.ndimension() not in {2, 3}: raise ValueError('pic should be 2/3 dimensional. Got {} dimensions.'.format(pic.ndimension())) + elif pic.ndimension() == 3: + # check number of channels + if pic.shape[0] > 4: + raise ValueError('pic should not have > 4 channels. Got {} channels.'.format(pic.shape[0])) + elif pic.ndimension() == 2: # if 2D image, add channel dimension (CHW) pic = pic.unsqueeze(0) @@ -187,6 +192,11 @@ def to_pil_image(pic, mode=None): if pic.ndim not in {2, 3}: raise ValueError('pic should be 2/3 dimensional. Got {} dimensions.'.format(pic.ndim)) + elif pic.ndim == 3: + # check number of channels + if pic.shape[2] > 4: + raise ValueError('pic should not have > 4 channels. Got {} channels.'.format(pic.shape[2])) + elif pic.ndim == 2: # if 2D image, add channel dimension (HWC) pic = np.expand_dims(pic, 2) @@ -200,8 +210,6 @@ def to_pil_image(pic, mode=None): if not isinstance(npimg, np.ndarray): raise TypeError('Input pic must be a torch.Tensor or NumPy ndarray, ' + 'not {}'.format(type(npimg))) - if npimg.shape[2] > 4: - raise ValueError('pic should not have > 4 channels. Got {} channels.'.format(npimg.shape[2])) if npimg.shape[2] == 1: expected_mode = None From 2399e3ff30c47cda912bc4691655367e18ad85f8 Mon Sep 17 00:00:00 2001 From: Demyanchuk Date: Wed, 18 Nov 2020 12:16:19 +0100 Subject: [PATCH 03/10] Add test for invalid number of channels --- test/test_transforms.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/test_transforms.py b/test/test_transforms.py index f9add6d1b57..aed917fe340 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -989,6 +989,7 @@ def test_2d_ndarray_to_pil_image(self): def test_tensor_bad_types_to_pil_image(self): with self.assertRaises(ValueError): transforms.ToPILImage()(torch.ones(1, 3, 4, 4)) + transforms.ToPILImage()(torch.ones(6, 4, 4)) def test_ndarray_bad_types_to_pil_image(self): trans = transforms.ToPILImage() @@ -1000,6 +1001,7 @@ def test_ndarray_bad_types_to_pil_image(self): with self.assertRaises(ValueError): transforms.ToPILImage()(np.ones([1, 4, 4, 3])) + transforms.ToPILImage()(np.ones([4, 4, 6])) @unittest.skipIf(stats is None, 'scipy.stats not available') def test_random_vertical_flip(self): From b504d213478def191d16c826335be49ec4ef6b2d Mon Sep 17 00:00:00 2001 From: Alexey Demyanchuk Date: Tue, 17 Nov 2020 11:25:08 +0100 Subject: [PATCH 04/10] Add explicit check for number of channels Example why you need to check it: `M = torch.randint(low=0, high=2, size=(6, 64, 64), dtype = torch.float)` When you put this input through to_pil_image without mode argument, it converts to uint8 here: ``` if pic.is_floating_point() and mode != 'F': pic = pic.mul(255).byte() ``` and change the mode to RGB here: ``` if mode is None and npimg.dtype == np.uint8: mode = 'RGB' ``` Image.fromarray doesn't raise if provided with mode RGB and just cut number of channels from what you have to 3 --- torchvision/transforms/functional.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 13135807091..995290f2054 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -200,6 +200,8 @@ def to_pil_image(pic, mode=None): if not isinstance(npimg, np.ndarray): raise TypeError('Input pic must be a torch.Tensor or NumPy ndarray, ' + 'not {}'.format(type(npimg))) + if npimg.shape[2] > 4: + raise ValueError('pic should not have > 4 channels. Got {} channels.'.format(npimg.shape[2])) if npimg.shape[2] == 1: expected_mode = None From 2824ff964de4fa2c091617a5de1620d4569c9576 Mon Sep 17 00:00:00 2001 From: Demyanchuk Date: Wed, 18 Nov 2020 12:15:00 +0100 Subject: [PATCH 05/10] Check number of channels before processing --- torchvision/transforms/functional.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 995290f2054..899fd1199db 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -179,6 +179,11 @@ def to_pil_image(pic, mode=None): if pic.ndimension() not in {2, 3}: raise ValueError('pic should be 2/3 dimensional. Got {} dimensions.'.format(pic.ndimension())) + elif pic.ndimension() == 3: + # check number of channels + if pic.shape[0] > 4: + raise ValueError('pic should not have > 4 channels. Got {} channels.'.format(pic.shape[0])) + elif pic.ndimension() == 2: # if 2D image, add channel dimension (CHW) pic = pic.unsqueeze(0) @@ -187,6 +192,11 @@ def to_pil_image(pic, mode=None): if pic.ndim not in {2, 3}: raise ValueError('pic should be 2/3 dimensional. Got {} dimensions.'.format(pic.ndim)) + elif pic.ndim == 3: + # check number of channels + if pic.shape[2] > 4: + raise ValueError('pic should not have > 4 channels. Got {} channels.'.format(pic.shape[2])) + elif pic.ndim == 2: # if 2D image, add channel dimension (HWC) pic = np.expand_dims(pic, 2) @@ -200,8 +210,6 @@ def to_pil_image(pic, mode=None): if not isinstance(npimg, np.ndarray): raise TypeError('Input pic must be a torch.Tensor or NumPy ndarray, ' + 'not {}'.format(type(npimg))) - if npimg.shape[2] > 4: - raise ValueError('pic should not have > 4 channels. Got {} channels.'.format(npimg.shape[2])) if npimg.shape[2] == 1: expected_mode = None From 61fa06b6efe18a8e0ae2b9274f20bd21fb87a540 Mon Sep 17 00:00:00 2001 From: Demyanchuk Date: Wed, 18 Nov 2020 12:16:19 +0100 Subject: [PATCH 06/10] Add test for invalid number of channels --- test/test_transforms.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/test_transforms.py b/test/test_transforms.py index 72100d0feac..f1e4dbe7e3f 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -989,6 +989,7 @@ def test_2d_ndarray_to_pil_image(self): def test_tensor_bad_types_to_pil_image(self): with self.assertRaises(ValueError): transforms.ToPILImage()(torch.ones(1, 3, 4, 4)) + transforms.ToPILImage()(torch.ones(6, 4, 4)) def test_ndarray_bad_types_to_pil_image(self): trans = transforms.ToPILImage() @@ -1000,6 +1001,7 @@ def test_ndarray_bad_types_to_pil_image(self): with self.assertRaises(ValueError): transforms.ToPILImage()(np.ones([1, 4, 4, 3])) + transforms.ToPILImage()(np.ones([4, 4, 6])) @unittest.skipIf(stats is None, 'scipy.stats not available') def test_random_vertical_flip(self): From dc9a8ed7736b6726342d8e7d1e03574ab4d64420 Mon Sep 17 00:00:00 2001 From: Demyanchuk Date: Wed, 18 Nov 2020 14:36:58 +0100 Subject: [PATCH 07/10] Put check after channel dim unsqueeze --- torchvision/transforms/functional.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 899fd1199db..4d8a0c09e34 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -179,28 +179,26 @@ def to_pil_image(pic, mode=None): if pic.ndimension() not in {2, 3}: raise ValueError('pic should be 2/3 dimensional. Got {} dimensions.'.format(pic.ndimension())) - elif pic.ndimension() == 3: - # check number of channels - if pic.shape[0] > 4: - raise ValueError('pic should not have > 4 channels. Got {} channels.'.format(pic.shape[0])) - elif pic.ndimension() == 2: # if 2D image, add channel dimension (CHW) pic = pic.unsqueeze(0) + # check number of channels + if pic.shape[-3] > 4: + raise ValueError('pic should not have > 4 channels. Got {} channels.'.format(pic.shape[-3])) + elif isinstance(pic, np.ndarray): if pic.ndim not in {2, 3}: raise ValueError('pic should be 2/3 dimensional. Got {} dimensions.'.format(pic.ndim)) - elif pic.ndim == 3: - # check number of channels - if pic.shape[2] > 4: - raise ValueError('pic should not have > 4 channels. Got {} channels.'.format(pic.shape[2])) - elif pic.ndim == 2: # if 2D image, add channel dimension (HWC) pic = np.expand_dims(pic, 2) + # check number of channels + if pic.shape[-1] > 4: + raise ValueError('pic should not have > 4 channels. Got {} channels.'.format(pic.shape[-1])) + npimg = pic if isinstance(pic, torch.Tensor): if pic.is_floating_point() and mode != 'F': From 85eb98c26477c5748eafd63921618172f9a87842 Mon Sep 17 00:00:00 2001 From: Demyanchuk Date: Wed, 18 Nov 2020 14:38:21 +0100 Subject: [PATCH 08/10] Add test if error message is matching --- test/test_transforms.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index f1e4dbe7e3f..9822552095c 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -987,20 +987,22 @@ def test_2d_ndarray_to_pil_image(self): self.assertTrue(np.allclose(img_data, img)) def test_tensor_bad_types_to_pil_image(self): - with self.assertRaises(ValueError): + with self.assertRaisesRegex(ValueError, r'pic should be 2/3 dimensional. Got \d+ dimensions.'): transforms.ToPILImage()(torch.ones(1, 3, 4, 4)) + with self.assertRaisesRegex(ValueError, r'pic should not have > 4 channels. Got \d+ channels.'): transforms.ToPILImage()(torch.ones(6, 4, 4)) def test_ndarray_bad_types_to_pil_image(self): trans = transforms.ToPILImage() - with self.assertRaises(TypeError): + with self.assertRaisesRegex(TypeError, r'Input type \w+ is not supported'): trans(np.ones([4, 4, 1], np.int64)) trans(np.ones([4, 4, 1], np.uint16)) trans(np.ones([4, 4, 1], np.uint32)) trans(np.ones([4, 4, 1], np.float64)) - with self.assertRaises(ValueError): + with self.assertRaisesRegex(ValueError, r'pic should be 2/3 dimensional. Got \d+ dimensions.'): transforms.ToPILImage()(np.ones([1, 4, 4, 3])) + with self.assertRaisesRegex(ValueError, r'pic should not have > 4 channels. Got \d+ channels.'): transforms.ToPILImage()(np.ones([4, 4, 6])) @unittest.skipIf(stats is None, 'scipy.stats not available') From b772f23335da62b9a33bddb499879316e1b097af Mon Sep 17 00:00:00 2001 From: Alexey Demyanchuk Date: Wed, 18 Nov 2020 19:30:00 +0100 Subject: [PATCH 09/10] Delete redundant code --- torchvision/transforms/functional.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 7eaaab498a8..4d8a0c09e34 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -179,11 +179,6 @@ def to_pil_image(pic, mode=None): if pic.ndimension() not in {2, 3}: raise ValueError('pic should be 2/3 dimensional. Got {} dimensions.'.format(pic.ndimension())) - elif pic.ndimension() == 3: - # check number of channels - if pic.shape[0] > 4: - raise ValueError('pic should not have > 4 channels. Got {} channels.'.format(pic.shape[0])) - elif pic.ndimension() == 2: # if 2D image, add channel dimension (CHW) pic = pic.unsqueeze(0) @@ -196,11 +191,6 @@ def to_pil_image(pic, mode=None): if pic.ndim not in {2, 3}: raise ValueError('pic should be 2/3 dimensional. Got {} dimensions.'.format(pic.ndim)) - elif pic.ndim == 3: - # check number of channels - if pic.shape[2] > 4: - raise ValueError('pic should not have > 4 channels. Got {} channels.'.format(pic.shape[2])) - elif pic.ndim == 2: # if 2D image, add channel dimension (HWC) pic = np.expand_dims(pic, 2) From 9d625f3923c172064e2cf6ca705c0e9983bde84d Mon Sep 17 00:00:00 2001 From: Alexey Demyanchuk Date: Wed, 18 Nov 2020 19:44:31 +0100 Subject: [PATCH 10/10] Bug fix in checking for bad types --- test/test_transforms.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index 9822552095c..d6651816cd2 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -994,10 +994,14 @@ def test_tensor_bad_types_to_pil_image(self): def test_ndarray_bad_types_to_pil_image(self): trans = transforms.ToPILImage() - with self.assertRaisesRegex(TypeError, r'Input type \w+ is not supported'): + reg_msg = r'Input type \w+ is not supported' + with self.assertRaisesRegex(TypeError, reg_msg): trans(np.ones([4, 4, 1], np.int64)) + with self.assertRaisesRegex(TypeError, reg_msg): trans(np.ones([4, 4, 1], np.uint16)) + with self.assertRaisesRegex(TypeError, reg_msg): trans(np.ones([4, 4, 1], np.uint32)) + with self.assertRaisesRegex(TypeError, reg_msg): trans(np.ones([4, 4, 1], np.float64)) with self.assertRaisesRegex(ValueError, r'pic should be 2/3 dimensional. Got \d+ dimensions.'):