[PoC] simplify simple tensor fallback heuristic #7340

Draft · wants to merge 1 commit into main

Conversation

@pmeier (Collaborator) commented on Feb 27, 2023

Since we had some internal discussions about the heuristic before and it came up again in #7331 (comment), this PR is an attempt to simplify it while adhering to the original goals. Let's start with a little bit of context:

When transforms v2 was conceived, one major design goal was to make it backwards compatible (BC) with v1. Part of that is that we need to treat simple torch.Tensor's as images and not require users to wrap them into a datapoints.Image or similar. To achieve that, the functional API internally just dispatches to the *_image_tensor kernel, e.g.

if torch.jit.is_scripting() or is_simple_tensor(inpt):
    return horizontal_flip_image_tensor(inpt)
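
For reference, is_simple_tensor boils down to roughly the following check (a sketch; the actual helper lives in the v2 utilities and the name of the common base class may differ):

from typing import Any

import torch
from torchvision import datapoints

def is_simple_tensor(inpt: Any) -> bool:
    # A "simple" tensor is a plain torch.Tensor that is not one of the datapoint
    # subclasses (Image, Video, BoundingBox, Mask, ...). `datapoints.Datapoint`
    # stands in for that common base class here.
    return isinstance(inpt, torch.Tensor) and not isinstance(inpt, datapoints.Datapoint)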

By not adding any logic other than allowing simple tensors to be transformed, the transforms inherited this behavior. However, this proved detrimental for two reasons:

  1. After the decision was made that we'll leave datapoints.Label and datapoints.OneHotLabel in the prototype area for now, we wanted to represent them as simple tensors.
  2. Some datasets like CelebA return simple tensors as part of their annotations.

To support these use cases, the initial idea was to introduce a no-op datapoint:

class DontTouchMe(Datapoint):
    pass
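
Usage would have looked roughly like this (hypothetical, since the datapoint was never added; joint_transform is a placeholder for any v2 transform):

import torch

image = torch.rand(3, 32, 32)         # still treated as an image
label = DontTouchMe(torch.tensor(3))  # explicitly opted out of being transformed

# The transform would flip/resize/... `image`, but pass `label` through untouched.
image, label = joint_transform(image, label)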

This could be easily filtered out by the transforms. However, this again had two issues:

  1. Users are forced to address this situation by wrapping their simple tensor inputs that aren't images.
  2. We increase our API surface.

To overcome this, #7170 added a heuristic that currently behaves as follows:

# * If we find an explicit image or video (:class:`torchvision.datapoints.Image`, :class:`torchvision.datapoints.Video`,
# or :class:`PIL.Image.Image`) in the input, all other plain tensors are passed through.
# * If there is no explicit image or video, only the first plain :class:`torch.Tensor` will be transformed as image or
# video, while all others will be passed through.
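
Concretely, that means the following (a sketch assuming the v2 API on main at the time, i.e. torchvision.transforms.v2 and torchvision.datapoints; namespaces and constructors may differ slightly):

import torch
from torchvision import datapoints
from torchvision.transforms import v2 as transforms

flip = transforms.RandomHorizontalFlip(p=1.0)

plain = torch.rand(3, 8, 8)
image = datapoints.Image(torch.rand(3, 8, 8))

# An explicit image is present: `plain` is passed through unchanged.
out_image, out_plain = flip(image, plain)

# No explicit image or video in the input: the first plain tensor is transformed
# as an image, all other plain tensors are passed through.
out_a, out_b = flip(plain, torch.rand(3, 8, 8))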

This solves the issues above. However, it goes beyond the original goal of keeping BC: v1 does not support joint transformations, so allowing simple tensors to act as images in a joint context is not needed for BC. This is the part that makes the current heuristic more complicated than it has to be, since it brings the order of the inputs into the picture.

The heuristic this PR proposes takes a more pragmatic approach regarding BC:

  1. If the input is a single simple tensor, it will be transformed as an image, just as in v1.
  2. Otherwise, simple tensors will not be transformed, but rather passed through (see the sketch below).
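
Continuing the sketch from above, the proposed behavior would look roughly like this (hypothetical, since this PR is still a draft; datapoints.Mask stands in for any non-image datapoint):

# Single simple tensor: transformed as an image, exactly like in v1.
out = flip(plain)

# Joint input: the simple tensor is now passed through. To have it transformed,
# it has to be wrapped explicitly, e.g. as datapoints.Image(plain).
masks = datapoints.Mask(torch.zeros(1, 8, 8, dtype=torch.bool))
out_plain, out_masks = flip(plain, masks)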

The only thing we are losing by going for this simplification is the ability to intentionally use simple tensors as images in a joint setting. IMHO, it isn't a big ask of users to wrap into a datapoints.Image there, since they will have to wrap into datapoints.Mask and datapoints.BoundingBox anyway in a joint context.

@pmeier (Collaborator, Author) left a comment:

I intentionally only touched the Transform base class for now. Of course we also need to add these changes to the other special transforms that override forward.


  flat_outputs = [
-     self._transform(inpt, params) if needs_transform else inpt
-     for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list)
+     self._transform(inpt, params) if isinstance(inpt, self._transformed_types) else inpt for inpt in flat_inputs
  ]
@pmeier (Collaborator, Author) commented on these lines:

Since we no longer care about simple tensors here, we can use the regular isinstance and can potentially eliminate

def check_type(obj: Any, types_or_checks: Tuple[Union[Type, Callable[[Any], bool]], ...]) -> bool:
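
For context, that helper roughly does the following (paraphrased; the actual implementation in the v2 utilities may differ in details):

from typing import Any, Callable, Tuple, Type, Union

def check_type(obj: Any, types_or_checks: Tuple[Union[Type, Callable[[Any], bool]], ...]) -> bool:
    # Each entry is either a plain type (checked via isinstance) or a callable
    # predicate such as is_simple_tensor (called on the object).
    for type_or_check in types_or_checks:
        if isinstance(type_or_check, type):
            if isinstance(obj, type_or_check):
                return True
        elif type_or_check(obj):
            return True
    return False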

@NicolasHug (Member) commented:

Thanks for the proposal @pmeier and for trying to simplify the heuristic.

IIUC this proposal makes this use-case invalid now:

transform(some_pure_tensor, bboxes, masks)

I agree this isn't something that was supported before, so BC requirements are lighter here. However, so far we've tried hard to keep a strong equivalence between pure tensors and images (as e.g. reflected in the dispatchers or lower-level kernels). It makes the mental model simpler (although it seems to make our code harder, that's true).

I don't really know how I feel about this right now. On one hand it seems to simplify our code, but I'm not sure it simplifies the overall mental model for users. The current heuristic is somewhat tricky, but it is irrelevant for 99% of users; in contrast, the one introduced in this PR would have to be disclosed to all users. I'm also not sure how it plays out with the fact that we unwrap subclasses for any non-transform operation, which may force users to re-wrap their input manually into Images and create friction. There's also the fact that another (opposite) direction we could take right now would be to literally get rid of the Image subclass, as it has no meta-data attached to it...

Hopefully we'll get more direct feedback from users about this and about the unwrapping mechanism, which will allow us to make a more informed decision?
