Skip to content

add torch.script tests to convert_image_dtype #2313

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,7 +532,13 @@ def test_convert_image_dtype_float_to_float(self):
for output_dtype in output_dtypes:
with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype):
transform = transforms.ConvertImageDtype(output_dtype)
transform_script = torch.jit.script(F.convert_image_dtype)

output_image = transform(input_image)
output_image_script = transform_script(input_image, output_dtype)

script_diff = output_image_script - output_image
self.assertTrue(script_diff.abs().max() < 1e-6)

actual_min, actual_max = output_image.tolist()
desired_min, desired_max = 0.0, 1.0
Expand All @@ -546,6 +552,7 @@ def test_convert_image_dtype_float_to_int(self):
for output_dtype in int_dtypes():
with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype):
transform = transforms.ConvertImageDtype(output_dtype)
transform_script = torch.jit.script(F.convert_image_dtype)

if (input_dtype == torch.float32 and output_dtype in (torch.int32, torch.int64)) or (
input_dtype == torch.float64 and output_dtype == torch.int64
Expand All @@ -554,6 +561,10 @@ def test_convert_image_dtype_float_to_int(self):
transform(input_image)
else:
output_image = transform(input_image)
output_image_script = transform_script(input_image, output_dtype)

script_diff = output_image_script - output_image
self.assertTrue(script_diff.abs().max() < 1e-6)

actual_min, actual_max = output_image.tolist()
desired_min, desired_max = 0, torch.iinfo(output_dtype).max
Expand All @@ -567,7 +578,13 @@ def test_convert_image_dtype_int_to_float(self):
for output_dtype in float_dtypes():
with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype):
transform = transforms.ConvertImageDtype(output_dtype)
transform_script = torch.jit.script(F.convert_image_dtype)

output_image = transform(input_image)
output_image_script = transform_script(input_image, output_dtype)

script_diff = output_image_script - output_image
self.assertTrue(script_diff.abs().max() < 1e-6)

actual_min, actual_max = output_image.tolist()
desired_min, desired_max = 0.0, 1.0
Expand All @@ -586,7 +603,13 @@ def test_convert_image_dtype_int_to_int(self):

with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype):
transform = transforms.ConvertImageDtype(output_dtype)
transform_script = torch.jit.script(F.convert_image_dtype)

output_image = transform(input_image)
output_image_script = transform_script(input_image, output_dtype)

script_diff = output_image_script - output_image
self.assertTrue(script_diff.abs().max() < 1e-6)

actual_min, actual_max = output_image.tolist()
desired_min, desired_max = 0, output_max
Expand Down
14 changes: 11 additions & 3 deletions torchvision/transforms/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,11 @@ def pil_to_tensor(pic):
return img


def _is_floating_point(dtype: torch.dtype) -> bool:
# helper function since torch.dtype.is_floating_point is not scriptable
return isinstance(dtype, (torch.float32, torch.float, torch.float64, torch.double))


def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) -> torch.Tensor:
"""Convert a tensor image to the given ``dtype`` and scale the values accordingly

Expand All @@ -137,9 +142,12 @@ def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) -
if image.dtype == dtype:
return image

if image.dtype.is_floating_point:
input_is_float = _is_floating_point(image.dtype)
output_is_float = _is_floating_point(dtype)

if input_is_float:
# float to float
if dtype.is_floating_point:
if output_is_float:
return image.to(dtype)

# float to int
Expand All @@ -153,7 +161,7 @@ def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) -
return image.mul(torch.iinfo(dtype).max + 1 - eps).to(dtype)
else:
# int to float
if dtype.is_floating_point:
if output_is_float:
max = torch.iinfo(image.dtype).max
image = image.to(dtype)
return image / max
Expand Down