Skip to content

Commit ea0bdec

Browse files
committed
update rotate to main standards
1 parent 5fbeac3 commit ea0bdec

File tree

3 files changed

+11
-12
lines changed

3 files changed

+11
-12
lines changed

test/test_transforms_v2.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
assert_equal,
2626
cache,
2727
cpu_and_cuda,
28-
cvcuda_to_pil_compatible_tensor,
2928
freeze_rng_state,
3029
ignore_jit_no_profile_information_warning,
3130
make_bounding_boxes,
@@ -2149,14 +2148,14 @@ def test_functional(self, make_input):
21492148
(F.rotate_video, tv_tensors.Video),
21502149
(F.rotate_keypoints, tv_tensors.KeyPoints),
21512150
pytest.param(
2152-
F._geometry._rotate_cvcuda,
2153-
"cvcuda.Tensor",
2151+
F._geometry._rotate_image_cvcuda,
2152+
None,
21542153
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available"),
21552154
),
21562155
],
21572156
)
21582157
def test_functional_signature(self, kernel, input_type):
2159-
if input_type == "cvcuda.Tensor":
2158+
if kernel is F._geometry._rotate_image_cvcuda:
21602159
input_type = _import_cvcuda().Tensor
21612160
check_functional_kernel_signature_match(F.rotate, kernel=kernel, input_type=input_type)
21622161

@@ -2205,8 +2204,8 @@ def test_functional_image_correctness(self, angle, center, interpolation, expand
22052204
actual = F.rotate(image, angle=angle, center=center, interpolation=interpolation, expand=expand, fill=fill)
22062205

22072206
if make_input is make_image_cvcuda:
2208-
actual = cvcuda_to_pil_compatible_tensor(actual)
2209-
image = cvcuda_to_pil_compatible_tensor(image)
2207+
actual = F.cvcuda_to_tensor(actual)[0].cpu()
2208+
image = F.cvcuda_to_tensor(image)[0].cpu()
22102209

22112210
expected = F.to_image(
22122211
F.rotate(
@@ -2256,8 +2255,8 @@ def test_transform_image_correctness(self, center, interpolation, expand, fill,
22562255
torch.manual_seed(seed)
22572256

22582257
if make_input is make_image_cvcuda:
2259-
actual = cvcuda_to_pil_compatible_tensor(actual)
2260-
image = cvcuda_to_pil_compatible_tensor(image)
2258+
actual = F.cvcuda_to_tensor(actual)[0].cpu()
2259+
image = F.cvcuda_to_tensor(image)[0].cpu()
22612260

22622261
expected = F.to_image(transform(F.to_pil_image(image)))
22632262

torchvision/transforms/v2/_geometry.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
is_pure_tensor,
3030
query_size,
3131
)
32-
from .functional._utils import is_cvcuda_tensor
3332

3433
CVCUDA_AVAILABLE = _is_cvcuda_available()
3534

@@ -607,7 +606,8 @@ class RandomRotation(Transform):
607606

608607
_v1_transform_cls = _transforms.RandomRotation
609608

610-
_transformed_types = Transform._transformed_types + (is_cvcuda_tensor,)
609+
if CVCUDA_AVAILABLE:
610+
_transformed_types = Transform._transformed_types + (_is_cvcuda_tensor,)
611611

612612
def __init__(
613613
self,

torchvision/transforms/v2/functional/_geometry.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1559,7 +1559,7 @@ def rotate_video(
15591559
}
15601560

15611561

1562-
def _rotate_cvcuda(
1562+
def _rotate_image_cvcuda(
15631563
inpt: "cvcuda.Tensor",
15641564
angle: float,
15651565
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
@@ -1642,7 +1642,7 @@ def _rotate_cvcuda(
16421642

16431643

16441644
if CVCUDA_AVAILABLE:
1645-
_register_kernel_internal(rotate, _import_cvcuda().Tensor)(_rotate_cvcuda)
1645+
_register_kernel_internal(rotate, _import_cvcuda().Tensor)(_rotate_image_cvcuda)
16461646

16471647

16481648
def pad(

0 commit comments

Comments
 (0)