From 44db71c0772e5ef5758c38d0e4e8ad9995946c80 Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Tue, 25 Nov 2025 09:14:49 -0800 Subject: [PATCH 01/17] implement additional cvcuda infra for all branches to avoid duplicate setup --- torchvision/transforms/v2/_transform.py | 4 ++-- torchvision/transforms/v2/_utils.py | 3 ++- .../transforms/v2/functional/__init__.py | 2 +- .../transforms/v2/functional/_augment.py | 11 ++++++++++- .../transforms/v2/functional/_color.py | 12 +++++++++++- .../transforms/v2/functional/_geometry.py | 19 +++++++++++++++++-- torchvision/transforms/v2/functional/_misc.py | 11 +++++++++-- .../transforms/v2/functional/_utils.py | 16 ++++++++++++++++ 8 files changed, 68 insertions(+), 10 deletions(-) diff --git a/torchvision/transforms/v2/_transform.py b/torchvision/transforms/v2/_transform.py index ac84fcb6c82..bec9ffcf714 100644 --- a/torchvision/transforms/v2/_transform.py +++ b/torchvision/transforms/v2/_transform.py @@ -11,7 +11,7 @@ from torchvision.transforms.v2._utils import check_type, has_any, is_pure_tensor from torchvision.utils import _log_api_usage_once -from .functional._utils import _get_kernel +from .functional._utils import _get_kernel, is_cvcuda_tensor class Transform(nn.Module): @@ -23,7 +23,7 @@ class Transform(nn.Module): # Class attribute defining transformed types. Other types are passed-through without any transformation # We support both Types and callables that are able to do further checks on the type of the input. - _transformed_types: tuple[type | Callable[[Any], bool], ...] = (torch.Tensor, PIL.Image.Image) + _transformed_types: tuple[type | Callable[[Any], bool], ...] = (torch.Tensor, PIL.Image.Image, is_cvcuda_tensor) def __init__(self) -> None: super().__init__() diff --git a/torchvision/transforms/v2/_utils.py b/torchvision/transforms/v2/_utils.py index bb6051b4e61..765a772fe41 100644 --- a/torchvision/transforms/v2/_utils.py +++ b/torchvision/transforms/v2/_utils.py @@ -15,7 +15,7 @@ from torchvision._utils import sequence_to_str from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size # noqa: F401 -from torchvision.transforms.v2.functional import get_dimensions, get_size, is_pure_tensor +from torchvision.transforms.v2.functional import get_dimensions, get_size, is_cvcuda_tensor, is_pure_tensor from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT @@ -207,6 +207,7 @@ def query_size(flat_inputs: list[Any]) -> tuple[int, int]: tv_tensors.Mask, tv_tensors.BoundingBoxes, tv_tensors.KeyPoints, + is_cvcuda_tensor, ), ) } diff --git a/torchvision/transforms/v2/functional/__init__.py b/torchvision/transforms/v2/functional/__init__.py index 032a993b1f0..52181e4624b 100644 --- a/torchvision/transforms/v2/functional/__init__.py +++ b/torchvision/transforms/v2/functional/__init__.py @@ -1,6 +1,6 @@ from torchvision.transforms import InterpolationMode # usort: skip -from ._utils import is_pure_tensor, register_kernel # usort: skip +from ._utils import is_pure_tensor, register_kernel, is_cvcuda_tensor # usort: skip from ._meta import ( clamp_bounding_boxes, diff --git a/torchvision/transforms/v2/functional/_augment.py b/torchvision/transforms/v2/functional/_augment.py index a904d8d7cbd..7ce5bdc7b7e 100644 --- a/torchvision/transforms/v2/functional/_augment.py +++ b/torchvision/transforms/v2/functional/_augment.py @@ -1,4 +1,5 @@ import io +from typing import TYPE_CHECKING import PIL.Image @@ -8,7 +9,15 @@ from torchvision.transforms.functional import pil_to_tensor, to_pil_image from torchvision.utils import _log_api_usage_once -from ._utils import _get_kernel, _register_kernel_internal +from ._utils import _get_kernel, _import_cvcuda, _is_cvcuda_available, _register_kernel_internal + + +CVCUDA_AVAILABLE = _is_cvcuda_available() + +if TYPE_CHECKING: + import cvcuda # type: ignore[import-not-found] +if CVCUDA_AVAILABLE: + cvcuda = _import_cvcuda() # noqa: F811 def erase( diff --git a/torchvision/transforms/v2/functional/_color.py b/torchvision/transforms/v2/functional/_color.py index be254c0d63a..5be9c62902a 100644 --- a/torchvision/transforms/v2/functional/_color.py +++ b/torchvision/transforms/v2/functional/_color.py @@ -1,3 +1,5 @@ +from typing import TYPE_CHECKING + import PIL.Image import torch from torch.nn.functional import conv2d @@ -9,7 +11,15 @@ from ._misc import _num_value_bits, to_dtype_image from ._type_conversion import pil_to_tensor, to_pil_image -from ._utils import _get_kernel, _register_kernel_internal +from ._utils import _get_kernel, _import_cvcuda, _is_cvcuda_available, _register_kernel_internal + + +CVCUDA_AVAILABLE = _is_cvcuda_available() + +if TYPE_CHECKING: + import cvcuda # type: ignore[import-not-found] +if CVCUDA_AVAILABLE: + cvcuda = _import_cvcuda() # noqa: F811 def rgb_to_grayscale(inpt: torch.Tensor, num_output_channels: int = 1) -> torch.Tensor: diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py index 4fcb7fabe0d..c029488001c 100644 --- a/torchvision/transforms/v2/functional/_geometry.py +++ b/torchvision/transforms/v2/functional/_geometry.py @@ -2,7 +2,7 @@ import numbers import warnings from collections.abc import Sequence -from typing import Any, Optional, Union +from typing import Any, Optional, TYPE_CHECKING, Union import PIL.Image import torch @@ -26,7 +26,22 @@ from ._meta import _get_size_image_pil, clamp_bounding_boxes, convert_bounding_box_format -from ._utils import _FillTypeJIT, _get_kernel, _register_five_ten_crop_kernel_internal, _register_kernel_internal +from ._utils import ( + _FillTypeJIT, + _get_kernel, + _import_cvcuda, + _is_cvcuda_available, + _register_five_ten_crop_kernel_internal, + _register_kernel_internal, +) + + +CVCUDA_AVAILABLE = _is_cvcuda_available() + +if TYPE_CHECKING: + import cvcuda # type: ignore[import-not-found] +if CVCUDA_AVAILABLE: + cvcuda = _import_cvcuda() # noqa: F811 def _check_interpolation(interpolation: Union[InterpolationMode, int]) -> InterpolationMode: diff --git a/torchvision/transforms/v2/functional/_misc.py b/torchvision/transforms/v2/functional/_misc.py index daf263df046..0fa05a2113c 100644 --- a/torchvision/transforms/v2/functional/_misc.py +++ b/torchvision/transforms/v2/functional/_misc.py @@ -1,5 +1,5 @@ import math -from typing import Optional +from typing import Optional, TYPE_CHECKING import PIL.Image import torch @@ -13,7 +13,14 @@ from ._meta import _convert_bounding_box_format -from ._utils import _get_kernel, _register_kernel_internal, is_pure_tensor +from ._utils import _get_kernel, _import_cvcuda, _is_cvcuda_available, _register_kernel_internal, is_pure_tensor + +CVCUDA_AVAILABLE = _is_cvcuda_available() + +if TYPE_CHECKING: + import cvcuda # type: ignore[import-not-found] +if CVCUDA_AVAILABLE: + cvcuda = _import_cvcuda() # noqa: F811 def normalize( diff --git a/torchvision/transforms/v2/functional/_utils.py b/torchvision/transforms/v2/functional/_utils.py index ad1eddd258b..73fafaf7425 100644 --- a/torchvision/transforms/v2/functional/_utils.py +++ b/torchvision/transforms/v2/functional/_utils.py @@ -169,3 +169,19 @@ def _is_cvcuda_available(): return True except ImportError: return False + + +def is_cvcuda_tensor(inpt: Any) -> bool: + """ + Check if the input is a CVCUDA tensor. + + Args: + inpt: The input to check. + + Returns: + True if the input is a CV-CUDA tensor, False otherwise. + """ + if _is_cvcuda_available(): + cvcuda = _import_cvcuda() + return isinstance(inpt, cvcuda.Tensor) + return False From e3dd70022fa1c87aca7a9a98068b6e13e802a375 Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Tue, 25 Nov 2025 09:26:19 -0800 Subject: [PATCH 02/17] update make_image_cvcuda to have default batch dim --- test/common_utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/test/common_utils.py b/test/common_utils.py index 8c3c9dd58a8..e7bae60c41b 100644 --- a/test/common_utils.py +++ b/test/common_utils.py @@ -400,8 +400,9 @@ def make_image_pil(*args, **kwargs): return to_pil_image(make_image(*args, **kwargs)) -def make_image_cvcuda(*args, **kwargs): - return to_cvcuda_tensor(make_image(*args, **kwargs)) +def make_image_cvcuda(*args, batch_dims=(1,), **kwargs): + # explicitly default batch_dims to (1,) since to_cvcuda_tensor requires a batch dimension (ndims == 4) + return to_cvcuda_tensor(make_image(*args, batch_dims=batch_dims, **kwargs)) def make_keypoints(canvas_size=DEFAULT_SIZE, *, num_points=4, dtype=None, device="cpu"): From c035df1c6eaebcad25604f8c298a7d9eaf86864b Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Mon, 1 Dec 2025 18:16:27 -0800 Subject: [PATCH 03/17] add stanardized setup to main for easier updating of PRs and branches --- test/common_utils.py | 21 ++++++++++++++-- test/test_transforms_v2.py | 2 +- torchvision/transforms/v2/_utils.py | 2 +- torchvision/transforms/v2/functional/_meta.py | 24 +++++++++++++++++-- 4 files changed, 43 insertions(+), 6 deletions(-) diff --git a/test/common_utils.py b/test/common_utils.py index e7bae60c41b..3b889e93d2e 100644 --- a/test/common_utils.py +++ b/test/common_utils.py @@ -20,13 +20,15 @@ from torch.testing._comparison import BooleanPair, NonePair, not_close_error_metas, NumberPair, TensorLikePair from torchvision import io, tv_tensors from torchvision.transforms._functional_tensor import _max_value as get_max_value -from torchvision.transforms.v2.functional import to_cvcuda_tensor, to_image, to_pil_image +from torchvision.transforms.v2.functional import cvcuda_to_tensor, to_cvcuda_tensor, to_image, to_pil_image +from torchvision.transforms.v2.functional._utils import _import_cvcuda, _is_cvcuda_available from torchvision.utils import _Image_fromarray IN_OSS_CI = any(os.getenv(var) == "true" for var in ["CIRCLECI", "GITHUB_ACTIONS"]) IN_RE_WORKER = os.environ.get("INSIDE_RE_WORKER") is not None IN_FBCODE = os.environ.get("IN_FBCODE_TORCHVISION") == "1" +CVCUDA_AVAILABLE = _is_cvcuda_available() CUDA_NOT_AVAILABLE_MSG = "CUDA device not available" MPS_NOT_AVAILABLE_MSG = "MPS device not available" OSS_CI_GPU_NO_CUDA_MSG = "We're in an OSS GPU machine, and this test doesn't need cuda." @@ -275,6 +277,17 @@ def combinations_grid(**kwargs): return [dict(zip(kwargs.keys(), values)) for values in itertools.product(*kwargs.values())] +def cvcuda_to_pil_compatible_tensor(tensor: "cvcuda.Tensor") -> torch.Tensor: + tensor = cvcuda_to_tensor(tensor) + if tensor.ndim != 4: + raise ValueError(f"CV-CUDA Tensor should be 4 dimensional. Got {tensor.ndim} dimensions.") + if tensor.shape[0] != 1: + raise ValueError( + f"CV-CUDA Tensor should have batch dimension 1 for comparison with PIL.Image.Image. Got {tensor.shape[0]}." + ) + return tensor.squeeze(0).cpu() + + class ImagePair(TensorLikePair): def __init__( self, @@ -287,6 +300,11 @@ def __init__( if all(isinstance(input, PIL.Image.Image) for input in [actual, expected]): actual, expected = (to_image(input) for input in [actual, expected]) + # handle check for CV-CUDA Tensors + if CVCUDA_AVAILABLE and isinstance(actual, _import_cvcuda().Tensor): + # Use the PIL compatible tensor, so we can always compare with PIL.Image.Image + actual = cvcuda_to_pil_compatible_tensor(actual) + super().__init__(actual, expected, **other_parameters) self.mae = mae @@ -401,7 +419,6 @@ def make_image_pil(*args, **kwargs): def make_image_cvcuda(*args, batch_dims=(1,), **kwargs): - # explicitly default batch_dims to (1,) since to_cvcuda_tensor requires a batch dimension (ndims == 4) return to_cvcuda_tensor(make_image(*args, batch_dims=batch_dims, **kwargs)) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 670a9d00ffb..7eba65550da 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -21,6 +21,7 @@ import torchvision.transforms.v2 as transforms from common_utils import ( + assert_close, assert_equal, cache, cpu_and_cuda, @@ -41,7 +42,6 @@ ) from torch import nn -from torch.testing import assert_close from torch.utils._pytree import tree_flatten, tree_map from torch.utils.data import DataLoader, default_collate from torchvision import tv_tensors diff --git a/torchvision/transforms/v2/_utils.py b/torchvision/transforms/v2/_utils.py index 765a772fe41..3fc33ce5964 100644 --- a/torchvision/transforms/v2/_utils.py +++ b/torchvision/transforms/v2/_utils.py @@ -182,7 +182,7 @@ def query_chw(flat_inputs: list[Any]) -> tuple[int, int, int]: chws = { tuple(get_dimensions(inpt)) for inpt in flat_inputs - if check_type(inpt, (is_pure_tensor, tv_tensors.Image, PIL.Image.Image, tv_tensors.Video)) + if check_type(inpt, (is_pure_tensor, tv_tensors.Image, PIL.Image.Image, tv_tensors.Video, is_cvcuda_tensor)) } if not chws: raise TypeError("No image or video was found in the sample") diff --git a/torchvision/transforms/v2/functional/_meta.py b/torchvision/transforms/v2/functional/_meta.py index 6b8f19f12f4..ee562cb2aee 100644 --- a/torchvision/transforms/v2/functional/_meta.py +++ b/torchvision/transforms/v2/functional/_meta.py @@ -51,6 +51,16 @@ def get_dimensions_video(video: torch.Tensor) -> list[int]: return get_dimensions_image(video) +def _get_dimensions_cvcuda(image: "cvcuda.Tensor") -> list[int]: + # CV-CUDA tensor is always in NHWC layout + # get_dimensions is CHW + return [image.shape[3], image.shape[1], image.shape[2]] + + +if CVCUDA_AVAILABLE: + _register_kernel_internal(get_dimensions, cvcuda.Tensor)(_get_dimensions_cvcuda) + + def get_num_channels(inpt: torch.Tensor) -> int: if torch.jit.is_scripting(): return get_num_channels_image(inpt) @@ -87,6 +97,16 @@ def get_num_channels_video(video: torch.Tensor) -> int: get_image_num_channels = get_num_channels +def _get_num_channels_cvcuda(image: "cvcuda.Tensor") -> int: + # CV-CUDA tensor is always in NHWC layout + # get_num_channels is C + return image.shape[3] + + +if CVCUDA_AVAILABLE: + _register_kernel_internal(get_num_channels, cvcuda.Tensor)(_get_num_channels_cvcuda) + + def get_size(inpt: torch.Tensor) -> list[int]: if torch.jit.is_scripting(): return get_size_image(inpt) @@ -114,7 +134,7 @@ def _get_size_image_pil(image: PIL.Image.Image) -> list[int]: return [height, width] -def get_size_image_cvcuda(image: "cvcuda.Tensor") -> list[int]: +def _get_size_cvcuda(image: "cvcuda.Tensor") -> list[int]: """Get size of `cvcuda.Tensor` with NHWC layout.""" hw = list(image.shape[-3:-1]) ndims = len(hw) @@ -125,7 +145,7 @@ def get_size_image_cvcuda(image: "cvcuda.Tensor") -> list[int]: if CVCUDA_AVAILABLE: - _get_size_image_cvcuda = _register_kernel_internal(get_size, cvcuda.Tensor)(get_size_image_cvcuda) + _register_kernel_internal(get_size, cvcuda.Tensor)(_get_size_cvcuda) @_register_kernel_internal(get_size, tv_tensors.Video, tv_tensor_wrapper=False) From 98d7dfb2059eaf2c10c3f549ea45f1d27875134c Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Mon, 1 Dec 2025 18:25:09 -0800 Subject: [PATCH 04/17] update is_cvcuda_tensor --- torchvision/transforms/v2/functional/_utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torchvision/transforms/v2/functional/_utils.py b/torchvision/transforms/v2/functional/_utils.py index 73fafaf7425..44b2edeaf2d 100644 --- a/torchvision/transforms/v2/functional/_utils.py +++ b/torchvision/transforms/v2/functional/_utils.py @@ -181,7 +181,8 @@ def is_cvcuda_tensor(inpt: Any) -> bool: Returns: True if the input is a CV-CUDA tensor, False otherwise. """ - if _is_cvcuda_available(): + try: cvcuda = _import_cvcuda() return isinstance(inpt, cvcuda.Tensor) - return False + except ImportError: + return False From ddc116d13febdae1d53507bcde9f103a4c14eba7 Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Tue, 2 Dec 2025 12:37:03 -0800 Subject: [PATCH 05/17] add cvcuda to pil compatible to transforms by default --- test/test_transforms_v2.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 7eba65550da..87166477669 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -25,6 +25,7 @@ assert_equal, cache, cpu_and_cuda, + cvcuda_to_pil_compatible_tensor, freeze_rng_state, ignore_jit_no_profile_information_warning, make_bounding_boxes, From e51dc7eabd254261347245f4492892fd0944aae5 Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Tue, 2 Dec 2025 12:46:23 -0800 Subject: [PATCH 06/17] remove cvcuda from transform class --- torchvision/transforms/v2/_transform.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchvision/transforms/v2/_transform.py b/torchvision/transforms/v2/_transform.py index bec9ffcf714..ac84fcb6c82 100644 --- a/torchvision/transforms/v2/_transform.py +++ b/torchvision/transforms/v2/_transform.py @@ -11,7 +11,7 @@ from torchvision.transforms.v2._utils import check_type, has_any, is_pure_tensor from torchvision.utils import _log_api_usage_once -from .functional._utils import _get_kernel, is_cvcuda_tensor +from .functional._utils import _get_kernel class Transform(nn.Module): @@ -23,7 +23,7 @@ class Transform(nn.Module): # Class attribute defining transformed types. Other types are passed-through without any transformation # We support both Types and callables that are able to do further checks on the type of the input. - _transformed_types: tuple[type | Callable[[Any], bool], ...] = (torch.Tensor, PIL.Image.Image, is_cvcuda_tensor) + _transformed_types: tuple[type | Callable[[Any], bool], ...] = (torch.Tensor, PIL.Image.Image) def __init__(self) -> None: super().__init__() From 4939355a2c7421eeba95d7f155fe7953066aec6d Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Thu, 4 Dec 2025 11:07:08 -0800 Subject: [PATCH 07/17] resolve more formatting naming --- torchvision/transforms/v2/functional/__init__.py | 2 +- torchvision/transforms/v2/functional/_meta.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/torchvision/transforms/v2/functional/__init__.py b/torchvision/transforms/v2/functional/__init__.py index 52181e4624b..032a993b1f0 100644 --- a/torchvision/transforms/v2/functional/__init__.py +++ b/torchvision/transforms/v2/functional/__init__.py @@ -1,6 +1,6 @@ from torchvision.transforms import InterpolationMode # usort: skip -from ._utils import is_pure_tensor, register_kernel, is_cvcuda_tensor # usort: skip +from ._utils import is_pure_tensor, register_kernel # usort: skip from ._meta import ( clamp_bounding_boxes, diff --git a/torchvision/transforms/v2/functional/_meta.py b/torchvision/transforms/v2/functional/_meta.py index e8630f788ca..af03ad018d4 100644 --- a/torchvision/transforms/v2/functional/_meta.py +++ b/torchvision/transforms/v2/functional/_meta.py @@ -51,14 +51,14 @@ def get_dimensions_video(video: torch.Tensor) -> list[int]: return get_dimensions_image(video) -def _get_dimensions_cvcuda(image: "cvcuda.Tensor") -> list[int]: +def get_dimensions_image_cvcuda(image: "cvcuda.Tensor") -> list[int]: # CV-CUDA tensor is always in NHWC layout # get_dimensions is CHW return [image.shape[3], image.shape[1], image.shape[2]] if CVCUDA_AVAILABLE: - _register_kernel_internal(get_dimensions, cvcuda.Tensor)(_get_dimensions_cvcuda) + _register_kernel_internal(get_dimensions, cvcuda.Tensor)(get_dimensions_image_cvcuda) def get_num_channels(inpt: torch.Tensor) -> int: @@ -97,14 +97,14 @@ def get_num_channels_video(video: torch.Tensor) -> int: get_image_num_channels = get_num_channels -def _get_num_channels_cvcuda(image: "cvcuda.Tensor") -> int: +def get_num_channels_image_cvcuda(image: "cvcuda.Tensor") -> int: # CV-CUDA tensor is always in NHWC layout # get_num_channels is C return image.shape[3] if CVCUDA_AVAILABLE: - _register_kernel_internal(get_num_channels, cvcuda.Tensor)(_get_num_channels_cvcuda) + _register_kernel_internal(get_num_channels, cvcuda.Tensor)(get_num_channels_image_cvcuda) def get_size(inpt: torch.Tensor) -> list[int]: From fbea584365311ae6b56be7e4f6bbff1f834dd31a Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Thu, 4 Dec 2025 11:15:49 -0800 Subject: [PATCH 08/17] update is cvcuda tensor impl --- torchvision/transforms/v2/_utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torchvision/transforms/v2/_utils.py b/torchvision/transforms/v2/_utils.py index 3fc33ce5964..e803aa49c60 100644 --- a/torchvision/transforms/v2/_utils.py +++ b/torchvision/transforms/v2/_utils.py @@ -15,8 +15,8 @@ from torchvision._utils import sequence_to_str from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size # noqa: F401 -from torchvision.transforms.v2.functional import get_dimensions, get_size, is_cvcuda_tensor, is_pure_tensor -from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT +from torchvision.transforms.v2.functional import get_dimensions, get_size, is_pure_tensor +from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT, _is_cvcuda_tensor def _setup_number_or_seq(arg: int | float | Sequence[int | float], name: str) -> Sequence[float]: @@ -182,7 +182,7 @@ def query_chw(flat_inputs: list[Any]) -> tuple[int, int, int]: chws = { tuple(get_dimensions(inpt)) for inpt in flat_inputs - if check_type(inpt, (is_pure_tensor, tv_tensors.Image, PIL.Image.Image, tv_tensors.Video, is_cvcuda_tensor)) + if check_type(inpt, (is_pure_tensor, tv_tensors.Image, PIL.Image.Image, tv_tensors.Video, _is_cvcuda_tensor)) } if not chws: raise TypeError("No image or video was found in the sample") @@ -207,7 +207,7 @@ def query_size(flat_inputs: list[Any]) -> tuple[int, int]: tv_tensors.Mask, tv_tensors.BoundingBoxes, tv_tensors.KeyPoints, - is_cvcuda_tensor, + _is_cvcuda_tensor, ), ) } From 511c169f0fd238a0b3d4eb564533f8d4087f9a0b Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Tue, 25 Nov 2025 14:28:45 -0800 Subject: [PATCH 09/17] adjust brightness done and tested --- test/test_transforms_v2.py | 45 +++++++++++++++++-- .../transforms/v2/functional/_color.py | 13 ++++++ 2 files changed, 54 insertions(+), 4 deletions(-) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index f767e211125..bab29ccf0a5 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -2823,7 +2823,18 @@ class TestAdjustBrightness: def test_kernel(self, kernel, make_input, dtype, device): check_kernel(kernel, make_input(dtype=dtype, device=device), brightness_factor=self._DEFAULT_BRIGHTNESS_FACTOR) - @pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image, make_video]) + @pytest.mark.parametrize( + "make_input", + [ + make_image_tensor, + make_image_pil, + make_image, + make_video, + pytest.param( + make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available") + ), + ], + ) def test_functional(self, make_input): check_functional(F.adjust_brightness, make_input(), brightness_factor=self._DEFAULT_BRIGHTNESS_FACTOR) @@ -2834,19 +2845,45 @@ def test_functional(self, make_input): (F._color._adjust_brightness_image_pil, PIL.Image.Image), (F.adjust_brightness_image, tv_tensors.Image), (F.adjust_brightness_video, tv_tensors.Video), + pytest.param( + F._color._adjust_brightness_cvcuda, + "cvcuda.Tensor", + marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available"), + ), ], ) def test_functional_signature(self, kernel, input_type): + if input_type == "cvcuda.Tensor": + input_type = _import_cvcuda().Tensor check_functional_kernel_signature_match(F.adjust_brightness, kernel=kernel, input_type=input_type) + @pytest.mark.parametrize( + "make_input", + [ + make_image, + pytest.param( + make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available") + ), + ], + ) @pytest.mark.parametrize("brightness_factor", _CORRECTNESS_BRIGHTNESS_FACTORS) - def test_image_correctness(self, brightness_factor): - image = make_image(dtype=torch.uint8, device="cpu") + def test_image_correctness(self, make_input, brightness_factor): + image = make_input(dtype=torch.uint8, device="cpu") actual = F.adjust_brightness(image, brightness_factor=brightness_factor) + + if make_input is make_image_cvcuda: + actual = F.cvcuda_to_tensor(actual).to(device="cpu") + actual = actual.squeeze(0) + image = F.cvcuda_to_tensor(image) + image = image.squeeze(0) + expected = F.to_image(F.adjust_brightness(F.to_pil_image(image), brightness_factor=brightness_factor)) - torch.testing.assert_close(actual, expected) + if make_input is make_image_cvcuda: + torch.testing.assert_close(actual, expected, rtol=0, atol=1) + else: + torch.testing.assert_close(actual, expected) class TestCutMixMixUp: diff --git a/torchvision/transforms/v2/functional/_color.py b/torchvision/transforms/v2/functional/_color.py index 5be9c62902a..2e663ae56fe 100644 --- a/torchvision/transforms/v2/functional/_color.py +++ b/torchvision/transforms/v2/functional/_color.py @@ -145,6 +145,19 @@ def adjust_brightness_video(video: torch.Tensor, brightness_factor: float) -> to return adjust_brightness_image(video, brightness_factor=brightness_factor) +def _adjust_brightness_cvcuda(image: "cvcuda.Tensor", brightness_factor: float) -> "cvcuda.Tensor": + cvcuda = _import_cvcuda() + + cv_brightness = torch.tensor([brightness_factor], dtype=torch.float32, device="cuda") + cv_brightness = cvcuda.as_tensor(cv_brightness, "N") + + return cvcuda.brightness_contrast(image, brightness=cv_brightness) + + +if CVCUDA_AVAILABLE: + _register_kernel_internal(adjust_brightness, _import_cvcuda().Tensor)(_adjust_brightness_cvcuda) + + def adjust_saturation(inpt: torch.Tensor, saturation_factor: float) -> torch.Tensor: """Adjust saturation.""" if torch.jit.is_scripting(): From 54f3f4af9794b92dd4cf8a4de40a8d6fbeb384ab Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Tue, 25 Nov 2025 17:21:04 -0800 Subject: [PATCH 10/17] completed and tested adjust_contrast --- test/test_transforms_v2.py | 40 +++++++++++++++++-- .../transforms/v2/functional/_color.py | 29 ++++++++++++++ 2 files changed, 66 insertions(+), 3 deletions(-) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index bab29ccf0a5..c8cc93fe629 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -6091,7 +6091,18 @@ def test_kernel_image(self, dtype, device): def test_kernel_video(self): check_kernel(F.adjust_contrast_video, make_video(), contrast_factor=0.5) - @pytest.mark.parametrize("make_input", [make_image_tensor, make_image, make_image_pil, make_video]) + @pytest.mark.parametrize( + "make_input", + [ + make_image_tensor, + make_image, + make_image_pil, + make_video, + pytest.param( + make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available") + ), + ], + ) def test_functional(self, make_input): check_functional(F.adjust_contrast, make_input(), contrast_factor=0.5) @@ -6102,9 +6113,16 @@ def test_functional(self, make_input): (F._color._adjust_contrast_image_pil, PIL.Image.Image), (F.adjust_contrast_image, tv_tensors.Image), (F.adjust_contrast_video, tv_tensors.Video), + pytest.param( + F._color._adjust_contrast_cvcuda, + "cvcuda.Tensor", + marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available"), + ), ], ) def test_functional_signature(self, kernel, input_type): + if input_type == "cvcuda.Tensor": + input_type = _import_cvcuda().Tensor check_functional_kernel_signature_match(F.adjust_contrast, kernel=kernel, input_type=input_type) def test_functional_error(self): @@ -6114,11 +6132,27 @@ def test_functional_error(self): with pytest.raises(ValueError, match="is not non-negative"): F.adjust_contrast(make_image(), contrast_factor=-1) + @pytest.mark.parametrize( + "make_input", + [ + make_image, + pytest.param( + make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available") + ), + ], + ) @pytest.mark.parametrize("contrast_factor", [0.1, 0.5, 1.0]) - def test_correctness_image(self, contrast_factor): - image = make_image(dtype=torch.uint8, device="cpu") + def test_correctness_image(self, make_input, contrast_factor): + image = make_input(dtype=torch.uint8, device="cpu") actual = F.adjust_contrast(image, contrast_factor=contrast_factor) + + if make_input is make_image_cvcuda: + actual = F.cvcuda_to_tensor(actual).to(device="cpu") + actual = actual.squeeze(0) + image = F.cvcuda_to_tensor(image) + image = image.squeeze(0) + expected = F.to_image(F.adjust_contrast(F.to_pil_image(image), contrast_factor=contrast_factor)) assert_close(actual, expected, rtol=0, atol=1) diff --git a/torchvision/transforms/v2/functional/_color.py b/torchvision/transforms/v2/functional/_color.py index 2e663ae56fe..061604f0f94 100644 --- a/torchvision/transforms/v2/functional/_color.py +++ b/torchvision/transforms/v2/functional/_color.py @@ -236,6 +236,35 @@ def adjust_contrast_video(video: torch.Tensor, contrast_factor: float) -> torch. return adjust_contrast_image(video, contrast_factor=contrast_factor) +def _adjust_contrast_cvcuda(image: "cvcuda.Tensor", contrast_factor: float) -> "cvcuda.Tensor": + cvcuda = _import_cvcuda() + + if contrast_factor < 0: + raise ValueError(f"contrast_factor ({contrast_factor}) is not non-negative.") + + c = image.shape[3] + if c not in [1, 3]: + raise TypeError(f"Input image tensor permitted channel values are 1 or 3, but found {c}") + + if c == 3: + grayscale = cvcuda.cvtcolor(image, cvcuda.ColorConversion.RGB2GRAY) + else: + grayscale = image + + contrast = cvcuda.as_tensor(torch.tensor([contrast_factor], dtype=torch.float32, device="cuda")) + + torch_image = torch.as_tensor(grayscale.cuda()) + mean = torch.mean(torch_image.float()) + + contrast_center = cvcuda.as_tensor(torch.tensor([mean.item()], dtype=torch.float32, device="cuda")) + + return cvcuda.brightness_contrast(image, contrast=contrast, contrast_center=contrast_center) + + +if CVCUDA_AVAILABLE: + _register_kernel_internal(adjust_contrast, _import_cvcuda().Tensor)(_adjust_contrast_cvcuda) + + def adjust_sharpness(inpt: torch.Tensor, sharpness_factor: float) -> torch.Tensor: """See :class:`~torchvision.transforms.RandomAdjustSharpness`""" if torch.jit.is_scripting(): From b11c38a589dc3e2bedfac7cf005ef39c52cb0254 Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Tue, 2 Dec 2025 15:03:06 -0800 Subject: [PATCH 11/17] update brightness contrast tests plus add comment on mean calc for contrast --- test/test_transforms_v2.py | 14 ++++---------- torchvision/transforms/v2/functional/_color.py | 3 ++- 2 files changed, 6 insertions(+), 11 deletions(-) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index c8cc93fe629..8acd48a301d 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -2873,17 +2873,14 @@ def test_image_correctness(self, make_input, brightness_factor): actual = F.adjust_brightness(image, brightness_factor=brightness_factor) if make_input is make_image_cvcuda: - actual = F.cvcuda_to_tensor(actual).to(device="cpu") - actual = actual.squeeze(0) - image = F.cvcuda_to_tensor(image) - image = image.squeeze(0) + image = cvcuda_to_pil_compatible_tensor(image) expected = F.to_image(F.adjust_brightness(F.to_pil_image(image), brightness_factor=brightness_factor)) if make_input is make_image_cvcuda: - torch.testing.assert_close(actual, expected, rtol=0, atol=1) + assert_close(actual, expected, rtol=0, atol=1) else: - torch.testing.assert_close(actual, expected) + assert_close(actual, expected) class TestCutMixMixUp: @@ -6148,10 +6145,7 @@ def test_correctness_image(self, make_input, contrast_factor): actual = F.adjust_contrast(image, contrast_factor=contrast_factor) if make_input is make_image_cvcuda: - actual = F.cvcuda_to_tensor(actual).to(device="cpu") - actual = actual.squeeze(0) - image = F.cvcuda_to_tensor(image) - image = image.squeeze(0) + image = cvcuda_to_pil_compatible_tensor(image) expected = F.to_image(F.adjust_contrast(F.to_pil_image(image), contrast_factor=contrast_factor)) diff --git a/torchvision/transforms/v2/functional/_color.py b/torchvision/transforms/v2/functional/_color.py index 061604f0f94..64bebbe05b7 100644 --- a/torchvision/transforms/v2/functional/_color.py +++ b/torchvision/transforms/v2/functional/_color.py @@ -253,9 +253,10 @@ def _adjust_contrast_cvcuda(image: "cvcuda.Tensor", contrast_factor: float) -> " contrast = cvcuda.as_tensor(torch.tensor([contrast_factor], dtype=torch.float32, device="cuda")) + # torchvision uses the mean of the image as the center of contrast + # we will compute that here using torch as well for consistency torch_image = torch.as_tensor(grayscale.cuda()) mean = torch.mean(torch_image.float()) - contrast_center = cvcuda.as_tensor(torch.tensor([mean.item()], dtype=torch.float32, device="cuda")) return cvcuda.brightness_contrast(image, contrast=contrast, contrast_center=contrast_center) From d379658cd2228f7a5e153c942cdc1859800f6904 Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Tue, 25 Nov 2025 17:00:40 -0800 Subject: [PATCH 12/17] complete and tested adjust_hue --- test/test_transforms_v2.py | 40 +++++++++++++++++-- .../transforms/v2/functional/_color.py | 25 ++++++++++++ 2 files changed, 62 insertions(+), 3 deletions(-) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 8acd48a301d..7cd11ef57be 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -6201,7 +6201,18 @@ def test_kernel_image(self, dtype, device): def test_kernel_video(self): check_kernel(F.adjust_hue_video, make_video(), hue_factor=0.25) - @pytest.mark.parametrize("make_input", [make_image_tensor, make_image, make_image_pil, make_video]) + @pytest.mark.parametrize( + "make_input", + [ + make_image_tensor, + make_image, + make_image_pil, + make_video, + pytest.param( + make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available") + ), + ], + ) def test_functional(self, make_input): check_functional(F.adjust_hue, make_input(), hue_factor=0.25) @@ -6212,9 +6223,16 @@ def test_functional(self, make_input): (F._color._adjust_hue_image_pil, PIL.Image.Image), (F.adjust_hue_image, tv_tensors.Image), (F.adjust_hue_video, tv_tensors.Video), + pytest.param( + F._color._adjust_hue_cvcuda, + "cvcuda.Tensor", + marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available"), + ), ], ) def test_functional_signature(self, kernel, input_type): + if input_type == "cvcuda.Tensor": + input_type = _import_cvcuda().Tensor check_functional_kernel_signature_match(F.adjust_hue, kernel=kernel, input_type=input_type) def test_functional_error(self): @@ -6225,11 +6243,27 @@ def test_functional_error(self): with pytest.raises(ValueError, match=re.escape("is not in [-0.5, 0.5]")): F.adjust_hue(make_image(), hue_factor=hue_factor) + @pytest.mark.parametrize( + "make_input", + [ + make_image, + pytest.param( + make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available") + ), + ], + ) @pytest.mark.parametrize("hue_factor", [-0.5, -0.3, 0.0, 0.2, 0.5]) - def test_correctness_image(self, hue_factor): - image = make_image(dtype=torch.uint8, device="cpu") + def test_correctness_image(self, make_input, hue_factor): + image = make_input(dtype=torch.uint8, device="cpu") actual = F.adjust_hue(image, hue_factor=hue_factor) + + if make_input is make_image_cvcuda: + actual = F.cvcuda_to_tensor(actual).to(device="cpu") + actual = actual.squeeze(0) + image = F.cvcuda_to_tensor(image) + image = image.squeeze(0) + expected = F.to_image(F.adjust_hue(F.to_pil_image(image), hue_factor=hue_factor)) mae = (actual.float() - expected.float()).abs().mean() diff --git a/torchvision/transforms/v2/functional/_color.py b/torchvision/transforms/v2/functional/_color.py index 64bebbe05b7..6fe03ec36ad 100644 --- a/torchvision/transforms/v2/functional/_color.py +++ b/torchvision/transforms/v2/functional/_color.py @@ -457,6 +457,31 @@ def adjust_hue_video(video: torch.Tensor, hue_factor: float) -> torch.Tensor: return adjust_hue_image(video, hue_factor=hue_factor) +def _adjust_hue_cvcuda(image: "cvcuda.Tensor", hue_factor: float) -> "cvcuda.Tensor": + cvcuda = _import_cvcuda() + + if not (-0.5 <= hue_factor <= 0.5): + raise ValueError(f"hue_factor ({hue_factor}) is not in [-0.5, 0.5].") + + c = image.shape[3] + if c not in [1, 3, 4]: + raise TypeError(f"Input image tensor permitted channel values are 1, 3, or 4, but found {c}") + + if c == 1: # Match PIL behaviour + return image + + # no native adjust_hue, use CV-CUDA for color converison, use torch for elementwise operations + hsv = cvcuda.cvtcolor(image, cvcuda.ColorConversion.RGB2HSV) + hsv_torch = torch.as_tensor(hsv.cuda()).float() + hsv_torch[..., 0] = (hsv_torch[..., 0] + hue_factor * 180) % 180 + hsv_modified = cvcuda.as_tensor(hsv_torch.to(torch.uint8), "NHWC") + return cvcuda.cvtcolor(hsv_modified, cvcuda.ColorConversion.HSV2RGB) + + +if CVCUDA_AVAILABLE: + _register_kernel_internal(adjust_hue, _import_cvcuda().Tensor)(_adjust_hue_cvcuda) + + def adjust_gamma(inpt: torch.Tensor, gamma: float, gain: float = 1) -> torch.Tensor: """Adjust gamma.""" if torch.jit.is_scripting(): From 310982c115463b4a28f54946e6a9b92f2d51d2c9 Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Tue, 2 Dec 2025 15:05:09 -0800 Subject: [PATCH 13/17] merge brightness contrast and hue adjustment together --- test/test_transforms_v2.py | 6 ++---- torchvision/transforms/v2/functional/_color.py | 3 +++ 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 7cd11ef57be..50e5756d1fe 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -6259,10 +6259,8 @@ def test_correctness_image(self, make_input, hue_factor): actual = F.adjust_hue(image, hue_factor=hue_factor) if make_input is make_image_cvcuda: - actual = F.cvcuda_to_tensor(actual).to(device="cpu") - actual = actual.squeeze(0) - image = F.cvcuda_to_tensor(image) - image = image.squeeze(0) + actual = cvcuda_to_pil_compatible_tensor(actual) + image = cvcuda_to_pil_compatible_tensor(image) expected = F.to_image(F.adjust_hue(F.to_pil_image(image), hue_factor=hue_factor)) diff --git a/torchvision/transforms/v2/functional/_color.py b/torchvision/transforms/v2/functional/_color.py index 6fe03ec36ad..3b776d077bf 100644 --- a/torchvision/transforms/v2/functional/_color.py +++ b/torchvision/transforms/v2/functional/_color.py @@ -471,9 +471,12 @@ def _adjust_hue_cvcuda(image: "cvcuda.Tensor", hue_factor: float) -> "cvcuda.Ten return image # no native adjust_hue, use CV-CUDA for color converison, use torch for elementwise operations + # CV-CUDA accelerates the HSV conversion hsv = cvcuda.cvtcolor(image, cvcuda.ColorConversion.RGB2HSV) + # then use torch for elementwise operations hsv_torch = torch.as_tensor(hsv.cuda()).float() hsv_torch[..., 0] = (hsv_torch[..., 0] + hue_factor * 180) % 180 + # convert back to cvcuda tensor and accelerate the HSV2RGB conversion hsv_modified = cvcuda.as_tensor(hsv_torch.to(torch.uint8), "NHWC") return cvcuda.cvtcolor(hsv_modified, cvcuda.ColorConversion.HSV2RGB) From e0392a0a0001c63a0300d13ecc02a54e19efb118 Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Tue, 25 Nov 2025 14:43:57 -0800 Subject: [PATCH 14/17] wip adjust_saturation --- .../transforms/v2/functional/_color.py | 35 +++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/torchvision/transforms/v2/functional/_color.py b/torchvision/transforms/v2/functional/_color.py index 3b776d077bf..1613353b247 100644 --- a/torchvision/transforms/v2/functional/_color.py +++ b/torchvision/transforms/v2/functional/_color.py @@ -197,6 +197,41 @@ def adjust_saturation_video(video: torch.Tensor, saturation_factor: float) -> to return adjust_saturation_image(video, saturation_factor=saturation_factor) +def _adjust_saturation_cvcuda(image: "cvcuda.Tensor", saturation_factor: float) -> "cvcuda.Tensor": + if saturation_factor < 0: + raise ValueError(f"saturation_factor ({saturation_factor}) is not non-negative.") + + c = image.shape[-1] # NHWC layout + if c not in [1, 3, 4]: + raise TypeError(f"Input image tensor permitted channel values are 1, 3, or 4, but found {c}") + + if c == 1: # Match PIL behaviour + return image + + # Grayscale weights (same as _rgb_to_grayscale_image) + sf = saturation_factor + r, g, b = 0.2989, 0.587, 0.114 + + # Build 3x4 saturation matrix + twist_data = [ + [sf + (1 - sf) * r, (1 - sf) * g, (1 - sf) * b, 0.0], + [(1 - sf) * r, sf + (1 - sf) * g, (1 - sf) * b, 0.0], + [(1 - sf) * r, (1 - sf) * g, sf + (1 - sf) * b, 0.0], + ] + twist = cvcuda.Tensor( + torch.tensor(twist_data, dtype=torch.float32, device="cuda").contiguous(), + layout="HW", + ) + + return cvcuda.color_twist(image, twist) + + +if CVCUDA_AVAILABLE: + _register_kernel_internal(adjust_saturation, cvcuda.Tensor)( + _adjust_saturation_cvcuda + ) + + def adjust_contrast(inpt: torch.Tensor, contrast_factor: float) -> torch.Tensor: """See :class:`~torchvision.transforms.RandomAutocontrast`""" if torch.jit.is_scripting(): From 61b237c8aad27e5ada00f300259460913fbbd139 Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Tue, 25 Nov 2025 15:43:11 -0800 Subject: [PATCH 15/17] adjust saturation complete and tested --- test/test_transforms_v2.py | 41 +++++++++++++++++-- .../transforms/v2/functional/_color.py | 18 ++++---- 2 files changed, 45 insertions(+), 14 deletions(-) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 50e5756d1fe..fea0762222f 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -6277,7 +6277,18 @@ def test_kernel_image(self, dtype, device): def test_kernel_video(self): check_kernel(F.adjust_saturation_video, make_video(), saturation_factor=0.5) - @pytest.mark.parametrize("make_input", [make_image_tensor, make_image, make_image_pil, make_video]) + @pytest.mark.parametrize( + "make_input", + [ + make_image_tensor, + make_image, + make_image_pil, + make_video, + pytest.param( + make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available") + ), + ], + ) def test_functional(self, make_input): check_functional(F.adjust_saturation, make_input(), saturation_factor=0.5) @@ -6288,9 +6299,16 @@ def test_functional(self, make_input): (F._color._adjust_saturation_image_pil, PIL.Image.Image), (F.adjust_saturation_image, tv_tensors.Image), (F.adjust_saturation_video, tv_tensors.Video), + pytest.param( + F._color._adjust_saturation_cvcuda, + "cvcuda.Tensor", + marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available"), + ), ], ) def test_functional_signature(self, kernel, input_type): + if input_type == "cvcuda.Tensor": + input_type = _import_cvcuda().Tensor check_functional_kernel_signature_match(F.adjust_saturation, kernel=kernel, input_type=input_type) def test_functional_error(self): @@ -6300,11 +6318,28 @@ def test_functional_error(self): with pytest.raises(ValueError, match="is not non-negative"): F.adjust_saturation(make_image(), saturation_factor=-1) + @pytest.mark.parametrize( + "make_input", + [ + make_image, + pytest.param( + make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available") + ), + ], + ) + @pytest.mark.parametrize("color_space", ["RGB", "GRAY"]) @pytest.mark.parametrize("saturation_factor", [0.1, 0.5, 1.0]) - def test_correctness_image(self, saturation_factor): - image = make_image(dtype=torch.uint8, device="cpu") + def test_correctness_image(self, make_input, color_space, saturation_factor): + image = make_input(dtype=torch.uint8, color_space=color_space, device="cpu") actual = F.adjust_saturation(image, saturation_factor=saturation_factor) + + if make_input is make_image_cvcuda: + actual = F.cvcuda_to_tensor(actual).to(device="cpu") + actual = actual.squeeze(0) + image = F.cvcuda_to_tensor(image) + image = image.squeeze(0) + expected = F.to_image(F.adjust_saturation(F.to_pil_image(image), saturation_factor=saturation_factor)) assert_close(actual, expected, rtol=0, atol=1) diff --git a/torchvision/transforms/v2/functional/_color.py b/torchvision/transforms/v2/functional/_color.py index 1613353b247..ba8785deb24 100644 --- a/torchvision/transforms/v2/functional/_color.py +++ b/torchvision/transforms/v2/functional/_color.py @@ -201,35 +201,31 @@ def _adjust_saturation_cvcuda(image: "cvcuda.Tensor", saturation_factor: float) if saturation_factor < 0: raise ValueError(f"saturation_factor ({saturation_factor}) is not non-negative.") - c = image.shape[-1] # NHWC layout + c = image.shape[3] if c not in [1, 3, 4]: raise TypeError(f"Input image tensor permitted channel values are 1, 3, or 4, but found {c}") if c == 1: # Match PIL behaviour return image - # Grayscale weights (same as _rgb_to_grayscale_image) + # grayscale weights sf = saturation_factor r, g, b = 0.2989, 0.587, 0.114 - - # Build 3x4 saturation matrix twist_data = [ [sf + (1 - sf) * r, (1 - sf) * g, (1 - sf) * b, 0.0], [(1 - sf) * r, sf + (1 - sf) * g, (1 - sf) * b, 0.0], [(1 - sf) * r, (1 - sf) * g, sf + (1 - sf) * b, 0.0], ] - twist = cvcuda.Tensor( - torch.tensor(twist_data, dtype=torch.float32, device="cuda").contiguous(), - layout="HW", + twist_tensor = cvcuda.as_tensor( + torch.tensor(twist_data, dtype=torch.float32, device="cuda"), + "HW", ) - return cvcuda.color_twist(image, twist) + return cvcuda.color_twist(image, twist_tensor) if CVCUDA_AVAILABLE: - _register_kernel_internal(adjust_saturation, cvcuda.Tensor)( - _adjust_saturation_cvcuda - ) + _register_kernel_internal(adjust_saturation, cvcuda.Tensor)(_adjust_saturation_cvcuda) def adjust_contrast(inpt: torch.Tensor, contrast_factor: float) -> torch.Tensor: From 2c68fc370b8f9aa574080f4bfd750c53cf3c67e4 Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Tue, 2 Dec 2025 15:26:04 -0800 Subject: [PATCH 16/17] add adjust saturation --- test/test_transforms_v2.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index fea0762222f..9515fcf9702 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -6335,10 +6335,7 @@ def test_correctness_image(self, make_input, color_space, saturation_factor): actual = F.adjust_saturation(image, saturation_factor=saturation_factor) if make_input is make_image_cvcuda: - actual = F.cvcuda_to_tensor(actual).to(device="cpu") - actual = actual.squeeze(0) - image = F.cvcuda_to_tensor(image) - image = image.squeeze(0) + image = cvcuda_to_pil_compatible_tensor(image) expected = F.to_image(F.adjust_saturation(F.to_pil_image(image), saturation_factor=saturation_factor)) From 8564f0a7568fd3d1cf1b85c1b34eeb25b1bc3bdc Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Thu, 4 Dec 2025 14:47:07 -0800 Subject: [PATCH 17/17] update to main standards --- test/test_transforms_v2.py | 35 +++++++++---------- .../transforms/v2/functional/_augment.py | 11 +----- .../transforms/v2/functional/_color.py | 16 ++++----- torchvision/transforms/v2/functional/_misc.py | 11 ++---- 4 files changed, 28 insertions(+), 45 deletions(-) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 9515fcf9702..380b661454e 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -25,7 +25,6 @@ assert_equal, cache, cpu_and_cuda, - cvcuda_to_pil_compatible_tensor, freeze_rng_state, ignore_jit_no_profile_information_warning, make_bounding_boxes, @@ -2846,14 +2845,14 @@ def test_functional(self, make_input): (F.adjust_brightness_image, tv_tensors.Image), (F.adjust_brightness_video, tv_tensors.Video), pytest.param( - F._color._adjust_brightness_cvcuda, - "cvcuda.Tensor", + F._color._adjust_brightness_image_cvcuda, + None, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available"), ), ], ) def test_functional_signature(self, kernel, input_type): - if input_type == "cvcuda.Tensor": + if kernel is F._color._adjust_brightness_image_cvcuda: input_type = _import_cvcuda().Tensor check_functional_kernel_signature_match(F.adjust_brightness, kernel=kernel, input_type=input_type) @@ -2873,7 +2872,7 @@ def test_image_correctness(self, make_input, brightness_factor): actual = F.adjust_brightness(image, brightness_factor=brightness_factor) if make_input is make_image_cvcuda: - image = cvcuda_to_pil_compatible_tensor(image) + image = F.cvcuda_to_tensor(image)[0].cpu() expected = F.to_image(F.adjust_brightness(F.to_pil_image(image), brightness_factor=brightness_factor)) @@ -6111,14 +6110,14 @@ def test_functional(self, make_input): (F.adjust_contrast_image, tv_tensors.Image), (F.adjust_contrast_video, tv_tensors.Video), pytest.param( - F._color._adjust_contrast_cvcuda, - "cvcuda.Tensor", + F._color._adjust_contrast_image_cvcuda, + None, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available"), ), ], ) def test_functional_signature(self, kernel, input_type): - if input_type == "cvcuda.Tensor": + if kernel is F._color._adjust_contrast_image_cvcuda: input_type = _import_cvcuda().Tensor check_functional_kernel_signature_match(F.adjust_contrast, kernel=kernel, input_type=input_type) @@ -6145,7 +6144,7 @@ def test_correctness_image(self, make_input, contrast_factor): actual = F.adjust_contrast(image, contrast_factor=contrast_factor) if make_input is make_image_cvcuda: - image = cvcuda_to_pil_compatible_tensor(image) + image = F.cvcuda_to_tensor(image)[0].cpu() expected = F.to_image(F.adjust_contrast(F.to_pil_image(image), contrast_factor=contrast_factor)) @@ -6224,14 +6223,14 @@ def test_functional(self, make_input): (F.adjust_hue_image, tv_tensors.Image), (F.adjust_hue_video, tv_tensors.Video), pytest.param( - F._color._adjust_hue_cvcuda, - "cvcuda.Tensor", + F._color._adjust_hue_image_cvcuda, + None, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available"), ), ], ) def test_functional_signature(self, kernel, input_type): - if input_type == "cvcuda.Tensor": + if kernel is F._color._adjust_hue_image_cvcuda: input_type = _import_cvcuda().Tensor check_functional_kernel_signature_match(F.adjust_hue, kernel=kernel, input_type=input_type) @@ -6259,8 +6258,8 @@ def test_correctness_image(self, make_input, hue_factor): actual = F.adjust_hue(image, hue_factor=hue_factor) if make_input is make_image_cvcuda: - actual = cvcuda_to_pil_compatible_tensor(actual) - image = cvcuda_to_pil_compatible_tensor(image) + actual = F.cvcuda_to_tensor(actual)[0].cpu() + image = F.cvcuda_to_tensor(image)[0].cpu() expected = F.to_image(F.adjust_hue(F.to_pil_image(image), hue_factor=hue_factor)) @@ -6300,14 +6299,14 @@ def test_functional(self, make_input): (F.adjust_saturation_image, tv_tensors.Image), (F.adjust_saturation_video, tv_tensors.Video), pytest.param( - F._color._adjust_saturation_cvcuda, - "cvcuda.Tensor", + F._color._adjust_saturation_image_cvcuda, + None, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available"), ), ], ) def test_functional_signature(self, kernel, input_type): - if input_type == "cvcuda.Tensor": + if kernel is F._color._adjust_saturation_image_cvcuda: input_type = _import_cvcuda().Tensor check_functional_kernel_signature_match(F.adjust_saturation, kernel=kernel, input_type=input_type) @@ -6335,7 +6334,7 @@ def test_correctness_image(self, make_input, color_space, saturation_factor): actual = F.adjust_saturation(image, saturation_factor=saturation_factor) if make_input is make_image_cvcuda: - image = cvcuda_to_pil_compatible_tensor(image) + image = F.cvcuda_to_tensor(image)[0].cpu() expected = F.to_image(F.adjust_saturation(F.to_pil_image(image), saturation_factor=saturation_factor)) diff --git a/torchvision/transforms/v2/functional/_augment.py b/torchvision/transforms/v2/functional/_augment.py index 7ce5bdc7b7e..a904d8d7cbd 100644 --- a/torchvision/transforms/v2/functional/_augment.py +++ b/torchvision/transforms/v2/functional/_augment.py @@ -1,5 +1,4 @@ import io -from typing import TYPE_CHECKING import PIL.Image @@ -9,15 +8,7 @@ from torchvision.transforms.functional import pil_to_tensor, to_pil_image from torchvision.utils import _log_api_usage_once -from ._utils import _get_kernel, _import_cvcuda, _is_cvcuda_available, _register_kernel_internal - - -CVCUDA_AVAILABLE = _is_cvcuda_available() - -if TYPE_CHECKING: - import cvcuda # type: ignore[import-not-found] -if CVCUDA_AVAILABLE: - cvcuda = _import_cvcuda() # noqa: F811 +from ._utils import _get_kernel, _register_kernel_internal def erase( diff --git a/torchvision/transforms/v2/functional/_color.py b/torchvision/transforms/v2/functional/_color.py index ba8785deb24..599e9b99461 100644 --- a/torchvision/transforms/v2/functional/_color.py +++ b/torchvision/transforms/v2/functional/_color.py @@ -145,7 +145,7 @@ def adjust_brightness_video(video: torch.Tensor, brightness_factor: float) -> to return adjust_brightness_image(video, brightness_factor=brightness_factor) -def _adjust_brightness_cvcuda(image: "cvcuda.Tensor", brightness_factor: float) -> "cvcuda.Tensor": +def _adjust_brightness_image_cvcuda(image: "cvcuda.Tensor", brightness_factor: float) -> "cvcuda.Tensor": cvcuda = _import_cvcuda() cv_brightness = torch.tensor([brightness_factor], dtype=torch.float32, device="cuda") @@ -155,7 +155,7 @@ def _adjust_brightness_cvcuda(image: "cvcuda.Tensor", brightness_factor: float) if CVCUDA_AVAILABLE: - _register_kernel_internal(adjust_brightness, _import_cvcuda().Tensor)(_adjust_brightness_cvcuda) + _register_kernel_internal(adjust_brightness, _import_cvcuda().Tensor)(_adjust_brightness_image_cvcuda) def adjust_saturation(inpt: torch.Tensor, saturation_factor: float) -> torch.Tensor: @@ -197,7 +197,7 @@ def adjust_saturation_video(video: torch.Tensor, saturation_factor: float) -> to return adjust_saturation_image(video, saturation_factor=saturation_factor) -def _adjust_saturation_cvcuda(image: "cvcuda.Tensor", saturation_factor: float) -> "cvcuda.Tensor": +def _adjust_saturation_image_cvcuda(image: "cvcuda.Tensor", saturation_factor: float) -> "cvcuda.Tensor": if saturation_factor < 0: raise ValueError(f"saturation_factor ({saturation_factor}) is not non-negative.") @@ -225,7 +225,7 @@ def _adjust_saturation_cvcuda(image: "cvcuda.Tensor", saturation_factor: float) if CVCUDA_AVAILABLE: - _register_kernel_internal(adjust_saturation, cvcuda.Tensor)(_adjust_saturation_cvcuda) + _register_kernel_internal(adjust_saturation, cvcuda.Tensor)(_adjust_saturation_image_cvcuda) def adjust_contrast(inpt: torch.Tensor, contrast_factor: float) -> torch.Tensor: @@ -267,7 +267,7 @@ def adjust_contrast_video(video: torch.Tensor, contrast_factor: float) -> torch. return adjust_contrast_image(video, contrast_factor=contrast_factor) -def _adjust_contrast_cvcuda(image: "cvcuda.Tensor", contrast_factor: float) -> "cvcuda.Tensor": +def _adjust_contrast_image_cvcuda(image: "cvcuda.Tensor", contrast_factor: float) -> "cvcuda.Tensor": cvcuda = _import_cvcuda() if contrast_factor < 0: @@ -294,7 +294,7 @@ def _adjust_contrast_cvcuda(image: "cvcuda.Tensor", contrast_factor: float) -> " if CVCUDA_AVAILABLE: - _register_kernel_internal(adjust_contrast, _import_cvcuda().Tensor)(_adjust_contrast_cvcuda) + _register_kernel_internal(adjust_contrast, _import_cvcuda().Tensor)(_adjust_contrast_image_cvcuda) def adjust_sharpness(inpt: torch.Tensor, sharpness_factor: float) -> torch.Tensor: @@ -488,7 +488,7 @@ def adjust_hue_video(video: torch.Tensor, hue_factor: float) -> torch.Tensor: return adjust_hue_image(video, hue_factor=hue_factor) -def _adjust_hue_cvcuda(image: "cvcuda.Tensor", hue_factor: float) -> "cvcuda.Tensor": +def _adjust_hue_image_cvcuda(image: "cvcuda.Tensor", hue_factor: float) -> "cvcuda.Tensor": cvcuda = _import_cvcuda() if not (-0.5 <= hue_factor <= 0.5): @@ -513,7 +513,7 @@ def _adjust_hue_cvcuda(image: "cvcuda.Tensor", hue_factor: float) -> "cvcuda.Ten if CVCUDA_AVAILABLE: - _register_kernel_internal(adjust_hue, _import_cvcuda().Tensor)(_adjust_hue_cvcuda) + _register_kernel_internal(adjust_hue, _import_cvcuda().Tensor)(_adjust_hue_image_cvcuda) def adjust_gamma(inpt: torch.Tensor, gamma: float, gain: float = 1) -> torch.Tensor: diff --git a/torchvision/transforms/v2/functional/_misc.py b/torchvision/transforms/v2/functional/_misc.py index 0fa05a2113c..daf263df046 100644 --- a/torchvision/transforms/v2/functional/_misc.py +++ b/torchvision/transforms/v2/functional/_misc.py @@ -1,5 +1,5 @@ import math -from typing import Optional, TYPE_CHECKING +from typing import Optional import PIL.Image import torch @@ -13,14 +13,7 @@ from ._meta import _convert_bounding_box_format -from ._utils import _get_kernel, _import_cvcuda, _is_cvcuda_available, _register_kernel_internal, is_pure_tensor - -CVCUDA_AVAILABLE = _is_cvcuda_available() - -if TYPE_CHECKING: - import cvcuda # type: ignore[import-not-found] -if CVCUDA_AVAILABLE: - cvcuda = _import_cvcuda() # noqa: F811 +from ._utils import _get_kernel, _register_kernel_internal, is_pure_tensor def normalize(