From 09a586fe6ee3d3b8e780260ac24eafa10572108a Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 15 Dec 2022 11:46:15 +0100 Subject: [PATCH 1/5] Remove duplicated test parametrization --- pymc/tests/logprob/test_mixture.py | 19 ------------------- 1 file changed, 19 deletions(-) diff --git a/pymc/tests/logprob/test_mixture.py b/pymc/tests/logprob/test_mixture.py index 984bf264bb3..e188052bb10 100644 --- a/pymc/tests/logprob/test_mixture.py +++ b/pymc/tests/logprob/test_mixture.py @@ -284,25 +284,6 @@ def test_hetero_mixture_binomial(p_val, size): (), 0, ), - ( - ( - np.array(0, dtype=pytensor.config.floatX), - np.array(1, dtype=pytensor.config.floatX), - ), - ( - np.array(0.5, dtype=pytensor.config.floatX), - np.array(0.5, dtype=pytensor.config.floatX), - ), - ( - np.array(100, dtype=pytensor.config.floatX), - np.array(1, dtype=pytensor.config.floatX), - ), - np.array([0.1, 0.5, 0.4], dtype=pytensor.config.floatX), - (), - (), - (), - 0, - ), ( ( np.array(0, dtype=pytensor.config.floatX), From 37977ad9cbbf4c086691ce2459f49956469bf3f4 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 15 Dec 2022 10:34:36 +0100 Subject: [PATCH 2/5] Fix bug in IfElse Mixture logprob It did not account for the extra dimension added by the stacking operation, resulting in a logp call to an expanded RV with the original unexpanded value --- pymc/logprob/mixture.py | 12 +++++++++- pymc/tests/logprob/test_mixture.py | 35 +++++++++++++++++++++++++++--- 2 files changed, 43 insertions(+), 4 deletions(-) diff --git a/pymc/logprob/mixture.py b/pymc/logprob/mixture.py index 0b0d7b7cf54..38383bac770 100644 --- a/pymc/logprob/mixture.py +++ b/pymc/logprob/mixture.py @@ -443,13 +443,23 @@ def logprob_MixtureRV( logp_val = at.set_subtensor(logp_val[idx_m_on_axis], logp_m) else: + # If the stacking operation expands the component RVs, we have + # to expand the value and later squeeze the logprob for everything + # to work correctly + join_axis_val = None if isinstance(join_axis.type, 
NoneTypeT) else join_axis.data + + if join_axis_val is not None: + value = at.expand_dims(value, axis=join_axis_val) + logp_val = 0.0 for i, comp_rv in enumerate(comp_rvs): comp_logp = logprob(comp_rv, value) + if join_axis_val is not None: + comp_logp = at.squeeze(comp_logp, axis=join_axis_val) logp_val += ifelse( at.eq(indices[0], i), comp_logp, - at.zeros_like(value), + at.zeros_like(comp_logp), ) return logp_val diff --git a/pymc/tests/logprob/test_mixture.py b/pymc/tests/logprob/test_mixture.py index e188052bb10..f143ad81e75 100644 --- a/pymc/tests/logprob/test_mixture.py +++ b/pymc/tests/logprob/test_mixture.py @@ -226,6 +226,26 @@ def test_hetero_mixture_binomial(p_val, size): (), 0, ), + # Degenerate vector mixture components, scalar index + ( + ( + np.array([0], dtype=pytensor.config.floatX), + np.array(1, dtype=pytensor.config.floatX), + ), + ( + np.array([0.5], dtype=pytensor.config.floatX), + np.array(0.5, dtype=pytensor.config.floatX), + ), + ( + np.array([100], dtype=pytensor.config.floatX), + np.array(1, dtype=pytensor.config.floatX), + ), + np.array([0.1, 0.5, 0.4], dtype=pytensor.config.floatX), + None, + (), + (), + 0, + ), # Scalar mixture components, vector index ( ( @@ -443,6 +463,9 @@ def test_hetero_mixture_categorical( gamma_sp = sp.gamma(Y_args[0], scale=1 / Y_args[1]) norm_2_sp = sp.norm(loc=Z_args[0], scale=Z_args[1]) + # Handle scipy annoying squeeze of random draws + real_comp_size = tuple(X_rv.shape.eval()) + for i in range(10): i_val = CategoricalRV.rng_fn(test_val_rng, p_val, idx_size) @@ -450,9 +473,15 @@ def test_hetero_mixture_categorical( indices_val.insert(join_axis, i_val) indices_val = tuple(indices_val) - x_val = norm_1_sp.rvs(size=comp_size, random_state=test_val_rng) - y_val = gamma_sp.rvs(size=comp_size, random_state=test_val_rng) - z_val = norm_2_sp.rvs(size=comp_size, random_state=test_val_rng) + x_val = np.broadcast_to( + norm_1_sp.rvs(size=comp_size, random_state=test_val_rng), real_comp_size + ) + y_val = 
np.broadcast_to( + gamma_sp.rvs(size=comp_size, random_state=test_val_rng), real_comp_size + ) + z_val = np.broadcast_to( + norm_2_sp.rvs(size=comp_size, random_state=test_val_rng), real_comp_size + ) component_logps = np.stack( [norm_1_sp.logpdf(x_val), gamma_sp.logpdf(y_val), norm_2_sp.logpdf(z_val)], From 4a2ce47fe1dad296a65cd44273a24a089b39fc14 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 15 Dec 2022 12:18:52 +0100 Subject: [PATCH 3/5] Add test showing pre-existing bug in Advanced Mixture logprob --- pymc/tests/logprob/test_mixture.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/pymc/tests/logprob/test_mixture.py b/pymc/tests/logprob/test_mixture.py index f143ad81e75..34f8f4fed71 100644 --- a/pymc/tests/logprob/test_mixture.py +++ b/pymc/tests/logprob/test_mixture.py @@ -323,6 +323,27 @@ def test_hetero_mixture_binomial(p_val, size): (), 0, ), + # Same as before but with degenerate vector parameters + pytest.param( + ( + np.array([0], dtype=pytensor.config.floatX), + np.array(1, dtype=pytensor.config.floatX), + ), + ( + np.array([0.5], dtype=pytensor.config.floatX), + np.array(0.5, dtype=pytensor.config.floatX), + ), + ( + np.array([100], dtype=pytensor.config.floatX), + np.array(1, dtype=pytensor.config.floatX), + ), + np.array([0.1, 0.5, 0.4], dtype=pytensor.config.floatX), + (2,), + (2,), + (), + 0, + marks=pytest.mark.xfail(IndexError, reason="Bug in AdvancedIndexing Mixture logprob"), + ), ( ( np.array(0, dtype=pytensor.config.floatX), From be097f748ef11fe6634403d31aaaf6d16aaf9a82 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Wed, 14 Dec 2022 11:16:44 +0100 Subject: [PATCH 4/5] Update PyTensor dependency --- conda-envs/environment-dev.yml | 2 +- conda-envs/environment-test.yml | 2 +- conda-envs/windows-environment-dev.yml | 2 +- conda-envs/windows-environment-test.yml | 2 +- pymc/distributions/distribution.py | 3 ++- pymc/distributions/timeseries.py | 3 ++- pymc/logprob/transforms.py | 3 ++- pymc/pytensorf.py | 
2 +- pymc/sampling/jax.py | 5 +++-- pymc/smc/kernels.py | 2 +- pymc/tests/distributions/test_shape_utils.py | 8 ++++---- pymc/tests/test_initial_point.py | 2 +- pymc/tests/test_pytensorf.py | 2 +- requirements-dev.txt | 2 +- requirements.txt | 2 +- 15 files changed, 23 insertions(+), 19 deletions(-) diff --git a/conda-envs/environment-dev.yml b/conda-envs/environment-dev.yml index 0600471e7eb..6f40d2badfd 100644 --- a/conda-envs/environment-dev.yml +++ b/conda-envs/environment-dev.yml @@ -14,7 +14,7 @@ dependencies: - numpy>=1.15.0 - pandas>=0.24.0 - pip -- pytensor=2.8.10 +- pytensor=2.8.11 - python-graphviz - networkx - scipy>=1.4.1 diff --git a/conda-envs/environment-test.yml b/conda-envs/environment-test.yml index fc1decdec71..d1603b5720e 100644 --- a/conda-envs/environment-test.yml +++ b/conda-envs/environment-test.yml @@ -17,7 +17,7 @@ dependencies: - numpy>=1.15.0 - pandas>=0.24.0 - pip -- pytensor=2.8.10 +- pytensor=2.8.11 - python-graphviz - networkx - scipy>=1.4.1 diff --git a/conda-envs/windows-environment-dev.yml b/conda-envs/windows-environment-dev.yml index 0459beb8a8b..319a070105c 100644 --- a/conda-envs/windows-environment-dev.yml +++ b/conda-envs/windows-environment-dev.yml @@ -14,7 +14,7 @@ dependencies: - numpy>=1.15.0 - pandas>=0.24.0 - pip -- pytensor=2.8.10 +- pytensor=2.8.11 - python-graphviz - networkx - scipy>=1.4.1 diff --git a/conda-envs/windows-environment-test.yml b/conda-envs/windows-environment-test.yml index b637993a3f2..067e3f1008e 100644 --- a/conda-envs/windows-environment-test.yml +++ b/conda-envs/windows-environment-test.yml @@ -17,7 +17,7 @@ dependencies: - numpy>=1.15.0 - pandas>=0.24.0 - pip -- pytensor=2.8.10 +- pytensor=2.8.11 - python-graphviz - networkx - scipy>=1.4.1 diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py index f59f5b633c6..6e4599c985c 100644 --- a/pymc/distributions/distribution.py +++ b/pymc/distributions/distribution.py @@ -26,7 +26,8 @@ from pytensor import tensor as at 
from pytensor.compile.builders import OpFromGraph from pytensor.graph import node_rewriter -from pytensor.graph.basic import Node, Variable, clone_replace +from pytensor.graph.basic import Node, Variable +from pytensor.graph.replace import clone_replace from pytensor.graph.rewriting.basic import in2out from pytensor.graph.utils import MetaType from pytensor.tensor.basic import as_tensor_variable diff --git a/pymc/distributions/timeseries.py b/pymc/distributions/timeseries.py index 1af675126ee..14cce247611 100644 --- a/pymc/distributions/timeseries.py +++ b/pymc/distributions/timeseries.py @@ -21,7 +21,8 @@ import pytensor import pytensor.tensor as at -from pytensor.graph.basic import Node, clone_replace +from pytensor.graph.basic import Node +from pytensor.graph.replace import clone_replace from pytensor.tensor import TensorVariable from pytensor.tensor.random.op import RandomVariable diff --git a/pymc/logprob/transforms.py b/pymc/logprob/transforms.py index 8ae8b432409..463c59d1fd3 100644 --- a/pymc/logprob/transforms.py +++ b/pymc/logprob/transforms.py @@ -42,10 +42,11 @@ import pytensor.tensor as at from pytensor.gradient import DisconnectedType, jacobian -from pytensor.graph.basic import Apply, Node, Variable, clone_replace +from pytensor.graph.basic import Apply, Node, Variable from pytensor.graph.features import AlreadyThere, Feature from pytensor.graph.fg import FunctionGraph from pytensor.graph.op import Op +from pytensor.graph.replace import clone_replace from pytensor.graph.rewriting.basic import GraphRewriter, in2out, node_rewriter from pytensor.scalar import Add, Exp, Log, Mul from pytensor.scan.op import Scan diff --git a/pymc/pytensorf.py b/pymc/pytensorf.py index 211c4a99a68..c92bed6ef1b 100644 --- a/pymc/pytensorf.py +++ b/pymc/pytensorf.py @@ -1082,7 +1082,7 @@ def collect_default_updates( assert random_var.owner.op is not None if isinstance(random_var.owner.op, RandomVariable): rng = random_var.owner.inputs[0] - if hasattr(rng, "default_update"): 
+ if getattr(rng, "default_update", None) is not None: update_map = {rng: rng.default_update} else: update_map = {rng: random_var.owner.outputs[0]} diff --git a/pymc/sampling/jax.py b/pymc/sampling/jax.py index c3ac07fa33a..3f017c5a8f6 100644 --- a/pymc/sampling/jax.py +++ b/pymc/sampling/jax.py @@ -22,8 +22,9 @@ from arviz.data.base import make_attrs from pytensor.compile import SharedVariable, Supervisor, mode -from pytensor.graph.basic import clone_replace, graph_inputs +from pytensor.graph.basic import graph_inputs from pytensor.graph.fg import FunctionGraph +from pytensor.graph.replace import clone_replace from pytensor.link.jax.dispatch import jax_funcify from pytensor.raise_op import Assert from pytensor.tensor import TensorVariable @@ -70,7 +71,7 @@ def _replace_shared_variables(graph: List[TensorVariable]) -> List[TensorVariabl shared_variables = [var for var in graph_inputs(graph) if isinstance(var, SharedVariable)] - if any(hasattr(var, "default_update") for var in shared_variables): + if any(var.default_update is not None for var in shared_variables): raise ValueError( "Graph contains shared variables with default_update which cannot " "be safely replaced." 
diff --git a/pymc/smc/kernels.py b/pymc/smc/kernels.py index 9b7a02e11cf..7cdbfd3166f 100644 --- a/pymc/smc/kernels.py +++ b/pymc/smc/kernels.py @@ -21,7 +21,7 @@ import numpy as np import pytensor.tensor as at -from pytensor.graph.basic import clone_replace +from pytensor.graph.replace import clone_replace from scipy.special import logsumexp from scipy.stats import multivariate_normal diff --git a/pymc/tests/distributions/test_shape_utils.py b/pymc/tests/distributions/test_shape_utils.py index b335fefbdc2..dcd78dab8f5 100644 --- a/pymc/tests/distributions/test_shape_utils.py +++ b/pymc/tests/distributions/test_shape_utils.py @@ -562,13 +562,13 @@ def test_change_rv_size_default_update(): new_x = change_dist_size(x, new_size=(2,)) new_rng = new_x.owner.inputs[0] assert rng.default_update is next_rng - assert not hasattr(new_rng, "default_update") + assert new_rng.default_update is None - # Test that default_update is not set if there was none before - del rng.default_update + # Test that default_update is not set if it was None before + rng.default_update = None new_x = change_dist_size(x, new_size=(2,)) new_rng = new_x.owner.inputs[0] - assert not hasattr(new_rng, "default_update") + assert new_rng.default_update is None def test_change_specify_shape_size_univariate(): diff --git a/pymc/tests/test_initial_point.py b/pymc/tests/test_initial_point.py index a9fdb5a6200..400b510a599 100644 --- a/pymc/tests/test_initial_point.py +++ b/pymc/tests/test_initial_point.py @@ -238,7 +238,7 @@ def test_numeric_moment_shape(self, rv_cls): @pytest.mark.parametrize("rv_cls", [pm.Flat, pm.HalfFlat]) def test_symbolic_moment_shape(self, rv_cls): - s = at.scalar() + s = at.scalar(dtype="int64") rv = rv_cls.dist(shape=(s,)) assert not hasattr(rv.tag, "test_value") assert tuple(moment(rv).shape.eval({s: 4})) == (4,) diff --git a/pymc/tests/test_pytensorf.py b/pymc/tests/test_pytensorf.py index ce02dbae18f..983e5275aaf 100644 --- a/pymc/tests/test_pytensorf.py +++ 
b/pymc/tests/test_pytensorf.py @@ -336,7 +336,7 @@ def test_compile_pymc_sets_rng_updates(self): assert not np.isclose(f(), f()) # Check that update was not done inplace - assert not hasattr(rng, "default_update") + assert rng.default_update is None f = pytensor.function([], x) assert f() == f() diff --git a/requirements-dev.txt b/requirements-dev.txt index 714bc30d4f8..c1ef9b70e80 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -17,7 +17,7 @@ numpydoc pandas>=0.24.0 polyagamma pre-commit>=2.8.0 -pytensor==2.8.10 +pytensor==2.8.11 pytest-cov>=2.5 pytest>=3.0 scipy>=1.4.1 diff --git a/requirements.txt b/requirements.txt index 4f64eb94740..cc6e6afc177 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,6 +4,6 @@ cloudpickle fastprogress>=0.2.0 numpy>=1.15.0 pandas>=0.24.0 -pytensor==2.8.10 +pytensor==2.8.11 scipy>=1.4.1 typing-extensions>=3.7.4 From b54b2d88eb17fadc5df591a260d9dfdd8caf2c3b Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 15 Dec 2022 11:50:03 +0100 Subject: [PATCH 5/5] Workaround buggy AdvancedIndexing Mixture logprob Now that Dimshuffle lift broadcasts both the parameters and the size, the AdvancedIndexing logprob fails most of the time, even though these are valid graphs. All but one of the failing cases can be helped by introducing the local_rv_size_lift rewrite. 
--- pymc/logprob/mixture.py | 2 ++ pymc/tests/logprob/test_mixture.py | 6 +++--- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/pymc/logprob/mixture.py b/pymc/logprob/mixture.py index 38383bac770..9f2e45e1dc7 100644 --- a/pymc/logprob/mixture.py +++ b/pymc/logprob/mixture.py @@ -53,6 +53,7 @@ from pytensor.tensor.elemwise import Elemwise from pytensor.tensor.random.rewriting import ( local_dimshuffle_rv_lift, + local_rv_size_lift, local_subtensor_rv_lift, ) from pytensor.tensor.shape import shape_tuple @@ -210,6 +211,7 @@ def rv_pull_down(x: TensorVariable, dont_touch_vars=None) -> TensorVariable: return pre_greedy_node_rewriter( fgraph, [ + local_rv_size_lift, local_dimshuffle_rv_lift, local_subtensor_rv_lift, naive_bcast_rv_lift, diff --git a/pymc/tests/logprob/test_mixture.py b/pymc/tests/logprob/test_mixture.py index 34f8f4fed71..e29c7abc5e7 100644 --- a/pymc/tests/logprob/test_mixture.py +++ b/pymc/tests/logprob/test_mixture.py @@ -324,7 +324,7 @@ def test_hetero_mixture_binomial(p_val, size): 0, ), # Same as before but with degenerate vector parameters - pytest.param( + ( ( np.array([0], dtype=pytensor.config.floatX), np.array(1, dtype=pytensor.config.floatX), @@ -342,7 +342,6 @@ def test_hetero_mixture_binomial(p_val, size): (2,), (), 0, - marks=pytest.mark.xfail(IndexError, reason="Bug in AdvancedIndexing Mixture logprob"), ), ( ( @@ -382,7 +381,7 @@ def test_hetero_mixture_binomial(p_val, size): (), 0, ), - ( + pytest.param( ( np.array(0, dtype=pytensor.config.floatX), np.array(1, dtype=pytensor.config.floatX), @@ -400,6 +399,7 @@ def test_hetero_mixture_binomial(p_val, size): (3,), (slice(None),), 1, + marks=pytest.mark.xfail(IndexError, reason="Bug in AdvancedIndex Mixture logprob"), ), ( (