enable get_params alias for transforms v2 #7153

Merged: 4 commits into pytorch:main on Feb 1, 2023

Conversation

@pmeier (Collaborator) commented on Jan 31, 2023:

After #7135, this is trivial to achieve. If available, we simply alias get_params from the v1 transformation to v2.

Although this doesn't sound like much, it is the final step in making v2 backwards compatible (BC) with v1. Meaning, bugs aside, v2 should now be a proper drop-in replacement for v1 🎉

cc @vfdev-5 @bjuncek

_v1_transform_cls: Optional[Type[nn.Module]] = None

def __init_subclass__(cls) -> None:
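
For context, a minimal sketch of what the aliasing inside __init_subclass__ could look like (the method body is not shown in the excerpt above, so the Transform base-class name and the body are assumptions reconstructed from the surrounding discussion):

    from typing import Optional, Type

    from torch import nn


    class Transform(nn.Module):
        # v2 transforms that have a v1 counterpart point this at the v1 class
        # (e.g. transforms.RandomCrop); it stays None for transforms that are new in v2.
        _v1_transform_cls: Optional[Type[nn.Module]] = None

        def __init_subclass__(cls) -> None:
            super().__init_subclass__()
            # Runs once when each v2 subclass is defined. If the matching v1 class exposes
            # get_params, re-expose the very same function on the v2 class as a staticmethod,
            # so SomeV2Transform.get_params(...) behaves exactly like the v1 call.
            if cls._v1_transform_cls is not None and hasattr(cls._v1_transform_cls, "get_params"):
                cls.get_params = staticmethod(cls._v1_transform_cls.get_params)

Because the alias is created once at class-definition time and simply re-binds the existing v1 function, nothing on the v2 side has to be scripted, which is what the JIT discussion below comes back to.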
Member commented:

As an alternative to __init_subclass__() (whose existence and purpose I keep forgetting), would this work too?

@staticmethod
def get_params(cls):
    if cls._v1_transform_cls is not None and hasattr(cls._v1_transform_cls, "get_params"):
        return cls._v1_transform_cls.get_params()
    else:
        raise AttributeError(
            f"cls {cls} has no get_params method. You probably don't need one anymore "
            "as the same RNG is applied to all images, bboxes and masks in the same transform call. "
            "If what you need is a way to transform different batches with the same RNG, "
            "please reach out at #1234567 (the feedback issue)."
        )

pmeier (Collaborator, Author) replied:

Two problems here with JIT:

  1. get_params takes parameters in v1, e.g.

    @staticmethod
    def get_params(img: Tensor, output_size: Tuple[int, int]) -> Tuple[int, int, int, int]:

    Since I'm not aware of an arbitrary parameter passthrough like (*args: Any, **kwargs: Any) that works with JIT, we would have to implement this manually on every class that needs it.

  2. You used the @staticmethod decorator, but take cls as the first input. Since you need cls, we would have to switch to @classmethod to make it work. That works in eager mode, but under JIT it means the whole class needs to be scriptable. For example, manually defining get_params for RandomCrop as explained in 1. and trying to script it leads to

    RuntimeError: 
    'Tensor (inferred)' object has no attribute or method '_v1_transform_cls'.:
      File "/home/philip/git/pytorch/vision/torchvision/prototype/transforms/_geometry.py", line 437
        @classmethod
        def get_params(cls, img: torch.Tensor, output_size: Tuple[int, int]) -> Tuple[int, int, int, int]:
            return cls._v1_transform_cls.get_params(img, output_size)
                   ~~~~~~~~~~~~~~~~~~~~~ <--- HERE
    

    It seems JIT can't infer the type of cls and falls back to torch.Tensor. Annotating cls: Type[RandomCrop] yields

    RuntimeError: 
    Unknown type constructor Type:
    

I agree using __init_subclass__ is unconventional, but it seems like the cleanest solution here. Since we actually alias the function, we avoid all of the JIT craziness that we would otherwise have to deal with. If you can find a working solution, I'm happy to adopt it though.
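
To make the "avoid the JIT craziness" point concrete, a hedged usage sketch (not taken from the PR; the prototype import path and the identity check are assumptions about how the alias behaves):

    import torch
    from torchvision import transforms as v1
    from torchvision.prototype import transforms as v2  # v2 lived under the prototype namespace at the time

    img = torch.rand(3, 256, 256)

    # Assumption: with the alias in place, the v2 class re-exposes the v1 static method itself,
    # so both attributes resolve to the same underlying function ...
    assert v2.RandomCrop.get_params is v1.RandomCrop.get_params

    # ... and the familiar v1-style call works unchanged through the v2 class,
    # without ever having to script anything on the v2 side.
    i, j, h, w = v2.RandomCrop.get_params(img, output_size=(224, 224))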

Member replied:

Thanks for the details. My last try would be to declare get_params as a @property but... let me guess... JIT doesn't support it?

@pmeier (Collaborator, Author) replied on Feb 1, 2023:

You guessed right. But even if it did, it wouldn't work for us here: @property needs an instance.
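
A tiny standalone sketch (not torchvision code) of why a @property cannot stand in for get_params: the descriptor only fires on instance attribute access, while get_params has to be callable on the class itself:

    class Example:
        @property
        def get_params(self):
            return "params"

    Example().get_params  # 'params': the property is evaluated on an instance
    Example.get_params    # returns the property descriptor object itself, so a class-level
                          # Example.get_params(...) call in the v1 style is not possible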

@NicolasHug (Member) left a comment:

Thanks Philip, LGTM, just some minor comments/questions before approving.

Resolved review threads: torchvision/prototype/transforms/_transform.py (outdated), test/test_prototype_transforms_consistency.py
@NicolasHug (Member) left a comment:

Thanks a lot Philip

@pmeier merged commit 82c51c4 into pytorch:main on Feb 1, 2023
@pmeier deleted the v2-get_params branch on February 1, 2023 at 15:37
facebook-github-bot pushed a commit that referenced this pull request on Feb 9, 2023:
Reviewed By: vmoens

Differential Revision: D43116104

fbshipit-source-id: 2e1139ef8da93850de7945baa4ea0d4f7ca667cc