Skip to content

Commit

Permalink
add dispatch to adjust_brightness
Browse files Browse the repository at this point in the history
  • Loading branch information
pmeier committed Jul 27, 2023
1 parent f36c64c commit bbaa35c
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 22 deletions.
4 changes: 2 additions & 2 deletions test/test_transforms_v2_refactored.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def _check_dispatcher_dispatch(dispatcher, kernel, input, *args, **kwargs):
preserved in doing so. For bounding boxes also checks that the format is preserved.
"""
if isinstance(input, datapoints._datapoint.Datapoint):
if dispatcher is F.resize:
if dispatcher in {F.resize, F.adjust_brightness}:
output = dispatcher(input, *args, **kwargs)
else:
# Due to our complex dispatch architecture for datapoints, we cannot spy on the kernel directly,
Expand Down Expand Up @@ -254,7 +254,7 @@ def _check_dispatcher_kernel_signature_match(dispatcher, *, kernel, input_type):

def _check_dispatcher_datapoint_signature_match(dispatcher):
"""Checks if the signature of the dispatcher matches the corresponding method signature on the Datapoint class."""
if dispatcher is F.resize:
if dispatcher in {F.resize, F.adjust_brightness}:
return
dispatcher_signature = inspect.signature(dispatcher)
dispatcher_params = list(dispatcher_signature.parameters.values())[1:]
Expand Down
39 changes: 21 additions & 18 deletions torchvision/transforms/v2/functional/_color.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from torchvision.utils import _log_api_usage_once

from ._meta import _num_value_bits, convert_dtype_image_tensor
from ._utils import is_simple_tensor
from ._utils import _get_kernel, is_simple_tensor, register_kernel


def _rgb_to_grayscale_image_tensor(
Expand Down Expand Up @@ -69,6 +69,25 @@ def _blend(image1: torch.Tensor, image2: torch.Tensor, ratio: float) -> torch.Te
return output if fp else output.to(image1.dtype)


def adjust_brightness(inpt: datapoints._InputTypeJIT, brightness_factor: float) -> datapoints._InputTypeJIT:
if not torch.jit.is_scripting():
_log_api_usage_once(adjust_brightness)

if torch.jit.is_scripting() or is_simple_tensor(inpt):
kernel = _get_kernel(adjust_brightness, type(inpt))
return kernel(inpt, brightness_factor=brightness_factor)
elif isinstance(inpt, datapoints._datapoint.Datapoint):
return inpt.adjust_brightness(brightness_factor=brightness_factor)
elif isinstance(inpt, PIL.Image.Image):
return adjust_brightness_image_pil(inpt, brightness_factor=brightness_factor)
else:
raise TypeError(
f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f"but got {type(inpt)} instead."
)


@register_kernel(adjust_brightness, datapoints.Image)
def adjust_brightness_image_tensor(image: torch.Tensor, brightness_factor: float) -> torch.Tensor:
if brightness_factor < 0:
raise ValueError(f"brightness_factor ({brightness_factor}) is not non-negative.")
Expand All @@ -86,27 +105,11 @@ def adjust_brightness_image_tensor(image: torch.Tensor, brightness_factor: float
adjust_brightness_image_pil = _FP.adjust_brightness


@register_kernel(adjust_brightness, datapoints.Video)
def adjust_brightness_video(video: torch.Tensor, brightness_factor: float) -> torch.Tensor:
return adjust_brightness_image_tensor(video, brightness_factor=brightness_factor)


def adjust_brightness(inpt: datapoints._InputTypeJIT, brightness_factor: float) -> datapoints._InputTypeJIT:
if not torch.jit.is_scripting():
_log_api_usage_once(adjust_brightness)

if torch.jit.is_scripting() or is_simple_tensor(inpt):
return adjust_brightness_image_tensor(inpt, brightness_factor=brightness_factor)
elif isinstance(inpt, datapoints._datapoint.Datapoint):
return inpt.adjust_brightness(brightness_factor=brightness_factor)
elif isinstance(inpt, PIL.Image.Image):
return adjust_brightness_image_pil(inpt, brightness_factor=brightness_factor)
else:
raise TypeError(
f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f"but got {type(inpt)} instead."
)


def adjust_saturation_image_tensor(image: torch.Tensor, saturation_factor: float) -> torch.Tensor:
if saturation_factor < 0:
raise ValueError(f"saturation_factor ({saturation_factor}) is not non-negative.")
Expand Down
15 changes: 13 additions & 2 deletions torchvision/transforms/v2/functional/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,23 @@ def wrapper(inpt, *args, **kwargs):

return wrapper

registry = _KERNEL_REGISTRY.get(dispatcher, {})

def decorator(kernel):
_KERNEL_REGISTRY[(dispatcher, datapoint_cls)] = datapoint_wrapper(kernel) if datapoint_wrapping else kernel
registry[datapoint_cls] = datapoint_wrapper(kernel) if datapoint_wrapping else kernel
return kernel

return decorator


def _noop(inpt, *args, **kwargs):
return inpt


def _get_kernel(dispatcher, datapoint_cls):
return _KERNEL_REGISTRY[(dispatcher, datapoint_cls)]
registry = _KERNEL_REGISTRY.get(dispatcher, {})

if datapoint_cls in registry:
return registry[datapoint_cls]

return _noop

0 comments on commit bbaa35c

Please sign in to comment.