Add float cast tensor op #2483
Conversation
Codecov Report

Attention: Patch coverage is

Additional details and impacted files

@@            Coverage Diff             @@
##             main    #2483      +/-   ##
==========================================
- Coverage   82.97%   82.85%   -0.13%
==========================================
  Files         812      814       +2
  Lines      105212   105356     +144
==========================================
- Hits        87304    87288      -16
- Misses      17908    18068     +160

☔ View full report in Codecov by Sentry.
LGTM
Checklist

The `run-checks all` script has been executed.

Changes
Add a `tensor.cast(dtype)` op for casting floating-point tensors to a different precision type. Currently implemented for `candle` and `tch`. The element type generic had to be removed from both tensor primitives for this change. Support will be added in a follow-up PR for `burn-ndarray`, `burn-jit`, `burn-fusion` and `burn-router`.

Important: handling different floating-point precision types in operations between multiple tensors might require automatic type promotion based on a selected strategy (e.g., casting to the higher-precision type, as torch does). This is out of scope for this PR; instead, backends can panic on a dtype mismatch.
Testing
Added a unit test.
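As a rough idea of what such a test might check, here is a hedged sketch (not the PR's actual test): casting to f16 and back should preserve values within half-precision tolerance. The backend choice and the `assert_approx_eq` tolerance argument are assumptions.

```rust
#[cfg(test)]
mod tests {
    use burn::backend::LibTorch;
    use burn::tensor::{FloatDType, Tensor};

    // Hypothetical test: a round trip through f16 keeps values close to the
    // original f32 tensor.
    #[test]
    fn cast_roundtrip_preserves_values() {
        let device = Default::default();
        let input = Tensor::<LibTorch, 1>::from_floats([1.0, 2.5, -3.25], &device);

        let roundtrip = input.clone().cast(FloatDType::F16).cast(FloatDType::F32);

        input
            .into_data()
            .assert_approx_eq(&roundtrip.into_data(), 3);
    }
}
```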