Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add convert_image_dtype to functionals #2078

Merged
merged 20 commits into from
Jun 11, 2020
Merged
110 changes: 110 additions & 0 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,22 @@
os.path.dirname(os.path.abspath(__file__)), 'assets', 'grace_hopper_517x606.jpg')


def cycle_over(objs):
objs = list(objs)
for idx, obj in enumerate(objs):
yield obj, objs[:idx] + objs[idx + 1:]


def int_dtypes():
yield from iter(
(torch.uint8, torch.int8, torch.int16, torch.short, torch.int32, torch.int, torch.int64, torch.long,)
)


def float_dtypes():
yield from iter((torch.float32, torch.float, torch.float64, torch.double))


class Tester(unittest.TestCase):

def test_crop(self):
Expand Down Expand Up @@ -510,6 +526,100 @@ def test_to_tensor(self):
output = trans(img)
self.assertTrue(np.allclose(input_data.numpy(), output.numpy()))

def test_convert_image_dtype_float_to_float(self):
for input_dtype, output_dtypes in cycle_over(float_dtypes()):
input_image = torch.tensor((0.0, 1.0), dtype=input_dtype)
for output_dtype in output_dtypes:
with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype):
transform = transforms.ConvertImageDtype(output_dtype)
output_image = transform(input_image)

actual_min, actual_max = output_image.tolist()
desired_min, desired_max = 0.0, 1.0

self.assertAlmostEqual(actual_min, desired_min)
self.assertAlmostEqual(actual_max, desired_max)

def test_convert_image_dtype_float_to_int(self):
for input_dtype in float_dtypes():
input_image = torch.tensor((0.0, 1.0), dtype=input_dtype)
for output_dtype in int_dtypes():
with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype):
transform = transforms.ConvertImageDtype(output_dtype)

if (input_dtype == torch.float32 and output_dtype in (torch.int32, torch.int64)) or (
input_dtype == torch.float64 and output_dtype == torch.int64
):
with self.assertRaises(RuntimeError):
transform(input_image)
else:
output_image = transform(input_image)

actual_min, actual_max = output_image.tolist()
desired_min, desired_max = 0, torch.iinfo(output_dtype).max

self.assertEqual(actual_min, desired_min)
self.assertEqual(actual_max, desired_max)

def test_convert_image_dtype_int_to_float(self):
for input_dtype in int_dtypes():
input_image = torch.tensor((0, torch.iinfo(input_dtype).max), dtype=input_dtype)
for output_dtype in float_dtypes():
with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype):
transform = transforms.ConvertImageDtype(output_dtype)
output_image = transform(input_image)

actual_min, actual_max = output_image.tolist()
desired_min, desired_max = 0.0, 1.0

self.assertAlmostEqual(actual_min, desired_min)
self.assertGreaterEqual(actual_min, desired_min)
self.assertAlmostEqual(actual_max, desired_max)
self.assertLessEqual(actual_max, desired_max)

def test_convert_image_dtype_int_to_int(self):
for input_dtype, output_dtypes in cycle_over(int_dtypes()):
input_max = torch.iinfo(input_dtype).max
input_image = torch.tensor((0, input_max), dtype=input_dtype)
for output_dtype in output_dtypes:
output_max = torch.iinfo(output_dtype).max

with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype):
transform = transforms.ConvertImageDtype(output_dtype)
output_image = transform(input_image)

actual_min, actual_max = output_image.tolist()
desired_min, desired_max = 0, output_max

# see https://github.com/pytorch/vision/pull/2078#issuecomment-641036236 for details
if input_max >= output_max:
error_term = 0
else:
error_term = 1 - (torch.iinfo(output_dtype).max + 1) // (torch.iinfo(input_dtype).max + 1)

self.assertEqual(actual_min, desired_min)
self.assertEqual(actual_max, desired_max + error_term)

def test_convert_image_dtype_int_to_int_consistency(self):
for input_dtype, output_dtypes in cycle_over(int_dtypes()):
input_max = torch.iinfo(input_dtype).max
input_image = torch.tensor((0, input_max), dtype=input_dtype)
for output_dtype in output_dtypes:
output_max = torch.iinfo(output_dtype).max
if output_max <= input_max:
continue

