Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PoC arbitrary dispatch #3

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions poc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import contextlib

import torch
from torchvision import datapoints
from torchvision.transforms import v2 as transforms
from torchvision.transforms.v2 import functional as F


@contextlib.contextmanager
def assert_raises():
    """Context manager asserting that the wrapped block raises an exception.

    Raises:
        AssertionError: if the block completes without raising.
    """
    try:
        yield
    # Catch Exception rather than a bare `except:` so that BaseExceptions
    # such as KeyboardInterrupt/SystemExit still propagate.
    except Exception:
        pass
    else:
        # Don't use `assert False` here: asserts are stripped under `python -O`,
        # which would silently disable this check.
        raise AssertionError("expected the block to raise, but it did not")


# One instance of each built-in datapoint type used by the PoC.
image = datapoints.Image(torch.rand(3, 32, 32))
bounding_box = datapoints.BoundingBox([1, 2, 3, 4], format="XYXY", spatial_size=(32, 32))
mask = datapoints.Mask(torch.randint(10, size=(32, 32), dtype=torch.uint8))
video = datapoints.Video(torch.rand(2, 3, 32, 32))

# Functional level: equalize modifies images, while datapoints without a
# registered kernel are rejected with an error.
assert (F.equalize(image) != image).all()

for rejected in (bounding_box, mask):
    with assert_raises():
        F.equalize(rejected)

assert (F.equalize(video) != video).all()

# Transform level: with p=1 the transform always applies, but datapoints it
# does not support must be passed through untouched (same object identity).
transform = transforms.RandomEqualize(p=1)

sample = {
    "image": image,
    "bounding_box": bounding_box,
    "mask": mask,
    "video": video,
}

transformed_sample = transform(sample)

for key in ("image", "video"):
    assert (transformed_sample[key] != sample[key]).all()
for key in ("bounding_box", "mask"):
    assert transformed_sample[key] is sample[key]
2 changes: 2 additions & 0 deletions torchvision/transforms/v2/_color.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,8 @@ class RandomEqualize(_RandomApplyTransform):

_v1_transform_cls = _transforms.RandomEqualize

_USED_DISPATCHERS = {F.equalize}

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

QQs:

  1. Do we expect more than one dispatcher per transform?
  2. does this need really to be a set?

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
    # Delegate to the functional dispatcher; `params` is unused because
    # equalize takes no parameters beyond the input itself.
    return F.equalize(inpt)

Expand Down
25 changes: 15 additions & 10 deletions torchvision/transforms/v2/_transform.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,32 @@
from __future__ import annotations

import enum
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union

import PIL.Image
import torch
from torch import nn
from torch.utils._pytree import tree_flatten, tree_unflatten
from torchvision import datapoints
from torchvision.transforms.v2.functional._utils import _get_supported_datapoint_classes
from torchvision.transforms.v2.utils import check_type, has_any, is_simple_tensor
from torchvision.utils import _log_api_usage_once


class Transform(nn.Module):

# Class attribute defining transformed types. Other types are passed-through without any transformation
# We support both Types and callables that are able to do further checks on the type of the input.
_transformed_types: Tuple[Union[Type, Callable[[Any], bool]], ...] = (torch.Tensor, PIL.Image.Image)
# This needs to be re-defined by the subclasses
_USED_DISPATCHERS: Set[Type] = {}
Copy link

@NicolasHug NicolasHug Jul 28, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure about the pros and cons of making this None instead... Are we cool with mutable class attributes? Maybe we can think of that later

(If the answer is that this makes mypy unhappy then hmmmmmmmm)

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, we can go with either.


