From fba59a11b4ff0c43cec5b661b86fa4f932eb713d Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 11 Jun 2020 19:35:20 +0200 Subject: [PATCH 1/6] add torch.script tests to convert_image_dtype --- test/test_transforms.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/test/test_transforms.py b/test/test_transforms.py index 8423bf99ee3..24c26a50a47 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -532,7 +532,13 @@ def test_convert_image_dtype_float_to_float(self): for output_dtype in output_dtypes: with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype): transform = transforms.ConvertImageDtype(output_dtype) + transform_script = lambda image: torch.jit.script(F.convert_image_dtype(image, output_dtype)) + output_image = transform(input_image) + output_image_script = transform_script(input_image) + + script_diff = output_image_script - output_image + self.assertTrue(script_diff.abs().max() < 1e-6) actual_min, actual_max = output_image.tolist() desired_min, desired_max = 0.0, 1.0 @@ -546,6 +552,7 @@ def test_convert_image_dtype_float_to_int(self): for output_dtype in int_dtypes(): with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype): transform = transforms.ConvertImageDtype(output_dtype) + transform_script = lambda image: torch.jit.script(F.convert_image_dtype(image, output_dtype)) if (input_dtype == torch.float32 and output_dtype in (torch.int32, torch.int64)) or ( input_dtype == torch.float64 and output_dtype == torch.int64 @@ -554,6 +561,10 @@ def test_convert_image_dtype_float_to_int(self): transform(input_image) else: output_image = transform(input_image) + output_image_script = transform_script(input_image) + + script_diff = output_image_script - output_image + self.assertTrue(script_diff.abs().max() < 1e-6) actual_min, actual_max = output_image.tolist() desired_min, desired_max = 0, torch.iinfo(output_dtype).max @@ -567,7 +578,13 @@ def test_convert_image_dtype_int_to_float(self): for output_dtype in float_dtypes(): with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype): transform = transforms.ConvertImageDtype(output_dtype) + transform_script = lambda image: torch.jit.script(F.convert_image_dtype(image, output_dtype)) + output_image = transform(input_image) + output_image_script = transform_script(input_image) + + script_diff = output_image_script - output_image + self.assertTrue(script_diff.abs().max() < 1e-6) actual_min, actual_max = output_image.tolist() desired_min, desired_max = 0.0, 1.0 @@ -586,7 +603,13 @@ def test_convert_image_dtype_int_to_int(self): with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype): transform = transforms.ConvertImageDtype(output_dtype) + transform_script = lambda image: torch.jit.script(F.convert_image_dtype(image, output_dtype)) + output_image = transform(input_image) + output_image_script = transform_script(input_image) + + script_diff = output_image_script - output_image + self.assertTrue(script_diff.abs().max() < 1e-6) actual_min, actual_max = output_image.tolist() desired_min, desired_max = 0, output_max From 074e8ceec3c8e1d99e16aec6a02fa61afd12c145 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 11 Jun 2020 19:38:48 +0200 Subject: [PATCH 2/6] lint --- test/test_transforms.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index 24c26a50a47..ce72f46750d 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -532,7 +532,7 @@ def test_convert_image_dtype_float_to_float(self): for output_dtype in output_dtypes: with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype): transform = transforms.ConvertImageDtype(output_dtype) - transform_script = lambda image: torch.jit.script(F.convert_image_dtype(image, output_dtype)) + transform_script = lambda image: torch.jit.script(F.convert_image_dtype(image, output_dtype)) # noqa: E731 output_image = transform(input_image) output_image_script = transform_script(input_image) @@ -552,7 +552,7 @@ def test_convert_image_dtype_float_to_int(self): for output_dtype in int_dtypes(): with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype): transform = transforms.ConvertImageDtype(output_dtype) - transform_script = lambda image: torch.jit.script(F.convert_image_dtype(image, output_dtype)) + transform_script = lambda image: torch.jit.script(F.convert_image_dtype(image, output_dtype)) # noqa: E731 if (input_dtype == torch.float32 and output_dtype in (torch.int32, torch.int64)) or ( input_dtype == torch.float64 and output_dtype == torch.int64 @@ -578,7 +578,7 @@ def test_convert_image_dtype_int_to_float(self): for output_dtype in float_dtypes(): with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype): transform = transforms.ConvertImageDtype(output_dtype) - transform_script = lambda image: torch.jit.script(F.convert_image_dtype(image, output_dtype)) + transform_script = lambda image: torch.jit.script(F.convert_image_dtype(image, output_dtype)) # noqa: E731 output_image = transform(input_image) output_image_script = transform_script(input_image) @@ -603,8 +603,8 @@ def test_convert_image_dtype_int_to_int(self): with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype): transform = transforms.ConvertImageDtype(output_dtype) - transform_script = lambda image: torch.jit.script(F.convert_image_dtype(image, output_dtype)) - + transform_script = lambda image: torch.jit.script(F.convert_image_dtype(image, output_dtype)) # noqa: E731 + output_image = transform(input_image) output_image_script = transform_script(input_image) From 7100a6ed94b09d31ca176e0649f15ce7f392253f Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 11 Jun 2020 19:44:05 +0200 Subject: [PATCH 3/6] lint --- test/test_transforms.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index ce72f46750d..c779801218a 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -532,7 +532,9 @@ def test_convert_image_dtype_float_to_float(self): for output_dtype in output_dtypes: with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype): transform = transforms.ConvertImageDtype(output_dtype) - transform_script = lambda image: torch.jit.script(F.convert_image_dtype(image, output_dtype)) # noqa: E731 + transform_script = lambda image: torch.jit.script( # noqa: E731 + F.convert_image_dtype(image, output_dtype) + ) output_image = transform(input_image) output_image_script = transform_script(input_image) @@ -552,7 +554,9 @@ def test_convert_image_dtype_float_to_int(self): for output_dtype in int_dtypes(): with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype): transform = transforms.ConvertImageDtype(output_dtype) - transform_script = lambda image: torch.jit.script(F.convert_image_dtype(image, output_dtype)) # noqa: E731 + transform_script = lambda image: torch.jit.script( # noqa: E731 + F.convert_image_dtype(image, output_dtype) + ) if (input_dtype == torch.float32 and output_dtype in (torch.int32, torch.int64)) or ( input_dtype == torch.float64 and output_dtype == torch.int64 @@ -578,7 +582,9 @@ def test_convert_image_dtype_int_to_float(self): for output_dtype in float_dtypes(): with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype): transform = transforms.ConvertImageDtype(output_dtype) - transform_script = lambda image: torch.jit.script(F.convert_image_dtype(image, output_dtype)) # noqa: E731 + transform_script = lambda image: torch.jit.script( # noqa: E731 + F.convert_image_dtype(image, output_dtype) + ) output_image = transform(input_image) output_image_script = transform_script(input_image) @@ -603,7 +609,9 @@ def test_convert_image_dtype_int_to_int(self): with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype): transform = transforms.ConvertImageDtype(output_dtype) - transform_script = lambda image: torch.jit.script(F.convert_image_dtype(image, output_dtype)) # noqa: E731 + transform_script = lambda image: torch.jit.script( # noqa: E731 + F.convert_image_dtype(image, output_dtype) + ) output_image = transform(input_image) output_image_script = transform_script(input_image) From 06393e452b59e9d10b1d73e7030c0603cbd08923 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Fri, 12 Jun 2020 09:23:44 +0200 Subject: [PATCH 4/6] bug fix --- test/test_transforms.py | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index c779801218a..02c735aa559 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -532,9 +532,7 @@ def test_convert_image_dtype_float_to_float(self): for output_dtype in output_dtypes: with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype): transform = transforms.ConvertImageDtype(output_dtype) - transform_script = lambda image: torch.jit.script( # noqa: E731 - F.convert_image_dtype(image, output_dtype) - ) + transform_script = torch.jit.script(lambda image: F.convert_image_dtype(image, output_dtype)) output_image = transform(input_image) output_image_script = transform_script(input_image) @@ -554,9 +552,7 @@ def test_convert_image_dtype_float_to_int(self): for output_dtype in int_dtypes(): with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype): transform = transforms.ConvertImageDtype(output_dtype) - transform_script = lambda image: torch.jit.script( # noqa: E731 - F.convert_image_dtype(image, output_dtype) - ) + transform_script = torch.jit.script(lambda image: F.convert_image_dtype(image, output_dtype)) if (input_dtype == torch.float32 and output_dtype in (torch.int32, torch.int64)) or ( input_dtype == torch.float64 and output_dtype == torch.int64 @@ -582,9 +578,7 @@ def test_convert_image_dtype_int_to_float(self): for output_dtype in float_dtypes(): with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype): transform = transforms.ConvertImageDtype(output_dtype) - transform_script = lambda image: torch.jit.script( # noqa: E731 - F.convert_image_dtype(image, output_dtype) - ) + transform_script = torch.jit.script(lambda image: F.convert_image_dtype(image, output_dtype)) output_image = transform(input_image) output_image_script = transform_script(input_image) @@ -609,9 +603,7 @@ def test_convert_image_dtype_int_to_int(self): with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype): transform = transforms.ConvertImageDtype(output_dtype) - transform_script = lambda image: torch.jit.script( # noqa: E731 - F.convert_image_dtype(image, output_dtype) - ) + transform_script = torch.jit.script(lambda image: F.convert_image_dtype(image, output_dtype)) output_image = transform(input_image) output_image_script = transform_script(input_image) From df58a273fa25432ff24b379c45c8f9622c76865b Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Fri, 12 Jun 2020 09:54:44 +0200 Subject: [PATCH 5/6] remove lambda --- test/test_transforms.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index 02c735aa559..f8101a1d862 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -532,10 +532,10 @@ def test_convert_image_dtype_float_to_float(self): for output_dtype in output_dtypes: with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype): transform = transforms.ConvertImageDtype(output_dtype) - transform_script = torch.jit.script(lambda image: F.convert_image_dtype(image, output_dtype)) + transform_script = torch.jit.script(F.convert_image_dtype) output_image = transform(input_image) - output_image_script = transform_script(input_image) + output_image_script = transform_script(input_image, output_dtype) script_diff = output_image_script - output_image self.assertTrue(script_diff.abs().max() < 1e-6) @@ -552,7 +552,7 @@ def test_convert_image_dtype_float_to_int(self): for output_dtype in int_dtypes(): with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype): transform = transforms.ConvertImageDtype(output_dtype) - transform_script = torch.jit.script(lambda image: F.convert_image_dtype(image, output_dtype)) + transform_script = torch.jit.script(F.convert_image_dtype) if (input_dtype == torch.float32 and output_dtype in (torch.int32, torch.int64)) or ( input_dtype == torch.float64 and output_dtype == torch.int64 @@ -561,7 +561,7 @@ def test_convert_image_dtype_float_to_int(self): transform(input_image) else: output_image = transform(input_image) - output_image_script = transform_script(input_image) + output_image_script = transform_script(input_image, output_dtype) script_diff = output_image_script - output_image self.assertTrue(script_diff.abs().max() < 1e-6) @@ -578,10 +578,10 @@ def test_convert_image_dtype_int_to_float(self): for output_dtype in float_dtypes(): with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype): transform = transforms.ConvertImageDtype(output_dtype) - transform_script = torch.jit.script(lambda image: F.convert_image_dtype(image, output_dtype)) + transform_script = torch.jit.script(F.convert_image_dtype) output_image = transform(input_image) - output_image_script = transform_script(input_image) + output_image_script = transform_script(input_image, output_dtype) script_diff = output_image_script - output_image self.assertTrue(script_diff.abs().max() < 1e-6) @@ -603,10 +603,10 @@ def test_convert_image_dtype_int_to_int(self): with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype): transform = transforms.ConvertImageDtype(output_dtype) - transform_script = torch.jit.script(lambda image: F.convert_image_dtype(image, output_dtype)) + transform_script = torch.jit.script(F.convert_image_dtype) output_image = transform(input_image) - output_image_script = transform_script(input_image) + output_image_script = transform_script(input_image, output_dtype) script_diff = output_image_script - output_image self.assertTrue(script_diff.abs().max() < 1e-6) From d789cb9cf562fd76908661f4affd0b260d7a2b35 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Fri, 12 Jun 2020 11:37:29 +0200 Subject: [PATCH 6/6] try remove torch.dtype.is_floating_point --- torchvision/transforms/functional.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index e49ff063dc8..c9609774c24 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -113,6 +113,11 @@ def pil_to_tensor(pic): return img +def _is_floating_point(dtype: torch.dtype) -> bool: + # helper function since torch.dtype.is_floating_point is not scriptable + return isinstance(dtype, (torch.float32, torch.float, torch.float64, torch.double)) + + def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) -> torch.Tensor: """Convert a tensor image to the given ``dtype`` and scale the values accordingly @@ -137,9 +142,12 @@ def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) - if image.dtype == dtype: return image - if image.dtype.is_floating_point: + input_is_float = _is_floating_point(image.dtype) + output_is_float = _is_floating_point(dtype) + + if input_is_float: # float to float - if dtype.is_floating_point: + if output_is_float: return image.to(dtype) # float to int @@ -153,7 +161,7 @@ def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) - return image.mul(torch.iinfo(dtype).max + 1 - eps).to(dtype) else: # int to float - if dtype.is_floating_point: + if output_is_float: max = torch.iinfo(image.dtype).max image = image.to(dtype) return image / max