with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype):
transform = transforms.ConvertImageDtype(output_dtype)
inverse_transfrom = transforms.ConvertImageDtype(input_dtype)
output_image = inverse_transfrom(transform(input_image))

actual_min, actual_max = output_image.tolist()
desired_min, desired_max = 0, input_max

self.assertEqual(actual_min, desired_min)
self.assertEqual(actual_max, desired_max)

@unittest.skipIf(accimage is None, 'accimage not available')
def test_accimage_to_tensor(self):
trans = transforms.ToTensor()
Expand Down
59 changes: 59 additions & 0 deletions torchvision/transforms/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,65 @@ def pil_to_tensor(pic):
return img


def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) -> torch.Tensor:
"""Convert a tensor image to the given ``dtype`` and scale the values accordingly

Args:
image (torch.Tensor): Image to be converted
dtype (torch.dtype): Desired data type of the output

Returns:
(torch.Tensor): Converted image

.. note::

When converting from a smaller to a larger integer ``dtype`` the maximum values are **not** mapped exactly.
If converted back and forth, this mismatch has no effect.

Raises:
RuntimeError: When trying to cast :class:`torch.float32` to :class:`torch.int32` or :class:`torch.int64` as
well as for trying to cast :class:`torch.float64` to :class:`torch.int64`. These conversions might lead to
overflow errors since the floating point ``dtype`` cannot store consecutive integers over the whole range
of the integer ``dtype``.
"""
if image.dtype == dtype:
return image

if image.dtype.is_floating_point:
# float to float
if dtype.is_floating_point:
return image.to(dtype)

# float to int
if (image.dtype == torch.float32 and dtype in (torch.int32, torch.int64)) or (
image.dtype == torch.float64 and dtype == torch.int64
):
msg = f"The cast from {image.dtype} to {dtype} cannot be performed safely."
raise RuntimeError(msg)

eps = 1e-3
return image.mul(torch.iinfo(dtype).max + 1 - eps).to(dtype)
else:
# int to float
if dtype.is_floating_point:
max = torch.iinfo(image.dtype).max
image = image.to(dtype)
return image / max

# int to int
input_max = torch.iinfo(image.dtype).max
output_max = torch.iinfo(dtype).max

if input_max > output_max:
factor = (input_max + 1) // (output_max + 1)
image = image // factor
return image.to(dtype)
else:
factor = (output_max + 1) // (input_max + 1)
image = image.to(dtype)
return image * factor


def to_pil_image(pic, mode=None):
"""Convert a tensor or an ndarray to PIL Image.

Expand Down
33 changes: 29 additions & 4 deletions torchvision/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@
from . import functional as F


__all__ = ["Compose", "ToTensor", "PILToTensor", "ToPILImage", "Normalize", "Resize", "Scale", "CenterCrop", "Pad",
"Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop", "RandomHorizontalFlip",
"RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop", "LinearTransformation",
"ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale",
__all__ = ["Compose", "ToTensor", "PILToTensor", "ConvertImageDtype", "ToPILImage", "Normalize", "Resize", "Scale",
"CenterCrop", "Pad", "Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop",
"RandomHorizontalFlip", "RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop",
"LinearTransformation", "ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale",
"RandomPerspective", "RandomErasing"]

_pil_interpolation_to_str = {
Expand Down Expand Up @@ -115,6 +115,31 @@ def __repr__(self):
return self.__class__.__name__ + '()'


class ConvertImageDtype(object):
"""Convert a tensor image to the given ``dtype`` and scale the values accordingly

Args:
dtype (torch.dtype): Desired data type of the output

.. note::

When converting from a smaller to a larger integer ``dtype`` the maximum values are **not** mapped exactly.
If converted back and forth, this mismatch has no effect.

Raises:
RuntimeError: When trying to cast :class:`torch.float32` to :class:`torch.int32` or :class:`torch.int64` as
well as for trying to cast :class:`torch.float64` to :class:`torch.int64`. These conversions might lead to
overflow errors since the floating point ``dtype`` cannot store consecutive integers over the whole range
of the integer ``dtype``.
"""

def __init__(self, dtype: torch.dtype) -> None:
self.dtype = dtype

def __call__(self, image: torch.Tensor) -> torch.Tensor:
return F.convert_image_dtype(image, self.dtype)


class ToPILImage(object):
"""Convert a tensor or an ndarray to PIL Image.

Expand Down