
local_sum_make_vector rewrite can introduce forbidden float64 operations at the graph level #653

Closed
tvwenger opened this issue Feb 21, 2024 · 14 comments
Labels
bug, graph rewriting

Comments

@tvwenger
Contributor

tvwenger commented Feb 21, 2024

Describe the issue:

This may be related to pymc-devs/pymc#6779

Description:

When creating a model with floatX="float32" that includes a single distribution, the floatX assignment is respected. When creating a model with two distributions, however, the floatX assignment is NOT respected, and the failure only appears upon sampling. This is a weird bug.

Expected Behavior

The model should respect floatX in all cases.

Actual Behavior

When the model includes two distributions, the graph includes float64 operations despite the request that floatX="float32".

Minimum Working Example

In the following MWE, I create four models. The first has one Dirichlet distribution, the second has one Normal distribution, and the remaining two include a Dirichlet distribution and then either a Normal or HalfCauchy distribution.

The first two models sample without issue, and floatX is respected.

The third and fourth models raise float64 errors during sampling. The error appears after model.point_logps(), which is all that was being checked in pymc-devs/pymc#6779.

The output (with truncated error messages) is appended below:

pytensor version: 2.18.6
pymc version: 5.10.4
pytensor.config.floatX =  float64

test_dirichlet
pytensor.config.floatX =  float32
foo float32
{'foo': -1.5}
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [foo]
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 2 seconds.

test_normal
pytensor.config.floatX =  float32
foo float32
{'foo': -0.92}
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [foo]
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 2 seconds.

test_dirichlet_normal
pytensor.config.floatX =  float32
foo float32
bar float32
{'foo': -1.5, 'bar': -0.92}
ERROR (pytensor.graph.rewriting.basic): Rewrite failure due to: local_sum_make_vector
ERROR (pytensor.graph.rewriting.basic): node: Sum{axes=None}(MakeVector{dtype='float32'}.0)
ERROR (pytensor.graph.rewriting.basic): TRACEBACK:
ERROR (pytensor.graph.rewriting.basic): Traceback (most recent call last):
 ...
Exception: You are creating a TensorVariable with float64 dtype. You requested an action via the PyTensor flag warn_float64={ignore,warn,raise,pdb}.

Multiprocess sampling (4 chains in 4 jobs)
NUTS: [foo, bar]
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 3 seconds.

test_dirichlet_halfcauchy
pytensor.config.floatX =  float32
foo float32
bar float32
{'foo': -1.5, 'bar': -1.14}
ERROR (pytensor.graph.rewriting.basic): Rewrite failure due to: local_sum_make_vector
ERROR (pytensor.graph.rewriting.basic): node: Sum{axes=None}(MakeVector{dtype='float32'}.0)
ERROR (pytensor.graph.rewriting.basic): TRACEBACK:
ERROR (pytensor.graph.rewriting.basic): Traceback (most recent call last):
...
Exception: You are creating a TensorVariable with float64 dtype. You requested an action via the PyTensor flag warn_float64={ignore,warn,raise,pdb}.

Multiprocess sampling (4 chains in 4 jobs)
NUTS: [foo, bar]
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 3 seconds.

Note that print(model.point_logps()) succeeds and its output appears before the error, demonstrating that the failure does not occur at logp evaluation; the error occurs during sampling.

Reproduceable code example:

import pytensor
import pytensor.tensor as pt
import pymc as pm

print("pytensor version:", pytensor.__version__)
print("pymc version:", pm.__version__)

print("pytensor.config.floatX = ", pytensor.config.floatX)
print()


def test_dirichlet():
    print("test_dirichlet")
    print("pytensor.config.floatX = ", pytensor.config.floatX)
    with pm.Model() as model:
        foo = pm.Dirichlet("foo", a=pt.ones(3))
        print(foo, foo.dtype)
    print(model.point_logps())
    with model:
        trace = pm.sample()
    print()


def test_normal():
    print("test_normal")
    print("pytensor.config.floatX = ", pytensor.config.floatX)
    with pm.Model() as model:
        foo = pm.Normal("foo", mu=0.0, sigma=1.0)
        print(foo, foo.dtype)
    print(model.point_logps())
    with model:
        trace = pm.sample()
    print()


def test_dirichlet_normal():
    print("test_dirichlet_normal")
    print("pytensor.config.floatX = ", pytensor.config.floatX)
    with pm.Model() as model:
        foo = pm.Dirichlet("foo", a=pt.ones(3))
        print(foo, foo.dtype)
        bar = pm.Normal("bar", mu=0.0, sigma=1.0)
        print(bar, bar.dtype)
    print(model.point_logps())
    with model:
        trace = pm.sample()
    print()


def test_dirichlet_halfcauchy():
    print("test_dirichlet_halfcauchy")
    print("pytensor.config.floatX = ", pytensor.config.floatX)
    with pm.Model() as model:
        foo = pm.Dirichlet("foo", a=pt.ones(3))
        print(foo, foo.dtype)
        bar = pm.HalfCauchy("bar", beta=1.0)
        print(bar, bar.dtype)
    print(model.point_logps())
    with model:
        trace = pm.sample()
    print()


with pytensor.config.change_flags(warn_float64="raise", floatX="float32"):
    test_dirichlet()

with pytensor.config.change_flags(warn_float64="raise", floatX="float32"):
    test_normal()

with pytensor.config.change_flags(warn_float64="raise", floatX="float32"):
    test_dirichlet_normal()

with pytensor.config.change_flags(warn_float64="raise", floatX="float32"):
    test_dirichlet_halfcauchy()

Error message:

ERROR (pytensor.graph.rewriting.basic): Rewrite failure due to: local_sum_make_vector
ERROR (pytensor.graph.rewriting.basic): node: Sum{axes=None}(MakeVector{dtype='float32'}.0)
ERROR (pytensor.graph.rewriting.basic): TRACEBACK:
ERROR (pytensor.graph.rewriting.basic): Traceback (most recent call last):
  File "/home/twenger/miniconda3/envs/pymc/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1922, in process_node
    replacements = node_rewriter.transform(fgraph, node)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/twenger/miniconda3/envs/pymc/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1082, in transform
    return self.fn(fgraph, node)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/twenger/miniconda3/envs/pymc/lib/python3.11/site-packages/pytensor/tensor/rewriting/basic.py", line 950, in local_sum_make_vector
    add(*[cast(value, acc_dtype) for value in elements]), out_dtype
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/twenger/miniconda3/envs/pymc/lib/python3.11/site-packages/pytensor/tensor/rewriting/basic.py", line 950, in <listcomp>
    add(*[cast(value, acc_dtype) for value in elements]), out_dtype
          ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/twenger/miniconda3/envs/pymc/lib/python3.11/site-packages/pytensor/tensor/basic.py", line 763, in cast
    return _cast_mapping[dtype_name](x)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/twenger/miniconda3/envs/pymc/lib/python3.11/site-packages/pytensor/graph/op.py", line 295, in __call__
    node = self.make_node(*inputs, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/twenger/miniconda3/envs/pymc/lib/python3.11/site-packages/pytensor/tensor/elemwise.py", line 484, in make_node
    outputs = [
              ^
  File "/home/twenger/miniconda3/envs/pymc/lib/python3.11/site-packages/pytensor/tensor/elemwise.py", line 485, in <listcomp>
    TensorType(dtype=dtype, shape=shape)()
  File "/home/twenger/miniconda3/envs/pymc/lib/python3.11/site-packages/pytensor/graph/type.py", line 228, in __call__
    return utils.add_tag_trace(self.make_variable(name))
                               ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/twenger/miniconda3/envs/pymc/lib/python3.11/site-packages/pytensor/graph/type.py", line 200, in make_variable
    return self.variable_type(self, None, name=name)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/twenger/miniconda3/envs/pymc/lib/python3.11/site-packages/pytensor/tensor/variable.py", line 900, in __init__
    raise Exception(msg)
Exception: You are creating a TensorVariable with float64 dtype. You requested an action via the PyTensor flag warn_float64={ignore,warn,raise,pdb}.

PyMC version information:

pytensor version: 2.18.6
pymc version: 5.10.4

Context for the issue:

Models that include a Dirichlet distribution as well as any other distribution cannot use float32.

@tvwenger added the bug label on Feb 21, 2024
@aerubanov
Contributor

This may also be related to pymc-devs/pymc#7114.

@tvwenger
Contributor Author

I forgot to test combinations of other distributions (besides Dirichlet), and I've just discovered that this issue is not specific to Dirichlet: any model that includes two distributions fails to respect floatX. Consider the following MWE, in which a model with a single Normal distribution samples without error, but a model with two Normal distributions fails to respect floatX.

Code:

import pytensor
import pytensor.tensor as pt
import pymc as pm

print("pytensor version:", pytensor.__version__)
print("pymc version:", pm.__version__)

print("pytensor.config.floatX = ", pytensor.config.floatX)
print()


def test_normal():
    print("test_normal")
    print("pytensor.config.floatX = ", pytensor.config.floatX)
    with pm.Model() as model:
        foo = pm.Normal("foo", mu=0.0, sigma=1.0)
        print(foo, foo.dtype)
    print(model.point_logps())
    with model:
        trace = pm.sample()
    print()


def test_normal_normal():
    print("test_normal_normal")
    print("pytensor.config.floatX = ", pytensor.config.floatX)
    with pm.Model() as model:
        foo = pm.Normal("foo", mu=0.0, sigma=1.0)
        print(foo, foo.dtype)
        bar = pm.Normal("bar", mu=0.0, sigma=1.0)
        print(bar, bar.dtype)
    print(model.point_logps())
    with model:
        trace = pm.sample()
    print()


with pytensor.config.change_flags(warn_float64="raise", floatX="float32"):
    test_normal()

with pytensor.config.change_flags(warn_float64="raise", floatX="float32"):
    test_normal_normal()

Output:

pytensor version: 2.18.6
pymc version: 5.10.4
pytensor.config.floatX =  float64

test_normal
pytensor.config.floatX =  float32
foo float32
{'foo': -0.92}
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [foo]
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 1 seconds.

test_normal_normal
pytensor.config.floatX =  float32
foo float32
bar float32
{'foo': -0.92, 'bar': -0.92}
ERROR (pytensor.graph.rewriting.basic): Rewrite failure due to: local_sum_make_vector
ERROR (pytensor.graph.rewriting.basic): node: Sum{axes=None}(MakeVector{dtype='float32'}.0)
ERROR (pytensor.graph.rewriting.basic): TRACEBACK:
ERROR (pytensor.graph.rewriting.basic): Traceback (most recent call last):
...
Exception: You are creating a TensorVariable with float64 dtype. You requested an action via the PyTensor flag warn_float64={ignore,warn,raise,pdb}.

Multiprocess sampling (4 chains in 4 jobs)
NUTS: [foo, bar]
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 1 seconds.

@tvwenger changed the title from "BUG: Models that include Dirichlet and another distribution do not respect floatX" to "BUG: Models that include two distributions do not respect floatX" on Feb 27, 2024
@ricardoV94
Member

The rewrite local_sum_make_vector seems to be creating float64. I don't know if that's something that can be avoided; I would need to look at the graph the rewrite is being applied to in an interactive debugger.

@tvwenger
Contributor Author

Looks like this was reported on the discourse back in September: https://discourse.pymc.io/t/how-to-force-float32/12947

@tvwenger
Contributor Author

The rewrite local_sum_make_vector seems to be creating float64. I don't know if that's something that can be avoided; I would need to look at the graph the rewrite is being applied to in an interactive debugger.

@ricardoV94 I suspect that pytensor is upcasting the dtype for the internal accumulator. Looking at the source for CAReduce in tensor/elemwise.py, that seems to be what is happening.

@ricardoV94
Member

But that flagged rewrite gets rid of the CAReduce (sum(inputs)) in favor of add(*inputs) when the inputs are all scalars. I think it's the rewrite that is causing the upcasting, not the original CAReduce version.
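
For illustration, here is a rough sketch of the graph transformation being discussed (the variable names are only for this example, not taken from the rewrite):

import pytensor.tensor as pt

x1, x2, x3 = pt.scalars("x1", "x2", "x3")
# Before the rewrite: Sum{axes=None}(MakeVector(x1, x2, x3))
via_sum = pt.sum(pt.stack([x1, x2, x3]))
# Roughly what local_sum_make_vector builds instead: a chain of adds over the
# elements, with the elements first cast to the internal accumulator dtype
via_add = x1 + x2 + x3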

@ricardoV94 ricardoV94 transferred this issue from pymc-devs/pymc Feb 27, 2024
@tvwenger
Contributor Author

Looking at the logic in pytensor/tensor/elemwise.py, the default behavior when acc_dtype is not specified is to upcast acc_dtype relative to the input dtype:

    def _acc_dtype(self, idtype):
        acc_dtype = self.acc_dtype
        if acc_dtype is None:
            return dict(
                bool="int64",
                int8="int64",
                int16="int64",
                int32="int64",
                uint8="uint64",
                uint16="uint64",
                uint32="uint64",
                float16="float32",
                float32="float64",
                complex64="complex128",
            ).get(idtype, idtype)
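
As a quick illustration of that upcasting (a sketch; the private _acc_dtype helper is only called here for inspection):

import pytensor
import pytensor.tensor as pt

with pytensor.config.change_flags(floatX="float32"):
    x = pt.vector("x")  # float32 under floatX="float32"
    s = pt.sum(x)
    print(s.dtype)                         # float32: the output dtype is preserved
    print(s.owner.op._acc_dtype(x.dtype))  # float64: the internal accumulator is upcast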

@ricardoV94
Member

The accumulator dtype shouldn't matter. That's internal and not what triggers the float64 check. Otherwise you could never sum inputs in float32.

The problem is the rewrite that removes the CAReduce is not watching out for this.

@ricardoV94
Member

Here is a minimal reproducible example:

import pytensor
import pytensor.tensor as pt

with pytensor.config.change_flags(warn_float64="raise", floatX="float32"):
    x1, x2, x3 = pt.scalars("x1", "x2", "x3")
    out = pt.sum([x1, x2, x3], acc_dtype="float32")
    out.eval({x1: 0.0, x2: 1.0, x3: 2.0})  # Fine

with pytensor.config.change_flags(warn_float64="raise", floatX="float32"):
    x1, x2, x3 = pt.scalars("x1", "x2", "x3")
    out = pt.sum([x1, x2, x3])
    out.eval({x1: 0.0, x2: 1.0, x3: 2.0})

The problem is that the rewrite exposes the internal acc_dtype when it replaces the CAReduce with an Add:

else:
    element_sum = cast(
        add(*[cast(value, acc_dtype) for value in elements]), out_dtype
    )

We could restrict the rewrite so that it only applies when the internal acc_dtype does not have higher precision than the input/output dtypes.
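
A minimal sketch of what such a guard might look like (hypothetical; the helper name and the numpy-based precision check are illustrative, not from the rewrite itself):

import numpy as np

def acc_dtype_would_upcast(acc_dtype: str, out_dtype: str) -> bool:
    # True when the accumulator dtype has higher precision than the output
    # dtype, i.e. the rewrite would expose a forbidden upcast such as
    # float64 in a float32 graph.
    return np.promote_types(acc_dtype, out_dtype) != np.dtype(out_dtype)

# Inside local_sum_make_vector, the rewrite would then bail out early:
#     if acc_dtype_would_upcast(acc_dtype, out_dtype):
#         return None  # keep the original Sum(MakeVector(...)) graph

print(acc_dtype_would_upcast("float64", "float32"))  # True  -> skip the rewrite
print(acc_dtype_would_upcast("float32", "float32"))  # False -> safe to rewrite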

@ricardoV94 changed the title from "BUG: Models that include two distributions do not respect floatX" to "local_sum_make_vector rewrite can introduce forbidden float64 operations at the graph level" on Feb 28, 2024
@tvwenger
Contributor Author

tvwenger commented Mar 1, 2024

Aha, I see now. Thanks for the clarification @ricardoV94

If you point me in the right direction, I can work on a PR.

@tvwenger tvwenger mentioned this issue Mar 2, 2024
@tvwenger
Contributor Author

tvwenger commented Mar 2, 2024

Well, actually I couldn't quite figure out what you meant. Instead, I noticed that the function local_sum_make_vector is getting acc_dtype from the node, which is ultimately being set via the logic I posted earlier in pytensor/tensor/elemwise.py. In the linked PR, I simply changed the default behavior when acc_dtype is not specified so that instead of upcasting from the input dtype it defaults to config.floatX. I tested the MWE you posted and this simple change fixes the problem.
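
Roughly, a sketch of that idea (a paraphrase of the description above, not the actual diff in the linked PR):

from pytensor import config

# Paraphrased sketch: when no acc_dtype is given, let the float entries of the
# upcast table follow config.floatX instead of always promoting to float64.
def _acc_dtype(self, idtype):
    acc_dtype = self.acc_dtype
    if acc_dtype is None:
        return dict(
            bool="int64",
            int8="int64",
            int16="int64",
            int32="int64",
            uint8="uint64",
            uint16="uint64",
            uint32="uint64",
            float16=config.floatX,
            float32=config.floatX,
            complex64="complex128",
        ).get(idtype, idtype)
    return acc_dtype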

@ricardoV94
Member

What I was saying is that it may be more reasonable to opt out of the rewrite than to try to change the default internals of CAReduce.

@tvwenger
Contributor Author

tvwenger commented Mar 3, 2024

@ricardoV94 OK, I finally understand what you meant about skipping the rewrite entirely. I've implemented that fix in a new PR: #656

@ricardoV94
Member

Closed via #659
