Skip to content

make convert_image_dtype scriptable #2485

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Oct 5, 2020
4 changes: 4 additions & 0 deletions test/test_functional_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -744,6 +744,10 @@ def test_perspective(self):
batch_tensors, F.perspective, startpoints=spoints, endpoints=epoints, interpolation=0
)

def test_convert_image_dtype(self):
    """Placeholder for ``convert_image_dtype`` tests; the body is empty and asserts nothing yet."""
    # TODO: add tests of CPU/CUDA on tensor and batch
    pass


@unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device")
class CUDATester(Tester):
Expand Down
33 changes: 33 additions & 0 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import torch
import torchvision.transforms as transforms
import torchvision.transforms.functional as F
import torchvision.transforms.functional_tensor as F_t
from torch._utils_internal import get_file_path_2
from numpy.testing import assert_array_almost_equal
import unittest
Expand Down Expand Up @@ -544,13 +545,26 @@ def test_to_tensor(self):
output = trans(img)
self.assertTrue(np.allclose(input_data.numpy(), output.numpy()))

def test_max_value(self):
    """Check ``F_t._max_value`` against ``torch.iinfo`` / ``torch.finfo``.

    Integer dtypes must yield exactly the dtype's maximum; floating point
    dtypes must yield something strictly larger than the largest finite value.
    """
    for integer_dtype in int_dtypes():
        computed = F_t._max_value(integer_dtype)
        self.assertEqual(computed, torch.iinfo(integer_dtype).max)

    for floating_dtype in float_dtypes():
        computed = F_t._max_value(floating_dtype)
        self.assertGreater(computed, torch.finfo(floating_dtype).max)

def test_convert_image_dtype_float_to_float(self):
for input_dtype, output_dtypes in cycle_over(float_dtypes()):
input_image = torch.tensor((0.0, 1.0), dtype=input_dtype)
for output_dtype in output_dtypes:
with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype):
transform = transforms.ConvertImageDtype(output_dtype)
transform_script = torch.jit.script(F.convert_image_dtype)

output_image = transform(input_image)
output_image_script = transform_script(input_image, output_dtype)

script_diff = output_image_script - output_image
self.assertLess(script_diff.abs().max(), 1e-6)

actual_min, actual_max = output_image.tolist()
desired_min, desired_max = 0.0, 1.0
Expand All @@ -564,6 +578,7 @@ def test_convert_image_dtype_float_to_int(self):
for output_dtype in int_dtypes():
with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype):
transform = transforms.ConvertImageDtype(output_dtype)
transform_script = torch.jit.script(F.convert_image_dtype)

if (input_dtype == torch.float32 and output_dtype in (torch.int32, torch.int64)) or (
input_dtype == torch.float64 and output_dtype == torch.int64
Expand All @@ -572,6 +587,10 @@ def test_convert_image_dtype_float_to_int(self):
transform(input_image)
else:
output_image = transform(input_image)
output_image_script = transform_script(input_image, output_dtype)

script_diff = output_image_script - output_image
self.assertLess(script_diff.abs().max(), 1e-6)

actual_min, actual_max = output_image.tolist()
desired_min, desired_max = 0, torch.iinfo(output_dtype).max
Expand All @@ -585,7 +604,13 @@ def test_convert_image_dtype_int_to_float(self):
for output_dtype in float_dtypes():
with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype):
transform = transforms.ConvertImageDtype(output_dtype)
transform_script = torch.jit.script(F.convert_image_dtype)

output_image = transform(input_image)
output_image_script = transform_script(input_image, output_dtype)

script_diff = output_image_script - output_image
self.assertLess(script_diff.abs().max(), 1e-6)

actual_min, actual_max = output_image.tolist()
desired_min, desired_max = 0.0, 1.0
Expand All @@ -604,7 +629,15 @@ def test_convert_image_dtype_int_to_int(self):

with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype):
transform = transforms.ConvertImageDtype(output_dtype)
transform_script = torch.jit.script(F.convert_image_dtype)

output_image = transform(input_image)
output_image_script = transform_script(input_image, output_dtype)

script_diff = output_image_script.float() - output_image.float()
self.assertLess(
script_diff.abs().max(), 1e-6, msg="{} vs {}".format(output_image_script, output_image)
)

