Skip to content

Commit e9efdd8

Browse files
committed
update todtype based on PR reviews
1 parent 889d4af commit e9efdd8

File tree

5 files changed

+15
-5
lines changed

5 files changed

+15
-5
lines changed

torchvision/transforms/v2/_misc.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
get_bounding_boxes,
1818
get_keypoints,
1919
has_any,
20+
is_cvcuda_tensor,
2021
is_pure_tensor,
2122
)
2223

@@ -267,7 +268,7 @@ class ToDtype(Transform):
267268
Default: ``False``.
268269
"""
269270

270-
_transformed_types = (torch.Tensor,)
271+
_transformed_types = (torch.Tensor, is_cvcuda_tensor)
271272

272273
def __init__(
273274
self, dtype: Union[torch.dtype, dict[Union[type, str], Optional[torch.dtype]]], scale: bool = False

torchvision/transforms/v2/_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from torchvision._utils import sequence_to_str
1616

1717
from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size # noqa: F401
18-
from torchvision.transforms.v2.functional import get_dimensions, get_size, is_pure_tensor
18+
from torchvision.transforms.v2.functional import get_dimensions, get_size, is_cvcuda_tensor, is_pure_tensor
1919
from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT
2020

2121

@@ -207,6 +207,7 @@ def query_size(flat_inputs: list[Any]) -> tuple[int, int]:
207207
tv_tensors.Mask,
208208
tv_tensors.BoundingBoxes,
209209
tv_tensors.KeyPoints,
210+
is_cvcuda_tensor,
210211
),
211212
)
212213
}

torchvision/transforms/v2/functional/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from torchvision.transforms import InterpolationMode # usort: skip
22

3-
from ._utils import is_pure_tensor, register_kernel # usort: skip
3+
from ._utils import is_cvcuda_tensor, is_pure_tensor, register_kernel # usort: skip
44

55
from ._meta import (
66
clamp_bounding_boxes,

torchvision/transforms/v2/functional/_meta.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def _get_size_image_pil(image: PIL.Image.Image) -> list[int]:
114114
return [height, width]
115115

116116

117-
def get_size_image_cvcuda(image: "cvcuda.Tensor") -> list[int]:
117+
def _get_size_cvcuda(image: "cvcuda.Tensor") -> list[int]:
118118
"""Get size of `cvcuda.Tensor` with NHWC layout."""
119119
hw = list(image.shape[-3:-1])
120120
ndims = len(hw)
@@ -125,7 +125,7 @@ def get_size_image_cvcuda(image: "cvcuda.Tensor") -> list[int]:
125125

126126

127127
if CVCUDA_AVAILABLE:
128-
_get_size_image_cvcuda = _register_kernel_internal(get_size, cvcuda.Tensor)(get_size_image_cvcuda)
128+
_register_kernel_internal(get_size, _import_cvcuda().Tensor)(_get_size_cvcuda)
129129

130130

131131
@_register_kernel_internal(get_size, tv_tensors.Video, tv_tensor_wrapper=False)

torchvision/transforms/v2/functional/_utils.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,3 +169,11 @@ def _is_cvcuda_available():
169169
return True
170170
except ImportError:
171171
return False
172+
173+
174+
def is_cvcuda_tensor(inpt: Any) -> bool:
175+
try:
176+
cvcuda = _import_cvcuda()
177+
return isinstance(inpt, cvcuda.Tensor)
178+
except ImportError:
179+
return False

0 commit comments

Comments
 (0)