`local_sum_make_vector` rewrite can introduce forbidden float64 operations at the graph level #653
This may also be related to pymc-devs/pymc#7114.
I forgot to test combinations of other distributions (besides Dirichlet), and I just discovered that this issue is not exclusively related to Dirichlet. Any model that includes two distributions fails to respect `floatX`. Consider the following MWE that samples from a model with one Normal distribution without error, but fails to respect `floatX` for a model with two Normal distributions.

Code:

```python
import pytensor
import pytensor.tensor as pt
import pymc as pm

print("pytensor version:", pytensor.__version__)
print("pymc version:", pm.__version__)
print("pytensor.config.floatX = ", pytensor.config.floatX)
print()


def test_normal():
    print("test_normal")
    print("pytensor.config.floatX = ", pytensor.config.floatX)
    with pm.Model() as model:
        foo = pm.Normal("foo", mu=0.0, sigma=1.0)
        print(foo, foo.dtype)
        print(model.point_logps())
    with model:
        trace = pm.sample()
    print()


def test_normal_normal():
    print("test_normal_normal")
    print("pytensor.config.floatX = ", pytensor.config.floatX)
    with pm.Model() as model:
        foo = pm.Normal("foo", mu=0.0, sigma=1.0)
        print(foo, foo.dtype)
        bar = pm.Normal("bar", mu=0.0, sigma=1.0)
        print(bar, bar.dtype)
        print(model.point_logps())
    with model:
        trace = pm.sample()
    print()


with pytensor.config.change_flags(warn_float64="raise", floatX="float32"):
    test_normal()

with pytensor.config.change_flags(warn_float64="raise", floatX="float32"):
    test_normal_normal()
```

Output:

```
pytensor version: 2.18.6
pymc version: 5.10.4
pytensor.config.floatX = float64
test_normal
pytensor.config.floatX = float32
foo float32
{'foo': -0.92}
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [foo]
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 1 seconds.

test_normal_normal
pytensor.config.floatX = float32
foo float32
bar float32
{'foo': -0.92, 'bar': -0.92}
ERROR (pytensor.graph.rewriting.basic): Rewrite failure due to: local_sum_make_vector
ERROR (pytensor.graph.rewriting.basic): node: Sum{axes=None}(MakeVector{dtype='float32'}.0)
ERROR (pytensor.graph.rewriting.basic): TRACEBACK:
ERROR (pytensor.graph.rewriting.basic): Traceback (most recent call last):
...
Exception: You are creating a TensorVariable with float64 dtype. You requested an action via the PyTensor flag warn_float64={ignore,warn,raise,pdb}.
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [foo, bar]
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 1 seconds.
```
The rewrite in question is `local_sum_make_vector`.
Looks like this was reported on the Discourse back in September: https://discourse.pymc.io/t/how-to-force-float32/12947
@ricardoV94 I suspect that the internal accumulator dtype of the CAReduce is what introduces the float64.
But that flagged rewrite gets rid of the CAReduce (i.e., `sum(inputs)`) in favor of `add(*inputs)` when they are all scalars. I think it's the rewrite that is causing the upcasting, not the original CAReduce version.
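To make the before/after concrete, here is a small sketch (not from the thread) that prints the graph before and after compilation; the exact rewritten graph can vary with the PyTensor version:

```python
import pytensor
import pytensor.tensor as pt

x1, x2, x3 = pt.scalars("x1", "x2", "x3")
out = pt.sum([x1, x2, x3])

# Unrewritten graph: Sum{axes=None} applied to MakeVector(x1, x2, x3)
pytensor.dprint(out)

# After compilation, local_sum_make_vector can replace the CAReduce
# with direct scalar additions, i.e. add(x1, x2, x3).
f = pytensor.function([x1, x2, x3], out)
pytensor.dprint(f)
```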
Looking at the logic in `CAReduce`, the internal accumulator dtype defaults to float64 even when the inputs and outputs are float32.
The accumulator dtype shouldn't matter. That's internal and not what's triggering the float64 check. Otherwise you couldn't ever sum inputs in float32. The problem is that the rewrite that removes the CAReduce is not watching out for this.
Here is a minimal reproducible example:

```python
import pytensor
import pytensor.tensor as pt

with pytensor.config.change_flags(warn_float64="raise", floatX="float32"):
    x1, x2, x3 = pt.scalars("x1", "x2", "x3")
    out = pt.sum([x1, x2, x3], acc_dtype="float32")
    out.eval({x1: 0.0, x2: 1.0, x3: 2.0})  # Fine

with pytensor.config.change_flags(warn_float64="raise", floatX="float32"):
    x1, x2, x3 = pt.scalars("x1", "x2", "x3")
    out = pt.sum([x1, x2, x3])
    out.eval({x1: 0.0, x2: 1.0, x3: 2.0})  # Raises the float64 error
```

The problem is that the rewrite exposes the internal `acc_dtype` (`pytensor/tensor/rewriting/basic.py`, lines 948 to 951 at 8aeda39). We could restrict the rewrite to only apply in cases where the internal `acc_dtype` doesn't have higher precision than the input/output dtypes.
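A rough sketch of the kind of guard being suggested; the helper name and the exact comparison are illustrative assumptions, not the actual PyTensor source or the eventual fix:

```python
import numpy as np

def acc_dtype_upcasts(out_dtype: str, acc_dtype: str) -> bool:
    """Return True if accumulating in `acc_dtype` would introduce a
    higher-precision intermediate than the node's own output dtype."""
    return np.dtype(acc_dtype).itemsize > np.dtype(out_dtype).itemsize

# Inside local_sum_make_vector, the rewrite could then bail out early:
#
#     if acc_dtype_upcasts(node.outputs[0].dtype, acc_dtype):
#         return None  # keep the original Sum(MakeVector(...)) graph
#
# so a float64 accumulator is never materialized in the rewritten graph.
print(acc_dtype_upcasts("float32", "float64"))  # True -> skip the rewrite
```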
Aha, I see now. Thanks for the clarification, @ricardoV94. If you point me in the right direction, I can work on a PR.
Well, actually I couldn't quite figure out what you meant. Instead, I noticed that the function that chooses the default `acc_dtype` inside `CAReduce` upcasts float32 to float64, and I started looking at changing that default.
What I was saying is that it may be more reasonable to opt out of the rewrite than to try to change the default internals of `CAReduce`.
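As an aside (my own assumption, not something proposed in the thread), a user-level way to opt out of the rewrite until it is fixed could be to exclude it by name from the compilation mode, assuming the rewrite is registered under its function name:

```python
import pytensor
import pytensor.tensor as pt
from pytensor.compile.mode import get_default_mode

# Exclude the offending rewrite from the default mode by name.
mode = get_default_mode().excluding("local_sum_make_vector")

with pytensor.config.change_flags(warn_float64="raise", floatX="float32"):
    x1, x2, x3 = pt.scalars("x1", "x2", "x3")
    out = pt.sum([x1, x2, x3])
    f = pytensor.function([x1, x2, x3], out, mode=mode)
    print(f(0.0, 1.0, 2.0))  # the CAReduce is kept, so no float64 shows up in the graph
```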
@ricardoV94 OK, I finally understand what you meant about skipping the rewrite entirely. I've implemented that fix in a new PR: #656
Closed via #659 |
Describe the issue:
This may be related to pymc-devs/pymc#6779
Description:
When creating a model with `floatX="float32"` that includes a single distribution, the `floatX` assignment is respected. When creating a model with two distributions, however, the `floatX` assignment is NOT respected, but only upon sampling. This is a weird bug.

Expected Behavior
The model should respect `floatX` in all cases.

Actual Behavior
When the model includes two distributions, the graph includes `float64` despite the request that `floatX="float32"`.

Minimum Working Example
In the following MWE, I create four models. The first has one Dirichlet distribution, the second has one Normal distribution, and the remaining two include a Dirichlet distribution and then either a Normal or HalfCauchy distribution.
The first two models sample without issue, and `floatX` is respected. The third and fourth models raise `float64` errors during sampling. The error appears after `model.point_logps()`, which is all that was being checked in pymc-devs/pymc#6779.

The output (with truncated error messages) is appended below:
Note that the output of `print(model.point_logps())` demonstrates that the error occurs after `model.point_logps()`. The error occurs during sampling.

Reproducible code example:
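A minimal sketch of the third model described above (a Dirichlet followed by a Normal); the parameter values are arbitrary assumptions and this is not the author's original script:

```python
import pytensor
import pymc as pm

with pytensor.config.change_flags(warn_float64="raise", floatX="float32"):
    with pm.Model() as model:
        a = pm.Dirichlet("a", a=[1.0, 1.0, 1.0])  # assumed concentration values
        b = pm.Normal("b", mu=0.0, sigma=1.0)
        print(model.point_logps())                # completes without error
        trace = pm.sample()                       # raises the float64 error during sampling
```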
Error message:
PyMC version information:
Context for the issue:
Models that include a Dirichlet distribution as well as any other distribution cannot use `float32`.