Skip to content

Commit

Permalink
[PoC] register all kernels centrally
Browse files Browse the repository at this point in the history
  • Loading branch information
pmeier committed Jul 27, 2023
1 parent f178373 commit a9125f1
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 79 deletions.
67 changes: 67 additions & 0 deletions torchvision/transforms/v2/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,3 +170,70 @@
from ._type_conversion import pil_to_tensor, to_image_pil, to_image_tensor, to_pil_image

from ._deprecated import get_image_size, to_tensor # usort: skip


def _register_builtin_kernels():
    """Populate ``_KERNEL_REGISTRY`` with the built-in dispatcher/datapoint kernel pairs.

    Runs once at import time and is deleted immediately afterwards (see the
    call and ``del`` below), so no helper names leak into the public module
    namespace.
    """
    import functools
    import inspect

    import torch
    from torchvision import datapoints

    from ._utils import _KERNEL_REGISTRY, _noop

    def default_kernel_wrapper(dispatcher, kernel):
        # Parameter names after the leading input argument, for both callables.
        dispatcher_params = list(inspect.signature(dispatcher).parameters)[1:]
        kernel_params = list(inspect.signature(kernel).parameters)[1:]

        # Argument translation is only needed when the signatures diverge.
        signatures_differ = kernel_params != dispatcher_params

        # Freeze as a set once, so the wrapper does O(1) membership tests.
        kernel_params = set(kernel_params)

        @functools.wraps(kernel)
        def wrapper(inpt, *args, **kwargs):
            datapoint_type = type(inpt)

            if signatures_differ:
                # Fold positional arguments into kwargs so everything can be
                # filtered uniformly below.
                kwargs.update(zip(dispatcher_params, args))
                args = ()

                # Discard dispatcher-only parameters (defaulted there) that the
                # kernel does not accept.
                for name in kwargs.keys() - kernel_params:
                    del kwargs[name]

                # Supply metadata that the dispatcher reads implicitly off the
                # datapoint, but that the kernel needs as an explicit argument.
                for name in datapoint_type.__annotations__.keys() & kernel_params:
                    kwargs[name] = getattr(inpt, name)

            output = kernel(inpt.as_subclass(torch.Tensor), *args, **kwargs)

            # Bounding-box kernels may return (output, spatial_size); thread the
            # new spatial size through to the re-wrapped datapoint.
            if isinstance(inpt, datapoints.BoundingBox) and isinstance(output, tuple):
                output, spatial_size = output
                metadata = {"spatial_size": spatial_size}
            else:
                metadata = {}

            return datapoint_type.wrap_like(inpt, output, **metadata)

        return wrapper

    def register(dispatcher, datapoint_cls, kernel):
        _KERNEL_REGISTRY.setdefault(dispatcher, {})[datapoint_cls] = default_kernel_wrapper(dispatcher, kernel)

    register(resize, datapoints.Image, resize_image_tensor)
    register(resize, datapoints.BoundingBox, resize_bounding_box)
    register(resize, datapoints.Mask, resize_mask)
    register(resize, datapoints.Video, resize_video)

    register(adjust_brightness, datapoints.Image, adjust_brightness_image_tensor)
    register(adjust_brightness, datapoints.BoundingBox, _noop)
    register(adjust_brightness, datapoints.Mask, _noop)
    register(adjust_brightness, datapoints.Video, adjust_brightness_video)


_register_builtin_kernels()
del _register_builtin_kernels
7 changes: 1 addition & 6 deletions torchvision/transforms/v2/functional/_color.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from torchvision.utils import _log_api_usage_once

from ._meta import _num_value_bits, to_dtype_image_tensor
from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_internal, is_simple_tensor
from ._utils import _get_kernel, is_simple_tensor


def _rgb_to_grayscale_image_tensor(
Expand Down Expand Up @@ -87,10 +87,6 @@ def adjust_brightness(inpt: datapoints._InputTypeJIT, brightness_factor: float)
)


_register_explicit_noop(adjust_brightness, datapoints.BoundingBox, datapoints.Mask)


