diff --git a/conda-envs/environment-dev.yml b/conda-envs/environment-dev.yml
index 0600471e7eb..6f40d2badfd 100644
--- a/conda-envs/environment-dev.yml
+++ b/conda-envs/environment-dev.yml
@@ -14,7 +14,7 @@ dependencies:
 - numpy>=1.15.0
 - pandas>=0.24.0
 - pip
-- pytensor=2.8.10
+- pytensor=2.8.11
 - python-graphviz
 - networkx
 - scipy>=1.4.1
diff --git a/conda-envs/environment-test.yml b/conda-envs/environment-test.yml
index fc1decdec71..d1603b5720e 100644
--- a/conda-envs/environment-test.yml
+++ b/conda-envs/environment-test.yml
@@ -17,7 +17,7 @@ dependencies:
 - numpy>=1.15.0
 - pandas>=0.24.0
 - pip
-- pytensor=2.8.10
+- pytensor=2.8.11
 - python-graphviz
 - networkx
 - scipy>=1.4.1
diff --git a/conda-envs/windows-environment-dev.yml b/conda-envs/windows-environment-dev.yml
index 0459beb8a8b..319a070105c 100644
--- a/conda-envs/windows-environment-dev.yml
+++ b/conda-envs/windows-environment-dev.yml
@@ -14,7 +14,7 @@ dependencies:
 - numpy>=1.15.0
 - pandas>=0.24.0
 - pip
-- pytensor=2.8.10
+- pytensor=2.8.11
 - python-graphviz
 - networkx
 - scipy>=1.4.1
diff --git a/conda-envs/windows-environment-test.yml b/conda-envs/windows-environment-test.yml
index b637993a3f2..067e3f1008e 100644
--- a/conda-envs/windows-environment-test.yml
+++ b/conda-envs/windows-environment-test.yml
@@ -17,7 +17,7 @@ dependencies:
 - numpy>=1.15.0
 - pandas>=0.24.0
 - pip
-- pytensor=2.8.10
+- pytensor=2.8.11
 - python-graphviz
 - networkx
 - scipy>=1.4.1
diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py
index f59f5b633c6..6e4599c985c 100644
--- a/pymc/distributions/distribution.py
+++ b/pymc/distributions/distribution.py
@@ -26,7 +26,8 @@
 from pytensor import tensor as at
 from pytensor.compile.builders import OpFromGraph
 from pytensor.graph import node_rewriter
-from pytensor.graph.basic import Node, Variable, clone_replace
+from pytensor.graph.basic import Node, Variable
+from pytensor.graph.replace import clone_replace
 from pytensor.graph.rewriting.basic import in2out
 from pytensor.graph.utils import MetaType
 from pytensor.tensor.basic import as_tensor_variable
diff --git a/pymc/distributions/timeseries.py b/pymc/distributions/timeseries.py
index 1af675126ee..14cce247611 100644
--- a/pymc/distributions/timeseries.py
+++ b/pymc/distributions/timeseries.py
@@ -21,7 +21,8 @@
 import pytensor
 import pytensor.tensor as at
 
-from pytensor.graph.basic import Node, clone_replace
+from pytensor.graph.basic import Node
+from pytensor.graph.replace import clone_replace
 from pytensor.tensor import TensorVariable
 from pytensor.tensor.random.op import RandomVariable
 
diff --git a/pymc/logprob/mixture.py b/pymc/logprob/mixture.py
index 0b0d7b7cf54..9f2e45e1dc7 100644
--- a/pymc/logprob/mixture.py
+++ b/pymc/logprob/mixture.py
@@ -53,6 +53,7 @@
 from pytensor.tensor.elemwise import Elemwise
 from pytensor.tensor.random.rewriting import (
     local_dimshuffle_rv_lift,
+    local_rv_size_lift,
     local_subtensor_rv_lift,
 )
 from pytensor.tensor.shape import shape_tuple