actual_min, actual_max = output_image.tolist()
desired_min, desired_max = 0, output_max
Expand Down
46 changes: 4 additions & 42 deletions torchvision/transforms/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,48 +152,10 @@ def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) -
overflow errors since the floating point ``dtype`` cannot store consecutive integers over the whole range
of the integer ``dtype``.
"""
if image.dtype == dtype:
return image

if image.dtype.is_floating_point:
# float to float
if dtype.is_floating_point:
return image.to(dtype)

# float to int
if (image.dtype == torch.float32 and dtype in (torch.int32, torch.int64)) or (
image.dtype == torch.float64 and dtype == torch.int64
):
msg = f"The cast from {image.dtype} to {dtype} cannot be performed safely."
raise RuntimeError(msg)

# https://github.com/pytorch/vision/pull/2078#issuecomment-612045321
# For data in the range 0-1, (float * 255).to(uint) is only 255
# when float is exactly 1.0.
# `max + 1 - epsilon` provides more evenly distributed mapping of
# ranges of floats to ints.
eps = 1e-3
result = image.mul(torch.iinfo(dtype).max + 1 - eps)
return result.to(dtype)
else:
# int to float
if dtype.is_floating_point:
max = torch.iinfo(image.dtype).max
image = image.to(dtype)
return image / max

# int to int
input_max = torch.iinfo(image.dtype).max
output_max = torch.iinfo(dtype).max

if input_max > output_max:
factor = (input_max + 1) // (output_max + 1)
image = image // factor
return image.to(dtype)
else:
factor = (output_max + 1) // (input_max + 1)
image = image.to(dtype)
return image * factor
if not isinstance(image, torch.Tensor):
raise TypeError('Input img should be Tensor Image')

return F_t.convert_image_dtype(image, dtype)


def to_pil_image(pic, mode=None):
Expand Down
101 changes: 97 additions & 4 deletions torchvision/transforms/functional_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,101 @@ def _get_image_num_channels(img: Tensor) -> int:
raise TypeError("Input ndim should be 2 or more. Got {}".format(img.ndim))


def _max_value(dtype: torch.dtype) -> float:
# TODO: replace this method with torch.iinfo when it gets torchscript support.
# https://github.com/pytorch/pytorch/issues/41492

a = torch.tensor(2, dtype=dtype)
signed = 1 if torch.tensor(0, dtype=dtype).is_signed() else 0
bits = 1
max_value = torch.tensor(-signed, dtype=torch.long)
while True:
next_value = a.pow(bits - signed).sub(1)
if next_value > max_value:
max_value = next_value
bits *= 2
else:
return max_value.item()
return max_value.item()


def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) -> torch.Tensor:
    """PRIVATE METHOD. Convert a tensor image to the given ``dtype`` and scale the values accordingly

    .. warning::

        Module ``transforms.functional_tensor`` is private and should not be used in user application.
        Please, consider instead using methods from `transforms.functional` module.

    Args:
        image (torch.Tensor): Image to be converted
        dtype (torch.dtype): Desired data type of the output

    Returns:
        (torch.Tensor): Converted image. If ``image`` already has the requested ``dtype`` the very same
        tensor object is returned (no copy is made).

    .. note::

        When converting from a smaller to a larger integer ``dtype`` the maximum values are **not** mapped exactly.
        If converted back and forth, this mismatch has no effect.

    Raises:
        RuntimeError: When trying to cast :class:`torch.float32` to :class:`torch.int32` or :class:`torch.int64` as
            well as for trying to cast :class:`torch.float64` to :class:`torch.int64`. These conversions might lead to
            overflow errors since the floating point ``dtype`` cannot store consecutive integers over the whole range
            of the integer ``dtype``.
    """
    if image.dtype == dtype:
        return image

    # TODO: replace with dtype.is_floating_point when torchscript supports it
    # Evaluated once up front; both branches below need it.
    output_is_float = torch.tensor(0, dtype=dtype).is_floating_point()

    # Asking the tensor directly avoids allocating a throwaway tensor just to inspect the input dtype.
    if image.is_floating_point():
        # float to float
        if output_is_float:
            return image.to(dtype)

        # float to int
        if (image.dtype == torch.float32 and dtype in (torch.int32, torch.int64)) or (
            image.dtype == torch.float64 and dtype == torch.int64
        ):
            msg = f"The cast from {image.dtype} to {dtype} cannot be performed safely."
            raise RuntimeError(msg)

        # https://github.com/pytorch/vision/pull/2078#issuecomment-612045321
        # For data in the range 0-1, (float * 255).to(uint) is only 255
        # when float is exactly 1.0.
        # `max + 1 - epsilon` provides more evenly distributed mapping of
        # ranges of floats to ints.
        eps = 1e-3
        max_val = _max_value(dtype)
        result = image.mul(max_val + 1.0 - eps)
        return result.to(dtype)
    else:
        input_max = _max_value(image.dtype)

        # int to float
        if output_is_float:
            image = image.to(dtype)
            return image / input_max

        # int to int
        # The output maximum is only needed here, so it is not computed for the int-to-float path above.
        output_max = _max_value(dtype)

        if input_max > output_max:
            # factor should be forced to int for torch jit script
            # otherwise factor is a float and image // factor can produce different results
            factor = int((input_max + 1) // (output_max + 1))
            image = image // factor
            return image.to(dtype)
        else:
            # factor should be forced to int for torch jit script
            # otherwise factor is a float and image * factor can produce different results
            factor = int((output_max + 1) // (input_max + 1))
            image = image.to(dtype)
            return image * factor


def vflip(img: Tensor) -> Tensor:
"""PRIVATE METHOD. Vertically flip the given the Image Tensor.

Expand Down Expand Up @@ -302,13 +397,11 @@ def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor:
result = img
dtype = img.dtype
if not torch.is_floating_point(img):
result = result / 255.0
result = convert_image_dtype(result, torch.float32)

result = (gain * result ** gamma).clamp(0, 1)

if result.dtype != dtype:
eps = 1e-3
result = (255 + 1.0 - eps) * result
result = convert_image_dtype(result, dtype)
result = result.to(dtype)
return result

Expand Down
6 changes: 3 additions & 3 deletions torchvision/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

from . import functional as F


__all__ = ["Compose", "ToTensor", "PILToTensor", "ConvertImageDtype", "ToPILImage", "Normalize", "Resize", "Scale",
"CenterCrop", "Pad", "Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop",
"RandomHorizontalFlip", "RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop",
Expand Down Expand Up @@ -127,7 +126,7 @@ def __repr__(self):
return self.__class__.__name__ + '()'


class ConvertImageDtype:
class ConvertImageDtype(torch.nn.Module):
"""Convert a tensor image to the given ``dtype`` and scale the values accordingly

Args:
Expand All @@ -146,9 +145,10 @@ class ConvertImageDtype:
"""

def __init__(self, dtype: torch.dtype) -> None:
super().__init__()
self.dtype = dtype

def __call__(self, image: torch.Tensor) -> torch.Tensor:
def forward(self, image: torch.Tensor) -> torch.Tensor:
return F.convert_image_dtype(image, self.dtype)


Expand Down