From 002f1e2a631753fd0a304d6110d2c7f314a9c733 Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Fri, 28 Jul 2023 14:29:59 +0200
Subject: [PATCH 1/2] poc

---
 poc.py                                   | 48 +++++++++++++++++++
 torchvision/transforms/v2/_color.py      |  2 +
 torchvision/transforms/v2/_transform.py  | 25 ++++++----
 .../transforms/v2/functional/_color.py   | 37 +++++++-------
 .../transforms/v2/functional/_utils.py   | 27 ++++++++---
 5 files changed, 106 insertions(+), 33 deletions(-)
 create mode 100644 poc.py

diff --git a/poc.py b/poc.py
new file mode 100644
index 00000000000..a6697ea3e5d
--- /dev/null
+++ b/poc.py
@@ -0,0 +1,48 @@
+import contextlib
+
+import torch
+from torchvision import datapoints
+from torchvision.transforms import v2 as transforms
+from torchvision.transforms.v2 import functional as F
+
+
+@contextlib.contextmanager
+def assert_raises():
+    try:
+        yield
+    except:
+        pass
+    else:
+        assert False
+
+
+image = datapoints.Image(torch.rand(3, 32, 32))
+bounding_box = datapoints.BoundingBox([1, 2, 3, 4], format="XYXY", spatial_size=(32, 32))
+mask = datapoints.Mask(torch.randint(10, size=(32, 32), dtype=torch.uint8))
+video = datapoints.Video(torch.rand(2, 3, 32, 32))
+
+assert (F.equalize(image) != image).all()
+
+with assert_raises():
+    F.equalize(bounding_box)
+
+with assert_raises():
+    F.equalize(mask)
+
+assert (F.equalize(video) != video).all()
+
+transform = transforms.RandomEqualize(p=1)
+
+sample = dict(
+    image=image,
+    bounding_box=bounding_box,
+    mask=mask,
+    video=video,
+)
+
+transformed_sample = transform(sample)
+
+assert (transformed_sample["image"] != sample["image"]).all()
+assert transformed_sample["bounding_box"] is sample["bounding_box"]
+assert transformed_sample["mask"] is sample["mask"]
+assert (transformed_sample["video"] != sample["video"]).all()
diff --git a/torchvision/transforms/v2/_color.py b/torchvision/transforms/v2/_color.py
index 7dd8eeae236..4cfa5049112 100644
--- a/torchvision/transforms/v2/_color.py
+++ b/torchvision/transforms/v2/_color.py
@@ -289,6 +289,8 @@ class RandomEqualize(_RandomApplyTransform):
 
     _v1_transform_cls = _transforms.RandomEqualize
 
+    _USED_DISPATCHERS = {F.equalize}
+
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         return F.equalize(inpt)
 
diff --git a/torchvision/transforms/v2/_transform.py b/torchvision/transforms/v2/_transform.py
index f83ed5d6e11..1d1669e765f 100644
--- a/torchvision/transforms/v2/_transform.py
+++ b/torchvision/transforms/v2/_transform.py
@@ -1,27 +1,32 @@
 from __future__ import annotations
 
 import enum
-from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
+from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
 
 import PIL.Image
 import torch
 from torch import nn
 from torch.utils._pytree import tree_flatten, tree_unflatten
 from torchvision import datapoints
+from torchvision.transforms.v2.functional._utils import _get_supported_datapoint_classes
 from torchvision.transforms.v2.utils import check_type, has_any, is_simple_tensor
 from torchvision.utils import _log_api_usage_once
 
 
 class Transform(nn.Module):
-
-    # Class attribute defining transformed types. Other types are passed-through without any transformation
-    # We support both Types and callables that are able to do further checks on the type of the input.
-    _transformed_types: Tuple[Union[Type, Callable[[Any], bool]], ...] = (torch.Tensor, PIL.Image.Image)
+    # This needs to be populated by the subclasses
+    _USED_DISPATCHERS: Set[Type] = {}
 
     def __init__(self) -> None:
         super().__init__()
         _log_api_usage_once(self)
 
+        dispatchers = self._USED_DISPATCHERS.copy()
+        types = _get_supported_datapoint_classes(dispatchers.pop())
+        for dispatcher in dispatchers:
+            types &= _get_supported_datapoint_classes(dispatcher)
+        self._supported_datapoint_types = types
+
     def _check_inputs(self, flat_inputs: List[Any]) -> None:
         pass
 
@@ -68,15 +73,15 @@ def _needs_transform_list(self, flat_inputs: List[Any]) -> List[bool]:
         needs_transform_list = []
         transform_simple_tensor = not has_any(flat_inputs, datapoints.Image, datapoints.Video, PIL.Image.Image)
         for inpt in flat_inputs:
-            needs_transform = True
+            needs_transform = False
 
-            if not check_type(inpt, self._transformed_types):
-                needs_transform = False
+            # FIXME: we need to find a way to indicate "this transform doesn't support PIL images / plain tensors"
+            if any(isinstance(inpt, cls) for cls in [*self._supported_datapoint_types, PIL.Image.Image]):
+                needs_transform = True
             elif is_simple_tensor(inpt):
                 if transform_simple_tensor:
+                    needs_transform = True
                     transform_simple_tensor = False
-                else:
-                    needs_transform = False
             needs_transform_list.append(needs_transform)
         return needs_transform_list
 
diff --git a/torchvision/transforms/v2/functional/_color.py b/torchvision/transforms/v2/functional/_color.py
index 5a7008f8f0a..91bb7144125 100644
--- a/torchvision/transforms/v2/functional/_color.py
+++ b/torchvision/transforms/v2/functional/_color.py
@@ -557,6 +557,25 @@ def autocontrast(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT:
         )
 
 
+def equalize(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT:
+    if not torch.jit.is_scripting():
+        _log_api_usage_once(equalize)
+
+    if torch.jit.is_scripting() or is_simple_tensor(inpt):
+        return equalize_image_tensor(inpt)
+    elif isinstance(inpt, PIL.Image.Image):
+        return equalize_image_pil(inpt)
+    elif isinstance(inpt, datapoints._datapoint.Datapoint):
+        kernel = _get_kernel(equalize, type(inpt))
+        return kernel(inpt)
+    else:
+        raise TypeError(
+            f"Input can either be a plain tensor, a PIL image, or a TorchVision datapoint, "
+            f"but got {type(inpt)} instead."
+        )
+
+
+@_register_kernel_internal(equalize, datapoints.Image)
 def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor:
     if image.numel() == 0:
         return image
@@ -629,27 +648,11 @@ def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor:
 equalize_image_pil = _FP.equalize
 
 
+@_register_kernel_internal(equalize, datapoints.Video)
 def equalize_video(video: torch.Tensor) -> torch.Tensor:
     return equalize_image_tensor(video)
 
 
-def equalize(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT:
-    if not torch.jit.is_scripting():
-        _log_api_usage_once(equalize)
-
-    if torch.jit.is_scripting() or is_simple_tensor(inpt):
-        return equalize_image_tensor(inpt)
-    elif isinstance(inpt, datapoints._datapoint.Datapoint):
-        return inpt.equalize()
-    elif isinstance(inpt, PIL.Image.Image):
-        return equalize_image_pil(inpt)
-    else:
-        raise TypeError(
-            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
-            f"but got {type(inpt)} instead."
-        )
-
-
 def invert_image_tensor(image: torch.Tensor) -> torch.Tensor:
     if image.is_floating_point():
         return 1.0 - image
diff --git a/torchvision/transforms/v2/functional/_utils.py b/torchvision/transforms/v2/functional/_utils.py
index 23eeb626b9b..4ed97388b1b 100644
--- a/torchvision/transforms/v2/functional/_utils.py
+++ b/torchvision/transforms/v2/functional/_utils.py
@@ -69,7 +69,16 @@ def decorator(kernel):
     return decorator
 
 
+_BUILTIN_DATAPOINT_CLASSES = {
+    obj for obj in datapoints.__dict__.values() if isinstance(obj, type) and issubclass(obj, Datapoint)
+}
+
+
 def register_kernel(dispatcher, datapoint_cls):
+    if not issubclass(datapoint_cls, Datapoint):
+        raise TypeError(f"Can only register subclasses of datapoints.Datapoint, but got {datapoint_cls}.")
+    elif datapoint_cls in _BUILTIN_DATAPOINT_CLASSES:
+        raise TypeError(f"Cannot register kernel for builtin datapoints.{datapoint_cls.__name__}.")
     return _register_kernel_internal(dispatcher, datapoint_cls, wrap_kernel=False)
 
 
@@ -91,14 +100,16 @@ def _register_explicit_noop(dispatcher, *datapoints_classes):
         def lol(...):
             ...
     """
-    for cls in datapoints_classes:
-        register_kernel(dispatcher, cls)(_noop)
+    # FIXME: Just still here to keep the diff minimal
+    pass
+
+
+def _get_supported_datapoint_classes(dispatcher):
+    return set(_KERNEL_REGISTRY.get(dispatcher, {}).keys())
 
 
 def _get_kernel(dispatcher, datapoint_cls):
-    registry = _KERNEL_REGISTRY.get(dispatcher)
-    if not registry:
-        raise ValueError(f"No kernel registered for dispatcher '{dispatcher.__name__}'.")
+    registry = _KERNEL_REGISTRY[dispatcher]
 
     if datapoint_cls in registry:
         return registry[datapoint_cls]
@@ -107,4 +118,8 @@ def _get_kernel(dispatcher, datapoint_cls):
         if issubclass(datapoint_cls, registered_cls):
             return kernel
 
-    return _noop
+    supported_datapoint_type_names = "'" + "', '".join(cls.__name__ for cls in registry.keys()) + "'"
+    raise TypeError(
+        f"Datapoint type '{datapoint_cls.__name__}' is not supported by {dispatcher.__name__}. "
+        f"Supported types are {supported_datapoint_type_names} and subclasses thereof."
+    )

From 505a8a013be405c8790f5e245de9e8eb5ff8c86e Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Fri, 28 Jul 2023 15:08:47 +0200
Subject: [PATCH 2/2] Update torchvision/transforms/v2/_transform.py

Co-authored-by: Nicolas Hug
---
 torchvision/transforms/v2/_transform.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchvision/transforms/v2/_transform.py b/torchvision/transforms/v2/_transform.py
index 1d1669e765f..caff6ec23d8 100644
--- a/torchvision/transforms/v2/_transform.py
+++ b/torchvision/transforms/v2/_transform.py
@@ -14,7 +14,7 @@
 
 
 class Transform(nn.Module):
-    # This needs to be populated by the subclasses
+    # This needs to be re-defined by the subclasses
     _USED_DISPATCHERS: Set[Type] = {}
 
     def __init__(self) -> None:
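
Below is a minimal, illustrative sketch of how the registration mechanism introduced in
torchvision/transforms/v2/functional/_utils.py above could be exercised from user code. It is
not part of either patch. The names MyDatapoint and equalize_my_datapoint are hypothetical;
register_kernel is imported from the private _utils module because the patches do not add a
public export for it, and equalize_image_tensor is assumed to be re-exported from the
functional namespace like the other kernels.

    import torch

    from torchvision import datapoints
    from torchvision.transforms.v2 import functional as F
    from torchvision.transforms.v2.functional._utils import register_kernel


    class MyDatapoint(datapoints._datapoint.Datapoint):
        # User-defined datapoint subclass. Builtin datapoints (Image, Video, ...) are
        # rejected by the new checks in register_kernel.
        pass


    @register_kernel(F.equalize, MyDatapoint)
    def equalize_my_datapoint(inpt):
        # register_kernel registers with wrap_kernel=False, so the kernel itself is
        # responsible for unwrapping the input and re-wrapping the output.
        output = F.equalize_image_tensor(inpt.as_subclass(torch.Tensor))
        return output.as_subclass(MyDatapoint)


    my_dp = torch.rand(3, 32, 32).as_subclass(MyDatapoint)
    out = F.equalize(my_dp)  # routed through _get_kernel(equalize, MyDatapoint)
    assert isinstance(out, MyDatapoint)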
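
A second sketch, likewise not part of the patches: a custom transform declaring which
dispatchers it uses, mirroring the RandomEqualize change in torchvision/transforms/v2/_color.py.
MyEqualize is a hypothetical name, and Transform is imported from the private _transform module
touched by the first patch.

    from typing import Any, Dict

    from torchvision.transforms.v2 import functional as F
    from torchvision.transforms.v2._transform import Transform


    class MyEqualize(Transform):
        # Transform.__init__ intersects the supported datapoint classes of every
        # dispatcher listed here to build self._supported_datapoint_types.
        _USED_DISPATCHERS = {F.equalize}

        def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
            return F.equalize(inpt)


    transform = MyEqualize()
    # Image and Video inputs are transformed; BoundingBox and Mask inputs pass through
    # unchanged, exactly as poc.py asserts for RandomEqualize above.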