Skip to content

Commit

Permalink
add aesara assert
Browse files Browse the repository at this point in the history
  • Loading branch information
aerubanov committed Jul 14, 2021
1 parent 35c4a43 commit 5ddd1d5
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 6 deletions.
9 changes: 5 additions & 4 deletions pymc3/distributions/multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import numpy as np
import scipy

from aesara.assert_op import Assert
from aesara.graph.basic import Apply
from aesara.graph.op import Op
from aesara.sparse.basic import sp_sum
Expand Down Expand Up @@ -1939,12 +1940,12 @@ def make_node(self, rng, size, dtype, mu, W, alpha, tau):
raise ValueError("W must be a matrix (ndim=2).")

sparse = isinstance(W, aesara.sparse.SparseVariable)
msg = "W must be a symmetric adjacency matrix."
if sparse:
if not at.isclose(aesara.sparse.basic.sp_sum(W - W.T), 0):
raise ValueError("W must be a symmetric adjacency matrix.")
abs_diff = aesara.sparse.basic.mul(aesara.sparse.basic.sgn(W - W.T), W - W.T)
W = Assert(msg)(W, at.isclose(aesara.sparse.basic.sp_sum(abs_diff), 0))
else:
if not at.allclose(W, W.T):
raise ValueError("W must be a symmetric adjacency matrix.")
W = Assert(msg)(W, at.allclose(W, W.T))

tau = at.as_tensor_variable(floatX(tau))
alpha = at.as_tensor_variable(floatX(alpha))
Expand Down
6 changes: 4 additions & 2 deletions pymc3/tests/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3093,6 +3093,7 @@ def test_car_symmetry_check(sparse):
tau = 2
alpha = 0.5
mu = np.zeros(4)
xs = np.random.randn(*mu.shape)

# non-symmetric matrix
W = np.array(
Expand All @@ -3102,8 +3103,9 @@ def test_car_symmetry_check(sparse):
if sparse:
W = aesara.sparse.csr_from_dense(W)

with pytest.raises(ValueError):
car_dist = CAR.dist(mu, W, alpha, tau)
car_dist = CAR.dist(mu, W, alpha, tau)
with pytest.raises(AssertionError):
logp(car_dist, xs).eval()


class TestBugfixes:
Expand Down

0 comments on commit 5ddd1d5

Please sign in to comment.