
Commit b1584f9

Merge branch 'pymc-devs:main' into issue-173
2 parents bbcd30c + b8831aa commit b1584f9

File tree: 8 files changed (+176, -21 lines)

pytensor/link/jax/dispatch/nlinalg.py

Lines changed: 9 additions & 1 deletion
@@ -3,7 +3,7 @@
 from pytensor.link.jax.dispatch import jax_funcify
 from pytensor.tensor.blas import BatchedDot
 from pytensor.tensor.math import Dot, MaxAndArgmax
-from pytensor.tensor.nlinalg import SVD, Det, Eig, Eigh, MatrixInverse, QRFull
+from pytensor.tensor.nlinalg import SVD, Det, Eig, Eigh, MatrixInverse, QRFull, SLogDet


 @jax_funcify.register(SVD)
@@ -25,6 +25,14 @@ def det(x):
     return det


+@jax_funcify.register(SLogDet)
+def jax_funcify_SLogDet(op, **kwargs):
+    def slogdet(x):
+        return jnp.linalg.slogdet(x)
+
+    return slogdet
+
+
 @jax_funcify.register(Eig)
 def jax_funcify_Eig(op, **kwargs):
     def eig(x):
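
The JAX dispatch simply forwards to jnp.linalg.slogdet. Below is a minimal usage sketch, not part of this commit: it assumes a JAX-enabled install and that mode="JAX" selects the JAX backend, and it only illustrates how the new registration would be exercised end to end.

# Hedged sketch (not part of the diff): run the new SLogDet dispatch through
# the JAX backend and compare against NumPy.
import numpy as np
import pytensor
from pytensor.tensor import dmatrix
from pytensor.tensor.nlinalg import slogdet

x = dmatrix("x")
sign, logabsdet = slogdet(x)
# mode="JAX" selects the JAX linker; this assumes JAX is installed.
f = pytensor.function([x], [sign, logabsdet], mode="JAX")

a = np.random.default_rng(0).standard_normal((4, 4))
assert np.allclose(f(a), np.linalg.slogdet(a))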

pytensor/link/numba/dispatch/nlinalg.py

Lines changed: 20 additions & 0 deletions
@@ -18,6 +18,7 @@
     MatrixInverse,
     MatrixPinv,
     QRFull,
+    SLogDet,
 )


@@ -58,6 +59,25 @@ def det(x):
     return det


+@numba_funcify.register(SLogDet)
+def numba_funcify_SLogDet(op, node, **kwargs):
+
+    out_dtype_1 = node.outputs[0].type.numpy_dtype
+    out_dtype_2 = node.outputs[1].type.numpy_dtype
+
+    inputs_cast = int_to_float_fn(node.inputs, out_dtype_1)
+
+    @numba_basic.numba_njit
+    def slogdet(x):
+        sign, det = np.linalg.slogdet(inputs_cast(x))
+        return (
+            numba_basic.direct_cast(sign, out_dtype_1),
+            numba_basic.direct_cast(det, out_dtype_2),
+        )
+
+    return slogdet
+
+
 @numba_funcify.register(Eig)
 def numba_funcify_Eig(op, node, **kwargs):

pytensor/tensor/extra_ops.py

Lines changed: 28 additions & 6 deletions
@@ -1643,6 +1643,11 @@ def make_node(self, a, *shape):

         shape, static_shape = at.infer_static_shape(shape)

+        if len(shape) < a.ndim:
+            raise ValueError(
+                f"Broadcast target shape has {len(shape)} dims, which is shorter than input with {a.ndim} dims"
+            )
+
         out = TensorType(dtype=a.type.dtype, shape=static_shape)()

         # Attempt to prevent in-place operations on this view-based output
@@ -1686,9 +1691,12 @@ def infer_shape(self, fgraph, node, ins_shapes):
         return [node.inputs[1:]]

     def c_code(self, node, name, inputs, outputs, sub):
+        inp_dims = node.inputs[0].ndim
+        out_dims = node.outputs[0].ndim
+        new_dims = out_dims - inp_dims
+
         (x, *shape) = inputs
         (out,) = outputs
-        ndims = len(shape)
         fail = sub["fail"]

         # TODO: Could just use `PyArray_Return`, no?
@@ -1701,20 +1709,34 @@ def c_code(self, node, name, inputs, outputs, sub):

         src = (
             """
-            npy_intp itershape[%(ndims)s] = {%(dims_array)s};
+            npy_intp itershape[%(out_dims)s] = {%(dims_array)s};

+            NpyIter *iter;
             PyArrayObject *ops[1] = {%(x)s};
             npy_uint32 flags = NPY_ITER_MULTI_INDEX | NPY_ITER_REFS_OK | NPY_ITER_ZEROSIZE_OK;
             npy_uint32 op_flags[1] = {NPY_ITER_READONLY};
             PyArray_Descr *op_dtypes[1] = {NULL};
-            int oa_ndim = %(ndims)s;
+            int oa_ndim = %(out_dims)s;
             int* op_axes[1] = {NULL};
             npy_intp buffersize = 0;

-            NpyIter *iter = NpyIter_AdvancedNew(
+            for(int i = 0; i < %(inp_dims)s; i++)
+            {
+                if ((PyArray_DIMS(%(x)s)[i] != 1) && (PyArray_DIMS(%(x)s)[i] != itershape[i + %(new_dims)s]))
+                {
+                    PyErr_Format(PyExc_ValueError,
+                        "Shape mismatch in broadcast_to: target shape[%%i] = %%lld is incompatible with input shape = %%lld.",
+                        i,
+                        (long long int) itershape[i + %(new_dims)s],
+                        (long long int) PyArray_DIMS(%(x)s)[i]
+                    );
+                    %(fail)s
+                }
+            }
+
+            iter = NpyIter_AdvancedNew(
                 1, ops, flags, NPY_CORDER, NPY_NO_CASTING, op_flags, op_dtypes, oa_ndim, op_axes, itershape, buffersize
             );
-
             %(out)s = NpyIter_GetIterView(iter, 0);

             if(%(out)s == NULL){
@@ -1733,7 +1755,7 @@ def c_code(self, node, name, inputs, outputs, sub):
         return src

     def c_code_cache_version(self):
-        return (1,)
+        return (2,)


 broadcast_to_ = BroadcastTo()
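
The make_node check rejects target shapes with fewer dimensions than the input at graph-construction time, while the loop added to the generated C code validates the concrete input dimensions against the target shape at run time. The sketch below is not part of the commit; the variable names and the use of the default backend are assumptions, and it only illustrates the two failure modes the new checks cover.

# Hedged sketch (not part of the diff): the two failure modes of broadcast_to.
import numpy as np
import pytensor
import pytensor.tensor as at
from pytensor.tensor.extra_ops import broadcast_to

# 1) Graph-construction time: the target shape has fewer dims than the input.
try:
    broadcast_to(at.zeros((3, 4)), (5,))
except ValueError as err:
    print(err)  # Broadcast target shape has 1 dims, which is shorter than input with 2 dims

# 2) Run time: a symbolic target shape the concrete input cannot broadcast to.
x = at.dmatrix("x")
s = at.iscalar("s")
f = pytensor.function([x, s], broadcast_to(x, (s, 4)))
f(np.ones((1, 4)), 3)      # ok: the length-1 dim broadcasts to 3
try:
    f(np.ones((2, 4)), 3)  # 2 cannot broadcast to 3, so this should raise
except ValueError as err:
    print(err)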

pytensor/tensor/nlinalg.py

Lines changed: 33 additions & 0 deletions
@@ -231,6 +231,39 @@ def __str__(self):
 det = Det()


+class SLogDet(Op):
+    """
+    Compute the log determinant and its sign of the matrix. Input should be a square matrix.
+    """
+
+    __props__ = ()
+
+    def make_node(self, x):
+        x = as_tensor_variable(x)
+        assert x.ndim == 2
+        sign = scalar(dtype=x.dtype)
+        det = scalar(dtype=x.dtype)
+        return Apply(self, [x], [sign, det])
+
+    def perform(self, node, inputs, outputs):
+        (x,) = inputs
+        (sign, det) = outputs
+        try:
+            sign[0], det[0] = (z.astype(x.dtype) for z in np.linalg.slogdet(x))
+        except Exception:
+            print("Failed to compute determinant", x)
+            raise
+
+    def infer_shape(self, fgraph, node, shapes):
+        return [(), ()]
+
+    def __str__(self):
+        return "SLogDet"
+
+
+slogdet = SLogDet()
+
+
 class Eig(Op):
     """
     Compute the eigenvalues and right eigenvectors of a square array.
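
SLogDet mirrors np.linalg.slogdet: it returns the sign of the determinant and the natural log of its absolute value as two scalar outputs, which stays finite even when det(x) itself over- or underflows. A minimal usage sketch on the default backend follows; it is illustrative only and not taken from this commit.

# Hedged sketch (not part of the diff): slogdet stays usable where det underflows.
import numpy as np
import pytensor
from pytensor.tensor import dmatrix
from pytensor.tensor.nlinalg import det, slogdet

x = dmatrix("x")
sign, logabsdet = slogdet(x)
f = pytensor.function([x], [det(x), sign, logabsdet])

a = np.diag(np.full(200, 0.01))   # det(a) = 1e-400, below the float64 range
d, s, ld = f(a)
print(d)                          # 0.0: the plain determinant underflows
assert s == 1.0
assert np.isclose(ld, 200 * np.log(0.01))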

tests/link/jax/test_nlinalg.py

Lines changed: 4 additions & 0 deletions
@@ -85,6 +85,10 @@ def assert_fn(x, y):
     out_fg = FunctionGraph([x], outs)
     compare_jax_and_py(out_fg, [X.astype(config.floatX)], assert_fn=assert_fn)

+    outs = at_nlinalg.slogdet(x)
+    out_fg = FunctionGraph([x], outs)
+    compare_jax_and_py(out_fg, [X.astype(config.floatX)], assert_fn=assert_fn)
+

 @pytest.mark.xfail(
     version_parse(jax.__version__) >= version_parse("0.2.12"),

tests/link/numba/test_nlinalg.py

Lines changed: 35 additions & 0 deletions
@@ -179,6 +179,41 @@ def test_Det(x, exc):
     )


+@pytest.mark.parametrize(
+    "x, exc",
+    [
+        (
+            set_test_value(
+                at.dmatrix(),
+                (lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")),
+            ),
+            None,
+        ),
+        (
+            set_test_value(
+                at.lmatrix(),
+                (lambda x: x.T.dot(x))(rng.poisson(size=(3, 3)).astype("int64")),
+            ),
+            None,
+        ),
+    ],
+)
+def test_SLogDet(x, exc):
+    g = nlinalg.SLogDet()(x)
+    g_fg = FunctionGraph(outputs=g)
+
+    cm = contextlib.suppress() if exc is None else pytest.warns(exc)
+    with cm:
+        compare_numba_and_py(
+            g_fg,
+            [
+                i.tag.test_value
+                for i in g_fg.inputs
+                if not isinstance(i, (SharedVariable, Constant))
+            ],
+        )
+
+
 # We were seeing some weird results in CI where the following two almost
 # sign-swapped results were being return from Numba and Python, respectively.
 # The issue might be related to https://github.com/numba/numba/issues/4519.

tests/tensor/test_extra_ops.py

Lines changed: 34 additions & 14 deletions
@@ -1253,41 +1253,52 @@ def test_avoid_useless_subtensors(self):
     @pytest.mark.parametrize("linker", ["cvm", "py"])
     def test_perform(self, linker):

-        a = pytensor.shared(5)
+        a = pytensor.shared(np.full((3, 1, 1), 5))
+        s_0 = iscalar("s_0")
         s_1 = iscalar("s_1")
-        shape = (s_1, 1)
+        shape = (s_0, s_1, 1)

         bcast_res = broadcast_to(a, shape)
-        assert bcast_res.broadcastable == (False, True)
+        assert bcast_res.broadcastable == (False, False, True)

         bcast_fn = pytensor.function(
-            [s_1], bcast_res, mode=Mode(optimizer=None, linker=linker)
+            [s_0, s_1], bcast_res, mode=Mode(optimizer=None, linker=linker)
         )
         bcast_fn.vm.allow_gc = False

-        bcast_at = bcast_fn(4)
-        bcast_np = np.broadcast_to(5, (4, 1))
+        bcast_at = bcast_fn(3, 4)
+        bcast_np = np.broadcast_to(5, (3, 4, 1))

         assert np.array_equal(bcast_at, bcast_np)

-        bcast_var = bcast_fn.maker.fgraph.outputs[0].owner.inputs[0]
-        bcast_in = bcast_fn.vm.storage_map[a]
-        bcast_out = bcast_fn.vm.storage_map[bcast_var]
+        with pytest.raises(ValueError):
+            bcast_fn(5, 4)

         if linker != "py":
+            bcast_var = bcast_fn.maker.fgraph.outputs[0].owner.inputs[0]
+            bcast_in = bcast_fn.vm.storage_map[a]
+            bcast_out = bcast_fn.vm.storage_map[bcast_var]
             assert np.shares_memory(bcast_out[0], bcast_in[0])

+    def test_make_node_error_handling(self):
+        with pytest.raises(
+            ValueError,
+            match="Broadcast target shape has 1 dims, which is shorter than input with 2 dims",
+        ):
+            broadcast_to(at.zeros((3, 4)), (5,))
+
     @pytest.mark.skipif(
         not config.cxx, reason="G++ not available, so we need to skip this test."
     )
-    def test_memory_leak(self):
+    @pytest.mark.parametrize("valid", (True, False))
+    def test_memory_leak(self, valid):
         import gc
         import tracemalloc

         from pytensor.link.c.cvm import CVM

         n = 100_000
-        x = pytensor.shared(np.ones(n, dtype=np.float64))
+        x = pytensor.shared(np.ones((1, n), dtype=np.float64))
         y = broadcast_to(x, (5, n))

         f = pytensor.function([], y, mode=Mode(optimizer=None, linker="cvm"))
@@ -1303,8 +1314,17 @@ def test_memory_leak(self):
         blocks_last = None
         block_diffs = []
         for i in range(1, 50):
-            x.set_value(np.ones(n))
-            _ = f()
+            if valid:
+                x.set_value(np.ones((1, n)))
+                _ = f()
+            else:
+                x.set_value(np.ones((2, n)))
+                try:
+                    _ = f()
+                except ValueError:
+                    pass
+                else:
+                    raise RuntimeError("Should have failed")
             _ = gc.collect()
             blocks_i, _ = tracemalloc.get_traced_memory()
             if blocks_last is not None:
@@ -1313,7 +1333,7 @@ def test_memory_leak(self):
             blocks_last = blocks_i

         tracemalloc.stop()
-        assert np.allclose(np.mean(block_diffs), 0)
+        assert np.all(np.array(block_diffs) <= (0 + 1e-8))

     @pytest.mark.parametrize(
         "fn,input_dims",

tests/tensor/test_nlinalg.py

Lines changed: 13 additions & 0 deletions
@@ -24,6 +24,7 @@
     norm,
     pinv,
     qr,
+    slogdet,
     svd,
     tensorinv,
     tensorsolve,
@@ -280,6 +281,18 @@ def test_det_shape():
     assert tuple(det_shape.data) == ()


+def test_slogdet():
+    rng = np.random.default_rng(utt.fetch_seed())
+
+    r = rng.standard_normal((5, 5)).astype(config.floatX)
+    x = matrix()
+    f = pytensor.function([x], slogdet(x))
+    f_sign, f_det = f(r)
+    sign, det = np.linalg.slogdet(r)
+    assert np.equal(sign, f_sign)
+    assert np.allclose(det, f_det)
+
+
 def test_trace():
     rng = np.random.default_rng(utt.fetch_seed())
     x = matrix()