remove unnecessary checks from pad_image_tensor #6894
@@ -4,6 +4,7 @@

import PIL.Image
import torch
from torch.nn.functional import pad as torch_pad
from torchvision.prototype import features
from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT
from torchvision.transforms.functional import (

@@ -14,7 +15,6 @@
    pil_to_tensor,
    to_pil_image,
)
from torchvision.transforms.functional_tensor import _parse_pad_padding

from ._meta import convert_format_bounding_box, get_spatial_size_image_pil

@@ -645,7 +645,32 @@ def rotate(
    return rotate_image_pil(inpt, angle, interpolation=interpolation, expand=expand, fill=fill, center=center)


pad_image_pil = _FP.pad
def _parse_pad_padding(padding: Union[int, List[int]]) -> List[int]:
    if not isinstance(padding, (int, tuple, list)):
        raise TypeError(f"`padding` should be an integer or tuple or list of integers, but got {padding}")

    if isinstance(padding, int):
        if torch.jit.is_scripting():
            # This maybe unreachable
            raise ValueError("padding can't be an int while torchscripting, set it as a list [value, ]")

        pad_left = pad_right = pad_top = pad_bottom = padding
    else:
        if len(padding) == 1:
            pad_left = pad_right = pad_top = pad_bottom = padding[0]
        elif len(padding) == 2:
            pad_left = pad_right = padding[0]
            pad_top = pad_bottom = padding[1]
        elif len(padding) == 4:
            pad_left = padding[0]
            pad_top = padding[1]
            pad_right = padding[2]
            pad_bottom = padding[3]
        else:
            raise ValueError(
                f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple"
            )

    return [pad_left, pad_right, pad_top, pad_bottom]


def pad_image_tensor(
@@ -654,50 +679,85 @@ def pad_image_tensor(
    fill: features.FillTypeJIT = None,
    padding_mode: str = "constant",
) -> torch.Tensor:
    # Be aware that while `padding` has order `[left, top, right, bottom]`, `torch_padding` uses
    # `[left, right, top, bottom]`. This stems from the fact that we align our API with PIL, but need to use
    # `torch_pad` internally.
    torch_padding = _FT._parse_pad_padding(padding)

    if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
        raise ValueError(
            f"`padding_mode` should be either `'constant'`, `'edge'`, `'reflect'` or `'symmetric'`, "
            f"but got `'{padding_mode}'`."
        )

    if fill is None:
        # This is a JIT workaround
        return _pad_with_scalar_fill(image, padding, fill=None, padding_mode=padding_mode)
        return _pad_with_scalar_fill(image, torch_padding, fill=0, padding_mode=padding_mode)
Review comment: Why?

Reply: Previously, this was handled by vision/torchvision/transforms/functional_tensor.py, lines 386 to 387 (at e64784c). Since we are no longer calling it, we need to handle it on our own. For some reason, the fill overwrite above did not work for me while developing, so I went this way. Rechecking today, it seems to work and it was probably caused by something else. I've fixed this in 1622cd6. More generally though, why are we allowing …
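
Not part of the diff: a minimal sketch of the fill normalization discussed in this thread. `_normalize_fill` is a hypothetical helper used only for illustration; the PR itself inlines this logic in `pad_image_tensor`.

from typing import List, Union

def _normalize_fill(fill: Union[int, float, List[float], None]) -> Union[int, float, List[float]]:
    # A fill of None falls through to the scalar-fill path as 0, and a
    # single-element list is unwrapped to its scalar value.
    if fill is None:
        return 0
    if isinstance(fill, list) and len(fill) == 1:
        return fill[0]
    return fill

print(_normalize_fill(None))             # 0
print(_normalize_fill([0.5]))            # 0.5
print(_normalize_fill([1.0, 2.0, 3.0]))  # passed to the vector-fill path unchanged
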
    elif isinstance(fill, (int, float)) or len(fill) == 1:
        fill_number = fill[0] if isinstance(fill, list) else fill
        return _pad_with_scalar_fill(image, padding, fill=fill_number, padding_mode=padding_mode)
        return _pad_with_scalar_fill(image, torch_padding, fill=fill_number, padding_mode=padding_mode)
    else:
        return _pad_with_vector_fill(image, padding, fill=fill, padding_mode=padding_mode)
        return _pad_with_vector_fill(image, torch_padding, fill=fill, padding_mode=padding_mode)


def _pad_with_scalar_fill(
    image: torch.Tensor,
    padding: Union[int, List[int]],
    fill: Union[int, float, None],
    padding_mode: str = "constant",
    torch_padding: List[int],
    fill: Union[int, float],
    padding_mode: str,
) -> torch.Tensor:
    shape = image.shape
    num_channels, height, width = shape[-3:]

    if image.numel() > 0:
        image = _FT.pad(
            img=image.reshape(-1, num_channels, height, width), padding=padding, fill=fill, padding_mode=padding_mode
        )
        image = image.reshape(-1, num_channels, height, width)

        if padding_mode == "edge":
            # Similar to the padding order, `torch_pad`'s and PIL's padding modes don't have the same names. Thus,
            # we map the PIL name for the padding mode, which we are also using for our API, to the corresponding
            # `torch_pad` name.
            padding_mode = "replicate"

        if padding_mode == "constant":
            image = torch_pad(image, torch_padding, mode=padding_mode, value=float(fill))
        elif padding_mode in ("reflect", "replicate"):
            # `torch_pad` only supports `"reflect"` or `"replicate"` padding for floating point inputs.
            # TODO: See https://github.com/pytorch/pytorch/issues/40763
            dtype = image.dtype
            if not image.is_floating_point():
                needs_cast = True
                image = image.to(torch.float32)
            else:
                needs_cast = False

            image = torch_pad(image, torch_padding, mode=padding_mode)

            if needs_cast:
                image = image.to(dtype)
        else:  # padding_mode == "symmetric"
            image = _FT._pad_symmetric(image, torch_padding)

        new_height, new_width = image.shape[-2:]
    else:
        left, right, top, bottom = _FT._parse_pad_padding(padding)
        left, right, top, bottom = torch_padding
        new_height = height + top + bottom
        new_width = width + left + right

    return image.reshape(shape[:-3] + (num_channels, new_height, new_width))
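
As a side note (not part of the diff), a minimal sketch of the two conventions described in the comments above: the PIL-style [left, top, right, bottom] padding remapped to torch_pad's [left, right, top, bottom], plus the float32 round-trip used because "replicate"/"reflect" padding can reject integer inputs.

import torch
from torch.nn.functional import pad as torch_pad

img = torch.arange(3 * 4 * 5, dtype=torch.uint8).reshape(1, 3, 4, 5)  # NCHW
pil_padding = [1, 2, 3, 4]  # left, top, right, bottom (PIL / torchvision order)
torch_padding = [pil_padding[0], pil_padding[2], pil_padding[1], pil_padding[3]]  # left, right, top, bottom

# Cast to float32, pad, and cast back, mirroring what the kernel above does for non-float dtypes.
out = torch_pad(img.to(torch.float32), torch_padding, mode="replicate").to(img.dtype)
print(out.shape)  # torch.Size([1, 3, 10, 9]): height 4 + 2 + 4, width 5 + 1 + 3
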


# TODO: This should be removed once pytorch pad supports non-scalar padding values
# TODO: This should be removed once torch_pad supports non-scalar padding values
Review comment: @vfdev-5 Is there an issue for that?
def _pad_with_vector_fill(
    image: torch.Tensor,
    padding: Union[int, List[int]],
    torch_padding: List[int],
    fill: List[float],
    padding_mode: str = "constant",
    padding_mode: str,
) -> torch.Tensor:
    if padding_mode != "constant":
        raise ValueError(f"Padding mode '{padding_mode}' is not supported if fill is not scalar")

    output = _pad_with_scalar_fill(image, padding, fill=0, padding_mode="constant")
    left, right, top, bottom = _parse_pad_padding(padding)
    output = _pad_with_scalar_fill(image, torch_padding, fill=0, padding_mode="constant")
    left, right, top, bottom = torch_padding
    fill = torch.tensor(fill, dtype=image.dtype, device=image.device).reshape(-1, 1, 1)

    if top > 0:

@@ -711,6 +771,9 @@ def _pad_with_vector_fill(
    return output


pad_image_pil = _FP.pad
Review comment: Minor cleanup, since we normally define the PIL kernel below the tensor one.
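
Not part of the diff: a rough sketch of the workaround the TODO on _pad_with_vector_fill refers to, assuming a CHW tensor and torch_padding in [left, right, top, bottom] order. Since torch_pad only takes a scalar fill value, the image is padded with zeros first and the padded border strips are then overwritten with the per-channel fill. `pad_with_vector_fill_sketch` is an illustrative stand-in, not the PR's exact implementation.

import torch
from torch.nn.functional import pad as torch_pad

def pad_with_vector_fill_sketch(image: torch.Tensor, torch_padding, fill):
    # Pad with zeros, then overwrite the newly created border strips with the fill vector.
    left, right, top, bottom = torch_padding
    output = torch_pad(image, torch_padding, mode="constant", value=0.0)
    fill = torch.tensor(fill, dtype=output.dtype, device=output.device).reshape(-1, 1, 1)
    if top > 0:
        output[..., :top, :] = fill
    if bottom > 0:
        output[..., -bottom:, :] = fill
    if left > 0:
        output[..., :, :left] = fill
    if right > 0:
        output[..., :, -right:] = fill
    return output

out = pad_with_vector_fill_sketch(torch.zeros(3, 4, 4), [1, 1, 2, 2], [1.0, 2.0, 3.0])
print(out.shape)  # torch.Size([3, 8, 6])
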


def pad_mask(
    mask: torch.Tensor,
    padding: Union[int, List[int]],
Review comment: I've merged the check for invalid types as well as wrong lengths here, since this function is also used by pad_bounding_box, and that currently doesn't have these checks.
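
For reference (not part of the diff), a small illustration of what the merged checks in _parse_pad_padding now accept and reject, assuming the function defined above is in scope:

print(_parse_pad_padding(2))             # [2, 2, 2, 2]
print(_parse_pad_padding([1, 2]))        # [1, 1, 2, 2]
print(_parse_pad_padding([1, 2, 3, 4]))  # [1, 3, 2, 4] -> [left, right, top, bottom]

for bad in ("3", [1, 2, 3]):
    try:
        _parse_pad_padding(bad)
    except (TypeError, ValueError) as exc:
        # Both invalid types and invalid lengths are rejected here, which also
        # covers callers such as pad_bounding_box.
        print(type(exc).__name__, ":", exc)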