PoC arbitrary dispatch #3
Conversation
Thanks Philip, looks good, I'm just not sure about the need for sets
# We support both Types and callables that are able to do further checks on the type of the input.
_transformed_types: Tuple[Union[Type, Callable[[Any], bool]], ...] = (torch.Tensor, PIL.Image.Image)
# This needs to be populated by the subclasses
_USED_DISPATCHERS: Set[Type] = {}
Not sure about the pros and cons of making this None instead... Are we cool with mutable class attributes? Maybe we can think of that later.
(If the answer is that this makes mypy unhappy then hmmmmmmmm)
No, we can go with either.
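As an aside, here is a minimal sketch of how a tuple that mixes plain types and predicate callables, like _transformed_types above, could be checked against an input. needs_transform is a hypothetical name, and is_simple_tensor only stands in for the helper used elsewhere in this PR.

# Illustrative sketch only, not the PR's code.
from typing import Any, Callable, Tuple, Type, Union

import PIL.Image
import torch


def is_simple_tensor(inpt: Any) -> bool:
    # A plain tensor, i.e. not a tensor subclass such as a datapoint.
    return isinstance(inpt, torch.Tensor) and type(inpt) is torch.Tensor


_transformed_types: Tuple[Union[Type, Callable[[Any], bool]], ...] = (
    PIL.Image.Image,   # a plain type: checked with isinstance
    is_simple_tensor,  # a callable: performs a finer-grained check
)


def needs_transform(inpt: Any) -> bool:
    for type_or_check in _transformed_types:
        if isinstance(type_or_check, type):
            if isinstance(inpt, type_or_check):
                return True
        elif type_or_check(inpt):
            return True
    return False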
if torch.jit.is_scripting() or is_simple_tensor(inpt):
    return equalize_image_tensor(inpt)
elif isinstance(inpt, PIL.Image.Image):
    return equalize_image_pil(inpt)
We can worry about all that later, but will we be able to also register kernels for simple tensors and for PIL? Or are we going to hit JIT issues?
I would love that, since we would remove the special casing. I don't think JIT will be an issue as long as we have a

if torch.jit.is_scripting():
    return ...

at the top.
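For illustration, a dispatcher with kernels registered for both simple tensors and PIL images could look roughly like the sketch below. The _KERNEL_REGISTRY name and lookup are hypothetical; only the torch.jit.is_scripting() early return is the guard discussed above.

# Rough sketch, not the PR's implementation.
import PIL.Image
import torch


def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor:
    ...  # existing tensor kernel, stubbed here


def equalize_image_pil(image: PIL.Image.Image) -> PIL.Image.Image:
    ...  # existing PIL kernel, stubbed here


_KERNEL_REGISTRY = {
    torch.Tensor: equalize_image_tensor,
    PIL.Image.Image: equalize_image_pil,
}


def equalize(inpt):
    if torch.jit.is_scripting():
        # Scripted code only ever sees the tensor kernel; the dynamic lookup
        # below is never reached under scripting.
        return equalize_image_tensor(inpt)

    kernel = _KERNEL_REGISTRY.get(type(inpt))
    if kernel is None:
        raise TypeError(f"equalize() does not support inputs of type {type(inpt)}")
    return kernel(inpt)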
@@ -289,6 +289,8 @@ class RandomEqualize(_RandomApplyTransform):

    _v1_transform_cls = _transforms.RandomEqualize

    _USED_DISPATCHERS = {F.equalize}
QQs:
- Do we expect more than one dispatcher per transform?
- Does this really need to be a set?
> Do we expect more than one dispatcher per transform?

Yes. For example:
- https://github.com/pytorch/vision/blob/3966f9558bfc8443fc4fe16538b33805dd42812d/torchvision/transforms/v2/_color.py#L162-L177
- https://github.com/pytorch/vision/blob/3966f9558bfc8443fc4fe16538b33805dd42812d/torchvision/transforms/v2/_geometry.py#L919-L927
> Does this really need to be a set?
No. What would you prefer?
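To make the first answer concrete, here is a purely illustrative transform that declares several dispatchers, mirroring the linked RandomPhotometricDistort. The _USED_DISPATCHERS attribute is the PoC's; the import paths assume the current torchvision.transforms.v2 layout and may not match this PR's branch.

# Illustrative only: a transform with more than one dispatcher.
from torchvision.transforms.v2 import Transform
from torchvision.transforms.v2 import functional as F


class MyPhotometricDistort(Transform):
    # Every functional the transform may call is declared up front, so the
    # supported datapoint types can be derived from them.
    _USED_DISPATCHERS = {
        F.adjust_brightness,
        F.adjust_contrast,
        F.adjust_saturation,
        F.adjust_hue,
    }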
types = _get_supported_datapoint_classes(dispatchers.pop())
for dispatcher in dispatchers:
    types &= _get_supported_datapoint_classes(dispatcher)
self._supported_datapoint_types = types
Is this just

types = set()
for dispatcher in dispatchers:
    types.update(_get_supported_datapoint_classes(dispatcher))

? And _get_supported_datapoint_classes probably doesn't need to return a set if we do that.
I wanted to make sure that we only go for the common ground, meaning we only transfer types that are supported by all dispatchers. This is why I have the intersection & in there.
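A toy illustration with made-up type sets of why the intersection, rather than the union, is the safe choice here:

# Suppose one dispatcher supports images and videos, another only images.
supported_by_a = {"Image", "Video"}
supported_by_b = {"Image"}

common = supported_by_a & supported_by_b  # {"Image"}: safe to hand to every dispatcher
union = supported_by_a | supported_by_b   # {"Image", "Video"}: "Video" would reach
                                          # dispatcher B, which cannot handle it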
It makes me wonder about the semantics and the use of _supported_datapoint_types. What if there's a transform with multiple dispatchers where each of them handles a different part of the input? Would this transform end up being a pass-through for everything?
I mean, typically, if we're using the intersection, we can't just have simple logic in the transforms like

if type(inpt) not in self._supported_datapoint_types:
    # pass through inpt

We can't use the union either, because that would pass some inputs to dispatchers that don't support them (and error).
Seems like we have to implement the pass-through logic on a per-dispatcher basis instead (within the transforms).
> Seems like we have to implement the pass-through logic on a per-dispatcher basis instead (within the transforms).

You are right. Let me see how this fits in.
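A rough sketch of what that per-dispatcher pass-through could look like inside a transform. The _transform(inpt, params) signature follows the v2 convention, but the selection loop and the reuse of _get_supported_datapoint_classes are assumptions, not code from this PR.

def _transform(self, inpt, params):
    for dispatcher in self._USED_DISPATCHERS:
        if type(inpt) in _get_supported_datapoint_classes(dispatcher):
            # Hand the input to the first dispatcher that supports its type.
            # Parameter plumbing is elided here.
            return dispatcher(inpt, **params)
    # No dispatcher supports this datapoint type: pass it through untouched.
    return inpt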
Most of our transforms don't touch
Apart from these groups, there is also
Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
@NicolasHug @vfdev-5 following our offline discussion. This is just a PoC for F.equalize / transforms.RandomEqualize. I have to clean up a little, but the general structure is there.