remove default value from LabelToOneHot #7173
base: main
Conversation
LGTM
I don't have major concerns regarding removing the default, as it's always possible to re-add it later smoothly. However, I'd like to understand better what could go critically wrong if we left the
Side note: the above isn't strictly correct, as the
Also, if we really want to prevent bad stuff from happening, should we check against -1 and raise an error? Right now, when this PR is merged, users will still be able to do
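To make the "check against -1" suggestion concrete, here is a minimal sketch of what such a guard could look like. The class name `LabelToOneHotSketch` and the constructor are hypothetical stand-ins, not the actual torchvision implementation:

```python
# Hypothetical sketch of the guard discussed above: reject non-positive
# values (including the old -1 sentinel) at construction time, instead of
# letting a wrong inference surface later as a confusing shape mismatch.
class LabelToOneHotSketch:
    def __init__(self, num_categories: int):
        if num_categories < 1:
            raise ValueError(
                f"num_categories must be a positive integer, but got {num_categories}. "
                "Pass the number of categories of your dataset explicitly."
            )
        self.num_categories = num_categories


# a valid value works
transform = LabelToOneHotSketch(num_categories=1_000)
assert transform.num_categories == 1_000

# the old -1 sentinel is rejected with an expressive error message
try:
    LabelToOneHotSketch(num_categories=-1)
except ValueError as exc:
    print("rejected:", exc)
```

This moves the failure to the earliest possible point, where the error message can name the parameter the user has to fix.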
Thanks for noticing. It improves the situation, but doesn't resolve it.
Let's take ImageNet for example, since for a higher number of classes this becomes more pronounced. (See vision/torchvision/prototype/transforms/_augment.py, lines 118 to 122 in 135a0f9.)
Meaning, we are hitting this transformation after batching was performed. For example,

```python
import torch

from torchvision.prototype import datapoints, transforms

num_categories = 1_000
batch_size = 8

torch.manual_seed(0)
label = datapoints.Label(torch.randint(num_categories, (batch_size,), dtype=torch.int64))
print(label)

transform = transforms.Compose(
    [
        transforms.LabelToOneHot(),
        transforms.ToDtype(torch.float32),
    ]
)
one_hot_label = transform(label)
print(one_hot_label.shape)

image = datapoints.Image(torch.rand(batch_size, 3, 32, 32))

# transform works without issues
transform = transforms.RandomMixup(alpha=0.5, p=1.0)
mixed_sample = transform(image, one_hot_label)

# coming from the model that takes `num_categories` as input
prediction_logits = torch.randn(batch_size, num_categories)

# simulate loss function
assert prediction_logits.shape == one_hot_label.shape
```
So, I was wrong above in saying that this could cause silent errors. It will error in the evaluation stage. Still, it leaves the user searching for the problem, where we could have easily prevented it by raising an expressive error message. Note that we don't require any information that is not already there. Defining the model also requires the number of categories (see vision/torchvision/models/maxvit.py, line 586 in 135a0f9), so the user only has to set it explicitly on the transforms as well.
Just to make this clear: I don't want a default value like the one used for the models either. That value is closely tied to a dataset, but the transform should be more general.
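To make the "information is already there" point concrete, here is a sketch in plain PyTorch using `torch.nn.functional.one_hot` (a stand-in, not the prototype transform): the same category count the user already passes to the model can be reused for the encoding, while inferring the width from the batch can silently produce a too-narrow tensor.

```python
import torch
import torch.nn.functional as F

num_categories = 1_000  # the same value the user passes to the model constructor
batch_size = 8

label = torch.randint(num_categories, (batch_size,))

# explicit num_classes: the shape is correct regardless of which labels
# happen to be present in this particular batch
one_hot = F.one_hot(label, num_classes=num_categories)
assert one_hot.shape == (batch_size, num_categories)

# with the inferred default (num_classes=-1), the width depends on the
# largest label present in the batch and can silently be too small
inferred = F.one_hot(torch.tensor([0, 1, 2]), num_classes=-1)
assert inferred.shape == (3, 3)  # not (3, num_categories)!
```

The inferred tensor would later fail the shape check against the model's logits, which is exactly the late, hard-to-diagnose error described above.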
👍 I'll send a patch.
Following the conclusion in #7171 (comment), we will no longer remove the
Note for reviewers: this PR is now out of scope for the upcoming first release of transforms v2 and thus has no priority in reviewing. I'll ping you again when the dust has settled and we can deal with this again.
As noted in #7171 (comment), the default value of `-1` is sketchy at best. It can easily lead to silent bugs if the inference is wrong. Inferring from the `categories` field of `Label` is better in the sense that it will give the right answer. However, since `categories` is optional, we can't rely on it either. Thus, this PR removes the default value and users have to specify the number of categories explicitly.

One other option that depends on whether we move forward with #7171 or not is this: we use `num_categories: Optional[int] = None` and error on `< 1` in the constructor. Inside `_transform` we check whether either `self.num_categories` or `label.categories` is set and error otherwise.

cc @vfdev-5 @bjuncek
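A minimal sketch of that second option, assuming a hypothetical `LabelToOneHotSketch` class and a `_resolve_num_categories` helper (the real transform's `_transform` method would call something like this):

```python
from typing import Optional, Sequence


class LabelToOneHotSketch:
    # Hypothetical sketch of the option described above: num_categories may
    # be omitted at construction time, in which case it must be recoverable
    # from the label's `categories` field at transform time.
    def __init__(self, num_categories: Optional[int] = None):
        if num_categories is not None and num_categories < 1:
            raise ValueError(f"num_categories must be positive, but got {num_categories}")
        self.num_categories = num_categories

    def _resolve_num_categories(self, label_categories: Optional[Sequence[str]]) -> int:
        if self.num_categories is not None:
            return self.num_categories
        if label_categories is not None:
            return len(label_categories)
        raise ValueError(
            "Neither was num_categories passed to LabelToOneHot, "
            "nor does the label carry a categories field."
        )


# explicit value wins, even without categories on the label
assert LabelToOneHotSketch(num_categories=10)._resolve_num_categories(None) == 10

# otherwise, fall back to the label's categories
assert LabelToOneHotSketch()._resolve_num_categories(["cat", "dog"]) == 2

# neither is available: fail with an expressive error
try:
    LabelToOneHotSketch()._resolve_num_categories(None)
except ValueError as exc:
    print("error:", exc)
```

This keeps the transform general (no dataset-specific default) while still letting users omit the argument when the label already carries the information.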