
remove default value from LabelToOneHot #7173

Draft · wants to merge 2 commits into main
Conversation

@pmeier (Collaborator) commented Feb 3, 2023

As noted in #7171 (comment), the default value of -1 is sketchy at best. It can easily lead to silent bugs if the inference is wrong. Inferring from the categories field of Label is better in the sense that it will give the right answer. However, since categories is optional, we can't rely on it either.

Thus, this PR removes the default value and users have to specify the number of categories explicitly.

One other option that depends on whether we move forward with #7171 or not is this: we use num_categories: Optional[int] = None and error on < 1 in the constructor. Inside _transform we check whether either self.num_categories or label.categories is set and error otherwise.
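That option could look roughly like the following. This is a minimal plain-Python sketch of the fallback logic, with a stand-in `Label` class; the names and one-hot encoding here are illustrative only, not the actual torchvision implementation:

```python
from typing import List, Optional


class Label:
    """Minimal stand-in for the prototype datapoints.Label."""

    def __init__(self, data: List[int], categories: Optional[List[str]] = None):
        self.data = data
        self.categories = categories


class LabelToOneHot:
    def __init__(self, num_categories: Optional[int] = None):
        # error on < 1 in the constructor, as proposed above
        if num_categories is not None and num_categories < 1:
            raise ValueError(f"num_categories must be a positive integer, but got {num_categories}")
        self.num_categories = num_categories

    def _transform(self, label: Label) -> List[List[int]]:
        # fall back to label.categories when num_categories was not given,
        # and error if neither source of information is available
        num_categories = self.num_categories
        if num_categories is None:
            if label.categories is None:
                raise RuntimeError(
                    "num_categories was not passed and the label carries no categories"
                )
            num_categories = len(label.categories)
        return [[int(idx == i) for i in range(num_categories)] for idx in label.data]
```

With this, an explicit `LabelToOneHot(num_categories=...)` always wins, an unset value falls back to `len(label.categories)`, and only when both are missing does the transform raise.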

cc @vfdev-5 @bjuncek

@vfdev-5 (Collaborator) left a comment

LGTM

@NicolasHug (Member) commented

I don't have major concerns regarding removing the default, as it's always possible to re-add it later smoothly. However, I'd like to understand better what could go critically wrong if we left the -1 default.

> Obviously, if the input doesn't contain the smallest and largest value, this will give false results.

Side note that the above isn't strictly correct as the -1 heuristic only relies on the max value (not the min); Does that change anything to the potential silently wrong results?

Also, if we really want to prevent bad stuff from happening, should we check against -1 and raise an error? Right now, when this PR is merged users will still be able to do LabelToOneHot(-1) and rely on the heuristic.

@pmeier (Collaborator, Author) commented Feb 6, 2023

> Side note that the above isn't strictly correct as the -1 heuristic only relies on the max value (not the min); Does that change anything to the potential silently wrong results?

Thanks for noticing. It improves the situation, but doesn't resolve it.
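A toy illustration of why the max-only heuristic can still undercount (a plain-Python stand-in for the -1 inference, not torchvision's actual code):

```python
def infer_num_categories(labels):
    # the -1 default effectively infers the number of categories as max(labels) + 1
    return max(labels) + 1


# a batch from a 4-class problem that happens not to contain class 3
batch = [0, 2, 1, 2]
print(infer_num_categories(batch))  # 3, although the true number of categories is 4
```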

> I'd like to understand better what could go critically wrong if we left the -1 default?

Let's take ImageNet for example, since for a higher number of classes this becomes more pronounced. LabelToOneHot is usually used together with MixUp and CutMix, since they require it:

if not (
    has_any(flat_inputs, datapoints.Image, datapoints.Video, is_simple_tensor)
    and has_any(flat_inputs, datapoints.OneHotLabel)
):
    raise TypeError(f"{type(self).__name__}() is only defined for tensor images/videos and one-hot labels.")

Meaning, we are hitting this transformation after batching was performed. For example,

import torch

from torchvision.prototype import datapoints, transforms

num_categories = 1_000
batch_size = 8

torch.manual_seed(0)
label = datapoints.Label(torch.randint(num_categories, (batch_size,), dtype=torch.int64))
print(label)

transform = transforms.Compose(
    [
        transforms.LabelToOneHot(),
        transforms.ToDtype(torch.float32),
    ]
)
one_hot_label = transform(label)
print(one_hot_label.shape)

image = datapoints.Image(torch.rand(batch_size, 3, 32, 32))

# transform works without issues
transform = transforms.RandomMixup(alpha=0.5, p=1.0)
mixed_sample = transform(image, one_hot_label)

# coming from the model that takes `num_categories` as input
prediction_logits = torch.randn(batch_size, num_categories)

# simulate loss function
assert prediction_logits.shape == one_hot_label.shape
Output:

Label([ 44, 239, 933, 760, 963, 379, 427, 503])
torch.Size([8, 964])
AssertionError

So, I was wrong above in saying that this could cause silent errors. It will error in the evaluation stage. Still, it leaves the user searching for the problem, where we could easily have prevented it by raising an expressive error message. Note that we don't require any information that is not already there. Defining the model also requires the number of categories:

> num_classes (int): Number of classes. Default: 1000.

so the user only has to set it explicitly on the transforms as well.

Just to make this clear: I don't want a default value like what is used for the models either. This value is closely tied to a dataset, but the transform should be more general.

> Also, if we really want to prevent bad stuff from happening, should we check against -1 and raise an error?

👍 I'll send a patch.

@pmeier (Collaborator, Author) commented Feb 6, 2023

Following the conclusion in #7171 (comment), we will no longer remove the .categories attribute. Thus, I've opted to include the dynamic num_categories handling explained in my top post.

Note for reviewers: this PR is now out of scope for the upcoming first release of transforms v2 and thus has no priority in reviewing. I'll ping you again when the dust has settled and we can deal with this again.
