@pmeier
Contributor

@pmeier pmeier commented Feb 3, 2023

Addresses the thread in #6663 (comment).

cc @vfdev-5 @bjuncek

Member

@NicolasHug NicolasHug left a comment

Thanks Philip, gave it a quick look

Member

@NicolasHug NicolasHug left a comment

Thanks Philip, some minor comments but it looks great

@pmeier
Contributor Author

pmeier commented Feb 6, 2023

There are three transforms for which this heuristic is somewhat awkward, all for the same reason:

  • ToDtype
  • PermuteDimensions
  • TransposeDimensions

They all take a parameter that can be a dictionary mapping input types to values, and the transform selects the appropriate value based on the type of each input. For example:

import torch
from torchvision.prototype import datapoints, transforms

sample = dict(
    image=datapoints.Image(torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8)),
    boxes=datapoints.BoundingBox(torch.randint(0, 32, (5,)), format="xyxy", spatial_size=(32, 32)),
)

dtype = {
    datapoints.Image: torch.float32,
    datapoints.BoundingBox: torch.float64,
}
transform = transforms.ToDtype(dtype)

transformed_sample = transform(sample)

assert transformed_sample["image"].dtype is torch.float32
assert transformed_sample["boxes"].dtype is torch.float64

# Add a plain tensor to the sample and request a dtype for it explicitly.
sample["tensor"] = torch.rand((3, 16, 16), dtype=torch.float16)
dtype[torch.Tensor] = torch.int32

transform = transforms.ToDtype(dtype)
transformed_sample = transform(sample)

assert transformed_sample["image"].dtype is torch.float32
assert transformed_sample["boxes"].dtype is torch.float64
# Fails with this PR: the plain tensor is passed through, so it stays float16.
assert transformed_sample["tensor"].dtype is torch.int32  # boom

As shown above, the heuristic passes the plain tensor through without applying the transform. That is somewhat awkward, since we explicitly specified a dtype for torch.Tensor in the parameter. The example above works on main.
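
To make the awkwardness concrete without needing torch, here is a minimal sketch that uses stand-in classes instead of the real torch.Tensor and datapoints types. It assumes the heuristic boils down to "once the sample contains an explicit image, plain tensors are passed through"; the names Tensor, Image, BoundingBox, is_plain_tensor, and to_dtype below are hypothetical and are not the torchvision implementation.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class Tensor:
    """Stand-in for torch.Tensor; it only tracks a dtype label."""

    dtype: str


class Image(Tensor):
    """Stand-in for datapoints.Image."""


class BoundingBox(Tensor):
    """Stand-in for datapoints.BoundingBox."""


def is_plain_tensor(value):
    # A "plain" tensor is exactly a Tensor, not one of its subclasses.
    return type(value) is Tensor


def to_dtype(sample, dtype):
    # Assumed heuristic: if an explicit Image is present, plain tensors are
    # passed through, even when the dtype dictionary has a Tensor key.
    has_image = any(isinstance(v, Image) for v in sample.values())
    out = {}
    for key, value in sample.items():
        target = dtype.get(type(value))
        if target is None or (has_image and is_plain_tensor(value)):
            out[key] = value  # passthrough
        else:
            out[key] = replace(value, dtype=target)
    return out


sample = {
    "image": Image("uint8"),
    "boxes": BoundingBox("int64"),
    "tensor": Tensor("float16"),
}
dtype = {Image: "float32", BoundingBox: "float64", Tensor: "int32"}
result = to_dtype(sample, dtype)
```

In this sketch the Tensor entry in dtype is silently ignored as soon as an Image is part of the sample, which mirrors the failing assertion in the example.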

I guess one way to fix this is to disallow torch.Tensor in the parameter dictionary. Better yet, only allow datapoints. Meaning, if someone wants to use this fine-grained control, they'll have to wrap their inputs.
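
A rough sketch of what the stricter option could look like, again with stand-in classes rather than the real types. Datapoint, Image, and check_dtype_keys are hypothetical names used for illustration only; the actual check, if we add one, may look different.

```python
class Tensor:
    """Stand-in for torch.Tensor."""


class Datapoint(Tensor):
    """Stand-in for a common base class of all datapoints."""


class Image(Datapoint):
    """Stand-in for datapoints.Image."""


def check_dtype_keys(dtype):
    # Reject every key that is not a Datapoint subclass. This includes the
    # plain Tensor class itself, so users have to wrap their inputs.
    bad = [
        key
        for key in dtype
        if not (isinstance(key, type) and issubclass(key, Datapoint))
    ]
    if bad:
        names = ", ".join(getattr(key, "__name__", repr(key)) for key in bad)
        raise TypeError(
            f"Only datapoint types are allowed as keys, but got: {names}"
        )
    return dtype
```

With this check, check_dtype_keys({Image: "float32"}) returns the dictionary unchanged, while adding a Tensor key raises a TypeError at construction time instead of being ignored later.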

Thoughts?

Member

@NicolasHug NicolasHug left a comment

Thanks Philip, NIT but LGTM regardless

@NicolasHug
Member

Sorry, I missed your comment above before I approved.

This is unrelated to this issue, but I'm tempted to keep PermuteDimensions and TransposeDimensions in the prototype area for now, because they break a lot of assumptions; so I'll just focus on ToDtype here.

> I guess one way to fix this is to disallow torch.Tensor in the parameter dictionary. Better yet, only allow datapoints. Meaning, if someone wants to use this fine-grained control, they'll have to wrap their inputs.

Can we just raise a warning specific to ToDtype() if the Tensor key is specified alongside Image or Video, saying

> Hey, you passed Tensors and Images (or Videos), but we won't be transforming the tensors

and still support the Tensor key if neither Image nor Video is specified?
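
Something along these lines, as a sketch only. The classes are stand-ins for torch.Tensor, datapoints.Image, and datapoints.Video, and warn_on_ignored_tensor_key is a hypothetical helper name; the message wording is up for discussion.

```python
import warnings


class Tensor:
    """Stand-in for torch.Tensor."""


class Image(Tensor):
    """Stand-in for datapoints.Image."""


class Video(Tensor):
    """Stand-in for datapoints.Video."""


def warn_on_ignored_tensor_key(dtype):
    # Warn only when the Tensor key is combined with an Image or Video key.
    # A Tensor key on its own stays supported and triggers no warning.
    has_tensor = Tensor in dtype
    has_image_or_video = Image in dtype or Video in dtype
    if has_tensor and has_image_or_video:
        warnings.warn(
            "Got a dtype for plain tensors together with one for images or "
            "videos. Plain tensors will not be transformed in this case.",
            UserWarning,
            stacklevel=2,
        )
        return True
    return False
```

The helper returns whether it warned, which is only there to make the sketch easy to check; a real version would likely live in the constructor and return nothing.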

Member

@NicolasHug NicolasHug left a comment

Still LGTM!

@pmeier pmeier merged commit 1120aa9 into pytorch:main Feb 8, 2023
@pmeier pmeier deleted the tensor-fallback-heuristic branch February 8, 2023 19:02
facebook-github-bot pushed a commit that referenced this pull request Mar 28, 2023
… v2 (#7170)

Reviewed By: vmoens

Differential Revision: D44416271

fbshipit-source-id: 20c92067665ea106550bc29947f2596a36000025