Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

remove unnecessary checks from pad_image_tensor #6894

Merged
merged 8 commits into from
Nov 3, 2022
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 80 additions & 17 deletions torchvision/prototype/transforms/functional/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import PIL.Image
import torch
from torch.nn.functional import pad as torch_pad
from torchvision.prototype import features
from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT
from torchvision.transforms.functional import (
Expand All @@ -14,7 +15,6 @@
pil_to_tensor,
to_pil_image,
)
from torchvision.transforms.functional_tensor import _parse_pad_padding

from ._meta import convert_format_bounding_box, get_spatial_size_image_pil

Expand Down Expand Up @@ -645,7 +645,32 @@ def rotate(
return rotate_image_pil(inpt, angle, interpolation=interpolation, expand=expand, fill=fill, center=center)


pad_image_pil = _FP.pad
def _parse_pad_padding(padding: Union[int, List[int]]) -> List[int]:
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've merged the check for invalid types as well as wrong lengths here, since this function is also used by pad_bounding_box and that currently doesn't have these checks.

if not isinstance(padding, (int, tuple, list)):
raise TypeError(f"`padding` should be an integer or tuple or list of integers, but got {padding}")
pmeier marked this conversation as resolved.
Show resolved Hide resolved

if isinstance(padding, int):
if torch.jit.is_scripting():
# This may be unreachable
raise ValueError("padding can't be an int while torchscripting, set it as a list [value, ]")
datumbox marked this conversation as resolved.
Show resolved Hide resolved
pad_left = pad_right = pad_top = pad_bottom = padding
else:
if len(padding) == 1:
pad_left = pad_right = pad_top = pad_bottom = padding[0]
elif len(padding) == 2:
pad_left = pad_right = padding[0]
pad_top = pad_bottom = padding[1]
elif len(padding) == 4:
pad_left = padding[0]
pad_top = padding[1]
pad_right = padding[2]
pad_bottom = padding[3]
else:
raise ValueError(
f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple"
)

return [pad_left, pad_right, pad_top, pad_bottom]


def pad_image_tensor(
Expand All @@ -654,50 +679,85 @@ def pad_image_tensor(
fill: features.FillTypeJIT = None,
padding_mode: str = "constant",
) -> torch.Tensor:
# Be aware that while `padding` has order `[left, top, right, bottom]`, `torch_padding` uses
# `[left, right, top, bottom]`. This stems from the fact that we align our API with PIL, but need to use `torch_pad`
# internally.
torch_padding = _FT._parse_pad_padding(padding)

if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
raise ValueError(
f"`padding_mode` should be either `'constant'`, `'edge'`, `'reflect'` or `'symmetric'`, "
f"but got `'{padding_mode}'`."
)

if fill is None:
# This is a JIT workaround
return _pad_with_scalar_fill(image, padding, fill=None, padding_mode=padding_mode)
return _pad_with_scalar_fill(image, torch_padding, fill=0, padding_mode=padding_mode)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why fill=0 instead of fill=None as originally? Do we still need this workaround?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Previously, this was handled by _FT.pad:

if fill is None:
fill = 0

Since we are no longer calling it, we need to handle it on our own. For some reason, the fill overwrite above did not work for me while developing and so I went this way. Rechecking today, it seems to work and it was probably caused by something else. I've fixed this in 1622cd6.

More generally though, why are we allowing None in the first place, and even using it as the default, if all we do with it is map it to 0? I faintly remember there were discussions about it, but I think I was out of the loop on them. Is there a public discussion that you can point me to or did this happen offline? Intuitively, I would remove None as a valid value and just use 0 as the default.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

More generally though, why are we allowing None in the first place, and even using it as the default, if all we do with it is map it to 0?

The origin of fill=None is your PR #1760 on the rotate op.
See #6623 for later discussions.

elif isinstance(fill, (int, float)) or len(fill) == 1:
fill_number = fill[0] if isinstance(fill, list) else fill
return _pad_with_scalar_fill(image, padding, fill=fill_number, padding_mode=padding_mode)
return _pad_with_scalar_fill(image, torch_padding, fill=fill_number, padding_mode=padding_mode)
else:
return _pad_with_vector_fill(image, padding, fill=fill, padding_mode=padding_mode)
return _pad_with_vector_fill(image, torch_padding, fill=fill, padding_mode=padding_mode)


def _pad_with_scalar_fill(
image: torch.Tensor,
padding: Union[int, List[int]],
fill: Union[int, float, None],
padding_mode: str = "constant",
torch_padding: List[int],
fill: Union[int, float],
padding_mode: str,
) -> torch.Tensor:
shape = image.shape
num_channels, height, width = shape[-3:]

if image.numel() > 0:
image = _FT.pad(
img=image.reshape(-1, num_channels, height, width), padding=padding, fill=fill, padding_mode=padding_mode
)
image = image.reshape(-1, num_channels, height, width)

if padding_mode == "edge":
# Similar to the padding order, `torch_pad`'s and PIL's padding modes don't have the same names. Thus, we map
# the PIL name for the padding mode, which we are also using for our API, to the corresponding `torch_pad`
# name.
padding_mode = "replicate"

if padding_mode == "constant":
image = torch_pad(image, torch_padding, mode=padding_mode, value=float(fill))
elif padding_mode in ("reflect", "replicate"):
# `torch_pad` only supports `"reflect"` or `"replicate"` padding for floating point inputs.
# TODO: See https://github.com/pytorch/pytorch/issues/40763
dtype = image.dtype
if not image.is_floating_point():
needs_cast = True
image = image.to(torch.float32)
else:
needs_cast = False

image = torch_pad(image, torch_padding, mode=padding_mode)

if needs_cast:
image = image.to(dtype)
else: # padding_mode == "symmetric"
image = _FT._pad_symmetric(image, torch_padding)

new_height, new_width = image.shape[-2:]
else:
left, right, top, bottom = _FT._parse_pad_padding(padding)
left, right, top, bottom = torch_padding
new_height = height + top + bottom
new_width = width + left + right

return image.reshape(shape[:-3] + (num_channels, new_height, new_width))


# TODO: This should be removed once pytorch pad supports non-scalar padding values
# TODO: This should be removed once torch_pad supports non-scalar padding values
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@vfdev-5 Is there an issue for that?

def _pad_with_vector_fill(
image: torch.Tensor,
padding: Union[int, List[int]],
torch_padding: List[int],
fill: List[float],
padding_mode: str = "constant",
padding_mode: str,
) -> torch.Tensor:
if padding_mode != "constant":
raise ValueError(f"Padding mode '{padding_mode}' is not supported if fill is not scalar")

output = _pad_with_scalar_fill(image, padding, fill=0, padding_mode="constant")
left, right, top, bottom = _parse_pad_padding(padding)
output = _pad_with_scalar_fill(image, torch_padding, fill=0, padding_mode="constant")
left, right, top, bottom = torch_padding
fill = torch.tensor(fill, dtype=image.dtype, device=image.device).reshape(-1, 1, 1)

if top > 0:
Expand All @@ -711,6 +771,9 @@ def _pad_with_vector_fill(
return output


pad_image_pil = _FP.pad
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor cleanup, since we normally define the PIL kernel below the tensor one.



def pad_mask(
mask: torch.Tensor,
padding: Union[int, List[int]],
Expand Down