
Prevent local_sum_make_vector from introducing forbidden float64 at graph level #656

Closed

Conversation

@tvwenger (Contributor) commented Mar 3, 2024

Description

This PR changes the behavior of local_sum_make_vector to skip the rewrite whenever the internal accumulator is more precise than the input/output data. Otherwise, the internal accumulator is added to the graph, which can introduce a forbidden float64 into the graph when the user requests config.floatX="float32" or lower precision.

Replaces PR #655
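
For context, here is a minimal sketch of the kind of graph that triggers the problem. It is a hypothetical reproduction, not code from this PR, and it assumes the usual default of accumulating a float32 sum in float64:

import pytensor
import pytensor.tensor as pt

# Hypothetical reproduction: under floatX="float32", summing a MakeVector
# can let the float64 accumulator dtype leak into the rewritten graph.
with pytensor.config.change_flags(floatX="float32"):
    x = pt.scalar("x")  # float32 under this floatX
    y = pt.scalar("y")
    f = pytensor.function([x, y], pt.stack([x, y]).sum())
    dtypes = {v.type.dtype for v in f.maker.fgraph.variables if hasattr(v.type, "dtype")}
    print("float64" in dtypes)  # True before this fix, per the description above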

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

@codecov-commenter

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 80.82%. Comparing base (e8693bd) to head (d68a252).

Additional details and impacted files

@@           Coverage Diff           @@
##             main     #656   +/-   ##
=======================================
  Coverage   80.82%   80.82%           
=======================================
  Files         162      162           
  Lines       46812    46814    +2     
  Branches    11437    11438    +1     
=======================================
+ Hits        37836    37838    +2     
  Misses       6725     6725           
  Partials     2251     2251           
Files                                  Coverage Δ
pytensor/tensor/rewriting/basic.py    94.07% <100.00%> (+0.02%) ⬆️

@tvwenger (Contributor, Author) commented Mar 3, 2024

@ricardoV94 This PR is ready for review.

@@ -941,6 +941,11 @@ def local_sum_make_vector(fgraph, node):
elements = array.owner.inputs
acc_dtype = node.op.acc_dtype
out_dtype = node.op.dtype

# Skip rewrite when acc_dtype has higher precision than out_dtype
if is_an_upcast(out_dtype, acc_dtype):
@ricardoV94 (Member) commented Mar 4, 2024

This seems less restrictive and still avoids the original issue?

Suggested change:
- if is_an_upcast(out_dtype, acc_dtype):
+ if acc_dtype == "float64" and out_dtype != "float64" and pytensor.config.floatX != "float64":

@tvwenger (Contributor, Author) replied:

The issue could also appear if the user requests even less precision. For example, if config.floatX = "float16", then acc_dtype will be float32 and we'll have the same problem. The issue will always exist when acc_dtype is more precise than out_dtype, which is true for every floatX except float64.
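
To make the precision comparison concrete, here is a rough stand-in for the upcast check used in the diff above. It is illustrative only (PyTensor's actual is_an_upcast helper may differ) and simply leans on NumPy's safe-casting rules:

import numpy as np

def looks_like_an_upcast(from_dtype: str, to_dtype: str) -> bool:
    # Hypothetical stand-in for is_an_upcast: True when to_dtype can represent
    # from_dtype exactly but is not the same dtype, i.e. the accumulator is
    # strictly more precise than the output.
    return from_dtype != to_dtype and np.can_cast(from_dtype, to_dtype, casting="safe")

# With the default accumulator dtypes mentioned above:
print(looks_like_an_upcast("float16", "float32"))  # True  -> rewrite skipped
print(looks_like_an_upcast("float32", "float64"))  # True  -> rewrite skipped
print(looks_like_an_upcast("float64", "float64"))  # False -> rewrite still applies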

@tvwenger (Contributor, Author) replied:

I take your point though. I'll re-work this.

@ricardoV94 (Member) replied:

Do we support floatX="float16", btw? I think float16 was only ever an experimental feature in the codebase and was never widely supported.

@tvwenger (Contributor, Author) commented Mar 4, 2024

@ricardoV94 This will be ready for review when the checks are complete

@@ -941,6 +941,11 @@ def local_sum_make_vector(fgraph, node):
elements = array.owner.inputs
acc_dtype = node.op.acc_dtype
out_dtype = node.op.dtype

# Skip rewrite when acc_dtype has higher precision than floatX
if is_an_upcast(config.floatX, acc_dtype):
@ricardoV94 (Member) commented Mar 4, 2024

This is still too restrictive. floatX="float32" does not mean float64 is forbidden, only that PyTensor won't introduce it automatically when it can be avoided. The user is still allowed to call sum(dtype="float64") or pass a float64 input, in which case this check is too restrictive. There's already a float64 at one end of the operation, so it's fine for the float64 acc_dtype to be exposed.

That's why my suggestion was rather more verbose.
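
As an illustration of that point (a hypothetical example, with made-up variable names): the user can explicitly request a float64 sum even though floatX is "float32", so a float64 already exists at the output and a check based only on floatX would wrongly block the rewrite:

import pytensor
import pytensor.tensor as pt

with pytensor.config.change_flags(floatX="float32"):
    x = pt.scalar("x")  # float32
    y = pt.scalar("y")
    s = pt.stack([x, y]).sum(dtype="float64")  # float64 output explicitly requested
    # The output is already float64, so exposing a float64 accumulator here
    # would not introduce anything that floatX="float32" was meant to prevent.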
