JaccardIndex on MPS fails even with PYTORCH_ENABLE_MPS_FALLBACK=1 #1196
Comments
Hi! Thanks for your contribution, great first issue!
@justusschock, should we do anything to support this explicitly, or do we need to wait for support from the PyTorch side?
Hi @SkafteNicki, as a temporary fix, I think we could adjust the following part of the code:

```diff
 def _bincount(x: Tensor, minlength: Optional[int] = None) -> Tensor:
     """``torch.bincount`` currently does not support deterministic mode on GPU.

     This implementation falls back to a for-loop counting occurrences in that case.

     Args:
         x: tensor to count
         minlength: minimum length to count

     Returns:
         Number of occurrences for each unique element in x
     """
-    if x.is_cuda and deterministic():
+    if (x.is_cuda and deterministic()) or x.is_mps:
         if minlength is None:
             minlength = len(torch.unique(x))
         output = torch.zeros(minlength, device=x.device, dtype=torch.long)
         for i in range(minlength):
             output[i] = (x == i).sum()
         return output
     else:
         return torch.bincount(x, minlength=minlength)
```
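The for-loop in the diff above launches one comparison per bin. A vectorized variant of the same idea, avoiding `aten::bincount` entirely, can be sketched with broadcasting. This is an illustration only: `bincount_fallback` is a hypothetical name, not the TorchMetrics function. It also sizes the output as `max(x) + 1` (or `minlength`, if larger), which is how `torch.bincount` behaves; when `minlength` is omitted, the `len(torch.unique(x))` default in the diff yields fewer bins whenever some labels are absent (e.g. `[0, 2, 2, 5]` gives 3 bins, so the count for `5` is dropped).

```python
from typing import Optional

import torch
from torch import Tensor


def bincount_fallback(x: Tensor, minlength: Optional[int] = None) -> Tensor:
    """Hypothetical helper: count non-negative integers in a 1-D tensor without aten::bincount."""
    # Match torch.bincount sizing: max(x) + 1 bins, or minlength if that is larger.
    size = int(x.max().item()) + 1 if x.numel() > 0 else 0
    if minlength is not None:
        size = max(size, minlength)
    bins = torch.arange(size, device=x.device)
    # (N, 1) == (size,) broadcasts to an (N, size) boolean matrix;
    # summing over N gives the number of occurrences of each bin value.
    return (x.reshape(-1, 1) == bins).sum(dim=0)


x = torch.tensor([0, 2, 2, 5])
print(bincount_fallback(x).tolist())
```

The trade-off is memory: the intermediate boolean matrix has `N * size` elements, which can be large when counting a confusion matrix with `num_classes ** 2` bins over many pixels, so the for-loop may still be preferable in that case.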
Also, there are some other operations that are not yet supported on MPS.
@stancld, feel free to do so :)
@Borda, do you think it would be feasible and wise to also run our test suites on an M1 runner?
Fixes #1196

* Change the `_bincount` calculation for `MPS` to use the for-loop fallback
* Use a `torch.where` implementation to apply `absent_score` instead of relying on item assignment

There is still a warning on the PyTorch side (the CPU fallback is used for some operations); however, no action is now required on our users' side, and the results are obtained smoothly.
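The second bullet can be illustrated with a small sketch of per-class IoU computed from a confusion matrix, where classes that never occur (union of zero) receive `absent_score` through `torch.where` rather than masked item assignment such as `scores[union == 0] = absent_score`. The function name `iou_from_confmat` and its signature are hypothetical and only mirror the idea described in the PR, not the actual TorchMetrics code.

```python
import torch
from torch import Tensor


def iou_from_confmat(confmat: Tensor, absent_score: float = 0.0) -> Tensor:
    """Hypothetical helper: per-class IoU from a (C, C) confusion matrix."""
    confmat = confmat.float()
    intersection = torch.diag(confmat)
    # union = predicted count + target count - overlap, for each class
    union = confmat.sum(dim=0) + confmat.sum(dim=1) - intersection
    # Absent classes give 0 / 0 = nan here; torch.where replaces those entries below.
    scores = intersection / union
    # Select absent_score where union == 0, keeping the computed score elsewhere,
    # instead of writing into `scores` with a boolean-mask assignment.
    return torch.where(union == 0, torch.full_like(scores, absent_score), scores)


confmat = torch.tensor([[2, 1, 0], [0, 3, 0], [0, 0, 0]])
print(iou_from_confmat(confmat, absent_score=1.0))
```

Here class 2 never appears in either the target or the prediction, so its score comes from `absent_score` instead of `nan`.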
@stancld, thanks for taking a stab! @SkafteNicki, other than the fixes pointed out by @stancld, there is nothing we can do besides waiting.
Correct, we do not have any M1 hardware in our CI fleet yet... let us check what we can do here :)
🐛 Bug
Calling JaccardIndex fails on the M1 MPS backend because `aten::bincount` is not currently supported. This also happens when the `PYTORCH_ENABLE_MPS_FALLBACK=1` environment variable is set.
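For context, a minimal sketch of how the fallback flag is typically enabled from Python. It assumes, as is commonly advised, that the variable must be in the environment before `torch` is imported; on a machine without MPS the sketch simply runs on the CPU. As reported above, enabling the flag alone did not make JaccardIndex work, so this only shows the mechanism that produces the CPU-fallback warning for `aten::bincount`.

```python
import os

# Assumption: the flag has to be set before torch is imported to take effect.
os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")

import torch


def pick_device() -> torch.device:
    # torch.backends.mps.is_available() is available in PyTorch >= 1.12.
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


device = pick_device()
x = torch.tensor([0, 1, 1, 3], device=device)
# On MPS builds without a native kernel, this call emits the CPU-fallback UserWarning.
counts = torch.bincount(x, minlength=4)
print(device, counts.tolist())
```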
To Reproduce
```
UserWarning: The operator 'aten::bincount' is not currently supported on the MPS backend and will fall back to run on the CPU. This may have performance implications. (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/mps/MPSFallback.mm:11.)
  return torch.bincount(x, minlength=minlength)
```
Environment
- TorchMetrics version (and how you installed TM, e.g. `conda`, `pip`, build from source): 0.9.3