Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update PyTensor dependency and fix bugs in inferred mixture logprob #6397

Merged
merged 5 commits on Dec 15, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion conda-envs/environment-dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ dependencies:
- numpy>=1.15.0
- pandas>=0.24.0
- pip
- pytensor=2.8.10
- pytensor=2.8.11
- python-graphviz
- networkx
- scipy>=1.4.1
Expand Down
2 changes: 1 addition & 1 deletion conda-envs/environment-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ dependencies:
- numpy>=1.15.0
- pandas>=0.24.0
- pip
- pytensor=2.8.10
- pytensor=2.8.11
- python-graphviz
- networkx
- scipy>=1.4.1
Expand Down
2 changes: 1 addition & 1 deletion conda-envs/windows-environment-dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ dependencies:
- numpy>=1.15.0
- pandas>=0.24.0
- pip
- pytensor=2.8.10
- pytensor=2.8.11
- python-graphviz
- networkx
- scipy>=1.4.1
Expand Down
2 changes: 1 addition & 1 deletion conda-envs/windows-environment-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ dependencies:
- numpy>=1.15.0
- pandas>=0.24.0
- pip
- pytensor=2.8.10
- pytensor=2.8.11
- python-graphviz
- networkx
- scipy>=1.4.1
Expand Down
3 changes: 2 additions & 1 deletion pymc/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@
from pytensor import tensor as at
from pytensor.compile.builders import OpFromGraph
from pytensor.graph import node_rewriter
from pytensor.graph.basic import Node, Variable, clone_replace
from pytensor.graph.basic import Node, Variable
from pytensor.graph.replace import clone_replace
from pytensor.graph.rewriting.basic import in2out
from pytensor.graph.utils import MetaType
from pytensor.tensor.basic import as_tensor_variable
Expand Down
3 changes: 2 additions & 1 deletion pymc/distributions/timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
import pytensor
import pytensor.tensor as at

from pytensor.graph.basic import Node, clone_replace
from pytensor.graph.basic import Node
from pytensor.graph.replace import clone_replace
from pytensor.tensor import TensorVariable
from pytensor.tensor.random.op import RandomVariable

