Skip to content

Commit

Permalink
Add test for Blockwise logp regression
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Dec 21, 2023
1 parent 986738f commit 118be0f
Showing 1 changed file with 16 additions and 1 deletion.
17 changes: 16 additions & 1 deletion tests/distributions/test_multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@
import numpy.random as npr
import numpy.testing as npt
import pytensor
import pytensor.tensor as pt
import pytest
import scipy.special as sp
import scipy.stats as st

from pytensor import tensor as pt
from pytensor.tensor import TensorVariable
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.random.utils import broadcast_params
from pytensor.tensor.slinalg import Cholesky

Expand Down Expand Up @@ -2387,6 +2387,21 @@ def test_mvnormal_no_cholesky_in_model_logp():
assert not contains_cholesky_op(logp_dlogp._pytensor_function.maker.fgraph)


def test_mvnormal_blockwise_solve_opt():
"""Check that no blockwise show up in the d/logp graph of a 2D MvNormal with a single covariance.
See #6993
"""
with pm.Model() as m:
pm.MvNormal("y", mu=0, cov=pt.diag([2, 2]), shape=(3, 2))

logp = m.logp()
dlogp = pytensor.grad(logp, wrt=m.value_vars[0])
fn = m.compile_fn(inputs=m.value_vars, outs=[logp, dlogp], point_fn=False)

assert not any(isinstance(node.op, Blockwise) for node in fn.maker.fgraph.apply_nodes)


def test_mvnormal_mu_convenience():
"""Test that mu is broadcasted to the length of cov and provided a default of zero"""
x = pm.MvNormal.dist(cov=np.eye(3))
Expand Down

0 comments on commit 118be0f

Please sign in to comment.