PoC arbitrary dispatch #3

Closed
wants to merge 2 commits into kernel-registration from automatic-supported-types

Conversation

@pmeier (Owner) commented Jul 28, 2023

@NicolasHug @vfdev-5 following our offline discussion. This is just a PoC for F.equalize / transforms.RandomEqualize. I have to clean up a little, but the general structure is there.

@NicolasHug left a comment

Thanks Philip, looks good. I'm just not sure about the need for sets.

# We support both Types and callables that are able to do further checks on the type of the input.
_transformed_types: Tuple[Union[Type, Callable[[Any], bool]], ...] = (torch.Tensor, PIL.Image.Image)
# This needs to be populated by the subclasses
_USED_DISPATCHERS: Set[Type] = {}
@NicolasHug commented Jul 28, 2023

Not sure about the pros and cons of making this None instead... Are we cool with mutable class attributes? Maybe we can think of that later

(If the answer is that this makes mypy unhappy then hmmmmmmmm)

@pmeier (Owner Author) replied:

No, we can go with either.
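Not the PR's code, just a minimal sketch of the `None` alternative discussed above (the torchvision import path is assumed; the annotation is what mypy would need to accept both the base-class default and the subclass override):

```python
from typing import Callable, Optional, Set

from torchvision.transforms.v2 import functional as F  # import path assumed


class Transform:
    # Alternative to the mutable empty-set default above: the base class
    # defaults to None and every subclass overrides it with its dispatchers.
    _USED_DISPATCHERS: Optional[Set[Callable]] = None


class RandomEqualize(Transform):
    _USED_DISPATCHERS = {F.equalize}
```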

torchvision/transforms/v2/_transform.py (outdated; conversation resolved)
Comment on lines +564 to +567
if torch.jit.is_scripting() or is_simple_tensor(inpt):
    return equalize_image_tensor(inpt)
elif isinstance(inpt, PIL.Image.Image):
    return equalize_image_pil(inpt)

@NicolasHug commented:

We can worry about all that later, but will we be able to also register kernels for simple tensors and for PIL? Or are we going to hit JIT issues?

@pmeier (Owner Author) replied:

I would love that, since we would remove the special casing. I don't think JIT will be an issue as long as we have a

if torch.jit.is_scripting():
    return ...

at the top.
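For illustration, a rough sketch of what that could look like with kernels registered for plain tensors and PIL images; the registry and the import path are assumptions, not the PR's implementation:

```python
import PIL.Image
import torch

# Kernel names taken from the snippet above; the exact import path is assumed.
from torchvision.transforms.v2.functional import equalize_image_pil, equalize_image_tensor

# Hypothetical registry mapping input types to kernels.
_EQUALIZE_KERNELS = {
    torch.Tensor: equalize_image_tensor,
    PIL.Image.Image: equalize_image_pil,
}


def equalize(inpt):
    # The guard at the top: under scripting we return the pure-tensor kernel
    # right away, so JIT never reaches the dynamic lookup below.
    if torch.jit.is_scripting():
        return equalize_image_tensor(inpt)

    # Walk the MRO so subclasses of a registered type hit the same kernel.
    for cls in type(inpt).__mro__:
        kernel = _EQUALIZE_KERNELS.get(cls)
        if kernel is not None:
            return kernel(inpt)
    raise TypeError(f"equalize() does not support inputs of type {type(inpt)}")
```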

@@ -289,6 +289,8 @@ class RandomEqualize(_RandomApplyTransform):

    _v1_transform_cls = _transforms.RandomEqualize

    _USED_DISPATCHERS = {F.equalize}

@NicolasHug commented:

QQs:

  1. Do we expect more than one dispatcher per transform?
  2. Does this really need to be a set?


Comment on lines +25 to +28
types = _get_supported_datapoint_classes(dispatchers.pop())
for dispatcher in dispatchers:
    types &= _get_supported_datapoint_classes(dispatcher)
self._supported_datapoint_types = types

@NicolasHug commented:

Is this just

types = set()
for dispatcher in dispatchers:
    types.update(_get_supported_datapoint_classes(dispatcher))

?

and _get_supported_datapoint_classes probably doesn't need to return a set if we do that

@pmeier (Owner Author) replied:

I wanted to make sure that we only go for the common ground. Meaning, we only transform types that are supported by all dispatchers. This is why I have the intersection `&` in there.
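For illustration, with made-up type names (plain strings so the snippet runs on its own):

```python
# Supported types of two hypothetical dispatchers used by the same transform.
types_a = {"Image", "Video", "BoundingBox"}
types_b = {"Image", "Video"}

print(types_a & types_b)  # intersection: {'Image', 'Video'}, the common ground
print(types_a | types_b)  # union: also contains 'BoundingBox', which the second
                          # dispatcher does not support and would error on
```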

@NicolasHug commented:

It makes me wonder about the semantics and the use of `_supported_datapoint_types`... What if there's a transform with multiple dispatchers where each of them handles a different part of the input? Would this transform end up being a pass-through for everything?

@NicolasHug commented:

I mean, typically, if we're using the intersection, we can't just have simple logic in the transforms like

if type(inpt) not in self._supported_datapoint_types:
    return inpt  # pass-through

We can't use the union either because that would pass some inputs to some dispatchers that don't support them (and error).

Seems like we have to implement the pass-through logic on a per-dispatcher basis instead (within the transforms).

@pmeier (Owner Author) replied:

> Seems like we have to implement the pass-through logic on a per-dispatcher basis instead (within the transforms).

You are right. Let me see how this fits in.
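A self-contained sketch of that per-dispatcher pass-through idea (toy dispatchers and made-up names, not the PR's code):

```python
import torch


# Two toy dispatchers that each support a different subset of input types.
def dispatcher_a(inpt: torch.Tensor) -> torch.Tensor:
    return inpt + 1


def dispatcher_b(inpt: str) -> str:
    return inpt.upper()


# Per-dispatcher supported types, i.e. what something like
# _get_supported_datapoint_classes() would report for each dispatcher.
_SUPPORTED = {
    dispatcher_a: {torch.Tensor},
    dispatcher_b: {str},
}


def transform(inpt):
    # Pass-through is decided per dispatcher instead of via the intersection
    # (too strict) or the union (routes unsupported inputs) of all types.
    for dispatcher, types in _SUPPORTED.items():
        if type(inpt) in types:
            return dispatcher(inpt)
    return inpt  # no dispatcher supports this type: pass it through unchanged


print(transform(torch.zeros(2)))  # handled by dispatcher_a
print(transform("hello"))         # handled by dispatcher_b
print(transform(3.14))            # passed through unchanged
```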

@pmeier (Owner Author) commented Jul 28, 2023

Most of our transforms don't touch `_transformed_types` at all. Here is a list of the ones that do:

  • Dispatcher does not support arbitrary dispatch and would error for unsupported datapoints. Thus, the transform has to filter them out
  • The dispatcher supports arbitrary dispatch, but the transform filters for performance reasons (not sure the second half is true, but I can't think of anything else)
    • Grayscale
    • RandomGrayscale
  • Transform has no corresponding dispatcher and thus has to filter itself. We should look into why this is the case. Maybe it has a good reason, but maybe it is just an oversight?
    • LinearTransformation
  • Transform only applies to a single type by design. The corresponding dispatcher has no kernels registered
    • ConvertBoundingBoxFormat
    • ClampBoundingBox
  • Transform converts the actual type and thus applies only to a few specific types by design. The corresponding dispatcher has no kernels registered
    • ToTensor
    • PILToTensor
    • ToImageTensor
    • ToImagePIL
  • Transform supports tensor images, but not PIL images:

Apart from these groups, there is also Lambda. It takes the supported types as a parameter. Nothing would change with the design in the PR, since it already circumvents `_transformed_types`.
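For reference, a usage sketch of that Lambda behavior (import paths as of the v2 prototype at the time; they may differ in later releases):

```python
from torchvision import datapoints
from torchvision.transforms import v2 as transforms

# The trailing positional arguments are the types the callable is applied to;
# any other input is passed through untouched.
invert_images = transforms.Lambda(lambda img: 255 - img, datapoints.Image)
```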

@pmeier deleted the branch kernel-registration August 7, 2023 08:05
@pmeier closed this Aug 7, 2023
@pmeier deleted the automatic-supported-types branch August 7, 2023 10:48