Skip to content

Commit 1175844

Browse files
committed
update to main standards
1 parent c37991a commit 1175844

File tree

7 files changed

+62
-67
lines changed

7 files changed

+62
-67
lines changed

test/test_transforms_v2.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
assert_equal,
2626
cache,
2727
cpu_and_cuda,
28-
cvcuda_to_pil_compatible_tensor,
2928
freeze_rng_state,
3029
ignore_jit_no_profile_information_warning,
3130
make_bounding_boxes,
@@ -5200,14 +5199,14 @@ def test_functional(self, make_input):
52005199
(F.perspective_video, tv_tensors.Video),
52015200
(F.perspective_keypoints, tv_tensors.KeyPoints),
52025201
pytest.param(
5203-
F._geometry._perspective_cvcuda,
5204-
"cvcuda.Tensor",
5202+
F._geometry._perspective_image_cvcuda,
5203+
None,
52055204
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available"),
52065205
),
52075206
],
52085207
)
52095208
def test_functional_signature(self, kernel, input_type):
5210-
if input_type == "cvcuda.Tensor":
5209+
if kernel is F._geometry._perspective_image_cvcuda:
52115210
input_type = _import_cvcuda().Tensor
52125211
check_functional_kernel_signature_match(F.perspective, kernel=kernel, input_type=input_type)
52135212

@@ -5256,8 +5255,8 @@ def test_image_functional_correctness(self, coefficients, interpolation, fill, m
52565255
image, startpoints=None, endpoints=None, coefficients=coefficients, interpolation=interpolation, fill=fill
52575256
)
52585257
if make_input is make_image_cvcuda:
5259-
actual = cvcuda_to_pil_compatible_tensor(actual)
5260-
image = cvcuda_to_pil_compatible_tensor(image)
5258+
actual = F.cvcuda_to_tensor(actual)[0].cpu()
5259+
image = F.cvcuda_to_tensor(image)[0].cpu()
52615260