def __init__(self) -> None:
    super().__init__()
    _log_api_usage_once(self)

    # Compute the datapoint types this transform can handle as the
    # intersection of the types supported by every used dispatcher: a type is
    # only transformed if *all* dispatchers have a kernel for it.
    # Normalize through set() first: the base-class default `_USED_DISPATCHERS = {}`
    # is an (empty) dict, on which `.copy().pop()` would raise TypeError, and
    # even an empty set would raise KeyError on the unconditional `pop()`.
    dispatchers = set(self._USED_DISPATCHERS)
    if dispatchers:
        types = _get_supported_datapoint_classes(dispatchers.pop())
        for dispatcher in dispatchers:
            types &= _get_supported_datapoint_classes(dispatcher)
    else:
        # No dispatchers declared: nothing is supported, everything passes through.
        types = set()
    self._supported_datapoint_types = types
Comment on lines +25 to +28

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this just

types = {}
for dispatcher in dispatchers:
    types.update(_get_supported_datapoint_classes(dispatcher))

?

and _get_supported_datapoint_classes probably doesn't need to return a set if we do that

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wanted to make sure that we only go for the common ground. Meaning, we only transfer types that are supported by all dispatchers. This is why I have the intersection & in there.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It makes me wonder about the semantics and the use of _supported_datapoint_types.. what if there's a transform with multiple dispatchers where each of them handles a different part of the input? Would this transform end up being a passthrough for everything?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mean typically if we're using the intersection, we can't just have a simple logic in the transforms like

if type(inpt) not in self._supported_datapoint_types:
    # pass-through inpt

We can't use the union either because that would pass some inputs to some dispatchers that don't support them (and error).

Seems like we have to implement the pass-through logic on a per-dispatcher basis instead (within the transforms).

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like we have to implement the pass-through logic on a per-dispatcher basis instead (within the transforms).

You are right. Let me see how this fits in


def _check_inputs(self, flat_inputs: List[Any]) -> None:
    # Hook for subclasses to validate the flattened inputs before the
    # transform is applied; the base implementation accepts anything.
    pass

Expand Down Expand Up @@ -68,15 +73,15 @@ def _needs_transform_list(self, flat_inputs: List[Any]) -> List[bool]:
needs_transform_list = []
transform_simple_tensor = not has_any(flat_inputs, datapoints.Image, datapoints.Video, PIL.Image.Image)
for inpt in flat_inputs:
needs_transform = True
needs_transform = False

if not check_type(inpt, self._transformed_types):
needs_transform = False
# FIXME: we need to find a way to indicate "this transform doesn't support PIL images / plain tensors"
if any(isinstance(inpt, cls) for cls in [*self._supported_datapoint_types, PIL.Image.Image]):
needs_transform = True
elif is_simple_tensor(inpt):
if transform_simple_tensor:
needs_transform = True
transform_simple_tensor = False
else:
needs_transform = False
needs_transform_list.append(needs_transform)
return needs_transform_list

Expand Down
37 changes: 20 additions & 17 deletions torchvision/transforms/v2/functional/_color.py
Original file line number Diff line number Diff line change
Expand Up @@ -557,6 +557,25 @@ def autocontrast(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT:
)


