# PoC arbitrary dispatch #3

Changes from all commits.
The first file is a new PoC script (@@ -0,0 +1,48 @@) that exercises the new dispatch behaviour:

```python
import contextlib

import torch
from torchvision import datapoints
from torchvision.transforms import v2 as transforms
from torchvision.transforms.v2 import functional as F


@contextlib.contextmanager
def assert_raises():
    try:
        yield
    except:
        pass
    else:
        assert False


image = datapoints.Image(torch.rand(3, 32, 32))
bounding_box = datapoints.BoundingBox([1, 2, 3, 4], format="XYXY", spatial_size=(32, 32))
mask = datapoints.Mask(torch.randint(10, size=(32, 32), dtype=torch.uint8))
video = datapoints.Video(torch.rand(2, 3, 32, 32))

# equalize has kernels for images and videos ...
assert (F.equalize(image) != image).all()

# ... but not for bounding boxes and masks, so these calls must fail
with assert_raises():
    F.equalize(bounding_box)

with assert_raises():
    F.equalize(mask)

assert (F.equalize(video) != video).all()

transform = transforms.RandomEqualize(p=1)

sample = dict(
    image=image,
    bounding_box=bounding_box,
    mask=mask,
    video=video,
)

transformed_sample = transform(sample)

# the transform changes images and videos, and passes everything else
# through untouched (same object, not a copy)
assert (transformed_sample["image"] != sample["image"]).all()
assert transformed_sample["bounding_box"] is sample["bounding_box"]
assert transformed_sample["mask"] is sample["mask"]
assert (transformed_sample["video"] != sample["video"]).all()
```
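The hand-rolled `assert_raises` helper plays the role of `pytest.raises`. As a sketch, the same negative check could be written as a regular test (assuming `pytest` and the datapoints API used above):

```python
import pytest
from torchvision import datapoints
from torchvision.transforms.v2 import functional as F


def test_equalize_rejects_bounding_boxes():
    bounding_box = datapoints.BoundingBox([1, 2, 3, 4], format="XYXY", spatial_size=(32, 32))
    # No equalize kernel is registered for bounding boxes, so dispatch is
    # expected to fail instead of silently corrupting the coordinates.
    with pytest.raises(Exception):
        F.equalize(bounding_box)
```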
The second file changes the `Transform` base class (@@ -1,27 +1,32 @@):

```diff
@@ -1,27 +1,32 @@
 from __future__ import annotations

 import enum
-from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
+from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union

 import PIL.Image
 import torch
 from torch import nn
 from torch.utils._pytree import tree_flatten, tree_unflatten
 from torchvision import datapoints
+from torchvision.transforms.v2.functional._utils import _get_supported_datapoint_classes
 from torchvision.transforms.v2.utils import check_type, has_any, is_simple_tensor
 from torchvision.utils import _log_api_usage_once


 class Transform(nn.Module):

     # Class attribute defining transformed types. Other types are passed-through without any transformation
     # We support both Types and callables that are able to do further checks on the type of the input.
     _transformed_types: Tuple[Union[Type, Callable[[Any], bool]], ...] = (torch.Tensor, PIL.Image.Image)
+    # This needs to be re-defined by the subclasses
+    _USED_DISPATCHERS: Set[Type] = {}
```
> **Reviewer:** Not sure about the pros and cons of making this […] (If the answer is that this makes mypy unhappy then hmmmmmmmm)

> **Author:** No, we can go with either.
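For context, a subclass is then expected to declare the dispatchers it calls. A minimal sketch of what that could look like for `RandomEqualize` (the subclass changes are not part of the hunks shown here, so take the exact shape as an assumption):

```python
from torchvision.transforms.v2 import functional as F
from torchvision.transforms.v2._transform import _RandomApplyTransform


class RandomEqualize(_RandomApplyTransform):
    # Declare the dispatchers this transform uses so that the base class can
    # derive the supported datapoint types from the kernel registry.
    _USED_DISPATCHERS = {F.equalize}

    def _transform(self, inpt, params):
        return F.equalize(inpt)
```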
The constructor then computes the supported datapoint types as the intersection over all used dispatchers:

```diff
     def __init__(self) -> None:
         super().__init__()
         _log_api_usage_once(self)

+        dispatchers = self._USED_DISPATCHERS.copy()
+        types = _get_supported_datapoint_classes(dispatchers.pop())
+        for dispatcher in dispatchers:
+            types &= _get_supported_datapoint_classes(dispatcher)
+        self._supported_datapoint_types = types
```
> **Reviewer** (on lines +25 to +28): Is this just
>
> ```python
> types = {}
> for dispatcher in dispatchers:
>     types.update(_get_supported_datapoint_classes(dispatcher))
> ```
>
> ? And […]

> **Author:** I wanted to make sure that we only go for the common ground. Meaning, we only transfer types that are supported by all dispatchers. This is why I have the intersection.

> **Reviewer:** It makes me wonder about the semantics and the use of […] Typically, if we're using the intersection, we can't just have simple logic in the transforms like
>
> ```python
> if type(inpt) not in self._supported_datapoint_types:
>     # pass-through inpt
> ```
>
> We can't use the union either, because that would pass some inputs to some dispatchers that don't support them (and error). Seems like we have to implement the pass-through logic on a per-dispatcher basis instead (within the transforms).

> **Author:** You are right. Let me see how this fits in.
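To make the trade-off concrete, here is a toy comparison of the two semantics (the kernel sets are hypothetical, chosen to mirror the equalize example):

```python
from torchvision import datapoints

# Hypothetical supported-type sets of two dispatchers used by one transform:
equalize_types = {datapoints.Image, datapoints.Video}
resize_types = {datapoints.Image, datapoints.Video, datapoints.BoundingBox, datapoints.Mask}

# Intersection: only types that every dispatcher supports get transformed.
# A bounding box would be passed through even though resize could handle it.
print(equalize_types & resize_types)  # Image, Video

# Union: the transform would claim bounding boxes, but calling equalize on
# one would error, which is why a single pass-through check is not enough.
print(equalize_types | resize_types)  # Image, Video, BoundingBox, Mask
```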
```diff
     def _check_inputs(self, flat_inputs: List[Any]) -> None:
         pass
```

The pass-through logic in `_needs_transform_list` is inverted accordingly (@@ -68,15 +73,15 @@): instead of transforming everything and opting out unsupported types, it now opts in only the supported types (plus PIL images, see the FIXME):

```diff
@@ -68,15 +73,15 @@ def _needs_transform_list(self, flat_inputs: List[Any]) -> List[bool]:
         needs_transform_list = []
         transform_simple_tensor = not has_any(flat_inputs, datapoints.Image, datapoints.Video, PIL.Image.Image)
         for inpt in flat_inputs:
-            needs_transform = True
+            needs_transform = False

-            if not check_type(inpt, self._transformed_types):
-                needs_transform = False
+            # FIXME: we need to find a way to indicate "this transform doesn't support PIL images / plain tensors"
+            if any(isinstance(inpt, cls) for cls in [*self._supported_datapoint_types, PIL.Image.Image]):
+                needs_transform = True
             elif is_simple_tensor(inpt):
                 if transform_simple_tensor:
+                    needs_transform = True
                     transform_simple_tensor = False
-                else:
-                    needs_transform = False
             needs_transform_list.append(needs_transform)
         return needs_transform_list
```
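Applied to the sample from the PoC script, the new logic should flag only the image and the video. A sketch of the expected behaviour, reusing the objects defined in the script (`_needs_transform_list` is private, so this is for illustration only):

```python
transform = transforms.RandomEqualize(p=1)

# The sample contains Image/Video datapoints, so a plain tensor in the same
# sample would not be treated as an image and would be passed through.
flat_inputs = [image, bounding_box, mask, video]

# Only types with a registered equalize kernel need the transform.
assert transform._needs_transform_list(flat_inputs) == [True, False, False, True]
```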
The third file reworks the functional `equalize`: the dispatcher moves above its kernels and routes datapoints through a kernel registry (@@ -557,6 +557,25 @@):

```diff
@@ -557,6 +557,25 @@ def autocontrast(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT:
     )


+def equalize(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT:
+    if not torch.jit.is_scripting():
+        _log_api_usage_once(equalize)
+
+    if torch.jit.is_scripting() or is_simple_tensor(inpt):
+        return equalize_image_tensor(inpt)
+    elif isinstance(inpt, PIL.Image.Image):
+        return equalize_image_pil(inpt)
```
> **Reviewer** (on lines +564 to +567): We can worry about all that later, but will we be able to also register kernels for simple tensors and for PIL? Or are we going to hit JIT issues?

> **Author:** I would love that, since we would remove the special casing. I don't think JIT will be an issue as long as we have a
>
> ```python
> if torch.jit.is_scripting():
>     return ...
> ```
>
> at the top.
For datapoints, the dispatcher now looks up the kernel in the registry, and the image kernel is registered via a decorator:

```diff
+    elif isinstance(inpt, datapoints._datapoint.Datapoint):
+        kernel = _get_kernel(equalize, type(inpt))
+        return kernel(inpt)
+    else:
+        raise TypeError(
+            f"Input can either be a plain tensor, a PIL image, or a TorchVision datapoint, "
+            f"but got {type(inpt)} instead."
+        )
+
+
+@_register_kernel_internal(equalize, datapoints.Image)
 def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor:
     if image.numel() == 0:
         return image
```
Further down, the video kernel is registered too, and the old `equalize` dispatcher, which relied on the `Datapoint.equalize()` method, is removed:

```diff
@@ -629,27 +648,11 @@ def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor:
 equalize_image_pil = _FP.equalize


+@_register_kernel_internal(equalize, datapoints.Video)
 def equalize_video(video: torch.Tensor) -> torch.Tensor:
     return equalize_image_tensor(video)


-def equalize(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT:
-    if not torch.jit.is_scripting():
-        _log_api_usage_once(equalize)
-
-    if torch.jit.is_scripting() or is_simple_tensor(inpt):
-        return equalize_image_tensor(inpt)
-    elif isinstance(inpt, datapoints._datapoint.Datapoint):
-        return inpt.equalize()
-    elif isinstance(inpt, PIL.Image.Image):
-        return equalize_image_pil(inpt)
-    else:
-        raise TypeError(
-            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
-            f"but got {type(inpt)} instead."
-        )
-
-
 def invert_image_tensor(image: torch.Tensor) -> torch.Tensor:
     if image.is_floating_point():
         return 1.0 - image
```
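Continuing the same sketch, `_get_supported_datapoint_classes` (imported in the `Transform.__init__` change above) would then just read the registered keys back out, e.g. Image and Video for equalize after this diff (again an assumption, not the actual `_utils` implementation):

```python
from typing import Callable, Set, Type


def _get_supported_datapoint_classes(dispatcher: Callable) -> Set[Type]:
    # All input types for which a kernel has been registered for `dispatcher`.
    return set(_KERNEL_REGISTRY.get(dispatcher, {}))
```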
> **Reviewer:** QQs: […]

> **Author:** Yes. For example […]
>
> No. What would you prefer?