@@ -210,6 +211,7 @@ def rv_pull_down(x: TensorVariable, dont_touch_vars=None) -> TensorVariable:
     return pre_greedy_node_rewriter(
         fgraph,
         [
+            local_rv_size_lift,
             local_dimshuffle_rv_lift,
             local_subtensor_rv_lift,
             naive_bcast_rv_lift,
@@ -443,13 +445,23 @@ def logprob_MixtureRV(
             logp_val = at.set_subtensor(logp_val[idx_m_on_axis], logp_m)
 
     else:
+        # If the stacking operation expands the component RVs, we have
+        # to expand the value and later squeeze the logprob for everything
+        # to work correctly
+        join_axis_val = None if isinstance(join_axis.type, NoneTypeT) else join_axis.data
+
+        if join_axis_val is not None:
+            value = at.expand_dims(value, axis=join_axis_val)
+
         logp_val = 0.0
         for i, comp_rv in enumerate(comp_rvs):
             comp_logp = logprob(comp_rv, value)
+            if join_axis_val is not None:
+                comp_logp = at.squeeze(comp_logp, axis=join_axis_val)
             logp_val += ifelse(
                 at.eq(indices[0], i),
                 comp_logp,
-                at.zeros_like(value),
+                at.zeros_like(comp_logp),
             )
 
     return logp_val
diff --git a/pymc/logprob/transforms.py b/pymc/logprob/transforms.py
index 8ae8b432409..463c59d1fd3 100644
--- a/pymc/logprob/transforms.py
+++ b/pymc/logprob/transforms.py
@@ -42,10 +42,11 @@
 import pytensor.tensor as at
 
 from pytensor.gradient import DisconnectedType, jacobian
-from pytensor.graph.basic import Apply, Node, Variable, clone_replace
+from pytensor.graph.basic import Apply, Node, Variable
 from pytensor.graph.features import AlreadyThere, Feature
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.op import Op
+from pytensor.graph.replace import clone_replace
 from pytensor.graph.rewriting.basic import GraphRewriter, in2out, node_rewriter
 from pytensor.scalar import Add, Exp, Log, Mul
 from pytensor.scan.op import Scan
diff --git a/pymc/pytensorf.py b/pymc/pytensorf.py
index 211c4a99a68..c92bed6ef1b 100644
--- a/pymc/pytensorf.py
+++ b/pymc/pytensorf.py
@@ -1082,7 +1082,7 @@ def collect_default_updates(
         assert random_var.owner.op is not None
         if isinstance(random_var.owner.op, RandomVariable):
             rng = random_var.owner.inputs[0]
-            if hasattr(rng, "default_update"):
+            if getattr(rng, "default_update", None) is not None:
                 update_map = {rng: rng.default_update}
             else:
                 update_map = {rng: random_var.owner.outputs[0]}
diff --git a/pymc/sampling/jax.py b/pymc/sampling/jax.py
index c3ac07fa33a..3f017c5a8f6 100644
--- a/pymc/sampling/jax.py
+++ b/pymc/sampling/jax.py
@@ -22,8 +22,9 @@
 
 from arviz.data.base import make_attrs
 from pytensor.compile import SharedVariable, Supervisor, mode
-from pytensor.graph.basic import clone_replace, graph_inputs
+from pytensor.graph.basic import graph_inputs
 from pytensor.graph.fg import FunctionGraph
+from pytensor.graph.replace import clone_replace
 from pytensor.link.jax.dispatch import jax_funcify
 from pytensor.raise_op import Assert
 from pytensor.tensor import TensorVariable
@@ -70,7 +71,7 @@ def _replace_shared_variables(graph: List[TensorVariable]) -> List[TensorVariabl
 
     shared_variables = [var for var in graph_inputs(graph) if isinstance(var, SharedVariable)]
 
-    if any(hasattr(var, "default_update") for var in shared_variables):
+    if any(var.default_update is not None for var in shared_variables):
         raise ValueError(
             "Graph contains shared variables with default_update which cannot "
             "be safely replaced."
diff --git a/pymc/smc/kernels.py b/pymc/smc/kernels.py
index 9b7a02e11cf..7cdbfd3166f 100644
--- a/pymc/smc/kernels.py
+++ b/pymc/smc/kernels.py
@@ -21,7 +21,7 @@
 import numpy as np
 import pytensor.tensor as at
 
-from pytensor.graph.basic import clone_replace
+from pytensor.graph.replace import clone_replace
 from scipy.special import logsumexp
 from scipy.stats import multivariate_normal
 
diff --git a/pymc/tests/distributions/test_shape_utils.py b/pymc/tests/distributions/test_shape_utils.py
index b335fefbdc2..dcd78dab8f5 100644
--- a/pymc/tests/distributions/test_shape_utils.py
+++ b/pymc/tests/distributions/test_shape_utils.py
@@ -562,13 +562,13 @@ def test_change_rv_size_default_update():
     new_x = change_dist_size(x, new_size=(2,))
     new_rng = new_x.owner.inputs[0]
     assert rng.default_update is next_rng
-    assert not hasattr(new_rng, "default_update")
+    assert new_rng.default_update is None
 
-    # Test that default_update is not set if there was none before
-    del rng.default_update
+    # Test that default_update is not set if it was None before
+    rng.default_update = None
     new_x = change_dist_size(x, new_size=(2,))
     new_rng = new_x.owner.inputs[0]
-    assert not hasattr(new_rng, "default_update")
+    assert new_rng.default_update is None
 
 
 def test_change_specify_shape_size_univariate():
diff --git a/pymc/tests/logprob/test_mixture.py b/pymc/tests/logprob/test_mixture.py
index 984bf264bb3..e29c7abc5e7 100644
--- a/pymc/tests/logprob/test_mixture.py
+++ b/pymc/tests/logprob/test_mixture.py
@@ -226,6 +226,26 @@ def test_hetero_mixture_binomial(p_val, size):
             (),
             0,
         ),
+        # Degenerate vector mixture components, scalar index
+        (
+            (
+                np.array([0], dtype=pytensor.config.floatX),
+                np.array(1, dtype=pytensor.config.floatX),
+            ),
+            (
+                np.array([0.5], dtype=pytensor.config.floatX),
+                np.array(0.5, dtype=pytensor.config.floatX),
+            ),
+            (
+                np.array([100], dtype=pytensor.config.floatX),
+                np.array(1, dtype=pytensor.config.floatX),
+            ),
+            np.array([0.1, 0.5, 0.4], dtype=pytensor.config.floatX),
+            None,
+            (),
+            (),
+            0,
+        ),
         # Scalar mixture components, vector index
         (
             (
@@ -298,22 +318,23 @@ def test_hetero_mixture_binomial(p_val, size):
                 np.array(1, dtype=pytensor.config.floatX),
             ),
             np.array([0.1, 0.5, 0.4], dtype=pytensor.config.floatX),
-            (),
-            (),
+            (2,),
+            (2,),
             (),
             0,
         ),
+        # Same as before but with degenerate vector parameters
         (
            (
-                np.array(0, dtype=pytensor.config.floatX),
+                np.array([0], dtype=pytensor.config.floatX),
                 np.array(1, dtype=pytensor.config.floatX),
             ),
             (
-                np.array(0.5, dtype=pytensor.config.floatX),
+                np.array([0.5], dtype=pytensor.config.floatX),
                 np.array(0.5, dtype=pytensor.config.floatX),
             ),
             (
-                np.array(100, dtype=pytensor.config.floatX),
+                np.array([100], dtype=pytensor.config.floatX),
                 np.array(1, dtype=pytensor.config.floatX),
             ),
             np.array([0.1, 0.5, 0.4], dtype=pytensor.config.floatX),
@@ -360,7 +381,7 @@ def test_hetero_mixture_binomial(p_val, size):
             (),
             0,
         ),
-        (
+        pytest.param(
             (
                 np.array(0, dtype=pytensor.config.floatX),
                 np.array(1, dtype=pytensor.config.floatX),
@@ -378,6 +399,7 @@ def test_hetero_mixture_binomial(p_val, size):
             (3,),
             (slice(None),),
             1,
+            marks=pytest.mark.xfail(IndexError, reason="Bug in AdvancedIndex Mixture logprob"),
         ),
         (
             (
@@ -462,6 +484,9 @@ def test_hetero_mixture_categorical(
     gamma_sp = sp.gamma(Y_args[0], scale=1 / Y_args[1])
     norm_2_sp = sp.norm(loc=Z_args[0], scale=Z_args[1])
 
+    # Handle scipy annoying squeeze of random draws
+    real_comp_size = tuple(X_rv.shape.eval())
+
     for i in range(10):
         i_val = CategoricalRV.rng_fn(test_val_rng, p_val, idx_size)
 
@@ -469,9 +494,15 @@ def test_hetero_mixture_categorical(
 
         indices_val.insert(join_axis, i_val)
         indices_val = tuple(indices_val)
-        x_val = norm_1_sp.rvs(size=comp_size, random_state=test_val_rng)
-        y_val = gamma_sp.rvs(size=comp_size, random_state=test_val_rng)
-        z_val = norm_2_sp.rvs(size=comp_size, random_state=test_val_rng)
+        x_val = np.broadcast_to(
+            norm_1_sp.rvs(size=comp_size, random_state=test_val_rng), real_comp_size
+        )
+        y_val = np.broadcast_to(
+            gamma_sp.rvs(size=comp_size, random_state=test_val_rng), real_comp_size
+        )
+        z_val = np.broadcast_to(
+            norm_2_sp.rvs(size=comp_size, random_state=test_val_rng), real_comp_size
+        )
 
         component_logps = np.stack(
             [norm_1_sp.logpdf(x_val), gamma_sp.logpdf(y_val), norm_2_sp.logpdf(z_val)],
diff --git a/pymc/tests/test_initial_point.py b/pymc/tests/test_initial_point.py
index a9fdb5a6200..400b510a599 100644
--- a/pymc/tests/test_initial_point.py
+++ b/pymc/tests/test_initial_point.py
@@ -238,7 +238,7 @@ def test_numeric_moment_shape(self, rv_cls):
 
     @pytest.mark.parametrize("rv_cls", [pm.Flat, pm.HalfFlat])
     def test_symbolic_moment_shape(self, rv_cls):
-        s = at.scalar()
+        s = at.scalar(dtype="int64")
         rv = rv_cls.dist(shape=(s,))
         assert not hasattr(rv.tag, "test_value")
         assert tuple(moment(rv).shape.eval({s: 4})) == (4,)
diff --git a/pymc/tests/test_pytensorf.py b/pymc/tests/test_pytensorf.py
index ce02dbae18f..983e5275aaf 100644
--- a/pymc/tests/test_pytensorf.py
+++ b/pymc/tests/test_pytensorf.py
@@ -336,7 +336,7 @@ def test_compile_pymc_sets_rng_updates(self):
         assert not np.isclose(f(), f())
 
         # Check that update was not done inplace
-        assert not hasattr(rng, "default_update")
+        assert rng.default_update is None
         f = pytensor.function([], x)
         assert f() == f()
 
diff --git a/requirements-dev.txt b/requirements-dev.txt
index 714bc30d4f8..c1ef9b70e80 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -17,7 +17,7 @@ numpydoc
 pandas>=0.24.0
 polyagamma
 pre-commit>=2.8.0
-pytensor==2.8.10
+pytensor==2.8.11
 pytest-cov>=2.5
 pytest>=3.0
 scipy>=1.4.1
diff --git a/requirements.txt b/requirements.txt
index 4f64eb94740..cc6e6afc177 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,6 +4,6 @@ cloudpickle
 fastprogress>=0.2.0
 numpy>=1.15.0
 pandas>=0.24.0
-pytensor==2.8.10
+pytensor==2.8.11
 scipy>=1.4.1
 typing-extensions>=3.7.4