Prevent local_sum_make_vector from introducing forbidden float64 at graph level #656
Conversation
Codecov Report
All modified and coverable lines are covered by tests ✅

Additional details and impacted files

@@           Coverage Diff           @@
##             main     #656   +/-   ##
=======================================
  Coverage   80.82%   80.82%
=======================================
  Files         162      162
  Lines       46812    46814       +2
  Branches    11437    11438       +1
=======================================
+ Hits        37836    37838       +2
  Misses       6725     6725
  Partials     2251     2251
@ricardoV94 This PR is ready for review.
pytensor/tensor/rewriting/basic.py
Outdated
@@ -941,6 +941,11 @@ def local_sum_make_vector(fgraph, node):
    elements = array.owner.inputs
    acc_dtype = node.op.acc_dtype
    out_dtype = node.op.dtype

    # Skip rewrite when acc_dtype has higher precision than out_dtype
    if is_an_upcast(out_dtype, acc_dtype):
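For context, `is_an_upcast(from_dtype, to_dtype)` decides whether the second dtype is strictly more precise than the first. A minimal sketch of equivalent logic using NumPy promotion rules (the actual helper in pytensor/tensor/rewriting/basic.py may be implemented differently):

```python
import numpy as np

def is_an_upcast_sketch(from_dtype: str, to_dtype: str) -> bool:
    # True when casting from_dtype -> to_dtype strictly gains precision,
    # e.g. float32 -> float64, but not float64 -> float64.
    return np.dtype(from_dtype) != np.dtype(to_dtype) and np.can_cast(
        from_dtype, to_dtype, casting="safe"
    )

assert is_an_upcast_sketch("float32", "float64")      # upcast
assert not is_an_upcast_sketch("float64", "float64")  # same dtype
assert not is_an_upcast_sketch("float64", "float32")  # downcast
```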
This seems less restrictive and still avoids the original issue?
Suggested change:
-if is_an_upcast(out_dtype, acc_dtype):
+if acc_dtype == "float64" and out_dtype != "float64" and pytensor.config.floatX != "float64":
The issue could also appear if the user requests even less precision. For example, if config.floatX = "float16", then acc_dtype will be float32 and we'll have the same problem. The issue will always exist when acc_dtype is more precise than out_dtype, which is true for every floatX except float64.
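A quick way to see the upcast behavior described above (a sketch assuming the default accumulator rules of float16 → float32 and float32 → float64 for reductions, and that the reduction op exposes `acc_dtype`):

```python
import pytensor.tensor as pt

x16 = pt.vector("x16", dtype="float16")
x32 = pt.vector("x32", dtype="float32")

# Reductions accumulate at higher precision than the output by default.
print(x16.sum().owner.op.acc_dtype)  # expected: float32
print(x32.sum().owner.op.acc_dtype)  # expected: float64
print(x32.sum().dtype)               # float32: output keeps input precision
```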
I take your point though. I'll re-work this.
Do we support floatX=16 btw? I think float16 was only ever an experimental feature in the codebase and not widely supported
@ricardoV94 This will be ready for review when the checks are complete.
pytensor/tensor/rewriting/basic.py
Outdated
@@ -941,6 +941,11 @@ def local_sum_make_vector(fgraph, node):
    elements = array.owner.inputs
    acc_dtype = node.op.acc_dtype
    out_dtype = node.op.dtype

    # Skip rewrite when acc_dtype has higher precision than floatX
    if is_an_upcast(config.floatX, acc_dtype):
This is still too restrictive. floatX=32 does not mean float64 is forbidden, only that PyTensor won't introduce it automatically when it can be avoided. The user is still allowed to call sum(dtype=float64) or pass a float64 input, in which case this check would wrongly block the rewrite. There's already a float64 at one end of the operation, so it's fine for the float64 acc_dtype to be exposed.
That's why my suggestion was more verbose.
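A sketch of the case being described (assuming standard PyTensor APIs): even under floatX="float32", the user may explicitly ask for a float64 sum, and the rewrite should still be allowed to fire there:

```python
import pytensor
import pytensor.tensor as pt

pytensor.config.floatX = "float32"

a = pt.scalar("a")            # float32 under this floatX
b = pt.scalar("b")
v = pt.stack([a, b])          # MakeVector of float32 elements

s = v.sum(dtype="float64")    # float64 requested explicitly by the user:
                              # exposing a float64 acc_dtype is fine here
```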
Force-pushed from 3a3c00a to d175203
Description
This PR changes the behavior of local_sum_make_vector to skip rewriting whenever the internal accumulator is more precise than the input/output data. Otherwise, the internal accumulator is added to the graph, which can introduce forbidden float64 operations when the user requests config.floatX="float32" or less precision.

Replaces PR #655
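A hypothetical minimal reproduction of the problem this PR guards against (variable names and the exact rewritten graph are illustrative; see #653 for the real report):

```python
import pytensor
import pytensor.tensor as pt

pytensor.config.floatX = "float32"

a = pt.scalar("a")
b = pt.scalar("b")
s = pt.stack([a, b]).sum()     # out_dtype float32, acc_dtype float64

f = pytensor.function([a, b], s)
pytensor.dprint(f)             # before this fix, the rewritten graph could
                               # contain float64 intermediates from acc_dtype
```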
Related Issue
local_sum_make_vector rewrite can introduce forbidden float64 operations at the graph level #653

Checklist
Type of change