Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MPS support] Make Jaccard Index working on MPS #1205

Merged
merged 7 commits into from
Sep 12, 2022
Merged

Conversation

stancld
Copy link
Contributor

@stancld stancld commented Sep 6, 2022

What does this PR do?

Fixes #1196

  • Change _bincount calculation for MPS to run for loop fallback
  • Use torch.where implementation to apply absent_score instead of relying on item assignment

There's still a warning on PyTorch side (using CPU fallback for some operations; see UserWarning below), however,
no actions on our users' side are now required and the results are obtained smoothly. Calculation runs with warnings (but without errors!) both with PYTORCH_ENABLE_MPS_FALLBACK=0 and PYTORCH_ENABLE_MPS_FALLBACK=1.

/Users/stancld/miniconda3/envs/metrics/lib/python3.9/site-packages/torch/_tensor_str.py:103: UserWarning: The operator 'aten::bitwise_and.Tensor_out' is not currently supported on the MPS backend and will fall back to run on the CPU. This
 may have performance implications. (Triggered internally at  /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/mps/MPSFallback.mm:11.)                                                                                                
  nonzero_finite_vals = torch.masked_select(tensor_view, torch.isfinite(tensor_view) & tensor_view.ne(0))  

(tested on my M1 Mac)

Before submitting

  • Was this discussed/approved via a Github issue? (no need for typos and docs improvements)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure to update the docs?
  • Did you write any new necessary tests?

PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.

Did you have fun?

Make sure you had fun coding 🙃

Fixes #1196

* Change `_bincount` calculation for `MPS` to run for loop fallback
* Use torch.where implementation to apply `absent_score` instead of relying on item assignment

There's still a warning on PyTorch side (using CPU fallback for some operations), however,
no actions on our users' side is now required and the results are obtained smoothly.
@stancld stancld added this to the v0.9 milestone Sep 6, 2022
@stancld stancld added the bug / fix Something isn't working label Sep 6, 2022
@stancld stancld marked this pull request as draft September 6, 2022 10:13
@stancld stancld marked this pull request as ready for review September 6, 2022 10:30
@codecov
Copy link

codecov bot commented Sep 6, 2022

Codecov Report

Merging #1205 (fbd9b57) into master (5bc828d) will increase coverage by 0%.
The diff coverage is 100%.

@@          Coverage Diff           @@
##           master   #1205   +/-   ##
======================================
  Coverage      94%     94%           
======================================
  Files         185     185           
  Lines        8407    8408    +1     
======================================
+ Hits         7895    7896    +1     
  Misses        512     512           

@mergify mergify bot added the ready label Sep 9, 2022
@Borda Borda enabled auto-merge (squash) September 12, 2022 17:03
@Borda Borda merged commit c97c59d into master Sep 12, 2022
@Borda Borda deleted the bugfix/fix-jacard-on-mps branch September 12, 2022 17:25
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug / fix Something isn't working ready
Projects
None yet
Development

Successfully merging this pull request may close these issues.

JaccardIndex mps fails also with PYTORCH_ENABLE_MPS_FALLBACK=1
3 participants