-
Notifications
You must be signed in to change notification settings - Fork 7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
enable get_params alias for transforms v2 #7153
Conversation
_v1_transform_cls: Optional[Type[nn.Module]] = None | ||
|
||
def __init_subclass__(cls) -> None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As an alternative to __init_subclass__()
(of which I keep forgetting the existence and the purpose), would this work too?
@staticmethod
def get_params(cls):
if cls._v1_transform_cls is not None and hasattr(cls._v1_transform_cls, "get_params"):
return cls._v1_transform_cls.get_params()
else:
raise AttributeError(
"cls {cls} has no get_params method. You probably don't need one anymore
as the same RNG is applied to all images, bboxes and masks in the same transform call.
If what you need is a way to transform different batches with the same RNG,
please reach out at #1234567 (the feedback issue.
")
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Two problems here with JIT:
-
get_params
takes parameters in v1, e.g.vision/torchvision/transforms/transforms.py
Lines 617 to 618 in 7cf0f4c
@staticmethod def get_params(img: Tensor, output_size: Tuple[int, int]) -> Tuple[int, int, int, int]: Since I'm not aware of an arbitrary parameter passthrough like
(*args: Any, **kwargs: Any)
that works with JIT, we would have to implement this manually on every class that needs it. -
You used the
@staticmethod
decorator, but takecls
as first input. Since you needcls
, we need to switch to@classmethod
to make it work. That should work in eager mode, but for JIT this now means that the whole class needs to be scriptable. For example, manually definingget_params
forRandomCrop
as explained in 1. and trying to script it, leads toRuntimeError: 'Tensor (inferred)' object has no attribute or method '_v1_transform_cls'.: File "/home/philip/git/pytorch/vision/torchvision/prototype/transforms/_geometry.py", line 437 @classmethod def get_params(cls, img: torch.Tensor, output_size: Tuple[int, int]) -> Tuple[int, int, int, int]: return cls._v1_transform_cls.get_params(img, output_size) ~~~~~~~~~~~~~~~~~~~~~ <--- HERE
It seems it can't infer the type of
cls
and usestorch.Tensor
as fallback. Annotationcls: Type[RandomCrop]
yieldsRuntimeError: Unknown type constructor Type:
I agree using __init_subclass__
is unconventional, but it seems like the cleanest solution here. Since we actually alias the function, we avoid all of the JIT crazyness that we would have to deal with otherwise. If you can find a working solution, I'm happy to adopt it though.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the details. My last try would be to declare get_params
as a @property
but... let me guess... JIT doesn't support it?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You guessed right. But even if did, it would wouldn't work for us here. @property
needs an instance.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks Philip, LGTM, just some minor comments / Qs before approving
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks a lot Philip
Reviewed By: vmoens Differential Revision: D43116104 fbshipit-source-id: 2e1139ef8da93850de7945baa4ea0d4f7ca667cc
After #7135, this is trivial to achieve. If available, we simply alias
get_params
from the v1 transformation to v2.Although this doesn't sound like much, it is the final step of getting v2 BC with v1. Meaning, leaving bugs aside, v2 should now be a proper drop-in replacement for v1 🎉
cc @vfdev-5 @bjuncek