From 286e3166cb61e45038245b979db31aac60d1745b Mon Sep 17 00:00:00 2001 From: Steven Basart Date: Fri, 10 Apr 2020 10:16:07 -0400 Subject: [PATCH 01/23] Adds as_tensor to functional.py Similar functionality to to_tensor without the default conversion to float and division by 255. Also adds support for Image mode 'L'. --- torchvision/transforms/functional.py | 47 ++++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 7ce1fb6ab36..746ca9eed77 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -81,6 +81,53 @@ def to_tensor(pic): else: return img +def as_tensor(pic): + """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor of same type. + + See ``AsTensor`` for more details. + + Args: + pic (PIL Image or numpy.ndarray): Image to be converted to tensor. + + Returns: + Tensor: Converted image. + """ + if not(_is_pil_image(pic) or _is_numpy(pic)): + raise TypeError('pic should be PIL Image or ndarray. Got {}'.format(type(pic))) + + if _is_numpy(pic) and not _is_numpy_image(pic): + raise ValueError('pic should be 2/3 dimensional. 
Got {} dimensions.'.format(pic.ndim)) + + if isinstance(pic, np.ndarray): + # handle numpy array + if pic.ndim == 2: + pic = pic[:, :, None] + + img = torch.from_numpy(pic.transpose((2, 0, 1))) + return img + + if accimage is not None and isinstance(pic, accimage.Image): + nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32) + pic.copyto(nppic) + return torch.from_numpy(nppic) + + # handle PIL Image + if pic.mode == 'I': + img = torch.from_numpy(np.array(pic, np.int32, copy=False)) + elif pic.mode == 'I;16': + img = torch.from_numpy(np.array(pic, np.int16, copy=False)) + elif pic.mode == 'F': + img = torch.from_numpy(np.array(pic, np.float32, copy=False)) + elif pic.mode == '1' or pic.mode == 'L': + img = torch.from_numpy(np.array(pic, np.uint8, copy=False)) + else: + img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes())) + + img = img.view(pic.size[1], pic.size[0], len(pic.getbands())) + # put it from HWC to CHW format + img = img.permute((2, 0, 1)).contiguous() + return img + def to_pil_image(pic, mode=None): """Convert a tensor or an ndarray to PIL Image. From f7eb48997520780c8dcffab5bc5e8c6ea87c2e2b Mon Sep 17 00:00:00 2001 From: Steven Basart Date: Fri, 10 Apr 2020 10:42:38 -0400 Subject: [PATCH 02/23] Adds tests to AsTensor() Adds tests to AsTensor and removes the conversion to float and division by 255. 
--- test/test_transforms.py | 44 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/test/test_transforms.py b/test/test_transforms.py index f19c5480b02..f2c9b71f500 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -512,6 +512,50 @@ def test_accimage_to_tensor(self): self.assertEqual(expected_output.size(), output.size()) self.assertTrue(np.allclose(output.numpy(), expected_output.numpy())) + def test_as_tensor(self): + test_channels = [1, 3, 4] + height, width = 4, 4 + trans = transforms.AsTensor() + + with self.assertRaises(TypeError): + trans(np.random.rand(1, height, width).tolist()) + + with self.assertRaises(ValueError): + trans(np.random.rand(height)) + trans(np.random.rand(1, 1, height, width)) + + for channels in test_channels: + input_data = torch.ByteTensor(channels, height, width).random_(0, 255) + img = transforms.ToPILImage()(input_data) + output = trans(img) + self.assertTrue(np.allclose(input_data.numpy(), output.numpy())) + + ndarray = np.random.randint(low=0, high=255, size=(height, width, channels)).astype(np.uint8) + output = trans(ndarray) + expected_output = ndarray.transpose((2, 0, 1)) + self.assertTrue(np.allclose(output.numpy(), expected_output)) + + ndarray = np.random.rand(height, width, channels).astype(np.float32) + output = trans(ndarray) + expected_output = ndarray.transpose((2, 0, 1)) + self.assertTrue(np.allclose(output.numpy(), expected_output)) + + # separate test for mode '1' PIL images + input_data = torch.ByteTensor(1, height, width).bernoulli_() + img = transforms.ToPILImage()(input_data.mul(255)).convert('1') + output = trans(img) + self.assertTrue(np.allclose(input_data.numpy(), output.numpy())) + + @unittest.skipIf(accimage is None, 'accimage not available') + def test_accimage_as_tensor(self): + trans = transforms.AsTensor() + + expected_output = trans(Image.open(GRACE_HOPPER).convert('RGB')) + output = trans(accimage.Image(GRACE_HOPPER)) + + 
self.assertEqual(expected_output.size(), output.size()) + self.assertTrue(np.allclose(output.numpy(), expected_output.numpy())) + @unittest.skipIf(accimage is None, 'accimage not available') def test_accimage_resize(self): trans = transforms.Compose([ From f90b3bcb45d5cc614702f404c1ba64c6354063a5 Mon Sep 17 00:00:00 2001 From: Steven Basart Date: Fri, 10 Apr 2020 10:47:59 -0400 Subject: [PATCH 03/23] Adds AsTensor to transforms.py Calls the as_tensor function in functionals and adds the function AsTensor as callable from transforms. --- torchvision/transforms/transforms.py | 24 +++++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index 10783c8e53d..a7c33bf3691 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -15,7 +15,7 @@ from . import functional as F -__all__ = ["Compose", "ToTensor", "ToPILImage", "Normalize", "Resize", "Scale", "CenterCrop", "Pad", +__all__ = ["Compose", "ToTensor", "AsTensor", "ToPILImage", "Normalize", "Resize", "Scale", "CenterCrop", "Pad", "Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop", "RandomHorizontalFlip", "RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop", "LinearTransformation", "ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale", @@ -95,6 +95,28 @@ def __repr__(self): return self.__class__.__name__ + '()' +class AsTensor(object): + """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor of the same type. + + Converts a PIL Image or numpy.ndarray (H x W x C) to a torch.Tensor of shape (C x H x W) + if the PIL Image belongs to one of the modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1) + or if the numpy.ndarray has dtype = np.uint8 + """ + + def __call__(self, pic): + """ + Args: + pic (PIL Image or numpy.ndarray): Image to be converted to tensor. + + Returns: + Tensor: Converted image. 
+ """ + return F.as_tensor(pic) + + def __repr__(self): + return self.__class__.__name__ + '()' + + class ToPILImage(object): """Convert a tensor or an ndarray to PIL Image. From 9c2fd3b465ea16033b0b080672612010c319a8c7 Mon Sep 17 00:00:00 2001 From: Steven Basart Date: Fri, 10 Apr 2020 10:49:36 -0400 Subject: [PATCH 04/23] Removes the pic.mode == 'L' This was handled by the else condition previously so I'll remove it. --- torchvision/transforms/functional.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 746ca9eed77..53e2420a6e3 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -118,7 +118,7 @@ def as_tensor(pic): img = torch.from_numpy(np.array(pic, np.int16, copy=False)) elif pic.mode == 'F': img = torch.from_numpy(np.array(pic, np.float32, copy=False)) - elif pic.mode == '1' or pic.mode == 'L': + elif pic.mode == '1': img = torch.from_numpy(np.array(pic, np.uint8, copy=False)) else: img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes())) From 08ab5ecf3516ffb89e06b90db343d163922d4931 Mon Sep 17 00:00:00 2001 From: Steven Basart Date: Fri, 10 Apr 2020 11:59:59 -0400 Subject: [PATCH 05/23] Fix Lint issue Adds two line breaks between functions to fix lint issue --- torchvision/transforms/functional.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 53e2420a6e3..0048507e572 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -81,6 +81,7 @@ def to_tensor(pic): else: return img + def as_tensor(pic): """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor of same type. 
From cb19ed41c0cb19c56fce8ffbd32aa19a462cfdbc Mon Sep 17 00:00:00 2001 From: Steven Basart Date: Wed, 15 Apr 2020 22:49:34 -0400 Subject: [PATCH 06/23] Replace from_numpy with as_tensor Removes the extra if conditionals and replaces from_numpy with as_tensor. --- torchvision/transforms/functional.py | 17 ++++------------- 1 file changed, 4 insertions(+), 13 deletions(-) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 0048507e572..cd1696bf7fb 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -104,29 +104,20 @@ def as_tensor(pic): if pic.ndim == 2: pic = pic[:, :, None] - img = torch.from_numpy(pic.transpose((2, 0, 1))) + img = torch.as_tensor(pic.transpose((2, 0, 1))) return img if accimage is not None and isinstance(pic, accimage.Image): nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32) pic.copyto(nppic) - return torch.from_numpy(nppic) + return torch.as_tensor(nppic) # handle PIL Image - if pic.mode == 'I': - img = torch.from_numpy(np.array(pic, np.int32, copy=False)) - elif pic.mode == 'I;16': - img = torch.from_numpy(np.array(pic, np.int16, copy=False)) - elif pic.mode == 'F': - img = torch.from_numpy(np.array(pic, np.float32, copy=False)) - elif pic.mode == '1': - img = torch.from_numpy(np.array(pic, np.uint8, copy=False)) - else: - img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes())) + img = torch.as_tensor(np.asarray(pic)) img = img.view(pic.size[1], pic.size[0], len(pic.getbands())) # put it from HWC to CHW format - img = img.permute((2, 0, 1)).contiguous() + img = img.permute((2, 0, 1)) return img From 38ad5f3986d6d9c90903ff11c044065a8649dae8 Mon Sep 17 00:00:00 2001 From: Steven Basart Date: Thu, 16 Apr 2020 16:45:55 -0400 Subject: [PATCH 07/23] Renames as_tensor to pil_to_tensor Renames the function as_tensor to pil_to_tensor and narrows the scope of the function. 
At the same time also creates a flag that defaults to True for swapping to the channels first format. --- torchvision/transforms/functional.py | 29 ++++++++++------------------ 1 file changed, 10 insertions(+), 19 deletions(-) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index cd1696bf7fb..21a0a32a2cf 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -82,30 +82,20 @@ def to_tensor(pic): return img -def as_tensor(pic): - """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor of same type. +def pil_to_tensor(pic, swap_to_channelsfirst=True): + """Convert a ``PIL Image`` to a tensor of the same type. See ``AsTensor`` for more details. Args: - pic (PIL Image or numpy.ndarray): Image to be converted to tensor. + pic (PIL Image): Image to be converted to tensor. + swap_to_channelsfirst (bool): Boolean indicator to convert to CHW format. Returns: Tensor: Converted image. """ - if not(_is_pil_image(pic) or _is_numpy(pic)): - raise TypeError('pic should be PIL Image or ndarray. Got {}'.format(type(pic))) - - if _is_numpy(pic) and not _is_numpy_image(pic): - raise ValueError('pic should be 2/3 dimensional. Got {} dimensions.'.format(pic.ndim)) - - if isinstance(pic, np.ndarray): - # handle numpy array - if pic.ndim == 2: - pic = pic[:, :, None] - - img = torch.as_tensor(pic.transpose((2, 0, 1))) - return img + if not(_is_pil_image(pic)): + raise TypeError('pic should be PIL Image. 
Got {}'.format(type(pic))) if accimage is not None and isinstance(pic, accimage.Image): nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32) @@ -115,9 +105,10 @@ def as_tensor(pic): # handle PIL Image img = torch.as_tensor(np.asarray(pic)) - img = img.view(pic.size[1], pic.size[0], len(pic.getbands())) - # put it from HWC to CHW format - img = img.permute((2, 0, 1)) + if swap_to_channelsfirst: + img = img.view(pic.size[1], pic.size[0], len(pic.getbands())) + # put it from HWC to CHW format + img = img.permute((2, 0, 1)) return img From 1fa91a844075a4b4e7da9568615645a477e2785b Mon Sep 17 00:00:00 2001 From: Steven Basart Date: Thu, 16 Apr 2020 17:04:52 -0400 Subject: [PATCH 08/23] Renames AsTensor to PILToTensor Renames the function AsTensor to PILToTensor and modifies the description. Adds the swap_to_channelsfirst boolean variable to indicate if the user wishes to change the shape of the input. --- torchvision/transforms/transforms.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index a7c33bf3691..ee5fd6c3218 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -15,7 +15,7 @@ from . import functional as F -__all__ = ["Compose", "ToTensor", "AsTensor", "ToPILImage", "Normalize", "Resize", "Scale", "CenterCrop", "Pad", +__all__ = ["Compose", "ToTensor", "PILToTensor", "ToPILImage", "Normalize", "Resize", "Scale", "CenterCrop", "Pad", "Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop", "RandomHorizontalFlip", "RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop", "LinearTransformation", "ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale", @@ -95,23 +95,24 @@ def __repr__(self): return self.__class__.__name__ + '()' -class AsTensor(object): - """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor of the same type. 
+class PILToTensor(object): + """Convert a ``PIL Image`` to a tensor of the same type. - Converts a PIL Image or numpy.ndarray (H x W x C) to a torch.Tensor of shape (C x H x W) - if the PIL Image belongs to one of the modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1) - or if the numpy.ndarray has dtype = np.uint8 + Converts a PIL Image (H x W x C) to a torch.Tensor. If swap_to_channelsfirst is True + the returned shape will be (C x H x W) otherwise the shape will remain unchanged. """ - def __call__(self, pic): + def __call__(self, pic, swap_to_channelsfirst=True): """ Args: - pic (PIL Image or numpy.ndarray): Image to be converted to tensor. + pic (PIL Image): Image to be converted to tensor. + swap_to_channelsfirst (bool): Boolean indicator to swap to channels first format. + Defaults to True. Returns: Tensor: Converted image. """ - return F.as_tensor(pic) + return F.pil_to_tensor(pic) def __repr__(self): return self.__class__.__name__ + '()' From 0fefbcbd72e24dee2b8ab845434e8e570ce87e31 Mon Sep 17 00:00:00 2001 From: Steven Basart Date: Thu, 16 Apr 2020 17:30:05 -0400 Subject: [PATCH 09/23] Add the __init__ function to PILToTensor Add the __init__ function to PILToTensor since it contains the swap_to_channelsfirst parameter now. --- torchvision/transforms/transforms.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index ee5fd6c3218..653f9228d29 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -101,8 +101,10 @@ class PILToTensor(object): Converts a PIL Image (H x W x C) to a torch.Tensor. If swap_to_channelsfirst is True the returned shape will be (C x H x W) otherwise the shape will remain unchanged. 
""" + def __init__(self, swap_to_channelsfirst=True): + self.swap_to_channelsfirst = swap_to_channelsfirst - def __call__(self, pic, swap_to_channelsfirst=True): + def __call__(self, pic): """ Args: pic (PIL Image): Image to be converted to tensor. @@ -112,7 +114,7 @@ def __call__(self, pic, swap_to_channelsfirst=True): Returns: Tensor: Converted image. """ - return F.pil_to_tensor(pic) + return F.pil_to_tensor(pic, self.swap_to_channelsfirst) def __repr__(self): return self.__class__.__name__ + '()' From 7662b23f0e974dfbab3c950cb0cc1588af304ef3 Mon Sep 17 00:00:00 2001 From: Steven Basart Date: Thu, 16 Apr 2020 17:32:26 -0400 Subject: [PATCH 10/23] fix lint issue remove trailing white space --- torchvision/transforms/transforms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index 653f9228d29..4699d71515e 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -108,7 +108,7 @@ def __call__(self, pic): """ Args: pic (PIL Image): Image to be converted to tensor. - swap_to_channelsfirst (bool): Boolean indicator to swap to channels first format. + swap_to_channelsfirst (bool): Boolean indicator to swap to channels first format. Defaults to True. Returns: From 75be7bbfac9f2589fb03f222e1a0945e00e6abff Mon Sep 17 00:00:00 2001 From: Steven Basart Date: Thu, 16 Apr 2020 17:37:24 -0400 Subject: [PATCH 11/23] Fix the tests Reflects the name change to PILToTensor and the parameter to the function as well as the new narrowed scope that the function only accepts PIL images. 
--- test/test_transforms.py | 36 ++++++++++++++++++++++-------------- 1 file changed, 22 insertions(+), 14 deletions(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index f2c9b71f500..bcdb99be90f 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -512,32 +512,40 @@ def test_accimage_to_tensor(self): self.assertEqual(expected_output.size(), output.size()) self.assertTrue(np.allclose(output.numpy(), expected_output.numpy())) - def test_as_tensor(self): + def test_pil_to_tensor(self): test_channels = [1, 3, 4] height, width = 4, 4 - trans = transforms.AsTensor() + trans = transforms.PILToTensor() + trans_noswap = transforms.PILToTensor(swap_to_channelsfirst=False) with self.assertRaises(TypeError): trans(np.random.rand(1, height, width).tolist()) - - with self.assertRaises(ValueError): - trans(np.random.rand(height)) - trans(np.random.rand(1, 1, height, width)) + trans(np.random.rand(1, height, width)) for channels in test_channels: input_data = torch.ByteTensor(channels, height, width).random_(0, 255) img = transforms.ToPILImage()(input_data) output = trans(img) self.assertTrue(np.allclose(input_data.numpy(), output.numpy())) + output = trans_noswap(img) + self.assertTrue(np.allclose(input_data.numpy(), output.numpy())) - ndarray = np.random.randint(low=0, high=255, size=(height, width, channels)).astype(np.uint8) - output = trans(ndarray) - expected_output = ndarray.transpose((2, 0, 1)) + input_data = np.random.randint(low=0, high=255, size=(height, width, channels)).astype(np.uint8) + img = transforms.ToPILImage()(input_data) + output = trans(img) + expected_output = input_data.transpose((2, 0, 1)) + self.assertTrue(np.allclose(output.numpy(), expected_output)) + output = trans_noswap(img) + expected_output = input_data self.assertTrue(np.allclose(output.numpy(), expected_output)) - ndarray = np.random.rand(height, width, channels).astype(np.float32) - output = trans(ndarray) - expected_output = ndarray.transpose((2, 0, 1)) + 
input_data = np.random.rand(height, width, channels).astype(np.float32) + img = transforms.ToPILImage()(input_data) + output = trans(img) + expected_output = input_data.transpose((2, 0, 1)) + self.assertTrue(np.allclose(output.numpy(), expected_output)) + output = trans_noswap(img) + expected_output = input_data self.assertTrue(np.allclose(output.numpy(), expected_output)) # separate test for mode '1' PIL images @@ -547,8 +555,8 @@ def test_as_tensor(self): self.assertTrue(np.allclose(input_data.numpy(), output.numpy())) @unittest.skipIf(accimage is None, 'accimage not available') - def test_accimage_as_tensor(self): - trans = transforms.AsTensor() + def test_accimage_pil_to_tensor(self): + trans = transforms.PILToTensor() expected_output = trans(Image.open(GRACE_HOPPER).convert('RGB')) output = trans(accimage.Image(GRACE_HOPPER)) From 123503acb65b6ebf2358500665df15833bf1ed00 Mon Sep 17 00:00:00 2001 From: Steven Basart Date: Thu, 16 Apr 2020 17:54:55 -0400 Subject: [PATCH 12/23] fix tests Instead of undoing the transpose just create a new tensor and test that one. 
--- test/test_transforms.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/test/test_transforms.py b/test/test_transforms.py index bcdb99be90f..fd681fe4cde 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -527,6 +527,9 @@ def test_pil_to_tensor(self): img = transforms.ToPILImage()(input_data) output = trans(img) self.assertTrue(np.allclose(input_data.numpy(), output.numpy())) + + input_data = torch.ByteTensor(channels, height, width).random_(0, 255) + img = transforms.ToPILImage()(input_data) output = trans_noswap(img) self.assertTrue(np.allclose(input_data.numpy(), output.numpy())) @@ -535,6 +538,9 @@ def test_pil_to_tensor(self): output = trans(img) expected_output = input_data.transpose((2, 0, 1)) self.assertTrue(np.allclose(output.numpy(), expected_output)) + + input_data = np.random.randint(low=0, high=255, size=(height, width, channels)).astype(np.uint8) + img = transforms.ToPILImage()(input_data) output = trans_noswap(img) expected_output = input_data self.assertTrue(np.allclose(output.numpy(), expected_output)) @@ -544,6 +550,10 @@ def test_pil_to_tensor(self): output = trans(img) expected_output = input_data.transpose((2, 0, 1)) self.assertTrue(np.allclose(output.numpy(), expected_output)) + input_data = np.random.rand(height, width, channels).astype(np.float32) + + input_data = np.random.rand(height, width, channels).astype(np.float32) + img = transforms.ToPILImage()(input_data) output = trans_noswap(img) expected_output = input_data self.assertTrue(np.allclose(output.numpy(), expected_output)) From eff1db0954867be1df7f061119bc7b7f12de43bf Mon Sep 17 00:00:00 2001 From: Steven Basart Date: Thu, 16 Apr 2020 19:46:54 -0400 Subject: [PATCH 13/23] Add the view back Add img.view(pic.size[1], pic.size[0], len(pic.getbands())) back to outside the if condition. 
--- torchvision/transforms/functional.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 21a0a32a2cf..ed48af7ef36 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -104,9 +104,9 @@ def pil_to_tensor(pic, swap_to_channelsfirst=True): # handle PIL Image img = torch.as_tensor(np.asarray(pic)) + img = img.view(pic.size[1], pic.size[0], len(pic.getbands())) if swap_to_channelsfirst: - img = img.view(pic.size[1], pic.size[0], len(pic.getbands())) # put it from HWC to CHW format img = img.permute((2, 0, 1)) return img From 266860af078e27d57c5cabee34e02fdaecd2c2c9 Mon Sep 17 00:00:00 2001 From: Steven Basart Date: Thu, 16 Apr 2020 20:09:41 -0400 Subject: [PATCH 14/23] fix test fix conversion from torch tensor to PIL back to torch tensor. --- test/test_transforms.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index fd681fe4cde..bf525bee003 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -529,8 +529,8 @@ def test_pil_to_tensor(self): self.assertTrue(np.allclose(input_data.numpy(), output.numpy())) input_data = torch.ByteTensor(channels, height, width).random_(0, 255) - img = transforms.ToPILImage()(input_data) - output = trans_noswap(img) + img = transforms.ToPILImage()(input_data) #HWC + output = trans_noswap(img).transpose((2, 0, 1)) #HWC -> CHW self.assertTrue(np.allclose(input_data.numpy(), output.numpy())) input_data = np.random.randint(low=0, high=255, size=(height, width, channels)).astype(np.uint8) From 610fc1e6dba9943b336e570547dc2fd1e44332a8 Mon Sep 17 00:00:00 2001 From: Steven Basart Date: Thu, 16 Apr 2020 20:12:24 -0400 Subject: [PATCH 15/23] fix lint issues --- test/test_transforms.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index 
bf525bee003..fc9ca302941 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -529,8 +529,8 @@ def test_pil_to_tensor(self): self.assertTrue(np.allclose(input_data.numpy(), output.numpy())) input_data = torch.ByteTensor(channels, height, width).random_(0, 255) - img = transforms.ToPILImage()(input_data) #HWC - output = trans_noswap(img).transpose((2, 0, 1)) #HWC -> CHW + img = transforms.ToPILImage()(input_data) # HWC + output = trans_noswap(img).transpose((2, 0, 1)) # HWC -> CHW self.assertTrue(np.allclose(input_data.numpy(), output.numpy())) input_data = np.random.randint(low=0, high=255, size=(height, width, channels)).astype(np.uint8) From b9cca779907e2e849fdc580d36cd78584e36f93a Mon Sep 17 00:00:00 2001 From: Steven Basart Date: Thu, 16 Apr 2020 20:14:09 -0400 Subject: [PATCH 16/23] fix lint remove trailing white space --- test/test_transforms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index fc9ca302941..a23196dc7ec 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -530,7 +530,7 @@ def test_pil_to_tensor(self): input_data = torch.ByteTensor(channels, height, width).random_(0, 255) img = transforms.ToPILImage()(input_data) # HWC - output = trans_noswap(img).transpose((2, 0, 1)) # HWC -> CHW + output = trans_noswap(img).transpose((2, 0, 1)) # HWC -> CHW self.assertTrue(np.allclose(input_data.numpy(), output.numpy())) input_data = np.random.randint(low=0, high=255, size=(height, width, channels)).astype(np.uint8) From 1b10f779dbf56159553a69b665884e1a6bf2fa2e Mon Sep 17 00:00:00 2001 From: Steven Basart Date: Thu, 16 Apr 2020 20:27:34 -0400 Subject: [PATCH 17/23] Fixed the channel swapping tensor test Torch transpose operates differently than numpy transpose. Changed operation to permute. 
--- test/test_transforms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index a23196dc7ec..787a211bde7 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -530,7 +530,7 @@ def test_pil_to_tensor(self): input_data = torch.ByteTensor(channels, height, width).random_(0, 255) img = transforms.ToPILImage()(input_data) # HWC - output = trans_noswap(img).transpose((2, 0, 1)) # HWC -> CHW + output = trans_noswap(img).permute(2, 0, 1) # HWC -> CHW self.assertTrue(np.allclose(input_data.numpy(), output.numpy())) input_data = np.random.randint(low=0, high=255, size=(height, width, channels)).astype(np.uint8) From fbf661c656a68d1826d3b24fac0395f3384d6301 Mon Sep 17 00:00:00 2001 From: Steven Basart Date: Thu, 16 Apr 2020 20:42:46 -0400 Subject: [PATCH 18/23] Add mode='F' Add mode information when converting to PIL Image from Float Tensor. --- test/test_transforms.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index 787a211bde7..fe227890db1 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -546,14 +546,13 @@ def test_pil_to_tensor(self): self.assertTrue(np.allclose(output.numpy(), expected_output)) input_data = np.random.rand(height, width, channels).astype(np.float32) - img = transforms.ToPILImage()(input_data) + img = transforms.ToPILImage(mode='F')(input_data) output = trans(img) expected_output = input_data.transpose((2, 0, 1)) self.assertTrue(np.allclose(output.numpy(), expected_output)) - input_data = np.random.rand(height, width, channels).astype(np.float32) input_data = np.random.rand(height, width, channels).astype(np.float32) - img = transforms.ToPILImage()(input_data) + img = transforms.ToPILImage(mode='F')(input_data) output = trans_noswap(img) expected_output = input_data self.assertTrue(np.allclose(output.numpy(), expected_output)) From 598107f577eb646f2f3450fa8de7f396442715e5 Mon 
Sep 17 00:00:00 2001 From: Steven Basart Date: Thu, 16 Apr 2020 21:46:18 -0400 Subject: [PATCH 19/23] Added inline comments to follow shape changes --- test/test_transforms.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index fe227890db1..ef87ddd84c7 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -529,7 +529,7 @@ def test_pil_to_tensor(self): self.assertTrue(np.allclose(input_data.numpy(), output.numpy())) input_data = torch.ByteTensor(channels, height, width).random_(0, 255) - img = transforms.ToPILImage()(input_data) # HWC + img = transforms.ToPILImage()(input_data) # CHW -> HWC output = trans_noswap(img).permute(2, 0, 1) # HWC -> CHW self.assertTrue(np.allclose(input_data.numpy(), output.numpy())) @@ -545,17 +545,17 @@ def test_pil_to_tensor(self): expected_output = input_data self.assertTrue(np.allclose(output.numpy(), expected_output)) - input_data = np.random.rand(height, width, channels).astype(np.float32) - img = transforms.ToPILImage(mode='F')(input_data) - output = trans(img) - expected_output = input_data.transpose((2, 0, 1)) - self.assertTrue(np.allclose(output.numpy(), expected_output)) - - input_data = np.random.rand(height, width, channels).astype(np.float32) - img = transforms.ToPILImage(mode='F')(input_data) - output = trans_noswap(img) + input_data = torch.as_tensor(np.random.rand(channels, height, width).astype(np.float32)) + img = transforms.ToPILImage()(input_data) # CHW -> HWC + output = trans(img) # HWC -> CHW expected_output = input_data - self.assertTrue(np.allclose(output.numpy(), expected_output)) + self.assertTrue(np.allclose(output.numpy(), expected_output.numpy())) + + input_data = torch.as_tensor(np.random.rand(channels, height, width).astype(np.float32)) + img = transforms.ToPILImage()(input_data) # CHW -> HWC + output = trans_noswap(img) # HWC -> HWC + expected_output = input_data.permute(1, 2, 0) # CHW -> HWC + 
self.assertTrue(np.allclose(output.numpy(), expected_output.numpy())) # separate test for mode '1' PIL images input_data = torch.ByteTensor(1, height, width).bernoulli_() From d69048e15aad8a6c97292a1471f142ec2a9a61b6 Mon Sep 17 00:00:00 2001 From: Steven Basart Date: Thu, 16 Apr 2020 22:03:36 -0400 Subject: [PATCH 20/23] ToPILImage converts FloatTensors to uint8 --- test/test_transforms.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index ef87ddd84c7..059873e5f29 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -546,15 +546,15 @@ def test_pil_to_tensor(self): self.assertTrue(np.allclose(output.numpy(), expected_output)) input_data = torch.as_tensor(np.random.rand(channels, height, width).astype(np.float32)) - img = transforms.ToPILImage()(input_data) # CHW -> HWC + img = transforms.ToPILImage()(input_data) # CHW -> HWC and (* 255).byte() output = trans(img) # HWC -> CHW - expected_output = input_data + expected_output = (input_data * 255).byte() self.assertTrue(np.allclose(output.numpy(), expected_output.numpy())) input_data = torch.as_tensor(np.random.rand(channels, height, width).astype(np.float32)) img = transforms.ToPILImage()(input_data) # CHW -> HWC output = trans_noswap(img) # HWC -> HWC - expected_output = input_data.permute(1, 2, 0) # CHW -> HWC + expected_output = (input_data.permute(1, 2, 0) * 255).byte() # CHW -> HWC self.assertTrue(np.allclose(output.numpy(), expected_output.numpy())) # separate test for mode '1' PIL images From fa1084c950b77396835660218ab0c7109a68e0cd Mon Sep 17 00:00:00 2001 From: Steven Basart Date: Fri, 15 May 2020 15:49:05 -0400 Subject: [PATCH 21/23] Remove testing not swapping --- test/test_transforms.py | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index 059873e5f29..978ab823c95 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -516,7 
+516,6 @@ def test_pil_to_tensor(self): test_channels = [1, 3, 4] height, width = 4, 4 trans = transforms.PILToTensor() - trans_noswap = transforms.PILToTensor(swap_to_channelsfirst=False) with self.assertRaises(TypeError): trans(np.random.rand(1, height, width).tolist()) @@ -528,35 +527,18 @@ def test_pil_to_tensor(self): output = trans(img) self.assertTrue(np.allclose(input_data.numpy(), output.numpy())) - input_data = torch.ByteTensor(channels, height, width).random_(0, 255) - img = transforms.ToPILImage()(input_data) # CHW -> HWC - output = trans_noswap(img).permute(2, 0, 1) # HWC -> CHW - self.assertTrue(np.allclose(input_data.numpy(), output.numpy())) - input_data = np.random.randint(low=0, high=255, size=(height, width, channels)).astype(np.uint8) img = transforms.ToPILImage()(input_data) output = trans(img) expected_output = input_data.transpose((2, 0, 1)) self.assertTrue(np.allclose(output.numpy(), expected_output)) - input_data = np.random.randint(low=0, high=255, size=(height, width, channels)).astype(np.uint8) - img = transforms.ToPILImage()(input_data) - output = trans_noswap(img) - expected_output = input_data - self.assertTrue(np.allclose(output.numpy(), expected_output)) - input_data = torch.as_tensor(np.random.rand(channels, height, width).astype(np.float32)) img = transforms.ToPILImage()(input_data) # CHW -> HWC and (* 255).byte() output = trans(img) # HWC -> CHW expected_output = (input_data * 255).byte() self.assertTrue(np.allclose(output.numpy(), expected_output.numpy())) - input_data = torch.as_tensor(np.random.rand(channels, height, width).astype(np.float32)) - img = transforms.ToPILImage()(input_data) # CHW -> HWC - output = trans_noswap(img) # HWC -> HWC - expected_output = (input_data.permute(1, 2, 0) * 255).byte() # CHW -> HWC - self.assertTrue(np.allclose(output.numpy(), expected_output.numpy())) - # separate test for mode '1' PIL images input_data = torch.ByteTensor(1, height, width).bernoulli_() img = 
transforms.ToPILImage()(input_data.mul(255)).convert('1') From 2cb7a4fe1a8694a396489fb9b189dc8735b97645 Mon Sep 17 00:00:00 2001 From: Steven Basart Date: Fri, 15 May 2020 15:53:09 -0400 Subject: [PATCH 22/23] Removes the swap_to_channelsfirst parameter Makes the channel swapping the default behavior. --- torchvision/transforms/transforms.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index 4699d71515e..eb49b99be93 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -98,23 +98,18 @@ def __repr__(self): class PILToTensor(object): """Convert a ``PIL Image`` to a tensor of the same type. - Converts a PIL Image (H x W x C) to a torch.Tensor. If swap_to_channelsfirst is True - the returned shape will be (C x H x W) otherwise the shape will remain unchanged. + Converts a PIL Image (H x W x C) to a torch.Tensor of shape (C x H x W). """ - def __init__(self, swap_to_channelsfirst=True): - self.swap_to_channelsfirst = swap_to_channelsfirst def __call__(self, pic): """ Args: pic (PIL Image): Image to be converted to tensor. - swap_to_channelsfirst (bool): Boolean indicator to swap to channels first format. - Defaults to True. Returns: Tensor: Converted image. """ - return F.pil_to_tensor(pic, self.swap_to_channelsfirst) + return F.pil_to_tensor(pic) def __repr__(self): return self.__class__.__name__ + '()' From 3d565fd721fa25be831c0bf65604d7d181f7fd63 Mon Sep 17 00:00:00 2001 From: Steven Basart Date: Fri, 15 May 2020 16:10:17 -0400 Subject: [PATCH 23/23] Remove the swap_to_channelsfirst argument Remove the swap_to_channelsfirst argument and make the swapping the default functionality.
--- torchvision/transforms/functional.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index ed48af7ef36..bdfa6567a82 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -82,14 +82,13 @@ def to_tensor(pic): return img -def pil_to_tensor(pic, swap_to_channelsfirst=True): +def pil_to_tensor(pic): """Convert a ``PIL Image`` to a tensor of the same type. See ``AsTensor`` for more details. Args: pic (PIL Image): Image to be converted to tensor. - swap_to_channelsfirst (bool): Boolean indicator to convert to CHW format. Returns: Tensor: Converted image. @@ -105,10 +104,8 @@ def pil_to_tensor(pic, swap_to_channelsfirst=True): # handle PIL Image img = torch.as_tensor(np.asarray(pic)) img = img.view(pic.size[1], pic.size[0], len(pic.getbands())) - - if swap_to_channelsfirst: - # put it from HWC to CHW format - img = img.permute((2, 0, 1)) + # put it from HWC to CHW format + img = img.permute((2, 0, 1)) return img