[Mask2Former] Move normalization for numerical stability #29542
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
```python
cross_entropy_loss_pos = criterion(inputs, torch.ones_like(inputs)) / height_and_width
cross_entropy_loss_neg = criterion(inputs, torch.zeros_like(inputs)) / height_and_width

loss_pos = torch.matmul(cross_entropy_loss_pos, labels.T)
loss_neg = torch.matmul(cross_entropy_loss_neg, (1 - labels).T)
loss = loss_pos + loss_neg
```
Fine for me. But the name might be a bit misleading (I am not 100% sure): does `criterion(inputs, torch.ones_like(inputs)) / height_and_width` really represent `cross_entropy_loss`?
I am not checking the full context here, but personally, I might just do something like:
```python
cross_entropy_loss_pos = criterion(inputs, torch.ones_like(inputs))
cross_entropy_loss_neg = criterion(inputs, torch.zeros_like(inputs))

loss_pos = torch.matmul(cross_entropy_loss_pos / height_and_width, labels.T)
loss_neg = torch.matmul(cross_entropy_loss_neg / height_and_width, (1 - labels).T)
```
to avoid (if any) possible confusion.
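(For reference, a minimal standalone sketch, with made-up shapes and values and names mirroring the snippet above, confirming that dividing before or after the matmul gives the same result up to floating-point rounding, so the suggestion only changes where the division sits:)

```python
import torch

torch.manual_seed(0)
num_queries, num_labels, height_and_width = 8, 4, 1024

# Made-up stand-ins for the per-pixel losses and the binary target masks.
cross_entropy_loss_pos = torch.rand(num_queries, height_and_width)
labels = torch.randint(0, 2, (num_labels, height_and_width)).float()

divide_after = torch.matmul(cross_entropy_loss_pos, labels.T) / height_and_width
divide_before = torch.matmul(cross_entropy_loss_pos / height_and_width, labels.T)

# Scalar division commutes with matmul, so both orderings agree (up to rounding).
print(torch.allclose(divide_after, divide_before, atol=1e-6))  # True
```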
Good point! Used your suggestion
Thanks!
Force-pushed from 51acd85 to 1a3c162:
* Move normalization for numerical stability
* Apply suggestions from code review (remove useless `x = x` line)
* PR comment - normalize later to preserve var name meaning
What does this PR do?
Moving the normalization before the matmul operation makes the calculation more stable and less likely to overflow.
These are the same changes as introduced in #26086, which was closed after becoming stale.
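To make the motivation concrete, here is a minimal, self-contained sketch, not the actual Mask2Former code: all shapes, magnitudes, and names (`inputs`, `labels`, `criterion`, `height_and_width`) are assumptions chosen for illustration. It shows how summing the unnormalized per-pixel losses in the matmul can produce values outside the fp16 range, while dividing by `height_and_width` first keeps every intermediate small:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
num_queries, num_labels, height_and_width = 8, 4, 256 * 256

# Illustrative logits and binary target masks (magnitudes exaggerated on purpose).
inputs = torch.randn(num_queries, height_and_width) * 10
labels = torch.randint(0, 2, (num_labels, height_and_width)).float()

criterion = nn.BCEWithLogitsLoss(reduction="none")
cross_entropy_loss_pos = criterion(inputs, torch.ones_like(inputs))

# Old order: sum ~65k per-pixel losses first, divide afterwards. The intermediate
# exceeds the fp16 maximum (~65504), so it would overflow in half precision.
divide_after = torch.matmul(cross_entropy_loss_pos, labels.T)
print(divide_after.half().isinf().any())   # tensor(True)

# New order (this PR): divide each per-pixel loss first, then matmul. The result is
# mathematically the same, but the intermediates stay on the order of a mean loss.
divide_before = torch.matmul(cross_entropy_loss_pos / height_and_width, labels.T)
print(divide_before.half().isinf().any())  # tensor(False)
```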