Skip to content

Commit

Permalink
Add mode to ToPILImage constructor (#300)
Browse files Browse the repository at this point in the history
* add assert for byte tensor to_pil_image

* add mode param to ToPILImage

* add some more tests

* fix formatiing on error messages

* added tests for to_pil_image with mode param

* flake8 fixes
  • Loading branch information
alykhantejani authored and fmassa committed Oct 17, 2017
1 parent 005adfd commit 901c1ad
Show file tree
Hide file tree
Showing 2 changed files with 155 additions and 95 deletions.
189 changes: 108 additions & 81 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,109 +302,136 @@ def test_accimage_crop(self):
self.assertEqual(expected_output.size(), output.size())
assert np.allclose(output.numpy(), expected_output.numpy())

def test_tensor_to_pil_image(self):
trans = transforms.ToPILImage()
to_tensor = transforms.ToTensor()

img_data = torch.Tensor(3, 4, 4).uniform_()
img = trans(img_data)
assert img.getbands() == ('R', 'G', 'B')
r, g, b = img.split()

expected_output = img_data.mul(255).int().float().div(255)
assert np.allclose(expected_output[0].numpy(), to_tensor(r).numpy())
assert np.allclose(expected_output[1].numpy(), to_tensor(g).numpy())
assert np.allclose(expected_output[2].numpy(), to_tensor(b).numpy())

# single channel image
img_data = torch.Tensor(1, 4, 4).uniform_()
img = trans(img_data)
assert img.getbands() == ('L',)
l, = img.split()
expected_output = img_data.mul(255).int().float().div(255)
assert np.allclose(expected_output[0].numpy(), to_tensor(l).numpy())

def test_tensor_gray_to_pil_image(self):
trans = transforms.ToPILImage()
def test_1_channel_tensor_to_pil_image(self):
to_tensor = transforms.ToTensor()

img_data_float = torch.Tensor(1, 4, 4).uniform_()
img_data_byte = torch.ByteTensor(1, 4, 4).random_(0, 255)
img_data_short = torch.ShortTensor(1, 4, 4).random_()
img_data_int = torch.IntTensor(1, 4, 4).random_()

img_byte = trans(img_data_byte)
img_short = trans(img_data_short)
img_int = trans(img_data_int)
assert img_byte.mode == 'L'
assert img_short.mode == 'I;16'
assert img_int.mode == 'I'
inputs = [img_data_float, img_data_byte, img_data_short, img_data_int]
expected_outputs = [img_data_float.mul(255).int().float().div(255).numpy(),
img_data_byte.float().div(255.0).numpy(),
img_data_short.numpy(),
img_data_int.numpy()]
expected_modes = ['L', 'L', 'I;16', 'I']

for img_data, expected_output, mode in zip(inputs, expected_outputs, expected_modes):
for transform in [transforms.ToPILImage(), transforms.ToPILImage(mode=mode)]:
img = transform(img_data)
assert img.mode == mode
assert np.allclose(expected_output, to_tensor(img).numpy())

def test_1_channel_ndarray_to_pil_image(self):
img_data_float = torch.Tensor(4, 4, 1).uniform_().numpy()
img_data_byte = torch.ByteTensor(4, 4, 1).random_(0, 255).numpy()
img_data_short = torch.ShortTensor(4, 4, 1).random_().numpy()
img_data_int = torch.IntTensor(4, 4, 1).random_().numpy()

inputs = [img_data_float, img_data_byte, img_data_short, img_data_int]
expected_modes = ['F', 'L', 'I;16', 'I']
for img_data, mode in zip(inputs, expected_modes):
for transform in [transforms.ToPILImage(), transforms.ToPILImage(mode=mode)]:
img = transform(img_data)
assert img.mode == mode
assert np.allclose(img_data[:, :, 0], img)

def test_3_channel_tensor_to_pil_image(self):
def verify_img_data(img_data, expected_output, mode):
if mode is None:
img = transforms.ToPILImage()(img_data)
assert img.mode == 'RGB' # default should assume RGB
else:
img = transforms.ToPILImage(mode=mode)(img_data)
assert img.mode == mode
split = img.split()
for i in range(3):
assert np.allclose(expected_output[i].numpy(), transforms.to_tensor(split[i]).numpy())

assert np.allclose(img_data_short.numpy(), to_tensor(img_short).numpy())
assert np.allclose(img_data_int.numpy(), to_tensor(img_int).numpy())
img_data = torch.Tensor(3, 4, 4).uniform_()
expected_output = img_data.mul(255).int().float().div(255)
for mode in [None, 'RGB', 'HSV', 'YCbCr']:
verify_img_data(img_data, expected_output, mode=mode)

def test_tensor_rgba_to_pil_image(self):
trans = transforms.ToPILImage()
to_tensor = transforms.ToTensor()
with self.assertRaises(ValueError):
# should raise if we try a mode for 4 or 1 channel images
transforms.ToPILImage(mode='RGBA')(img_data)
transforms.ToPILImage(mode='P')(img_data)

def test_3_channel_ndarray_to_pil_image(self):
def verify_img_data(img_data, mode):
if mode is None:
img = transforms.ToPILImage()(img_data)
assert img.mode == 'RGB' # default should assume RGB
else:
img = transforms.ToPILImage(mode=mode)(img_data)
assert img.mode == mode
split = img.split()
for i in range(3):
assert np.allclose(img_data[:, :, i], split[i])

img_data = torch.Tensor(4, 4, 4).uniform_()
img = trans(img_data)
assert img.mode == 'RGBA'
assert img.getbands() == ('R', 'G', 'B', 'A')
r, g, b, a = img.split()
img_data = torch.ByteTensor(4, 4, 3).random_(0, 255).numpy()
for mode in [None, 'RGB', 'HSV', 'YCbCr']:
verify_img_data(img_data, mode)

with self.assertRaises(ValueError):
# should raise if we try a mode for 4 or 1 channel images
transforms.ToPILImage(mode='RGBA')(img_data)
transforms.ToPILImage(mode='P')(img_data)

def test_4_channel_tensor_to_pil_image(self):
def verify_img_data(img_data, expected_output, mode):
if mode is None:
img = transforms.ToPILImage()(img_data)
assert img.mode == 'RGBA' # default should assume RGBA
else:
img = transforms.ToPILImage(mode=mode)(img_data)
assert img.mode == mode

split = img.split()
for i in range(4):
assert np.allclose(expected_output[i].numpy(), transforms.to_tensor(split[i]).numpy())

img_data = torch.Tensor(4, 4, 4).uniform_()
expected_output = img_data.mul(255).int().float().div(255)
assert np.allclose(expected_output[0].numpy(), to_tensor(r).numpy())
assert np.allclose(expected_output[1].numpy(), to_tensor(g).numpy())
assert np.allclose(expected_output[2].numpy(), to_tensor(b).numpy())
assert np.allclose(expected_output[3].numpy(), to_tensor(a).numpy())
for mode in [None, 'RGBA', 'CMYK']:
verify_img_data(img_data, expected_output, mode)

def test_ndarray_to_pil_image(self):
trans = transforms.ToPILImage()
img_data = torch.ByteTensor(4, 4, 3).random_(0, 255).numpy()
img = trans(img_data)
assert img.getbands() == ('R', 'G', 'B')
r, g, b = img.split()
with self.assertRaises(ValueError):
# should raise if we try a mode for 3 or 1 channel images
transforms.ToPILImage(mode='RGB')(img_data)
transforms.ToPILImage(mode='P')(img_data)

def test_4_channel_ndarray_to_pil_image(self):
def verify_img_data(img_data, mode):
if mode is None:
img = transforms.ToPILImage()(img_data)
assert img.mode == 'RGBA' # default should assume RGBA
else:
img = transforms.ToPILImage(mode=mode)(img_data)
assert img.mode == mode
split = img.split()
for i in range(4):
assert np.allclose(img_data[:, :, i], split[i])

assert np.allclose(r, img_data[:, :, 0])
assert np.allclose(g, img_data[:, :, 1])
assert np.allclose(b, img_data[:, :, 2])
img_data = torch.ByteTensor(4, 4, 4).random_(0, 255).numpy()
for mode in [None, 'RGBA', 'CMYK']:
verify_img_data(img_data, mode)

# single channel image
img_data = torch.ByteTensor(4, 4, 1).random_(0, 255).numpy()
img = trans(img_data)
assert img.getbands() == ('L',)
l, = img.split()
assert np.allclose(l, img_data[:, :, 0])
with self.assertRaises(ValueError):
# should raise if we try a mode for 3 or 1 channel images
transforms.ToPILImage(mode='RGB')(img_data)
transforms.ToPILImage(mode='P')(img_data)

def test_ndarray_bad_types_to_pil_image(self):
trans = transforms.ToPILImage()
with self.assertRaises(AssertionError):
with self.assertRaises(TypeError):
trans(np.ones([4, 4, 1], np.int64))
trans(np.ones([4, 4, 1], np.uint16))
trans(np.ones([4, 4, 1], np.uint32))
trans(np.ones([4, 4, 1], np.float64))

def test_ndarray_gray_float32_to_pil_image(self):
trans = transforms.ToPILImage()
img_data = torch.FloatTensor(4, 4, 1).random_().numpy()
img = trans(img_data)
assert img.mode == 'F'
assert np.allclose(img, img_data[:, :, 0])

def test_ndarray_gray_int16_to_pil_image(self):
trans = transforms.ToPILImage()
img_data = torch.ShortTensor(4, 4, 1).random_().numpy()
img = trans(img_data)
assert img.mode == 'I;16'
assert np.allclose(img, img_data[:, :, 0])

def test_ndarray_gray_int32_to_pil_image(self):
trans = transforms.ToPILImage()
img_data = torch.IntTensor(4, 4, 1).random_().numpy()
img = trans(img_data)
assert img.mode == 'I'
assert np.allclose(img, img_data[:, :, 0])

@unittest.skipIf(stats is None, 'scipy.stats not available')
def test_random_vertical_flip(self):
random_state = random.getstate()
Expand Down
61 changes: 47 additions & 14 deletions torchvision/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,13 +78,16 @@ def to_tensor(pic):
return img


def to_pil_image(pic):
def to_pil_image(pic, mode=None):
"""Convert a tensor or an ndarray to PIL Image.
See ``ToPIlImage`` for more details.
See :class:`~torchvision.transforms.ToPIlImage` for more details.
Args:
pic (Tensor or numpy.ndarray): Image to be converted to PIL Image.
mode (`PIL.Image mode`_): color space and pixel depth of input data (optional).
.. _PIL.Image mode: http://pillow.readthedocs.io/en/3.4.x/handbook/concepts.html#modes
Returns:
PIL Image: Image converted to PIL Image.
Expand All @@ -93,30 +96,48 @@ def to_pil_image(pic):
raise TypeError('pic should be Tensor or ndarray. Got {}.'.format(type(pic)))

npimg = pic
mode = None
if isinstance(pic, torch.FloatTensor):
pic = pic.mul(255).byte()
if torch.is_tensor(pic):
npimg = np.transpose(pic.numpy(), (1, 2, 0))
assert isinstance(npimg, np.ndarray)

if not isinstance(npimg, np.ndarray):
raise TypeError('Input pic must be a torch.Tensor or NumPy ndarray, ' +
'not {}'.format(type(npimg)))

if npimg.shape[2] == 1:
expected_mode = None
npimg = npimg[:, :, 0]

if npimg.dtype == np.uint8:
mode = 'L'
expected_mode = 'L'
if npimg.dtype == np.int16:
mode = 'I;16'
expected_mode = 'I;16'
if npimg.dtype == np.int32:
mode = 'I'
expected_mode = 'I'
elif npimg.dtype == np.float32:
mode = 'F'
expected_mode = 'F'
if mode is not None and mode != expected_mode:
raise ValueError("Incorrect mode ({}) supplied for input type {}. Should be {}"
.format(mode, np.dtype, expected_mode))
mode = expected_mode

elif npimg.shape[2] == 4:
if npimg.dtype == np.uint8:
mode = 'RGBA'
permitted_4_channel_modes = ['RGBA', 'CMYK']
if mode is not None and mode not in permitted_4_channel_modes:
raise ValueError("Only modes {} are supported for 4D inputs".format(permitted_4_channel_modes))

if mode is None and npimg.dtype == np.uint8:
mode = 'RGBA'
else:
if npimg.dtype == np.uint8:
permitted_3_channel_modes = ['RGB', 'YCbCr', 'HSV']
if mode is not None and mode not in permitted_3_channel_modes:
raise ValueError("Only modes {} are supported for 3D inputs".format(permitted_3_channel_modes))
if mode is None and npimg.dtype == np.uint8:
mode = 'RGB'
assert mode is not None, '{} is not supported'.format(npimg.dtype)

if mode is None:
raise TypeError('Input type {} is not supported'.format(npimg.dtype))

return Image.fromarray(npimg, mode=mode)


Expand Down Expand Up @@ -540,7 +561,19 @@ class ToPILImage(object):
Converts a torch.*Tensor of shape C x H x W or a numpy ndarray of shape
H x W x C to a PIL Image while preserving the value range.
Args:
mode (`PIL.Image mode`_): color space and pixel depth of input data (optional).
If ``mode`` is ``None`` (default) there are some assumptions made about the input data:
1. If the input has 3 channels, the ``mode`` is assumed to be ``RGB``.
2. If the input has 4 channels, the ``mode`` is assumed to be ``RGBA``.
3. If the input has 1 channel, the ``mode`` is determined by the data type (i,e,
``int``, ``float``, ``short``).
.. _PIL.Image mode: http://pillow.readthedocs.io/en/3.4.x/handbook/concepts.html#modes
"""
def __init__(self, mode=None):
self.mode = mode

def __call__(self, pic):
"""
Expand All @@ -551,7 +584,7 @@ def __call__(self, pic):
PIL Image: Image converted to PIL Image.
"""
return to_pil_image(pic)
return to_pil_image(pic, self.mode)


class Normalize(object):
Expand Down

0 comments on commit 901c1ad

Please sign in to comment.