@_register_kernel_internal(adjust_brightness, datapoints.Image)
def adjust_brightness_image_tensor(image: torch.Tensor, brightness_factor: float) -> torch.Tensor:
if brightness_factor < 0:
raise ValueError(f"brightness_factor ({brightness_factor}) is not non-negative.")
Expand All @@ -109,7 +105,6 @@ def adjust_brightness_image_pil(image: PIL.Image.Image, brightness_factor: float
return _FP.adjust_brightness(image, brightness_factor=brightness_factor)


@_register_kernel_internal(adjust_brightness, datapoints.Video)
def adjust_brightness_video(video: torch.Tensor, brightness_factor: float) -> torch.Tensor:
    """Adjust the brightness of a video tensor.

    Delegates to :func:`adjust_brightness_image_tensor`; the video is passed
    through as a plain tensor.
    """
    return adjust_brightness_image_tensor(video, brightness_factor)

Expand Down
6 changes: 1 addition & 5 deletions torchvision/transforms/v2/functional/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

from ._meta import clamp_bounding_box, convert_format_bounding_box, get_spatial_size_image_pil

from ._utils import _get_kernel, _register_kernel_internal, is_simple_tensor
from ._utils import _get_kernel, is_simple_tensor


def _check_interpolation(interpolation: Union[InterpolationMode, int]) -> InterpolationMode:
Expand Down Expand Up @@ -183,7 +183,6 @@ def resize(
)


@_register_kernel_internal(resize, datapoints.Image)
def resize_image_tensor(
image: torch.Tensor,
size: List[int],
Expand Down Expand Up @@ -285,7 +284,6 @@ def resize_image_pil(
return image.resize((new_width, new_height), resample=pil_modes_mapping[interpolation])


@_register_kernel_internal(resize, datapoints.Mask)
def resize_mask(mask: torch.Tensor, size: List[int], max_size: Optional[int] = None) -> torch.Tensor:
if mask.ndim < 3:
mask = mask.unsqueeze(0)
Expand All @@ -301,7 +299,6 @@ def resize_mask(mask: torch.Tensor, size: List[int], max_size: Optional[int] = N
return output


@_register_kernel_internal(resize, datapoints.BoundingBox)
def resize_bounding_box(
bounding_box: torch.Tensor, spatial_size: Tuple[int, int], size: List[int], max_size: Optional[int] = None
) -> Tuple[torch.Tensor, Tuple[int, int]]:
Expand All @@ -320,7 +317,6 @@ def resize_bounding_box(
)


@_register_kernel_internal(resize, datapoints.Video)
def resize_video(
video: torch.Tensor,
size: List[int],
Expand Down
70 changes: 2 additions & 68 deletions torchvision/transforms/v2/functional/_utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
import functools
import inspect
from typing import Any

import torch
from torchvision import datapoints
from torchvision.datapoints._datapoint import Datapoint


Expand All @@ -14,87 +11,24 @@ def is_simple_tensor(inpt: Any) -> bool:
_KERNEL_REGISTRY = {}


def _kernel_wrapper_internal(dispatcher, kernel):
dispatcher_params = list(inspect.signature(dispatcher).parameters)[1:]
kernel_params = list(inspect.signature(kernel).parameters)[1:]

needs_args_kwargs_handling = kernel_params != dispatcher_params

# this avoids converting list -> set at runtime below
kernel_params = set(kernel_params)

@functools.wraps(kernel)
def wrapper(inpt, *args, **kwargs):
input_type = type(inpt)

if needs_args_kwargs_handling:
# Convert args to kwargs to simplify further processing
kwargs.update(dict(zip(dispatcher_params, args)))
args = ()

# drop parameters that are not relevant for the kernel, but have a default value
# in the dispatcher
for kwarg in kwargs.keys() - kernel_params:
del kwargs[kwarg]

# add parameters that are passed implicitly to the dispatcher as metadata,
# but have to be explicit for the kernel
for kwarg in input_type.__annotations__.keys() & kernel_params:
kwargs[kwarg] = getattr(inpt, kwarg)

output = kernel(inpt.as_subclass(torch.Tensor), *args, **kwargs)

if isinstance(inpt, datapoints.BoundingBox) and isinstance(output, tuple):
output, spatial_size = output
metadata = dict(spatial_size=spatial_size)
else:
metadata = dict()

return input_type.wrap_like(inpt, output, **metadata)

return wrapper


def register_kernel(dispatcher, datapoint_cls):
    """Decorator factory: register a kernel as the implementation of *dispatcher* for *datapoint_cls*.

    The kernel is stored as-is (unwrapped) in ``_KERNEL_REGISTRY``.

    Raises:
        TypeError: if *dispatcher* already has a kernel registered for
            *datapoint_cls*, to prevent silent overrides.
    """
    registry = _KERNEL_REGISTRY.setdefault(dispatcher, {})
    if datapoint_cls in registry:
        raise TypeError(
            f"Dispatcher '{dispatcher.__name__}' already has a kernel registered for type '{datapoint_cls.__name__}'."
        )

    def decorator(kernel):
        registry[datapoint_cls] = kernel
        return kernel

    return decorator


def register_kernel(dispatcher, datapoint_cls):
return _register_kernel_internal(dispatcher, datapoint_cls, wrap_kernel=False)


def _noop(inpt, *args, **kwargs):
return inpt


def _register_explicit_noop(dispatcher, *datapoints_classes):
    """Register ``_noop`` for *dispatcher* on each of *datapoints_classes*.

    Although this looks redundant with the no-op fallback behavior of
    ``_get_kernel``, the explicit registration prevents users from registering
    their own kernels for builtin datapoints on builtin dispatchers that rely
    on that fallback, since ``register_kernel`` rejects duplicate
    registrations. For example, without explicit no-op registration the
    following would be valid user code:

    .. code::

        from torchvision.transforms.v2 import functional as F

        @F.register_kernel(F.adjust_brightness, datapoints.BoundingBox)
        def lol(...):
            ...
    """
    for cls in datapoints_classes:
        register_kernel(dispatcher, cls)(_noop)


def _get_kernel(dispatcher, datapoint_cls):
registry = _KERNEL_REGISTRY.get(dispatcher)
if not registry:
Expand Down

0 comments on commit a9125f1

Please sign in to comment.