Fast rotation for right angles #8295

Merged (13 commits) on Mar 13, 2024
9 changes: 9 additions & 0 deletions test/test_transforms_v2.py
@@ -1782,6 +1782,15 @@ def test_transform_unknown_fill_error(self):
        with pytest.raises(TypeError, match="Got inappropriate fill arg"):
            transforms.RandomAffine(degrees=0, fill="fill")

    @pytest.mark.parametrize("angle", [0, 90, 180, 270])
    def test_functional_image_fast_path_correctness(self, angle):
        image = make_image(dtype=torch.uint8, device="cpu")

        actual = F.rotate(image, angle=angle)
        expected = F.to_image(F.rotate(F.to_pil_image(image), angle=angle))

        torch.testing.assert_close(actual, expected)


class TestContainerTransforms:
    class BuiltinTransform(transforms.Transform):
13 changes: 13 additions & 0 deletions torchvision/transforms/v2/functional/_geometry.py
@@ -997,6 +997,19 @@ def rotate_image(
    center: Optional[List[float]] = None,
    fill: _FillTypeJIT = None,
) -> torch.Tensor:
    angle = angle % 360  # shift angle to [0, 360) range

    # fast path: transpose without affine transform
    if center is None:
Member

The tests are failing, and I think it's because we should only take this path when expand is True. When expand is False we're not supposed to change the shape of the output (which is what rot90 does for non-square images!).

Suggested change
- if center is None:
+ if expand and center is None:
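
To illustrate the shape concern raised in this comment, here is a minimal sketch. It assumes a non-square tensor in (C, H, W) layout; the sizes are made up for illustration and are not taken from the PR's tests.

```python
import torch

# Hypothetical non-square image in (C, H, W) layout: 3 channels, height 4, width 6.
image = torch.arange(3 * 4 * 6, dtype=torch.uint8).reshape(3, 4, 6)

# A quarter turn over the last two dims swaps height and width,
# which only matches the contract when the caller asked for expand=True.
quarter = torch.rot90(image, k=1, dims=(-2, -1))

# A half turn keeps the shape, so it is safe regardless of expand.
half = torch.rot90(image, k=2, dims=(-2, -1))

print(tuple(image.shape), tuple(quarter.shape), tuple(half.shape))
```

For a square input the quarter turn would keep the shape as well, which is the case discussed further down the thread.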

Collaborator

Here is what PIL does here:
https://github.com/python-pillow/Pillow/blob/8f63748e50378424628155994efd7e0739a4d1d1/src/PIL/Image.py#L2287-L2296

Yes, I agree that it is incorrect to ignore expand for images with h != w. Sorry for my earlier incorrect suggestion.

Contributor Author

Should we also check for square input then? For a square image, rot90 is still valid even if expand is False.

Collaborator

Following their code, they: 1) accept angles 0 and 180 whatever the provided expand value, and 2) for rotations by 90 or 270, they check `if angle in (90, 270) and (expand or self.width == self.height)`.
I think this is reasonable and we can do the same.
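
The rule described above can be sketched as a small predicate. This is an illustration only, not the code that was merged: `fast_path_turns` and its parameters are hypothetical names, and the function assumes the `center is None` requirement from the diff still applies.

```python
def fast_path_turns(angle, expand, height, width, center=None):
    """Return the number of quarter turns for the fast path, or None if ineligible.

    Hypothetical helper mirroring the rule quoted from PIL in this thread.
    """
    if center is not None:
        return None  # a custom center always needs the affine path

    angle = angle % 360  # shift angle to [0, 360) range

    # 0 and 180 never change the output shape, so expand does not matter.
    if angle in (0, 180):
        return int(angle // 90)

    # 90 and 270 swap height and width, so they need expand or a square input.
    if angle in (90, 270) and (expand or height == width):
        return int(angle // 90)

    return None


print(fast_path_turns(90, expand=False, height=4, width=6))
print(fast_path_turns(90, expand=True, height=4, width=6))
```

A caller could pass the returned value as `k` to `torch.rot90` and fall back to the affine transform whenever the result is `None`.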

Contributor Author

I will update accordingly.

Contributor Author

I have updated the code. Also updated the test cases to cover everything. I just realized the tests didn't fail on my machine because I was running inside an environment with torchvision installed (big mistake!).
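
One way to catch the environment mistake described above is to check where Python resolves a package from before running the tests. The sketch below uses only the standard library; `module_origin` is a hypothetical helper name, and `json` stands in for `torchvision` so the snippet runs anywhere.

```python
import importlib.util


def module_origin(name):
    """Return the file a top-level module would be imported from, or None if not found."""
    spec = importlib.util.find_spec(name)
    return None if spec is None else spec.origin


# With a local checkout you would pass "torchvision" and confirm the path
# points into the working tree rather than into site-packages.
print(module_origin("json"))
```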

Member

No worries @gau-nernst, thanks for the update. I can confirm the tests are passing now (tried locally). Let's just wait for the CI to be green before merging.

        if angle == 0:
            return image.clone()
        if angle == 90:
            return torch.rot90(image, k=1, dims=(-1, -2))
        if angle == 180:
            return torch.rot90(image, k=2, dims=(-1, -2))
        if angle == 270:
            return torch.rot90(image, k=3, dims=(-1, -2))

    interpolation = _check_interpolation(interpolation)

    input_height, input_width = image.shape[-2:]