Skip to content

Commit

Permalink
Add pil_to_tensor to functionals (#2092)
Browse files Browse the repository at this point in the history
* Adds as_tensor to functional.py

Similar functionality to to_tensor without the default conversion to float and division by 255.
Also adds support for Image mode 'L'.

* Adds tests to AsTensor()

Adds tests to AsTensor and removes the conversion to float and division by 255.

* Adds AsTensor to transforms.py

Calls the as_tensor function in functionals and adds the function AsTensor as callable from transforms.

* Removes the pic.mode == 'L'

This was handled by the else condition previously so I'll remove it.

* Fix Lint issue

Adds two line breaks between functions to fix lint issue

* Replace from_numpy with as_tensor

Removes the extra if conditionals and replaces from_numpy with as_tensor.

* Renames as_tensor to pil_to_tensor

Renames the function as_tensor to pil_to_tensor and narrows the scope of the function.  At the same time also creates a flag that defaults to True for swapping to the channels first format.

* Renames AsTensor to PILToTensor

Renames the function AsTensor to PILToTensor and modifies the description.  Adds the swap_to_channelsfirst boolean variable to indicate if the user wishes to change the shape of the input.

* Add the __init__ function to PILToTensor 

Add the __init__ function to PILToTensor since it contains the swap_to_channelsfirst parameter now.

* fix lint issue

remove trailing white space

* Fix the tests

Reflects the name change to PILToTensor and the parameter to the function as well as the new narrowed scope that the function only accepts PIL images.

* fix tests

Instead of undoing the transpose just create a new tensor and test that one.

* Add the view back

Add img.view(pic.size[1], pic.size[0], len(pic.getbands())) back to outside the if condition.

* fix test

fix conversion from torch tensor to PIL back to torch tensor.

* fix lint issues

* fix lint

remove trailing white space

* Fixed the channel swapping tensor test

Torch transpose operates differently than numpy transpose.  Changed operation to permute.

* Add mode='F'

Add mode information when converting to PIL Image from Float Tensor.

* Added inline comments to follow shape changes

* ToPILImage converts FloatTensors to uint8

* Remove testing not swapping

* Removes the swap_channelsfirst parameter

Makes the channel swapping the default behavior.

* Remove the swap_channelsfirst argument

Remove the swap_channelsfirst argument and makes the swapping the default functionality.
  • Loading branch information
xksteven authored May 18, 2020
1 parent e2e511b commit e6d3f8c
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 1 deletion.
43 changes: 43 additions & 0 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,6 +511,49 @@ def test_accimage_to_tensor(self):
self.assertEqual(expected_output.size(), output.size())
self.assertTrue(np.allclose(output.numpy(), expected_output.numpy()))

def test_pil_to_tensor(self):
    """Check ``PILToTensor`` on invalid inputs and on 1/3/4-channel images.

    Non-PIL inputs must raise ``TypeError``; PIL images built from byte
    tensors, uint8 arrays and float tensors must come back as CHW tensors
    holding the same pixel values.
    """
    test_channels = [1, 3, 4]
    height, width = 4, 4
    trans = transforms.PILToTensor()

    # Each invalid input needs its own context manager: once the first call
    # raises, nothing else inside the same ``with`` body is executed.
    with self.assertRaises(TypeError):
        trans(np.random.rand(1, height, width).tolist())
    with self.assertRaises(TypeError):
        trans(np.random.rand(1, height, width))

    for channels in test_channels:
        # byte tensor (CHW) -> PIL -> tensor round trip
        input_data = torch.ByteTensor(channels, height, width).random_(0, 255)
        img = transforms.ToPILImage()(input_data)
        output = trans(img)
        self.assertTrue(np.allclose(input_data.numpy(), output.numpy()))

        # uint8 ndarray (HWC) -> PIL -> tensor (CHW)
        input_data = np.random.randint(low=0, high=255, size=(height, width, channels)).astype(np.uint8)
        img = transforms.ToPILImage()(input_data)
        output = trans(img)
        expected_output = input_data.transpose((2, 0, 1))
        self.assertTrue(np.allclose(output.numpy(), expected_output))

        # float tensor (CHW) -> PIL -> tensor
        input_data = torch.as_tensor(np.random.rand(channels, height, width).astype(np.float32))
        img = transforms.ToPILImage()(input_data)  # CHW -> HWC and (* 255).byte()
        output = trans(img)  # HWC -> CHW
        expected_output = (input_data * 255).byte()
        self.assertTrue(np.allclose(output.numpy(), expected_output.numpy()))

    # separate test for mode '1' PIL images
    input_data = torch.ByteTensor(1, height, width).bernoulli_()
    img = transforms.ToPILImage()(input_data.mul(255)).convert('1')
    output = trans(img)
    self.assertTrue(np.allclose(input_data.numpy(), output.numpy()))

@unittest.skipIf(accimage is None, 'accimage not available')
def test_accimage_pil_to_tensor(self):
    """Compare ``PILToTensor`` output for an accimage image and a PIL image.

    The same file is loaded through PIL (converted to RGB) and through
    accimage; both tensors must have equal size and close values.
    """
    converter = transforms.PILToTensor()

    pil_image = Image.open(GRACE_HOPPER).convert('RGB')
    expected_output = converter(pil_image)

    acc_image = accimage.Image(GRACE_HOPPER)
    output = converter(acc_image)

    self.assertEqual(expected_output.size(), output.size())
    values_match = np.allclose(output.numpy(), expected_output.numpy())
    self.assertTrue(values_match)

@unittest.skipIf(accimage is None, 'accimage not available')
def test_accimage_resize(self):
trans = transforms.Compose([
Expand Down
27 changes: 27 additions & 0 deletions torchvision/transforms/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,33 @@ def to_tensor(pic):
return img


def pil_to_tensor(pic):
    """Convert a ``PIL Image`` to a tensor of the same type.

    Unlike ``to_tensor``, the pixel values are neither converted to float
    nor divided by 255.

    See ``PILToTensor`` for more details.

    Args:
        pic (PIL Image): Image to be converted to tensor.

    Returns:
        Tensor: Converted image, in (C x H x W) layout.

    Raises:
        TypeError: If ``pic`` is rejected by ``_is_pil_image``.
    """
    if not _is_pil_image(pic):
        raise TypeError('pic should be PIL Image. Got {}'.format(type(pic)))

    if accimage is not None and isinstance(pic, accimage.Image):
        # Use a uint8 buffer so the result keeps the raw 8-bit values, like
        # the PIL path below; a float32 buffer would hand back a float
        # tensor, contradicting "tensor of the same type".
        nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.uint8)
        pic.copyto(nppic)
        return torch.as_tensor(nppic)

    # handle PIL Image
    # np.array() copies the pixel data: the array obtained from a PIL image
    # via np.asarray() may be read-only, and torch.as_tensor would then share
    # that non-writable memory.
    img = torch.as_tensor(np.array(pic))
    # make the channel axis explicit (single-band images come back as H x W)
    img = img.view(pic.size[1], pic.size[0], len(pic.getbands()))
    # put it from HWC to CHW format
    img = img.permute((2, 0, 1))
    return img


def to_pil_image(pic, mode=None):
"""Convert a tensor or an ndarray to PIL Image.
Expand Down
22 changes: 21 additions & 1 deletion torchvision/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from . import functional as F


__all__ = ["Compose", "ToTensor", "ToPILImage", "Normalize", "Resize", "Scale", "CenterCrop", "Pad",
__all__ = ["Compose", "ToTensor", "PILToTensor", "ToPILImage", "Normalize", "Resize", "Scale", "CenterCrop", "Pad",
"Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop", "RandomHorizontalFlip",
"RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop", "LinearTransformation",
"ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale",
Expand Down Expand Up @@ -95,6 +95,26 @@ def __repr__(self):
return self.__class__.__name__ + '()'


class PILToTensor(object):
    """Transform turning a ``PIL Image`` into a tensor of the same type.

    A PIL Image (H x W x C) becomes a torch.Tensor of shape (C x H x W).
    The conversion itself is delegated to ``F.pil_to_tensor``.
    """

    def __call__(self, pic):
        """Run the conversion on a single image.

        Args:
            pic (PIL Image): Image to be converted to tensor.

        Returns:
            Tensor: Converted image.
        """
        converted = F.pil_to_tensor(pic)
        return converted

    def __repr__(self):
        # The transform takes no parameters, so the repr is just "Name()".
        return '{}()'.format(self.__class__.__name__)


class ToPILImage(object):
"""Convert a tensor or an ndarray to PIL Image.
Expand Down

0 comments on commit e6d3f8c

Please sign in to comment.