52625261
expected = F.to_image(
52635262
F.perspective(

torchvision/transforms/v2/_geometry.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -944,7 +944,8 @@ class RandomPerspective(_RandomApplyTransform):
944944

945945
_v1_transform_cls = _transforms.RandomPerspective
946946

947-
_transformed_types = _RandomApplyTransform._transformed_types + (is_cvcuda_tensor,)
947+
if CVCUDA_AVAILABLE:
948+
_transformed_types = _RandomApplyTransform._transformed_types + (_is_cvcuda_tensor,)
948949

949950
def __init__(
950951
self,

torchvision/transforms/v2/functional/_augment.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import io
2-
from typing import TYPE_CHECKING
32

43
import PIL.Image
54

@@ -9,15 +8,7 @@
98
from torchvision.transforms.functional import pil_to_tensor, to_pil_image
109
from torchvision.utils import _log_api_usage_once
1110

12-
from ._utils import _get_kernel, _import_cvcuda, _is_cvcuda_available, _register_kernel_internal
13-
14-
15-
CVCUDA_AVAILABLE = _is_cvcuda_available()
16-
17-
if TYPE_CHECKING:
18-
import cvcuda # type: ignore[import-not-found]
19-
if CVCUDA_AVAILABLE:
20-
cvcuda = _import_cvcuda() # noqa: F811
11+
from ._utils import _get_kernel, _register_kernel_internal
2112

2213

2314
def erase(

torchvision/transforms/v2/functional/_color.py

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
from typing import TYPE_CHECKING
2-
31
import PIL.Image
42
import torch
53
from torch.nn.functional import conv2d
@@ -11,15 +9,7 @@
119

1210
from ._misc import _num_value_bits, to_dtype_image
1311
from ._type_conversion import pil_to_tensor, to_pil_image
14-
from ._utils import _get_kernel, _import_cvcuda, _is_cvcuda_available, _register_kernel_internal
15-
16-
17-
CVCUDA_AVAILABLE = _is_cvcuda_available()
18-
19-
if TYPE_CHECKING:
20-
import cvcuda # type: ignore[import-not-found]
21-
if CVCUDA_AVAILABLE:
22-
cvcuda = _import_cvcuda() # noqa: F811
12+
from ._utils import _get_kernel, _register_kernel_internal
2313

2414

2515
def rgb_to_grayscale(inpt: torch.Tensor, num_output_channels: int = 1) -> torch.Tensor:

torchvision/transforms/v2/functional/_geometry.py

Lines changed: 4 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030

3131
from ._utils import (
3232
_FillTypeJIT,
33+
_get_cvcuda_interp,
3334
_get_kernel,
3435
_import_cvcuda,
3536
_is_cvcuda_available,
@@ -2287,31 +2288,7 @@ def perspective_video(
22872288
)
22882289

22892290

2290-
if CVCUDA_AVAILABLE:
2291-
_cvcuda_interp = {
2292-
InterpolationMode.BILINEAR: cvcuda.Interp.LINEAR,
2293-
"bilinear": cvcuda.Interp.LINEAR,
2294-
"linear": cvcuda.Interp.LINEAR,
2295-
2: cvcuda.Interp.LINEAR,
2296-
InterpolationMode.BICUBIC: cvcuda.Interp.CUBIC,
2297-
"bicubic": cvcuda.Interp.CUBIC,
2298-
3: cvcuda.Interp.CUBIC,
2299-
InterpolationMode.NEAREST: cvcuda.Interp.NEAREST,
2300-
"nearest": cvcuda.Interp.NEAREST,
2301-
0: cvcuda.Interp.NEAREST,
2302-
InterpolationMode.BOX: cvcuda.Interp.BOX,
2303-
"box": cvcuda.Interp.BOX,
2304-
4: cvcuda.Interp.BOX,
2305-
InterpolationMode.HAMMING: cvcuda.Interp.HAMMING,
2306-
"hamming": cvcuda.Interp.HAMMING,
2307-
5: cvcuda.Interp.HAMMING,
2308-
InterpolationMode.LANCZOS: cvcuda.Interp.LANCZOS,
2309-
"lanczos": cvcuda.Interp.LANCZOS,
2310-
1: cvcuda.Interp.LANCZOS,
2311-
}
2312-
2313-
2314-
def _perspective_cvcuda(
2291+
def _perspective_image_cvcuda(
23152292
image: "cvcuda.Tensor",
23162293
startpoints: Optional[list[list[int]]],
23172294
endpoints: Optional[list[list[int]]],
@@ -2324,9 +2301,7 @@ def _perspective_cvcuda(
23242301
c = _perspective_coefficients(startpoints, endpoints, coefficients)
23252302
interpolation = _check_interpolation(interpolation)
23262303

2327-
interp = _cvcuda_interp.get(interpolation)
2328-
if interp is None:
2329-
raise ValueError(f"Invalid interpolation mode: {interpolation}")
2304+
interp = _get_cvcuda_interp(interpolation)
23302305

23312306
xform = np.array([[c[0], c[1], c[2]], [c[3], c[4], c[5]], [c[6], c[7], 1.0]], dtype=np.float32)
23322307

@@ -2348,7 +2323,7 @@ def _perspective_cvcuda(
23482323

23492324

23502325
if CVCUDA_AVAILABLE:
2351-
_register_kernel_internal(perspective, _import_cvcuda().Tensor)(_perspective_cvcuda)
2326+
_register_kernel_internal(perspective, _import_cvcuda().Tensor)(_perspective_image_cvcuda)
23522327

23532328

23542329
def elastic(

torchvision/transforms/v2/functional/_misc.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import math
2-
from typing import Optional, TYPE_CHECKING
2+
from typing import Optional
33

44
import PIL.Image
55
import torch
@@ -13,14 +13,7 @@
1313

1414
from ._meta import _convert_bounding_box_format
1515

16-
from ._utils import _get_kernel, _import_cvcuda, _is_cvcuda_available, _register_kernel_internal, is_pure_tensor
17-
18-
CVCUDA_AVAILABLE = _is_cvcuda_available()
19-
20-
if TYPE_CHECKING:
21-
import cvcuda # type: ignore[import-not-found]
22-
if CVCUDA_AVAILABLE:
23-
cvcuda = _import_cvcuda() # noqa: F811
16+
from ._utils import _get_kernel, _register_kernel_internal, is_pure_tensor
2417

2518

2619
def normalize(

torchvision/transforms/v2/functional/_utils.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
11
import functools
22
from collections.abc import Sequence
3-
from typing import Any, Callable, Optional, Union
3+
from typing import Any, Callable, Optional, TYPE_CHECKING, Union
44

55
import torch
66
from torchvision import tv_tensors
7+
from torchvision.transforms.functional import InterpolationMode
8+
9+
if TYPE_CHECKING:
10+
import cvcuda # type: ignore[import-not-found]
711

812
_FillType = Union[int, float, Sequence[int], Sequence[float], None]
913
_FillTypeJIT = Optional[list[float]]
@@ -177,3 +181,45 @@ def _is_cvcuda_tensor(inpt: Any) -> bool:
177181
return isinstance(inpt, cvcuda.Tensor)
178182
except ImportError:
179183
return False
184+
185+
186+
_interpolation_mode_to_cvcuda_interp: dict[InterpolationMode | str | int, "cvcuda.Interp"] = {}
187+
188+
189+
def _populate_interpolation_mode_to_cvcuda_interp():
190+
cvcuda = _import_cvcuda()
191+
192+
global _interpolation_mode_to_cvcuda_interp
193+
194+
_interpolation_mode_to_cvcuda_interp = {
195+
InterpolationMode.BILINEAR: cvcuda.Interp.LINEAR,
196+
"bilinear": cvcuda.Interp.LINEAR,
197+
"linear": cvcuda.Interp.LINEAR,
198+
2: cvcuda.Interp.LINEAR,
199+
InterpolationMode.BICUBIC: cvcuda.Interp.CUBIC,
200+
"bicubic": cvcuda.Interp.CUBIC,
201+
3: cvcuda.Interp.CUBIC,
202+
InterpolationMode.NEAREST: cvcuda.Interp.NEAREST,
203+
"nearest": cvcuda.Interp.NEAREST,
204+
0: cvcuda.Interp.NEAREST,
205+
InterpolationMode.BOX: cvcuda.Interp.BOX,
206+
"box": cvcuda.Interp.BOX,
207+
4: cvcuda.Interp.BOX,
208+
InterpolationMode.HAMMING: cvcuda.Interp.HAMMING,
209+
"hamming": cvcuda.Interp.HAMMING,
210+
5: cvcuda.Interp.HAMMING,
211+
InterpolationMode.LANCZOS: cvcuda.Interp.LANCZOS,
212+
"lanczos": cvcuda.Interp.LANCZOS,
213+
1: cvcuda.Interp.LANCZOS,
214+
}
215+
216+
217+
def _get_cvcuda_interp(interpolation: InterpolationMode | str | int) -> "cvcuda.Interp":
218+
if len(_interpolation_mode_to_cvcuda_interp) == 0:
219+
_populate_interpolation_mode_to_cvcuda_interp()
220+
221+
interp = _interpolation_mode_to_cvcuda_interp.get(interpolation)
222+
if interp is None:
223+
raise ValueError(f"Interpolation mode {interpolation} is not supported with CV-CUDA")
224+
225+
return interp

0 commit comments

Comments
 (0)