def equalize(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT:
if not torch.jit.is_scripting():
_log_api_usage_once(equalize)

if torch.jit.is_scripting() or is_simple_tensor(inpt):
return equalize_image_tensor(inpt)
elif isinstance(inpt, PIL.Image.Image):
return equalize_image_pil(inpt)
Comment on lines +564 to +567

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can worry about all that later, but will we be able to also register kernels for simple tensors and for pil? Or are we going to hit JIT issues?

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would love that, since we would remove the special casing. I don't think JIT will be an issue as long as we have a

if torch.jit.is_scripting():
    return ...

at the top.

elif isinstance(inpt, datapoints._datapoint.Datapoint):
kernel = _get_kernel(equalize, type(inpt))
return kernel(inpt)
else:
raise TypeError(
f"Input can either be a plain tensor, a PIL image, or a TorchVision datapoint, "
f"but got {type(inpt)} instead."
)


@_register_kernel_internal(equalize, datapoints.Image)
def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor:
if image.numel() == 0:
return image
Expand Down Expand Up @@ -629,27 +648,11 @@ def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor:
equalize_image_pil = _FP.equalize


@_register_kernel_internal(equalize, datapoints.Video)
def equalize_video(video: torch.Tensor) -> torch.Tensor:
    # Videos reuse the image kernel; presumably equalize_image_tensor handles
    # the extra leading num-frames dimension — confirm against the kernel body.
    return equalize_image_tensor(video)


def equalize(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT:
    """Dispatch ``equalize`` to the implementation matching the input's type."""
    if not torch.jit.is_scripting():
        _log_api_usage_once(equalize)

    # Order matters: under TorchScript only the plain-tensor path is valid,
    # so it is checked first (is_scripting short-circuits the isinstance checks).
    if torch.jit.is_scripting() or is_simple_tensor(inpt):
        return equalize_image_tensor(inpt)
    elif isinstance(inpt, datapoints._datapoint.Datapoint):
        # Datapoints dispatch through their own method.
        return inpt.equalize()
    elif isinstance(inpt, PIL.Image.Image):
        return equalize_image_pil(inpt)
    else:
        raise TypeError(
            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
            f"but got {type(inpt)} instead."
        )


def invert_image_tensor(image: torch.Tensor) -> torch.Tensor:
if image.is_floating_point():
return 1.0 - image
Expand Down
27 changes: 21 additions & 6 deletions torchvision/transforms/v2/functional/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,16 @@ def decorator(kernel):
return decorator


# All Datapoint subclasses exported at the top level of ``torchvision.datapoints``.
# Membership in this set marks a class as builtin, for which user kernel
# registration is rejected below.
_BUILTIN_DATAPOINT_CLASSES = {
    obj for obj in datapoints.__dict__.values() if isinstance(obj, type) and issubclass(obj, Datapoint)
}


def register_kernel(dispatcher, datapoint_cls):
    """Register a user kernel for ``datapoint_cls`` on ``dispatcher``.

    Only custom (non-builtin) Datapoint subclasses may be registered.

    Raises:
        TypeError: if ``datapoint_cls`` is not a Datapoint subclass, or is one
            of the builtin datapoint classes.
    """
    # Guard clauses: reject non-Datapoint types and builtin datapoints.
    if not issubclass(datapoint_cls, Datapoint):
        raise TypeError(f"Can only register subclasses of datapoints.Datapoint, but got {datapoint_cls}.")
    if datapoint_cls in _BUILTIN_DATAPOINT_CLASSES:
        raise TypeError(f"Cannot register kernel for builtin datapoints.{datapoint_cls.__name__}.")
    return _register_kernel_internal(dispatcher, datapoint_cls, wrap_kernel=False)


Expand All @@ -91,14 +100,16 @@ def _register_explicit_noop(dispatcher, *datapoints_classes):
def lol(...):
...
"""
for cls in datapoints_classes:
register_kernel(dispatcher, cls)(_noop)
# FIXME: Just still here to keep the diff minimal
pass


def _get_supported_datapoint_classes(dispatcher):
    """Return the set of datapoint classes with a kernel registered for ``dispatcher``."""
    # Unregistered dispatchers yield an empty registry, hence an empty set.
    registry = _KERNEL_REGISTRY.get(dispatcher, {})
    return set(registry)


def _get_kernel(dispatcher, datapoint_cls):
registry = _KERNEL_REGISTRY.get(dispatcher)
if not registry:
raise ValueError(f"No kernel registered for dispatcher '{dispatcher.__name__}'.")
registry = _KERNEL_REGISTRY[dispatcher]

if datapoint_cls in registry:
return registry[datapoint_cls]
Expand All @@ -107,4 +118,8 @@ def _get_kernel(dispatcher, datapoint_cls):
if issubclass(datapoint_cls, registered_cls):
return kernel

return _noop
supported_datapoint_type_names = "'" + "', '".join(cls.__name__ for cls in registry.keys()) + "'"
raise TypeError(
f"Datapoint type '{datapoint_cls.__name__}' is not supported by {dispatcher.__name__}. "
f"Supported types are {supported_datapoint_type_names} and subclasses thereof."
)