Expand Down
14 changes: 13 additions & 1 deletion pymc/logprob/mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.random.rewriting import (
local_dimshuffle_rv_lift,
local_rv_size_lift,
local_subtensor_rv_lift,
)
from pytensor.tensor.shape import shape_tuple
Expand Down Expand Up @@ -210,6 +211,7 @@ def rv_pull_down(x: TensorVariable, dont_touch_vars=None) -> TensorVariable:
return pre_greedy_node_rewriter(
fgraph,
[
local_rv_size_lift,
local_dimshuffle_rv_lift,
local_subtensor_rv_lift,
naive_bcast_rv_lift,
Expand Down Expand Up @@ -443,13 +445,23 @@ def logprob_MixtureRV(
logp_val = at.set_subtensor(logp_val[idx_m_on_axis], logp_m)

else:
# If the stacking operation expands the component RVs, we have
# to expand the value and later squeeze the logprob for everything
# to work correctly
join_axis_val = None if isinstance(join_axis.type, NoneTypeT) else join_axis.data

if join_axis_val is not None:
value = at.expand_dims(value, axis=join_axis_val)

logp_val = 0.0
for i, comp_rv in enumerate(comp_rvs):
comp_logp = logprob(comp_rv, value)
if join_axis_val is not None:
comp_logp = at.squeeze(comp_logp, axis=join_axis_val)
logp_val += ifelse(
at.eq(indices[0], i),
comp_logp,
at.zeros_like(value),
at.zeros_like(comp_logp),
)

return logp_val
Expand Down
3 changes: 2 additions & 1 deletion pymc/logprob/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,11 @@
import pytensor.tensor as at

from pytensor.gradient import DisconnectedType, jacobian
from pytensor.graph.basic import Apply, Node, Variable, clone_replace
from pytensor.graph.basic import Apply, Node, Variable
from pytensor.graph.features import AlreadyThere, Feature
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import Op
from pytensor.graph.replace import clone_replace
from pytensor.graph.rewriting.basic import GraphRewriter, in2out, node_rewriter
from pytensor.scalar import Add, Exp, Log, Mul
from pytensor.scan.op import Scan
Expand Down
2 changes: 1 addition & 1 deletion pymc/pytensorf.py
Original file line number Diff line number Diff line change
Expand Up @@ -1082,7 +1082,7 @@ def collect_default_updates(
assert random_var.owner.op is not None
if isinstance(random_var.owner.op, RandomVariable):
rng = random_var.owner.inputs[0]
if hasattr(rng, "default_update"):
if getattr(rng, "default_update", None) is not None:
update_map = {rng: rng.default_update}
else:
update_map = {rng: random_var.owner.outputs[0]}
Expand Down
5 changes: 3 additions & 2 deletions pymc/sampling/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@

from arviz.data.base import make_attrs
from pytensor.compile import SharedVariable, Supervisor, mode
from pytensor.graph.basic import clone_replace, graph_inputs
from pytensor.graph.basic import graph_inputs
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.replace import clone_replace
from pytensor.link.jax.dispatch import jax_funcify
from pytensor.raise_op import Assert
from pytensor.tensor import TensorVariable
Expand Down Expand Up @@ -70,7 +71,7 @@ def _replace_shared_variables(graph: List[TensorVariable]) -> List[TensorVariabl

shared_variables = [var for var in graph_inputs(graph) if isinstance(var, SharedVariable)]

if any(hasattr(var, "default_update") for var in shared_variables):
if any(var.default_update is not None for var in shared_variables):
raise ValueError(
"Graph contains shared variables with default_update which cannot "
"be safely replaced."
Expand Down
2 changes: 1 addition & 1 deletion pymc/smc/kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import numpy as np
import pytensor.tensor as at

from pytensor.graph.basic import clone_replace
from pytensor.graph.replace import clone_replace
from scipy.special import logsumexp
from scipy.stats import multivariate_normal

Expand Down
8 changes: 4 additions & 4 deletions pymc/tests/distributions/test_shape_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,13 +562,13 @@ def test_change_rv_size_default_update():
new_x = change_dist_size(x, new_size=(2,))
new_rng = new_x.owner.inputs[0]
assert rng.default_update is next_rng
assert not hasattr(new_rng, "default_update")
assert new_rng.default_update is None

# Test that default_update is not set if there was none before
del rng.default_update
# Test that default_update is not set if it was None before
rng.default_update = None
new_x = change_dist_size(x, new_size=(2,))
new_rng = new_x.owner.inputs[0]
assert not hasattr(new_rng, "default_update")
assert new_rng.default_update is None


def test_change_specify_shape_size_univariate():
Expand Down
49 changes: 40 additions & 9 deletions pymc/tests/logprob/test_mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,26 @@ def test_hetero_mixture_binomial(p_val, size):
(),
0,
),
# Degenerate vector mixture components, scalar index
(
(
np.array([0], dtype=pytensor.config.floatX),
np.array(1, dtype=pytensor.config.floatX),
),
(
np.array([0.5], dtype=pytensor.config.floatX),
np.array(0.5, dtype=pytensor.config.floatX),
),
(
np.array([100], dtype=pytensor.config.floatX),
np.array(1, dtype=pytensor.config.floatX),
),
np.array([0.1, 0.5, 0.4], dtype=pytensor.config.floatX),
None,
(),
(),
0,
),
# Scalar mixture components, vector index
(
(
Expand Down Expand Up @@ -298,22 +318,23 @@ def test_hetero_mixture_binomial(p_val, size):
np.array(1, dtype=pytensor.config.floatX),
),
np.array([0.1, 0.5, 0.4], dtype=pytensor.config.floatX),
(),
(),
(2,),
(2,),
(),
0,
),
# Same as before but with degenerate vector parameters
(
(
np.array(0, dtype=pytensor.config.floatX),
np.array([0], dtype=pytensor.config.floatX),
np.array(1, dtype=pytensor.config.floatX),
),
(
np.array(0.5, dtype=pytensor.config.floatX),
np.array([0.5], dtype=pytensor.config.floatX),
np.array(0.5, dtype=pytensor.config.floatX),
),
(
np.array(100, dtype=pytensor.config.floatX),
np.array([100], dtype=pytensor.config.floatX),
np.array(1, dtype=pytensor.config.floatX),
),
np.array([0.1, 0.5, 0.4], dtype=pytensor.config.floatX),
Expand Down Expand Up @@ -360,7 +381,7 @@ def test_hetero_mixture_binomial(p_val, size):
(),
0,
),
(
pytest.param(
(
np.array(0, dtype=pytensor.config.floatX),
np.array(1, dtype=pytensor.config.floatX),
Expand All @@ -378,6 +399,7 @@ def test_hetero_mixture_binomial(p_val, size):
(3,),
(slice(None),),
1,
marks=pytest.mark.xfail(IndexError, reason="Bug in AdvancedIndex Mixture logprob"),
),
(
(
Expand Down Expand Up @@ -462,16 +484,25 @@ def test_hetero_mixture_categorical(
gamma_sp = sp.gamma(Y_args[0], scale=1 / Y_args[1])
norm_2_sp = sp.norm(loc=Z_args[0], scale=Z_args[1])

# Handle scipy annoying squeeze of random draws
real_comp_size = tuple(X_rv.shape.eval())

for i in range(10):
i_val = CategoricalRV.rng_fn(test_val_rng, p_val, idx_size)

indices_val = list(extra_indices)
indices_val.insert(join_axis, i_val)
indices_val = tuple(indices_val)

x_val = norm_1_sp.rvs(size=comp_size, random_state=test_val_rng)
y_val = gamma_sp.rvs(size=comp_size, random_state=test_val_rng)
z_val = norm_2_sp.rvs(size=comp_size, random_state=test_val_rng)
x_val = np.broadcast_to(
norm_1_sp.rvs(size=comp_size, random_state=test_val_rng), real_comp_size
)
y_val = np.broadcast_to(
gamma_sp.rvs(size=comp_size, random_state=test_val_rng), real_comp_size
)
z_val = np.broadcast_to(
norm_2_sp.rvs(size=comp_size, random_state=test_val_rng), real_comp_size
)

component_logps = np.stack(
[norm_1_sp.logpdf(x_val), gamma_sp.logpdf(y_val), norm_2_sp.logpdf(z_val)],
Expand Down
2 changes: 1 addition & 1 deletion pymc/tests/test_initial_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ def test_numeric_moment_shape(self, rv_cls):

@pytest.mark.parametrize("rv_cls", [pm.Flat, pm.HalfFlat])
def test_symbolic_moment_shape(self, rv_cls):
s = at.scalar()
s = at.scalar(dtype="int64")
rv = rv_cls.dist(shape=(s,))
assert not hasattr(rv.tag, "test_value")
assert tuple(moment(rv).shape.eval({s: 4})) == (4,)
Expand Down
2 changes: 1 addition & 1 deletion pymc/tests/test_pytensorf.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ def test_compile_pymc_sets_rng_updates(self):
assert not np.isclose(f(), f())

# Check that update was not done inplace
assert not hasattr(rng, "default_update")
assert rng.default_update is None
f = pytensor.function([], x)
assert f() == f()

Expand Down
2 changes: 1 addition & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ numpydoc
pandas>=0.24.0
polyagamma
pre-commit>=2.8.0
pytensor==2.8.10
pytensor==2.8.11
pytest-cov>=2.5
pytest>=3.0
scipy>=1.4.1
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@ cloudpickle
fastprogress>=0.2.0
numpy>=1.15.0
pandas>=0.24.0
pytensor==2.8.10
pytensor==2.8.11
scipy>=1.4.1
typing-extensions>=3.7.4