diff --git a/torchvision/transforms/v2/functional/__init__.py b/torchvision/transforms/v2/functional/__init__.py index 75f94744fcb..4998b46c282 100644 --- a/torchvision/transforms/v2/functional/__init__.py +++ b/torchvision/transforms/v2/functional/__init__.py @@ -170,3 +170,70 @@ from ._type_conversion import pil_to_tensor, to_image_pil, to_image_tensor, to_pil_image from ._deprecated import get_image_size, to_tensor # usort: skip + + +def _register_builtin_kernels(): + import functools + import inspect + + import torch + from torchvision import datapoints + + from ._utils import _KERNEL_REGISTRY, _noop + + def default_kernel_wrapper(dispatcher, kernel): + dispatcher_params = list(inspect.signature(dispatcher).parameters)[1:] + kernel_params = list(inspect.signature(kernel).parameters)[1:] + + needs_args_kwargs_handling = kernel_params != dispatcher_params + + # this avoids converting list -> set at runtime below + kernel_params = set(kernel_params) + + @functools.wraps(kernel) + def wrapper(inpt, *args, **kwargs): + input_type = type(inpt) + + if needs_args_kwargs_handling: + # Convert args to kwargs to simplify further processing + kwargs.update(dict(zip(dispatcher_params, args))) + args = () + + # drop parameters that are not relevant for the kernel, but have a default value + # in the dispatcher + for kwarg in kwargs.keys() - kernel_params: + del kwargs[kwarg] + + # add parameters that are passed implicitly to the dispatcher as metadata, + # but have to be explicit for the kernel + for kwarg in input_type.__annotations__.keys() & kernel_params: + kwargs[kwarg] = getattr(inpt, kwarg) + + output = kernel(inpt.as_subclass(torch.Tensor), *args, **kwargs) + + if isinstance(inpt, datapoints.BoundingBox) and isinstance(output, tuple): + output, spatial_size = output + metadata = dict(spatial_size=spatial_size) + else: + metadata = dict() + + return input_type.wrap_like(inpt, output, **metadata) + + return wrapper + + def register(dispatcher, datapoint_cls, kernel): + _KERNEL_REGISTRY.setdefault(dispatcher, {})[datapoint_cls] = default_kernel_wrapper(dispatcher, kernel) + + register(resize, datapoints.Image, resize_image_tensor) + register(resize, datapoints.BoundingBox, resize_bounding_box) + register(resize, datapoints.Mask, resize_mask) + register(resize, datapoints.Video, resize_video) + + register(adjust_brightness, datapoints.Image, adjust_brightness_image_tensor) + register(adjust_brightness, datapoints.BoundingBox, _noop) + register(adjust_brightness, datapoints.Mask, _noop) + register(adjust_brightness, datapoints.Video, adjust_brightness_video) + + +_register_builtin_kernels() +del _register_builtin_kernels diff --git a/torchvision/transforms/v2/functional/_color.py b/torchvision/transforms/v2/functional/_color.py index 5a7008f8f0a..f8d75e29ea2 100644 --- a/torchvision/transforms/v2/functional/_color.py +++ b/torchvision/transforms/v2/functional/_color.py @@ -10,7 +10,7 @@ from torchvision.utils import _log_api_usage_once from ._meta import _num_value_bits, to_dtype_image_tensor -from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_internal, is_simple_tensor +from ._utils import _get_kernel, is_simple_tensor def _rgb_to_grayscale_image_tensor( @@ -87,10 +87,6 @@ def adjust_brightness(inpt: datapoints._InputTypeJIT, brightness_factor: float) ) -_register_explicit_noop(adjust_brightness, datapoints.BoundingBox, datapoints.Mask) - - -@_register_kernel_internal(adjust_brightness, datapoints.Image) def adjust_brightness_image_tensor(image: torch.Tensor, brightness_factor: float) -> torch.Tensor: if brightness_factor < 0: raise ValueError(f"brightness_factor ({brightness_factor}) is not non-negative.") @@ -109,7 +105,6 @@ def adjust_brightness_image_pil(image: PIL.Image.Image, brightness_factor: float return _FP.adjust_brightness(image, brightness_factor=brightness_factor) -@_register_kernel_internal(adjust_brightness, datapoints.Video) def adjust_brightness_video(video: torch.Tensor, brightness_factor: float) -> torch.Tensor: return adjust_brightness_image_tensor(video, brightness_factor=brightness_factor) diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py index b7ac48c40ec..472da0d9305 100644 --- a/torchvision/transforms/v2/functional/_geometry.py +++ b/torchvision/transforms/v2/functional/_geometry.py @@ -25,7 +25,7 @@ from ._meta import clamp_bounding_box, convert_format_bounding_box, get_spatial_size_image_pil -from ._utils import _get_kernel, _register_kernel_internal, is_simple_tensor +from ._utils import _get_kernel, is_simple_tensor def _check_interpolation(interpolation: Union[InterpolationMode, int]) -> InterpolationMode: @@ -183,7 +183,6 @@ def resize( ) -@_register_kernel_internal(resize, datapoints.Image) def resize_image_tensor( image: torch.Tensor, size: List[int], @@ -285,7 +284,6 @@ def resize_image_pil( return image.resize((new_width, new_height), resample=pil_modes_mapping[interpolation]) -@_register_kernel_internal(resize, datapoints.Mask) def resize_mask(mask: torch.Tensor, size: List[int], max_size: Optional[int] = None) -> torch.Tensor: if mask.ndim < 3: mask = mask.unsqueeze(0) @@ -301,7 +299,6 @@ def resize_mask(mask: torch.Tensor, size: List[int], max_size: Optional[int] = N return output -@_register_kernel_internal(resize, datapoints.BoundingBox) def resize_bounding_box( bounding_box: torch.Tensor, spatial_size: Tuple[int, int], size: List[int], max_size: Optional[int] = None ) -> Tuple[torch.Tensor, Tuple[int, int]]: @@ -320,7 +317,6 @@ def resize_bounding_box( ) -@_register_kernel_internal(resize, datapoints.Video) def resize_video( video: torch.Tensor, size: List[int], diff --git a/torchvision/transforms/v2/functional/_utils.py b/torchvision/transforms/v2/functional/_utils.py index 23eeb626b9b..57082628ecf 100644 --- a/torchvision/transforms/v2/functional/_utils.py +++ b/torchvision/transforms/v2/functional/_utils.py @@ -1,9 +1,6 @@ -import functools -import inspect from typing import Any import torch -from torchvision import datapoints from torchvision.datapoints._datapoint import Datapoint @@ -14,48 +11,7 @@ def is_simple_tensor(inpt: Any) -> bool: _KERNEL_REGISTRY = {} -def _kernel_wrapper_internal(dispatcher, kernel): - dispatcher_params = list(inspect.signature(dispatcher).parameters)[1:] - kernel_params = list(inspect.signature(kernel).parameters)[1:] - - needs_args_kwargs_handling = kernel_params != dispatcher_params - - # this avoids converting list -> set at runtime below - kernel_params = set(kernel_params) - - @functools.wraps(kernel) - def wrapper(inpt, *args, **kwargs): - input_type = type(inpt) - - if needs_args_kwargs_handling: - # Convert args to kwargs to simplify further processing - kwargs.update(dict(zip(dispatcher_params, args))) - args = () - - # drop parameters that are not relevant for the kernel, but have a default value - # in the dispatcher - for kwarg in kwargs.keys() - kernel_params: - del kwargs[kwarg] - - # add parameters that are passed implicitly to the dispatcher as metadata, - # but have to be explicit for the kernel - for kwarg in input_type.__annotations__.keys() & kernel_params: - kwargs[kwarg] = getattr(inpt, kwarg) - - output = kernel(inpt.as_subclass(torch.Tensor), *args, **kwargs) - - if isinstance(inpt, datapoints.BoundingBox) and isinstance(output, tuple): - output, spatial_size = output - metadata = dict(spatial_size=spatial_size) - else: - metadata = dict() - - return input_type.wrap_like(inpt, output, **metadata) - - return wrapper - - -def _register_kernel_internal(dispatcher, datapoint_cls, *, wrap_kernel=True): +def register_kernel(dispatcher, datapoint_cls): registry = _KERNEL_REGISTRY.setdefault(dispatcher, {}) if datapoint_cls in registry: raise TypeError( @@ -63,38 +19,16 @@ def _register_kernel_internal(dispatcher, datapoint_cls, *, wrap_kernel=True): ) def decorator(kernel): - registry[datapoint_cls] = _kernel_wrapper_internal(dispatcher, kernel) if wrap_kernel else kernel + registry[datapoint_cls] = kernel return kernel return decorator -def register_kernel(dispatcher, datapoint_cls): - return _register_kernel_internal(dispatcher, datapoint_cls, wrap_kernel=False) - - def _noop(inpt, *args, **kwargs): return inpt -def _register_explicit_noop(dispatcher, *datapoints_classes): - """ - Although this looks redundant with the no-op behavior of _get_kernel, this explicit registration prevents users - from registering kernels for builtin datapoints on builtin dispatchers that rely on the no-op behavior. - - For example, without explicit no-op registration the following would be valid user code: - - .. code:: - from torchvision.transforms.v2 import functional as F - - @F.register_kernel(F.adjust_brightness, datapoints.BoundingBox) - def lol(...): - ... - """ - for cls in datapoints_classes: - register_kernel(dispatcher, cls)(_noop) - - def _get_kernel(dispatcher, datapoint_cls): registry = _KERNEL_REGISTRY.get(dispatcher) if not registry: