Skip to content

Commit d3ef0bd

Browse files
committed
update normalize with changes from main
1 parent e281964 commit d3ef0bd

File tree

9 files changed

+19
-51
lines changed

9 files changed

+19
-51
lines changed

test/common_utils.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -277,17 +277,6 @@ def combinations_grid(**kwargs):
277277
return [dict(zip(kwargs.keys(), values)) for values in itertools.product(*kwargs.values())]
278278

279279

280-
def cvcuda_to_pil_compatible_tensor(tensor: "cvcuda.Tensor") -> torch.Tensor:
281-
tensor = cvcuda_to_tensor(tensor)
282-
if tensor.ndim != 4:
283-
raise ValueError(f"CV-CUDA Tensor should be 4 dimensional. Got {tensor.ndim} dimensions.")
284-
if tensor.shape[0] != 1:
285-
raise ValueError(
286-
f"CV-CUDA Tensor should have batch dimension 1 for comparison with PIL.Image.Image. Got {tensor.shape[0]}."
287-
)
288-
return tensor.squeeze(0).cpu()
289-
290-
291280
class ImagePair(TensorLikePair):
292281
def __init__(
293282
self,
@@ -316,11 +305,6 @@ def __init__(
316305
expected = expected[0]
317306
expected = expected.cpu()
318307

319-
# handle check for CV-CUDA Tensors
320-
if CVCUDA_AVAILABLE and isinstance(actual, _import_cvcuda().Tensor):
321-
# Use the PIL compatible tensor, so we can always compare with PIL.Image.Image
322-
actual = cvcuda_to_pil_compatible_tensor(actual)
323-
324308
super().__init__(actual, expected, **other_parameters)
325309
self.mae = mae
326310

test/test_transforms_v2.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
assert_equal,
2626
cache,
2727
cpu_and_cuda,
28-
cvcuda_to_pil_compatible_tensor,
2928
freeze_rng_state,
3029
ignore_jit_no_profile_information_warning,
3130
make_bounding_boxes,
@@ -5591,7 +5590,7 @@ def test_functional(self, make_input):
55915590
(F.normalize_image, tv_tensors.Image),
55925591
(F.normalize_video, tv_tensors.Video),
55935592
pytest.param(
5594-
F._misc._normalize_cvcuda,
5593+
F._misc._normalize_image_cvcuda,
55955594
"cvcuda.Tensor",
55965595
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA"),
55975596
),
@@ -5669,7 +5668,7 @@ def test_correctness_image(self, mean, std, dtype, make_input, fn):
56695668
actual = fn(image, mean=mean, std=std)
56705669

56715670
if make_input == make_image_cvcuda:
5672-
image = cvcuda_to_pil_compatible_tensor(image)
5671+
image = F.cvcuda_to_tensor(image)[0].cpu()
56735672

56745673
expected = self._reference_normalize_image(image, mean=mean, std=std)
56755674

torchvision/transforms/v2/_misc.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from torchvision import transforms as _transforms, tv_tensors
1111
from torchvision.transforms.v2 import functional as F, Transform
12+
from torchvision.transforms.v2.functional._utils import _is_cvcuda_available, _is_cvcuda_tensor
1213

1314
from ._utils import (
1415
_parse_labels_getter,
@@ -17,11 +18,13 @@
1718
get_bounding_boxes,
1819
get_keypoints,
1920
has_any,
20-
is_cvcuda_tensor,
2121
is_pure_tensor,
2222
)
2323

2424

25+
CVCUDA_AVAILABLE = _is_cvcuda_available()
26+
27+
2528
# TODO: do we want/need to expose this?
2629
class Identity(Transform):
2730
def transform(self, inpt: Any, params: dict[str, Any]) -> Any:
@@ -161,7 +164,8 @@ class Normalize(Transform):
161164

162165
_v1_transform_cls = _transforms.Normalize
163166

164-
_transformed_types = Transform._transformed_types + (is_cvcuda_tensor,)
167+
if CVCUDA_AVAILABLE:
168+
_transformed_types = Transform._transformed_types + (_is_cvcuda_tensor,)
165169

166170
def __init__(self, mean: Sequence[float], std: Sequence[float], inplace: bool = False):
167171
super().__init__()

torchvision/transforms/v2/_transform.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from torch import nn
99
from torch.utils._pytree import tree_flatten, tree_unflatten
1010
from torchvision import tv_tensors
11-
from torchvision.transforms.v2._utils import check_type, has_any, is_cvcuda_tensor, is_pure_tensor
11+
from torchvision.transforms.v2._utils import check_type, has_any, is_pure_tensor
1212
from torchvision.utils import _log_api_usage_once
1313

1414
from .functional._utils import _get_kernel
@@ -23,7 +23,7 @@ class Transform(nn.Module):
2323

2424
# Class attribute defining transformed types. Other types are passed-through without any transformation
2525
# We support both Types and callables that are able to do further checks on the type of the input.
26-
_transformed_types: tuple[type | Callable[[Any], bool], ...] = (torch.Tensor, PIL.Image.Image, is_cvcuda_tensor)
26+
_transformed_types: tuple[type | Callable[[Any], bool], ...] = (torch.Tensor, PIL.Image.Image)
2727

2828
def __init__(self) -> None:
2929
super().__init__()

torchvision/transforms/v2/_utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@
1515
from torchvision._utils import sequence_to_str
1616

1717
from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size # noqa: F401
18-
from torchvision.transforms.v2.functional import get_dimensions, get_size, is_cvcuda_tensor, is_pure_tensor
19-
from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT
18+
from torchvision.transforms.v2.functional import get_dimensions, get_size, is_pure_tensor
19+
from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT, _is_cvcuda_tensor
2020

2121

2222
def _setup_number_or_seq(arg: int | float | Sequence[int | float], name: str) -> Sequence[float]:
@@ -182,7 +182,7 @@ def query_chw(flat_inputs: list[Any]) -> tuple[int, int, int]:
182182
chws = {
183183
tuple(get_dimensions(inpt))
184184
for inpt in flat_inputs
185-
if check_type(inpt, (is_pure_tensor, tv_tensors.Image, PIL.Image.Image, tv_tensors.Video, is_cvcuda_tensor))
185+
if check_type(inpt, (is_pure_tensor, tv_tensors.Image, PIL.Image.Image, tv_tensors.Video, _is_cvcuda_tensor))
186186
}
187187
if not chws:
188188
raise TypeError("No image or video was found in the sample")
@@ -207,7 +207,7 @@ def query_size(flat_inputs: list[Any]) -> tuple[int, int]:
207207
tv_tensors.Mask,
208208
tv_tensors.BoundingBoxes,
209209
tv_tensors.KeyPoints,
210-
is_cvcuda_tensor,
210+
_is_cvcuda_tensor,
211211
),
212212
)
213213
}

