diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py index 87373359e83..855754ea68d 100644 --- a/test/test_functional_tensor.py +++ b/test/test_functional_tensor.py @@ -744,6 +744,10 @@ def test_perspective(self): batch_tensors, F.perspective, startpoints=spoints, endpoints=epoints, interpolation=0 ) + def test_convert_image_dtype(self): + # TODO: add tests of CPU/CUDA on tensor and batch + pass + @unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device") class CUDATester(Tester): diff --git a/test/test_transforms.py b/test/test_transforms.py index dd92f2b868b..76d70b82ed6 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -2,6 +2,7 @@ import torch import torchvision.transforms as transforms import torchvision.transforms.functional as F +import torchvision.transforms.functional_tensor as F_t from torch._utils_internal import get_file_path_2 from numpy.testing import assert_array_almost_equal import unittest @@ -544,13 +545,26 @@ def test_to_tensor(self): output = trans(img) self.assertTrue(np.allclose(input_data.numpy(), output.numpy())) + def test_max_value(self): + for dtype in int_dtypes(): + self.assertEqual(F_t._max_value(dtype), torch.iinfo(dtype).max) + + for dtype in float_dtypes(): + self.assertGreater(F_t._max_value(dtype), torch.finfo(dtype).max) + def test_convert_image_dtype_float_to_float(self): for input_dtype, output_dtypes in cycle_over(float_dtypes()): input_image = torch.tensor((0.0, 1.0), dtype=input_dtype) for output_dtype in output_dtypes: with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype): transform = transforms.ConvertImageDtype(output_dtype) + transform_script = torch.jit.script(F.convert_image_dtype) + output_image = transform(input_image) + output_image_script = transform_script(input_image, output_dtype) + + script_diff = output_image_script - output_image + self.assertLess(script_diff.abs().max(), 1e-6) actual_min, actual_max = output_image.tolist() desired_min, desired_max = 0.0, 1.0 @@ -564,6 +578,7 @@ def test_convert_image_dtype_float_to_int(self): for output_dtype in int_dtypes(): with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype): transform = transforms.ConvertImageDtype(output_dtype) + transform_script = torch.jit.script(F.convert_image_dtype) if (input_dtype == torch.float32 and output_dtype in (torch.int32, torch.int64)) or ( input_dtype == torch.float64 and output_dtype == torch.int64 @@ -572,6 +587,10 @@ def test_convert_image_dtype_float_to_int(self): transform(input_image) else: output_image = transform(input_image) + output_image_script = transform_script(input_image, output_dtype) + + script_diff = output_image_script - output_image + self.assertLess(script_diff.abs().max(), 1e-6) actual_min, actual_max = output_image.tolist() desired_min, desired_max = 0, torch.iinfo(output_dtype).max @@ -585,7 +604,13 @@ def test_convert_image_dtype_int_to_float(self): for output_dtype in float_dtypes(): with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype): transform = transforms.ConvertImageDtype(output_dtype) + transform_script = torch.jit.script(F.convert_image_dtype) + output_image = transform(input_image) + output_image_script = transform_script(input_image, output_dtype) + + script_diff = output_image_script - output_image + self.assertLess(script_diff.abs().max(), 1e-6) actual_min, actual_max = output_image.tolist() desired_min, desired_max = 0.0, 1.0 @@ -604,7 +629,15 @@ def test_convert_image_dtype_int_to_int(self): with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype): transform = transforms.ConvertImageDtype(output_dtype) + transform_script = torch.jit.script(F.convert_image_dtype) + output_image = transform(input_image) + output_image_script = transform_script(input_image, output_dtype) + + script_diff = output_image_script.float() - output_image.float() + self.assertLess( + script_diff.abs().max(), 1e-6, msg="{} vs {}".format(output_image_script, output_image) + ) actual_min, actual_max = output_image.tolist() desired_min, desired_max = 0, output_max diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 3e1e57238a6..9085b0c45e8 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -152,48 +152,10 @@ def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) - overflow errors since the floating point ``dtype`` cannot store consecutive integers over the whole range of the integer ``dtype``. """ - if image.dtype == dtype: - return image - - if image.dtype.is_floating_point: - # float to float - if dtype.is_floating_point: - return image.to(dtype) - - # float to int - if (image.dtype == torch.float32 and dtype in (torch.int32, torch.int64)) or ( - image.dtype == torch.float64 and dtype == torch.int64 - ): - msg = f"The cast from {image.dtype} to {dtype} cannot be performed safely." - raise RuntimeError(msg) - - # https://github.com/pytorch/vision/pull/2078#issuecomment-612045321 - # For data in the range 0-1, (float * 255).to(uint) is only 255 - # when float is exactly 1.0. - # `max + 1 - epsilon` provides more evenly distributed mapping of - # ranges of floats to ints. - eps = 1e-3 - result = image.mul(torch.iinfo(dtype).max + 1 - eps) - return result.to(dtype) - else: - # int to float - if dtype.is_floating_point: - max = torch.iinfo(image.dtype).max - image = image.to(dtype) - return image / max - - # int to int - input_max = torch.iinfo(image.dtype).max - output_max = torch.iinfo(dtype).max - - if input_max > output_max: - factor = (input_max + 1) // (output_max + 1) - image = image // factor - return image.to(dtype) - else: - factor = (output_max + 1) // (input_max + 1) - image = image.to(dtype) - return image * factor + if not isinstance(image, torch.Tensor): + raise TypeError('Input img should be Tensor Image') + + return F_t.convert_image_dtype(image, dtype) def to_pil_image(pic, mode=None): diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index 81d059e9d14..5436aeff9c0 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -27,6 +27,101 @@ def _get_image_num_channels(img: Tensor) -> int: raise TypeError("Input ndim should be 2 or more. Got {}".format(img.ndim)) +def _max_value(dtype: torch.dtype) -> float: + # TODO: replace this method with torch.iinfo when it gets torchscript support. + # https://github.com/pytorch/pytorch/issues/41492 + + a = torch.tensor(2, dtype=dtype) + signed = 1 if torch.tensor(0, dtype=dtype).is_signed() else 0 + bits = 1 + max_value = torch.tensor(-signed, dtype=torch.long) + while True: + next_value = a.pow(bits - signed).sub(1) + if next_value > max_value: + max_value = next_value + bits *= 2 + else: + return max_value.item() + return max_value.item() + + +def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) -> torch.Tensor: + """PRIVATE METHOD. Convert a tensor image to the given ``dtype`` and scale the values accordingly + + .. warning:: + + Module ``transforms.functional_tensor`` is private and should not be used in user application. + Please, consider instead using methods from `transforms.functional` module. + + Args: + image (torch.Tensor): Image to be converted + dtype (torch.dtype): Desired data type of the output + + Returns: + (torch.Tensor): Converted image + + .. note:: + + When converting from a smaller to a larger integer ``dtype`` the maximum values are **not** mapped exactly. + If converted back and forth, this mismatch has no effect. + + Raises: + RuntimeError: When trying to cast :class:`torch.float32` to :class:`torch.int32` or :class:`torch.int64` as + well as for trying to cast :class:`torch.float64` to :class:`torch.int64`. These conversions might lead to + overflow errors since the floating point ``dtype`` cannot store consecutive integers over the whole range + of the integer ``dtype``. + """ + if image.dtype == dtype: + return image + + # TODO: replace with image.dtype.is_floating_point when torchscript supports it + if torch.empty(0, dtype=image.dtype).is_floating_point(): + + # TODO: replace with dtype.is_floating_point when torchscript supports it + if torch.tensor(0, dtype=dtype).is_floating_point(): + return image.to(dtype) + + # float to int + if (image.dtype == torch.float32 and dtype in (torch.int32, torch.int64)) or ( + image.dtype == torch.float64 and dtype == torch.int64 + ): + msg = f"The cast from {image.dtype} to {dtype} cannot be performed safely." + raise RuntimeError(msg) + + # https://github.com/pytorch/vision/pull/2078#issuecomment-612045321 + # For data in the range 0-1, (float * 255).to(uint) is only 255 + # when float is exactly 1.0. + # `max + 1 - epsilon` provides more evenly distributed mapping of + # ranges of floats to ints. + eps = 1e-3 + max_val = _max_value(dtype) + result = image.mul(max_val + 1.0 - eps) + return result.to(dtype) + else: + input_max = _max_value(image.dtype) + output_max = _max_value(dtype) + + # int to float + # TODO: replace with dtype.is_floating_point when torchscript supports it + if torch.tensor(0, dtype=dtype).is_floating_point(): + image = image.to(dtype) + return image / input_max + + # int to int + if input_max > output_max: + # factor should be forced to int for torch jit script + # otherwise factor is a float and image // factor can produce different results + factor = int((input_max + 1) // (output_max + 1)) + image = image // factor + return image.to(dtype) + else: + # factor should be forced to int for torch jit script + # otherwise factor is a float and image * factor can produce different results + factor = int((output_max + 1) // (input_max + 1)) + image = image.to(dtype) + return image * factor + + def vflip(img: Tensor) -> Tensor: """PRIVATE METHOD. Vertically flip the given the Image Tensor. @@ -302,13 +397,11 @@ def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor: result = img dtype = img.dtype if not torch.is_floating_point(img): - result = result / 255.0 + result = convert_image_dtype(result, torch.float32) result = (gain * result ** gamma).clamp(0, 1) - if result.dtype != dtype: - eps = 1e-3 - result = (255 + 1.0 - eps) * result + result = convert_image_dtype(result, dtype) result = result.to(dtype) return result diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index a672741baef..2a585f98c3f 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -16,7 +16,6 @@ from . import functional as F - __all__ = ["Compose", "ToTensor", "PILToTensor", "ConvertImageDtype", "ToPILImage", "Normalize", "Resize", "Scale", "CenterCrop", "Pad", "Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop", "RandomHorizontalFlip", "RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop", @@ -127,7 +126,7 @@ def __repr__(self): return self.__class__.__name__ + '()' -class ConvertImageDtype: +class ConvertImageDtype(torch.nn.Module): """Convert a tensor image to the given ``dtype`` and scale the values accordingly Args: @@ -146,9 +145,10 @@ class ConvertImageDtype: """ def __init__(self, dtype: torch.dtype) -> None: + super().__init__() self.dtype = dtype - def __call__(self, image: torch.Tensor) -> torch.Tensor: + def forward(self, image: torch.Tensor) -> torch.Tensor: return F.convert_image_dtype(image, self.dtype)