Skip to content

Commit cf2f2bc

Browse files
committed
move transformed type check to Rotate transform
1 parent 84d204b commit cf2f2bc

File tree

2 files changed

+5
-2
lines changed

2 files changed

+5
-2
lines changed

torchvision/transforms/v2/_geometry.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
is_pure_tensor,
3030
query_size,
3131
)
32+
from .functional._utils import is_cvcuda_tensor
3233

3334

3435
class RandomHorizontalFlip(_RandomApplyTransform):
@@ -598,6 +599,8 @@ class RandomRotation(Transform):
598599

599600
_v1_transform_cls = _transforms.RandomRotation
600601

602+
_transformed_types = Transform._transformed_types + (is_cvcuda_tensor,)
603+
601604
def __init__(
602605
self,
603606
degrees: Union[numbers.Number, Sequence],

torchvision/transforms/v2/_transform.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from torchvision.transforms.v2._utils import check_type, has_any, is_pure_tensor
1212
from torchvision.utils import _log_api_usage_once
1313

14-
from .functional._utils import _get_kernel, is_cvcuda_tensor
14+
from .functional._utils import _get_kernel
1515

1616

1717
class Transform(nn.Module):
@@ -23,7 +23,7 @@ class Transform(nn.Module):
2323

2424
# Class attribute defining transformed types. Other types are passed-through without any transformation
2525
# We support both Types and callables that are able to do further checks on the type of the input.
26-
_transformed_types: tuple[type | Callable[[Any], bool], ...] = (torch.Tensor, PIL.Image.Image, is_cvcuda_tensor)
26+
_transformed_types: tuple[type | Callable[[Any], bool], ...] = (torch.Tensor, PIL.Image.Image)
2727

2828
def __init__(self) -> None:
2929
super().__init__()

0 commit comments

Comments
 (0)