refactor Datapoint dispatch mechanism #7747
Conversation
```diff
@@ -158,6 +158,32 @@ def _compute_resized_output_size(
     return __compute_resized_output_size(spatial_size, size=size, max_size=max_size)
 
 
 def resize(
```
We need to move the definition of the dispatcher above the kernel definitions, since the dispatcher is used in the decorator. Other than that, only the datapoint branch was changed.
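For illustration, a minimal, self-contained sketch of why the definition order matters with decorator-based registration (the `_register_kernel_internal` name is borrowed from later in this thread; this is not the actual PR code):

```python
_KERNEL_REGISTRY = {}


def _register_kernel_internal(dispatcher, datapoint_cls):
    # Hypothetical registration decorator: it stores the kernel under the
    # (dispatcher, datapoint class) key when the module is imported.
    def decorator(kernel):
        _KERNEL_REGISTRY.setdefault(dispatcher, {})[datapoint_cls] = kernel
        return kernel

    return decorator


# The dispatcher must be defined *before* the kernels, because the decorator
# below evaluates `resize` at import time.
def resize(inpt, size):
    ...


class Image:  # stand-in for torchvision.datapoints.Image
    ...


@_register_kernel_internal(resize, Image)
def resize_image_tensor(image, size):
    ...
```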
I was also looking into a more involved solution, since we still have quite a bit of boilerplate. I'll put it down here as an idea.

```python
import torch

from torchvision import datapoints
from torchvision.transforms.v2.utils import is_simple_tensor


class dispatcher:
    def __init__(self, dispatcher_fn):
        self._dispatcher_fn = dispatcher_fn
        self._kernels = {}

    def register(self, datapoint_cls):
        def decorator(kernel):
            self._kernels[datapoint_cls] = kernel
            return kernel

        return decorator

    def __call__(self, inpt, *args, **kwargs):
        dispatch_cls = datapoints.Image if is_simple_tensor(inpt) else type(inpt)

        kernel = self._kernels.get(dispatch_cls)
        if kernel:
            return kernel(inpt, *args, **kwargs)

        output = self._dispatcher_fn(inpt, *args, **kwargs)
        if output is not None:
            return output

        raise TypeError(f"Got input of type {dispatch_cls}, but only {self._kernels.keys()} are supported")


@dispatcher
def resize(inpt, size, max_size=None):
    # We have the chance to handle any object here for which no kernel is registered.
    # This is useful for uncommon dispatchers like convert_bounding_box_format or any non-standard dispatcher.
    # If we return something here, this will be the output of the dispatcher.
    # Otherwise we error out since we don't know how to handle the object.
    if isinstance(inpt, str):
        return "Boo!"


@resize.register(datapoints.Image)
def resize_image_tensor(image, size, max_size=None):
    return image.new_zeros(size)


input = datapoints.Image(torch.ones(10, 10))
size = (20, 20)
output = resize(input, size)

assert (output == torch.zeros(size)).all()
assert resize("Boo?", size) == "Boo!"
```

Of course this solution is more complex, but it has a few upsides:
However, it also has one glaring hole: JIT. Due to our usage of […], this is not scriptable as is. Similar to what we do in […], we could add a hook like:

```python
def __prepare_scriptable__(self):
    kernel = self._kernels.get(datapoints.Image)
    if kernel is None:
        raise RuntimeError("Dispatcher cannot be scripted")
    return kernel
```
return kernel Unfortunately, this hook is only available for diff --git a/torch/jit/_script.py b/torch/jit/_script.py
index a6b2cb9cea7..cb37f372028 100644
--- a/torch/jit/_script.py
+++ b/torch/jit/_script.py
@@ -967,6 +967,7 @@ else:
super().__init__()
def call_prepare_scriptable_func_impl(obj, memo):
+ obj = obj.__prepare_scriptable__() if hasattr(obj, '__prepare_scriptable__') else obj # type: ignore[operator]
if not isinstance(obj, torch.nn.Module):
return obj
@@ -977,7 +978,6 @@ def call_prepare_scriptable_func_impl(obj, memo):
if obj_id in memo:
return memo[id(obj)]
- obj = obj.__prepare_scriptable__() if hasattr(obj, '__prepare_scriptable__') else obj # type: ignore[operator]
# Record obj in memo to avoid infinite recursion in the case of cycles in the module
# hierarchy when recursing below.
memo[obj_id] = obj Of course we need to talk to core before, but my guess is that don't really care as long as it doesn't break anything, given that JIT is only in maintenance. Thoughts? |
Would it make anything simpler if we were to write our dispatcher kernels like this?

```python
def resize(inpt, ...):
    if isinstance(inpt, datapoints.Image):
        return image_kernel(...)  # we can pass the strict subset of parameters that are needed here, no need for an extra layer
    elif isinstance(inpt, datapoints.BoundingBox):
        return bbox_kernel(...)
    # ...
    elif isinstance(inpt, datapoints.Datapoint):  # for all user-defined kernels
        return _get_kernel(...)  # pass all parameters, let users handle that themselves
```

Basically we hard-code the dispatching logic for the torchvision-owned datapoint classes, but we still allow arbitrary user-registered kernels?
Also 2 things come to mind:
(None of these are "new" problems, they exist with what we already have, but it's worth thinking about.)
It would make it simpler in the sense that we don't need an extra layer ever. However, it adds quite a bit of boilerplate. I know that the two of us ride on opposite sides of the simplicity / boilerplate <-> complexity / "magic" spectrum. So not sure here. One upside that I see in not hardcoding our datapoints is that we would use exactly the same mechanism as the users. Meaning, users could look at our source to see how it is done.

No, I don't see this as a good thing and yes, I do care. Imagine there is a third-party library that builds upon TorchVision. It could register something new for images at import time, such that the TorchVision behavior differs depending on whether the import of the other library is present. I don't think that is acceptable. However, I think there is a fairly straightforward solution that doesn't require us to hardcode our datapoints:

Yes, users would need to always add `**kwargs` […]. We don't need […].

Regarding JIT, neither the current approach nor my or your hardcoding proposal supports JIT for dispatch. We basically only support JIT for BC, so I wouldn't worry about it here.
OK.
> No, I don't see this as a good thing and yes I do care
I don't have a strong opinion on that yet... I'm OK to guard against this as long as we keep it very simple and there's no perf implication.
OK as well regarding BC guarantees. We'll have to remember to document the need for `**kwargs` properly.
I'm still uncomfortable with the wrapping / unwrapping happening in the registration decorator. It has some smell.
Can we completely separate the registration from the wrapping/unwrapping by just registering the already-decorated kernel manually, without a decorator? E.g.

```python
def _unwrap_then_call_then_rewrap(...):
    # some basic helper
    ...


_KERNEL_REGISTRY = {
    (resize, Image): _unwrap_then_call_then_rewrap(resize_image_tensor),
    (resize, BBox): ...,
}
```
This way we keep `resize_image_tensor` intact?

And maybe we can have another helper for the parameter-subsetting logic to avoid writing those additional kernels (but I'm not sure what it would look like yet, haven't thought about it).
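For illustration only, a rough sketch of what such a wrap/unwrap helper could look like; the helper name comes from the pseudocode above, while `wrap_like` and the exact unwrapping step are assumptions, not the actual torchvision implementation:

```python
import torch


def _unwrap_then_call_then_rewrap(kernel):
    # Hypothetical helper: unwrap the datapoint to a plain tensor, call the
    # plain-tensor kernel, and re-wrap the output as the input's datapoint type.
    def wrapper(inpt, *args, **kwargs):
        plain_tensor = inpt.as_subclass(torch.Tensor)
        output = kernel(plain_tensor, *args, **kwargs)
        # Assumes the datapoint class offers a `wrap_like`-style constructor
        # that copies metadata (e.g. bounding box format) from `inpt`.
        return type(inpt).wrap_like(inpt, output)

    return wrapper
```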
Pushed 4 new commits that are design related. I'll go over them in detail below. I think when we have that done, I can start porting everything.

- bbaa35c: Per offline request from @vfdev-5, I've added the new dispatch logic to a dispatcher that needs to pass through some datapoints. That was a good call, since the previous design did not account for that and would have errored out. This is solved by letting […] (see the rough sketch of the pass-through behavior after this list).
- ca4ad32: Acting on the second point in #7747 (comment), we no longer allow users to overwrite registered kernels. However, […]
- PoC ca4ad32: After #7747 (review) I had an offline discussion with @NicolasHug and we came to the conclusion that we don't want to expose the (un-)wrapping magic to the users. This is a fairly TorchVision-centric thing. Of course the user also needs to at least handle the re-wrapping before returning from the kernel, but it should be sufficient to properly document this rather than to provide magic helpers that might go wrong. However, we also agreed that it would be nice to have this kind of convenience for our internal stuff. This commit adds a […] With this we should be able to remove 99% of the explicit intermediate layers that we needed a lot before. Meaning, we can decorate mask and bounding box kernels directly.
- PoC bf47188: Due to the no-op behavior for unregistered kernels (see the explanation for the first commit in this comment), it seems like we don't need to register no-ops for builtin datapoints if the dispatcher does not support them. However, this would allow users to register kernels for builtin datapoints on builtin dispatchers that rely on this no-op behavior. For example,

  ```python
  from torchvision.transforms.v2 import functional as F

  @F.register_kernel(F.adjust_brightness, datapoints.BoundingBox)
  def lol(...):
      ...
  ```

  would be valid user code. That is likely not what we want. This PR adds a […]
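To make the pass-through behavior mentioned for the first commit concrete, here is a rough, hypothetical sketch (registry layout and names are illustrative, not the PR code): unregistered datapoint types fall back to an identity kernel, so a dispatcher that e.g. only operates on images simply returns bounding boxes unchanged instead of erroring.

```python
_KERNEL_REGISTRY = {}  # {dispatcher: {datapoint_cls: kernel}}


def _noop(inpt, *args, **kwargs):
    # Pass-through: datapoint types without a registered kernel are returned as is.
    return inpt


def _get_kernel_or_noop(dispatcher, datapoint_cls):
    registry = _KERNEL_REGISTRY.get(dispatcher, {})
    return registry.get(datapoint_cls, _noop)
```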
Thanks Philip, I made some comments but this looks mostly OK to me.
One thing though, as discussed offline and as @vfdev-5 pointed out: I think there's value in manually registering the kernels into the dict in one centralized place, instead of using `_register_kernel_internal` as a decorator or `_register_explicit_noops()` in some random places. It will be easier to read the code and understand the code paths being taken if we register everything e.g. in `__init__.py` like

```python
_KERNEL_REGISTRY = {
    # all (wrapped) registrations here
}
```
(This is non-blocking though)
```python
for registered_cls, kernel in registry.items():
    if issubclass(datapoint_cls, registered_cls):
        return kernel
```
Just to make sure I understand: this is just for subclasses of Datapoints, which we don't have in our codebase, right?
Correct. We currently don't have anything like that, but:

- We were talking about a `class DetectionMask(Mask)` and `class SegmentationMask(Mask)`. Unless we want to register them separately instead of having one kernel for `Mask`'s, we need to keep this loop.
- Users can do something like `class MyImage(Image)`.
```python
if datapoint_cls in registry:
    return registry[datapoint_cls]
```
Is this already covered by the `for` loop below? (No strong opinion)
Since our stuff is the first thing to be registered, you are correct. I thought I'd make it explicit that an exact type match always beats the subclass check, regardless of registration order.
Imagine the following:

```python
class Foo(datapoints.Datapoint):
    ...


class Bar(Foo):
    ...


def dispatcher(inpt):
    ...


@F.register_kernel(dispatcher, Foo)
def foo_kernel(foo):
    ...


@F.register_kernel(dispatcher, Bar)
def bar_kernel(bar):
    ...


dispatcher(Foo(...))  # calls foo_kernel, ok
dispatcher(Bar(...))  # calls foo_kernel, oops
```
This scenario is unlikely for now. However, who knows how our stuff will be adopted in the future.
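For reference, here is how the two snippets above fit together into a single lookup. This is a hedged sketch of the intent only, not the literal PR code, and it assumes the `_KERNEL_REGISTRY` dict ({dispatcher: {datapoint class: kernel}}) used in the sketches above:

```python
def _get_kernel(dispatcher, datapoint_cls):
    registry = _KERNEL_REGISTRY.get(dispatcher, {})

    # An exact type match always beats the subclass check, regardless of
    # registration order (see the Foo/Bar example above).
    if datapoint_cls in registry:
        return registry[datapoint_cls]

    # Otherwise fall back to a registered superclass, so e.g. a user-defined
    # MyImage(Image) is handled by the Image kernel.
    for registered_cls, kernel in registry.items():
        if issubclass(datapoint_cls, registered_cls):
            return kernel

    return None
```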
```python
registry = _KERNEL_REGISTRY.setdefault(dispatcher, {})
if datapoint_cls in registry:
    raise TypeError(
        f"Dispatcher '{dispatcher.__name__}' already has a kernel registered for type '{datapoint_cls.__name__}'."
```
I think it's OK to be restrictive for now, but we should keep an open mind and potentially reconsider eventually if there are user requests for that.
My concern here is that a library based on TorchVision registers something for our builtin stuff and thus invalidates our docs if we import the library. Imagine a library does this:

```python
@F.register_kernel(F.resize, datapoints.Image)
def lol(...):
    ...
```

I can already see a user of that library coming to our issue tracker with something like

> I see this weird behavior for resize that is not consistent with the docs. Here is my env

And after some back and forth we find that they have `import third_party_library` at the top of their script.

I want to prevent this, especially since, with the subclass check, there is no reason for it. The third-party library can simply do

```python
class CustomImage(datapoints.Image):
    ...


@F.register_kernel(F.resize, CustomImage)
def kernel(...):
    ...
```

Now, if the user deals with plain `datapoints.Image`'s, they will always get our behavior.
@vfdev-5 I've implemented your suggestion of a central registry here: a9125f1. I didn't push it to this branch to avoid adding yet another PoC and making it even more confusing. From my perspective:

Overall, I personally would stick with the "decentralized" registration using decorators.
@pmeier let's keep the original approach with the fully dynamic registration mechanism using the decorator.
I found a few outliers that we need to make decisions for:

They all currently don't support arbitrary datapoint dispatch, but rather do what is shown in torchvision/transforms/v2/functional/_augment.py, lines 52 to 57 at 72dcc17.

I do not recall why we went for that design. I think it was for "these transforms are not defined for bounding boxes and masks". However, we circumvent this from the transform side by just not passing these types to the dispatcher at all.

Thus, we have no-op behavior from the transform API, but error out on the dispatcher directly. This makes little sense. Plus, it makes it impossible for users to register custom kernels. I propose we fix this inconsistency by enabling arbitrary datapoint dispatch on these dispatchers.
It makes sense to me. We have to pass through in the transforms because these transforms can be called on any input, and the input may contain some bboxes etc. The transform may not apply to bboxes, but it applies to images, and we can't just error if there's a bbox in a sample. OTOH, the dispatchers only take one specific input (either an image OR a bbox, not both). There's no reason to allow that. We don't need to register a kernel for everything.
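As a rough, simplified illustration of that split (not the actual torchvision code; the classes, kernel, and signatures below are stand-ins): the transform walks a whole sample and only touches the types it supports, while the dispatcher receives exactly one input and can afford to be strict.

```python
class Image:  # stand-in for torchvision.datapoints.Image
    pass


class BoundingBox:  # stand-in for torchvision.datapoints.BoundingBox
    pass


def erase_image(image):
    # Placeholder for the real image kernel.
    return image


def erase_transform(sample: dict) -> dict:
    # Transform level: a sample can mix images, bounding boxes, masks, ...
    # Unsupported types are passed through untouched instead of erroring.
    return {key: erase_image(value) if isinstance(value, Image) else value for key, value in sample.items()}


def erase_dispatcher(inpt):
    # Dispatcher level: exactly one input, so erroring on unsupported types is fine.
    if not isinstance(inpt, Image):
        raise TypeError(f"erase() is not supported for inputs of type {type(inpt).__name__}")
    return erase_image(inpt)


sample = {"image": Image(), "boxes": BoundingBox()}
erase_transform(sample)    # the boxes are passed through
erase_dispatcher(Image())  # fine
# erase_dispatcher(BoundingBox())  # would raise TypeError, by design
```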
Agreed. Examples for this are […]
Conflicts:
- test/test_transforms_v2_refactored.py
- torchvision/datapoints/_bounding_box.py
- torchvision/transforms/v2/functional/_color.py
- torchvision/transforms/v2/functional/_geometry.py
The commit cac079b in particular, but also the ones after it, addresses #7747 (comment). We had a longer offline discussion and these are the conclusions (cherry-picked from a summary that @NicolasHug wrote, with a few changes from me):
("dispatcher", "registered_datapoint_clss"), | ||
[(dispatcher, set(registry.keys())) for dispatcher, registry in _KERNEL_REGISTRY.items()], | ||
) | ||
def test_exhaustive_kernel_registration(dispatcher, registered_datapoint_clss): |
This test will also be removed in the future, since we'll remove the passthrough behavior and thus the no-op registration. But let's keep it until that happens, to be sure this PR, in its intermediate design stage, is good as is.
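For context, the body of such a test could look roughly like this; this is a sketch of the intent only, and the exact set of builtin datapoint types and the assertion are assumptions, not the actual test code:

```python
from torchvision import datapoints


def test_exhaustive_kernel_registration(dispatcher, registered_datapoint_clss):
    # While the no-op registration exists, every builtin dispatcher should have
    # something registered (a real kernel or an explicit no-op) for every
    # builtin datapoint type, so users cannot hijack the passthrough behavior.
    builtin_datapoint_clss = {
        datapoints.Image,
        datapoints.Video,
        datapoints.BoundingBox,
        datapoints.Mask,
    }
    assert builtin_datapoint_clss <= registered_datapoint_clss
```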
Test failures for the normalize dispatcher are expected, since I haven't properly fixed the tests yet.
31bee5f is the bulk of the port after we have "finalized" the design:

- Remove all methods from the datapoints classes
- Expose `Datapoint` in `torchvision.datapoints`
- Use the new dispatch mechanism in the dispatchers
```diff
@@ -214,34 +213,32 @@ def check_dispatcher(
     check_dispatch=True,
     **kwargs,
 ):
     unknown_input = object()
```
Driveby. Minor refactoring to avoid calling the dispatcher twice.
```diff
     if check_scripted_smoke:
         _check_dispatcher_scripted_smoke(dispatcher, input, *args, **kwargs)
 
     if check_dispatch:
         _check_dispatcher_dispatch(dispatcher, kernel, input, *args, **kwargs)
 
 
-def _check_dispatcher_kernel_signature_match(dispatcher, *, kernel, input_type):
+def check_dispatcher_kernel_signature_match(dispatcher, *, kernel, input_type):
```
Rename, since the other check for the signature match of the datapoint method is removed. Meaning, this is the "single entrypoint" now.
```python
):
    if name in args_kwargs.kwargs:
        del args_kwargs.kwargs[name]

if hasattr(datapoint_type, "__annotations__"):
```
We need to guard here, for the reason I mentioned above.
Thanks a lot for this great effort Philip. I only have minor Qs but this LGTM. I can address my own comments if that helps.
Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
Hey @pmeier! You merged this PR, but no labels were added. The list of valid labels is available at https://github.com/pytorch/vision/blob/main/.github/process_commit.py
Although the sun is setting for torchscript, it is not [officially deprecated](#103841 (comment)) since nothing currently fully replaces it. Thus, "downstream" libraries like TorchVision, that started offering torchscript support, still need to support it for BC.

torchscript has forced us to use workaround after workaround since forever. Although this makes the code harder to read and maintain, we made our peace with it. However, we are currently looking into more elaborate API designs that are severely hampered by our torchscript BC guarantees.

Although likely not intended as such, while looking for ways to enable our design while keeping a subset of it scriptable, we found the undocumented `__prepare_scriptable__` escape hatch: https://github.com/pytorch/pytorch/blob/0cf918947d161e02f208a6e93d204a0f29aaa643/torch/jit/_script.py#L977 One can define this method and if you call `torch.jit.script` on the object, the returned object of the method will be scripted rather than the original object. In TorchVision we are using exactly [this mechanism to enable BC](https://github.com/pytorch/vision/blob/3966f9558bfc8443fc4fe16538b33805dd42812d/torchvision/transforms/v2/_transform.py#L122-L136) while allowing the object in eager mode to be a lot more flexible (`*args, **kwargs`, dynamic dispatch, ...).

Unfortunately, this escape hatch is only available for `nn.Module`'s: https://github.com/pytorch/pytorch/blob/0cf918947d161e02f208a6e93d204a0f29aaa643/torch/jit/_script.py#L1279-L1283 This was fine for the example above since we were subclassing from `nn.Module` anyway. However, we recently also hit a case [where this wasn't the case](pytorch/vision#7747 (comment)).

Given the frozen state on JIT, would it be possible to give us a general escape hatch so that we can move forward with the design unconstrained while still keeping BC? This PR implements just this by re-using the `__prepare_scriptable__` hook.

Pull Request resolved: #106229
Approved by: https://github.com/lezcano, https://github.com/ezyang
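To make the hook concrete, here is a minimal illustrative example of how `__prepare_scriptable__` already behaves for `nn.Module`s today (module names here are made up):

```python
import torch
from torch import nn


class Scriptable(nn.Module):
    # A torchscript-friendly stand-in with explicit annotations.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + 1


class Flexible(nn.Module):
    # Eager-mode forward uses constructs TorchScript cannot handle.
    def forward(self, *args, **kwargs):
        return args[0] + 1

    def __prepare_scriptable__(self):
        # torch.jit.script(Flexible()) scripts this returned module instead.
        return Scriptable()


scripted = torch.jit.script(Flexible())  # actually scripts Scriptable
assert scripted(torch.zeros(())) == 1
```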
Summary:
Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>

Reviewed By: matteobettini

Differential Revision: D48642281

fbshipit-source-id: 33a1dcba4bbc254a26ae091452a61609bb80f663
We are currently discussing removing the methods from the datapoints and replacing them with a registration mechanism for the dispatchers. This PR is a PoC for the latter. If we find this a viable approach, I'll post a more detailed plan in #7028 to get more feedback.
The overarching goal is twofold:
(See vision/torchvision/datapoints/_datapoint.py, lines 102 to 112 at 29418e3.)
I'll add some inline comments to explain, but here is how one can author a custom datapoint with this PoC:
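The embedded snippet is not shown here; purely for illustration, authoring a custom datapoint with this kind of registration mechanism could look roughly like the following. The `F.register_kernel` usage follows the examples earlier in this thread, while `F.resize_image_tensor` and the bare `Datapoint` subclass are assumptions, so this is not the original snippet:

```python
import torch
from torchvision import datapoints
from torchvision.transforms.v2 import functional as F


class MyDatapoint(datapoints.Datapoint):
    # Minimal custom datapoint: a plain tensor subclass without extra metadata.
    pass


@F.register_kernel(F.resize, MyDatapoint)
def resize_my_datapoint(inpt, size, **kwargs):
    # Compute on the plain tensor and re-wrap, so the subclass survives dispatch.
    output = F.resize_image_tensor(inpt.as_subclass(torch.Tensor), size, **kwargs)
    return output.as_subclass(MyDatapoint)


my_dp = torch.rand(3, 32, 32).as_subclass(MyDatapoint)
out = F.resize(my_dp, size=[16, 16])
assert isinstance(out, MyDatapoint)
```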
cc @vfdev-5