-
Notifications
You must be signed in to change notification settings - Fork 86
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fixing out of bound access for nll_loss #1752
base: main
Are you sure you want to change the base?
Conversation
for more information, see https://pre-commit.ci
thunder/torch/__init__.py
Outdated
@@ -5111,12 +5111,18 @@ def _nll_loss_helper( | |||
bcast_weight = reshape(weight, [num_class] + [1 for _ in range(2, a.ndim)]) | |||
out = out * bcast_weight | |||
|
|||
assert isinstance(ignore_index, Number) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If ignore_index
is a NumberProxy are the constraints created as expected for >= 0
and < num_class
comparisons with the symbolic values cache mode?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For today with constant constraints, I think that's right. But then we could have the same out of bound access issue with NumberProxy for ignore_index.
If we do need to support that, I think for now we need to use < 0 or >= num_class
to handle it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
BTW, added the support.
For some reason thunderfx does indeed hit give ignore_index
as NumberProxy instead, you probably know this better than I do.
Fixes #1744
See torch.nn.functional.nll_loss definition here: https://pytorch.org/docs/stable/generated/torch.nn.functional.nll_loss.html
Because we use
prims.take_along_axis
, which doesn't cap out-of-bound access. For ignore_index that's out side of [0, C), we need to:target
(so we are not going to be accessing arbitrary memory);