Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add circular padding #8619

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion test/test_transforms_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ def test_color_jitter_all(self, device, channels):


@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("m", ["constant", "edge", "reflect", "symmetric"])
@pytest.mark.parametrize("m", ["constant", "edge", "reflect", "symmetric", "circular"])
@pytest.mark.parametrize("mul", [1, -1])
def test_pad(m, mul, device):
fill = 127 if m == "constant" else 0
Expand Down Expand Up @@ -264,6 +264,7 @@ def test_crop(device):
{"padding_mode": "constant", "fill": 10},
{"padding_mode": "edge"},
{"padding_mode": "reflect"},
{"padding_mode": "circular"},
],
)
@pytest.mark.parametrize("pad_if_needed", [True, False])
Expand Down
10 changes: 7 additions & 3 deletions torchvision/transforms/_functional_pil.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def pad(
img: Image.Image,
padding: Union[int, List[int], Tuple[int, ...]],
fill: Optional[Union[float, List[float], Tuple[float, ...]]] = 0,
padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
padding_mode: Literal["constant", "edge", "reflect", "symmetric", "circular"] = "constant",
) -> Image.Image:

if not _is_pil_image(img):
Expand All @@ -168,8 +168,8 @@ def pad(
# Compatibility with `functional_tensor.pad`
padding = padding[0]

if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
raise ValueError("Padding mode should be either constant, edge, reflect or symmetric")
if padding_mode not in ["constant", "edge", "reflect", "symmetric", "circular"]:
raise ValueError("Padding mode should be either constant, edge, reflect, symmetric, or circular")

if padding_mode == "constant":
opts = _parse_fill(fill, img, name="fill")
Expand Down Expand Up @@ -201,6 +201,10 @@ def pad(

pad_left, pad_top, pad_right, pad_bottom = np.maximum(p, 0)

if padding_mode == "circular":
# Compatibility with np.pad
padding_mode = "wrap"

if img.mode == "P":
palette = img.getpalette()
img = np.asarray(img)
Expand Down
6 changes: 3 additions & 3 deletions torchvision/transforms/_functional_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,8 +398,8 @@ def pad(
f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple"
)

if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
raise ValueError("Padding mode should be either constant, edge, reflect or symmetric")
if padding_mode not in ["constant", "edge", "reflect", "symmetric", "circular"]:
raise ValueError("Padding mode should be either constant, edge, reflect, symmetric, or circular")

p = _parse_pad_padding(padding)

Expand All @@ -424,7 +424,7 @@ def pad(
need_cast = True
img = img.to(torch.float32)

if padding_mode in ("reflect", "replicate"):
if padding_mode in ("reflect", "replicate", "circular"):
img = torch_pad(img, p, mode=padding_mode)
else:
img = torch_pad(img, p, mode=padding_mode, value=float(fill))
Expand Down
8 changes: 6 additions & 2 deletions torchvision/transforms/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,7 +482,7 @@ def resize(
def pad(img: Tensor, padding: List[int], fill: Union[int, float] = 0, padding_mode: str = "constant") -> Tensor:
r"""Pad the given image on all sides with the given "pad" value.
If the image is torch Tensor, it is expected
to have [..., H, W] shape, where ... means at most 2 leading dimensions for mode reflect and symmetric,
to have [..., H, W] shape, where ... means at most 2 leading dimensions for modes reflect, symmetric, and circular,
at most 3 leading dimensions for mode edge,
and an arbitrary number of leading dimensions for mode constant

Expand All @@ -501,7 +501,7 @@ def pad(img: Tensor, padding: List[int], fill: Union[int, float] = 0, padding_mo
This value is only used when the padding_mode is constant.
Only number is supported for torch Tensor.
Only int or tuple value is supported for PIL Image.
padding_mode (str): Type of padding. Should be: constant, edge, reflect or symmetric.
padding_mode (str): Type of padding. Should be: constant, edge, reflect, symmetric, or circular.
Default is constant.

- constant: pads with a constant value, this value is specified with fill
Expand All @@ -517,6 +517,10 @@ def pad(img: Tensor, padding: List[int], fill: Union[int, float] = 0, padding_mo
For example, padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
will result in [2, 1, 1, 2, 3, 4, 4, 3]

- circular: pads by repeating the values from the opposite side of the image in order.
For example, padding [1, 2, 3, 4] with 2 elements on both sides in circular mode
will result in [3, 4, 1, 2, 3, 4, 1, 2]

Returns:
PIL Image or Tensor: Padded image.
"""
Expand Down