torchvision/transforms/v2/functional/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from torchvision.transforms import InterpolationMode # usort: skip
22

3-
from ._utils import is_pure_tensor, register_kernel, is_cvcuda_tensor # usort: skip
3+
from ._utils import is_pure_tensor, register_kernel # usort: skip
44

55
from ._meta import (
66
clamp_bounding_boxes,

torchvision/transforms/v2/functional/_augment.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import io
2-
from typing import TYPE_CHECKING
32

43
import PIL.Image
54

@@ -9,15 +8,7 @@
98
from torchvision.transforms.functional import pil_to_tensor, to_pil_image
109
from torchvision.utils import _log_api_usage_once
1110

12-
from ._utils import _get_kernel, _import_cvcuda, _is_cvcuda_available, _register_kernel_internal
13-
14-
15-
CVCUDA_AVAILABLE = _is_cvcuda_available()
16-
17-
if TYPE_CHECKING:
18-
import cvcuda # type: ignore[import-not-found]
19-
if CVCUDA_AVAILABLE:
20-
cvcuda = _import_cvcuda() # noqa: F811
11+
from ._utils import _get_kernel, _register_kernel_internal
2112

2213

2314
def erase(

torchvision/transforms/v2/functional/_color.py

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
from typing import TYPE_CHECKING
2-
31
import PIL.Image
42
import torch
53
from torch.nn.functional import conv2d
@@ -11,15 +9,7 @@
119

1210
from ._misc import _num_value_bits, to_dtype_image
1311
from ._type_conversion import pil_to_tensor, to_pil_image
14-
from ._utils import _get_kernel, _import_cvcuda, _is_cvcuda_available, _register_kernel_internal
15-
16-
17-
CVCUDA_AVAILABLE = _is_cvcuda_available()
18-
19-
if TYPE_CHECKING:
20-
import cvcuda # type: ignore[import-not-found]
21-
if CVCUDA_AVAILABLE:
22-
cvcuda = _import_cvcuda() # noqa: F811
12+
from ._utils import _get_kernel, _register_kernel_internal
2313

2414

2515
def rgb_to_grayscale(inpt: torch.Tensor, num_output_channels: int = 1) -> torch.Tensor:

torchvision/transforms/v2/functional/_misc.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def normalize_video(video: torch.Tensor, mean: list[float], std: list[float], in
7979
return normalize_image(video, mean, std, inplace=inplace)
8080

8181

82-
def _normalize_cvcuda(
82+
def _normalize_image_cvcuda(
8383
image: "cvcuda.Tensor",
8484
mean: list[float],
8585
std: list[float],
@@ -114,7 +114,7 @@ def _normalize_cvcuda(
114114

115115

116116
if CVCUDA_AVAILABLE:
117-
_register_kernel_internal(normalize, _import_cvcuda().Tensor)(_normalize_cvcuda)
117+
_register_kernel_internal(normalize, _import_cvcuda().Tensor)(_normalize_image_cvcuda)
118118

119119

120120
def gaussian_blur(inpt: torch.Tensor, kernel_size: list[int], sigma: Optional[list[float]] = None) -> torch.Tensor:

0 commit comments

Comments
 (0)