From 968f37bba0266318d1da9360a3b5768eea9329dc Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 2 Nov 2022 20:30:07 +0100 Subject: [PATCH 1/5] remove unnecessary changes from pad_image_tensor --- .../transforms/functional/_geometry.py | 97 +++++++++++++++---- 1 file changed, 80 insertions(+), 17 deletions(-) diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index 9ed8a965ee3..563ebd69506 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -4,6 +4,7 @@ import PIL.Image import torch +from torch.nn.functional import pad as torch_pad from torchvision.prototype import features from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT from torchvision.transforms.functional import ( @@ -14,7 +15,6 @@ pil_to_tensor, to_pil_image, ) -from torchvision.transforms.functional_tensor import _parse_pad_padding from ._meta import convert_format_bounding_box, get_spatial_size_image_pil @@ -645,7 +645,32 @@ def rotate( return rotate_image_pil(inpt, angle, interpolation=interpolation, expand=expand, fill=fill, center=center) -pad_image_pil = _FP.pad +def _parse_pad_padding(padding: Union[int, List[int]]) -> List[int]: + if not isinstance(padding, (int, tuple, list)): + raise TypeError(f"`padding` should be an integer or tuple or list of integers, but got {padding}") + + if isinstance(padding, int): + if torch.jit.is_scripting(): + # This maybe unreachable + raise ValueError("padding can't be an int while torchscripting, set it as a list [value, ]") + pad_left = pad_right = pad_top = pad_bottom = padding + else: + if len(padding) == 1: + pad_left = pad_right = pad_top = pad_bottom = padding[0] + elif len(padding) == 2: + pad_left = pad_right = padding[0] + pad_top = pad_bottom = padding[1] + elif len(padding) == 4: + pad_left = padding[0] + pad_top = padding[1] + pad_right = padding[2] + pad_bottom = padding[3] + else: + raise ValueError( + f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple" + ) + + return [pad_left, pad_right, pad_top, pad_bottom] def pad_image_tensor( @@ -654,50 +679,85 @@ def pad_image_tensor( fill: features.FillTypeJIT = None, padding_mode: str = "constant", ) -> torch.Tensor: + # Be aware that while `padding` has order `[left, top, right, bottom]` has order, `torch_padding` uses + # `[left, right, top, bottom]`. This stems from the fact that we align our API with PIL, but need to use `torch_pad` + # internally. + torch_padding = _FT._parse_pad_padding(padding) + + if padding_mode not in ["constant", "edge", "reflect", "symmetric"]: + raise ValueError( + f"`padding_mode` should be either `'constant'`, `'edge'`, `'reflect'` or `'symmetric'`, " + f"but got `'{padding_mode}'`." + ) + if fill is None: # This is a JIT workaround - return _pad_with_scalar_fill(image, padding, fill=None, padding_mode=padding_mode) + return _pad_with_scalar_fill(image, torch_padding, fill=0, padding_mode=padding_mode) elif isinstance(fill, (int, float)) or len(fill) == 1: fill_number = fill[0] if isinstance(fill, list) else fill - return _pad_with_scalar_fill(image, padding, fill=fill_number, padding_mode=padding_mode) + return _pad_with_scalar_fill(image, torch_padding, fill=fill_number, padding_mode=padding_mode) else: - return _pad_with_vector_fill(image, padding, fill=fill, padding_mode=padding_mode) + return _pad_with_vector_fill(image, torch_padding, fill=fill, padding_mode=padding_mode) def _pad_with_scalar_fill( image: torch.Tensor, - padding: Union[int, List[int]], - fill: Union[int, float, None], - padding_mode: str = "constant", + torch_padding: List[int], + fill: Union[int, float], + padding_mode: str, ) -> torch.Tensor: shape = image.shape num_channels, height, width = shape[-3:] if image.numel() > 0: - image = _FT.pad( - img=image.reshape(-1, num_channels, height, width), padding=padding, fill=fill, padding_mode=padding_mode - ) + image = image.reshape(-1, num_channels, height, width) + + if padding_mode == "edge": + # Similar to the padding order, `torch_pad`'s PIL's padding modes don't have the same names. Thus, we map + # the PIL name for the padding mode, which we are also using for our API, to the corresponding `torch_pad` + # name. + padding_mode = "replicate" + + if padding_mode == "constant": + image = torch_pad(image, torch_padding, mode=padding_mode, value=float(fill)) + elif padding_mode in ("reflect", "replicate"): + # `torch_pad` only supports `"reflect"` or `"replicate"` padding for floating point inputs. + # TODO: See https://github.com/pytorch/pytorch/issues/40763 + dtype = image.dtype + if not image.is_floating_point(): + needs_cast = True + image = image.to(torch.float32) + else: + needs_cast = False + + image = torch_pad(image, torch_padding, mode=padding_mode) + + if needs_cast: + image = image.to(dtype) + else: # padding_mode == "symmetric" + image = _FT._pad_symmetric(image, torch_padding) + new_height, new_width = image.shape[-2:] else: - left, right, top, bottom = _FT._parse_pad_padding(padding) + left, right, top, bottom = torch_padding new_height = height + top + bottom new_width = width + left + right return image.reshape(shape[:-3] + (num_channels, new_height, new_width)) -# TODO: This should be removed once pytorch pad supports non-scalar padding values +# TODO: This should be removed once torch_pad supports non-scalar padding values def _pad_with_vector_fill( image: torch.Tensor, - padding: Union[int, List[int]], + torch_padding: List[int], fill: List[float], - padding_mode: str = "constant", + padding_mode: str, ) -> torch.Tensor: if padding_mode != "constant": raise ValueError(f"Padding mode '{padding_mode}' is not supported if fill is not scalar") - output = _pad_with_scalar_fill(image, padding, fill=0, padding_mode="constant") - left, right, top, bottom = _parse_pad_padding(padding) + output = _pad_with_scalar_fill(image, torch_padding, fill=0, padding_mode="constant") + left, right, top, bottom = torch_padding fill = torch.tensor(fill, dtype=image.dtype, device=image.device).reshape(-1, 1, 1) if top > 0: @@ -711,6 +771,9 @@ def _pad_with_vector_fill( return output +pad_image_pil = _FP.pad + + def pad_mask( mask: torch.Tensor, padding: Union[int, List[int]], From 3f841ca879cb0ee0abbfbee8bc03ab971b1fc77e Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 2 Nov 2022 22:01:35 +0100 Subject: [PATCH 2/5] cleanup --- torchvision/prototype/transforms/functional/_geometry.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index 563ebd69506..d6ef750b87f 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -682,7 +682,7 @@ def pad_image_tensor( # Be aware that while `padding` has order `[left, top, right, bottom]` has order, `torch_padding` uses # `[left, right, top, bottom]`. This stems from the fact that we align our API with PIL, but need to use `torch_pad` # internally. - torch_padding = _FT._parse_pad_padding(padding) + torch_padding = _parse_pad_padding(padding) if padding_mode not in ["constant", "edge", "reflect", "symmetric"]: raise ValueError( @@ -693,9 +693,10 @@ def pad_image_tensor( if fill is None: # This is a JIT workaround return _pad_with_scalar_fill(image, torch_padding, fill=0, padding_mode=padding_mode) - elif isinstance(fill, (int, float)) or len(fill) == 1: - fill_number = fill[0] if isinstance(fill, list) else fill - return _pad_with_scalar_fill(image, torch_padding, fill=fill_number, padding_mode=padding_mode) + elif isinstance(fill, (int, float)): + return _pad_with_scalar_fill(image, torch_padding, fill=fill, padding_mode=padding_mode) + elif len(fill) == 1: + return _pad_with_scalar_fill(image, torch_padding, fill=fill[0], padding_mode=padding_mode) else: return _pad_with_vector_fill(image, torch_padding, fill=fill, padding_mode=padding_mode) From 1622cd6c0ca320ac7acb1b909924cc86db4af936 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 3 Nov 2022 09:08:23 +0100 Subject: [PATCH 3/5] fix fill=None workaround --- torchvision/prototype/transforms/functional/_geometry.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index d6ef750b87f..ff39ad8b7bc 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -691,9 +691,9 @@ def pad_image_tensor( ) if fill is None: - # This is a JIT workaround - return _pad_with_scalar_fill(image, torch_padding, fill=0, padding_mode=padding_mode) - elif isinstance(fill, (int, float)): + fill = 0 + + if isinstance(fill, (int, float)): return _pad_with_scalar_fill(image, torch_padding, fill=fill, padding_mode=padding_mode) elif len(fill) == 1: return _pad_with_scalar_fill(image, torch_padding, fill=fill[0], padding_mode=padding_mode) From 84146c670a0c40a3876aafd509b4259373567cb4 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 3 Nov 2022 11:13:33 +0100 Subject: [PATCH 4/5] address review comments --- test/prototype_transforms_kernel_infos.py | 1 - .../prototype/transforms/functional/_geometry.py | 10 +++------- 2 files changed, 3 insertions(+), 8 deletions(-) diff --git a/test/prototype_transforms_kernel_infos.py b/test/prototype_transforms_kernel_infos.py index eac1fc0ae9d..e5f3c819874 100644 --- a/test/prototype_transforms_kernel_infos.py +++ b/test/prototype_transforms_kernel_infos.py @@ -1087,7 +1087,6 @@ def sample_inputs_pad_video(): reference_inputs_fn=reference_inputs_pad_image_tensor, closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, test_marks=[ - xfail_jit_python_scalar_arg("padding"), xfail_jit_tuple_instead_of_list("padding"), xfail_jit_tuple_instead_of_list("fill"), # TODO: check if this is a regression since it seems that should be supported if `int` is ok diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index 405cdf47b91..bf4905c37c4 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -664,15 +664,9 @@ def rotate( def _parse_pad_padding(padding: Union[int, List[int]]) -> List[int]: - if not isinstance(padding, (int, tuple, list)): - raise TypeError(f"`padding` should be an integer or tuple or list of integers, but got {padding}") - if isinstance(padding, int): - if torch.jit.is_scripting(): - # This maybe unreachable - raise ValueError("padding can't be an int while torchscripting, set it as a list [value, ]") pad_left = pad_right = pad_top = pad_bottom = padding - else: + elif isinstance(padding, (tuple, list)): if len(padding) == 1: pad_left = pad_right = pad_top = pad_bottom = padding[0] elif len(padding) == 2: @@ -687,6 +681,8 @@ def _parse_pad_padding(padding: Union[int, List[int]]) -> List[int]: raise ValueError( f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple" ) + else: + raise TypeError(f"`padding` should be an integer or tuple or list of integers, but got {padding}") return [pad_left, pad_right, pad_top, pad_bottom] From 015ce01e001ec106f066500a171099e19eedea50 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 3 Nov 2022 15:29:11 +0100 Subject: [PATCH 5/5] remove more xfails --- test/prototype_transforms_dispatcher_infos.py | 1 - test/prototype_transforms_kernel_infos.py | 1 - 2 files changed, 2 deletions(-) diff --git a/test/prototype_transforms_dispatcher_infos.py b/test/prototype_transforms_dispatcher_infos.py index e570e4355c5..d817e4a71fb 100644 --- a/test/prototype_transforms_dispatcher_infos.py +++ b/test/prototype_transforms_dispatcher_infos.py @@ -234,7 +234,6 @@ def fill_sequence_needs_broadcast(args_kwargs): condition=lambda args_kwargs: fill_sequence_needs_broadcast(args_kwargs) and args_kwargs.kwargs.get("padding_mode", "constant") == "constant", ), - xfail_jit_python_scalar_arg("padding"), xfail_jit_tuple_instead_of_list("padding"), xfail_jit_tuple_instead_of_list("fill"), # TODO: check if this is a regression since it seems that should be supported if `int` is ok diff --git a/test/prototype_transforms_kernel_infos.py b/test/prototype_transforms_kernel_infos.py index 5a55e1fe271..a106aea65ba 100644 --- a/test/prototype_transforms_kernel_infos.py +++ b/test/prototype_transforms_kernel_infos.py @@ -1158,7 +1158,6 @@ def reference_inputs_pad_bounding_box(): reference_fn=reference_pad_bounding_box, reference_inputs_fn=reference_inputs_pad_bounding_box, test_marks=[ - xfail_jit_python_scalar_arg("padding"), xfail_jit_tuple_instead_of_list("padding"), ], ),