Commit 89d567a

Numba dispatch of ScalarLoop
1 parent 8dd6f7e commit 89d567a

3 files changed (+156 −10 lines)
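
The gist: this commit adds a native Numba implementation for ScalarLoop, so graphs containing it no longer fall back to object mode (note the pytest.warns(UserWarning, match="object mode") check removed from the test diff below). A minimal usage sketch of the Op being dispatched; the doubling update and the test values here are illustrative, not taken from the commit:

import pytensor.scalar as ps
from pytensor import function

n_steps = ps.int64("n_steps")
x0 = ps.float64("x0")
# One carried state, no constants: x doubles on each step.
op = ps.ScalarLoop(init=[x0], update=[x0 + x0])
# mode="NUMBA" assumes the Numba backend is installed and available.
fn = function([n_steps, x0], op(n_steps, x0), mode="NUMBA")
print(fn(n_steps=4, x0=1.0))  # 16.0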

pytensor/link/numba/dispatch/scalar.py
Lines changed: 50 additions & 0 deletions

@@ -16,6 +16,7 @@
     get_name_for_object,
     unique_name_generator,
 )
+from pytensor.scalar import ScalarLoop
 from pytensor.scalar.basic import (
     Add,
     Cast,
@@ -336,3 +337,52 @@ def softplus(x):
         return numba_basic.direct_cast(value, out_dtype)
 
     return softplus, scalar_op_cache_key(op)
+
+
+@register_funcify_and_cache_key(ScalarLoop)
+def numba_funcify_ScalarLoop(op, node, **kwargs):
+    inner_fn, inner_fn_cache_key = numba_funcify_and_cache_key(op.fgraph)
+    if inner_fn_cache_key is None:
+        loop_cache_key = None
+    else:
+        loop_cache_key = sha256(
+            str((type(op), op.is_while, inner_fn_cache_key)).encode()
+        ).hexdigest()
+
+    if op.is_while:
+        n_update = len(op.outputs) - 1
+
+        @numba_basic.numba_njit
+        def while_loop(n_steps, *inputs):
+            carry, constant = inputs[:n_update], inputs[n_update:]
+
+            until = False
+            for i in range(n_steps):
+                outputs = inner_fn(*carry, *constant)
+                carry, until = outputs[:-1], outputs[-1]
+                if until:
+                    break
+
+            return *carry, until
+
+        return while_loop, loop_cache_key
+
+    else:
+        n_update = len(op.outputs)
+
+        @numba_basic.numba_njit
+        def for_loop(n_steps, *inputs):
+            carry, constant = inputs[:n_update], inputs[n_update:]
+
+            if n_steps < 0:
+                raise ValueError("ScalarLoop does not have a termination condition.")
+
+            for i in range(n_steps):
+                carry = inner_fn(*carry, *constant)
+
+            if n_update == 1:
+                return carry[0]
+            else:
+                return carry
+
+        return for_loop, loop_cache_key

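For readers skimming the dispatch code above: the compiled loop receives n_steps followed by the flat list of loop inputs, where the first n_update entries are the carried state and the remainder are constants. A plain-Python sketch of that for-loop convention (run_for_loop and the lambda are illustrative stand-ins, not part of the commit):

def run_for_loop(inner_fn, n_update, n_steps, *inputs):
    # Split inputs the same way the dispatched code does: the first
    # n_update values are the carried state, the rest are constants.
    carry, constant = inputs[:n_update], inputs[n_update:]
    for _ in range(n_steps):
        carry = inner_fn(*carry, *constant)
    # Single-output loops unwrap the 1-tuple, matching the dispatch above.
    return carry[0] if n_update == 1 else carry

# x <- x + const for 5 steps, starting from x0=0 with const=1:
print(run_for_loop(lambda x, const: (x + const,), 1, 5, 0.0, 1.0))  # 5.0
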
tests/link/numba/test_elemwise.py
Lines changed: 34 additions & 10 deletions

@@ -587,18 +587,42 @@ def test_elemwise_multiple_inplace_outs():
 
 
 def test_scalar_loop():
-    a = float64("a")
-    scalar_loop = pytensor.scalar.ScalarLoop([a], [a + a])
+    a_scalar = float64("a")
+    const_scalar = float64("const")
+    scalar_loop = pytensor.scalar.ScalarLoop(
+        init=[a_scalar],
+        update=[a_scalar + a_scalar + const_scalar],
+        constant=[const_scalar],
+    )
 
-    x = pt.tensor("x", shape=(3,))
-    elemwise_loop = Elemwise(scalar_loop)(3, x)
+    a = pt.tensor("a", shape=(3,))
+    const = pt.tensor("const", shape=(3,))
+    n_steps = 3
+    elemwise_loop = Elemwise(scalar_loop)(n_steps, a, const)
 
-    with pytest.warns(UserWarning, match="object mode"):
-        compare_numba_and_py(
-            [x],
-            [elemwise_loop],
-            (np.array([1, 2, 3], dtype="float64"),),
-        )
+    compare_numba_and_py(
+        [a, const],
+        [elemwise_loop],
+        [np.array([1, 2, 3], dtype="float64"), np.array([1, 1, 1], dtype="float64")],
+    )
+
+
+def test_gammainc_wrt_k_grad():
+    x = pt.vector("x", dtype="float64")
+    k = pt.vector("k", dtype="float64")
+
+    out = pt.gammainc(k, x)
+    grad_out = grad(out.sum(), k)
+
+    compare_numba_and_py(
+        [x, k],
+        [grad_out],
+        # These values of x and k trigger all the branches in the gradient of gammainc
+        [
+            np.array([0.0, 29.0, 31.0], dtype="float64"),
+            np.array([1.0, 13.0, 11.0], dtype="float64"),
+        ],
+    )
 
 
 class TestsBenchmark:

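The gammainc test belongs in this commit because, to my understanding, PyTensor implements the gradient of gammainc with respect to k as a ScalarLoop graph, so the test exercises the new dispatch end to end. A quick independent cross-check of the same derivative via a finite difference on SciPy's regularized lower incomplete gamma (SciPy and this helper are assumptions for illustration, not used by the commit):

import numpy as np
from scipy.special import gammainc  # regularized lower incomplete gamma P(k, x)

def gammainc_grad_wrt_k(k, x, eps=1e-6):
    # Central finite difference in k, for comparison with grad(out.sum(), k).
    return (gammainc(k + eps, x) - gammainc(k - eps, x)) / (2 * eps)

k = np.array([1.0, 13.0, 11.0])
x = np.array([0.0, 29.0, 31.0])
print(gammainc_grad_wrt_k(k, x))
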
tests/link/numba/test_scalar.py
Lines changed: 72 additions & 0 deletions

@@ -6,6 +6,7 @@
 import pytensor.scalar.math as psm
 import pytensor.tensor as pt
 from pytensor import config, function
+from pytensor.scalar import ScalarLoop
 from pytensor.scalar.basic import Composite
 from pytensor.tensor import tensor
 from pytensor.tensor.elemwise import Elemwise
@@ -194,3 +195,74 @@ def test_discrete_power():
     compare_numba_and_py(
         [x, exponent], [out], [np.array(0.5), np.array(2, dtype="int8")]
     )
+
+
+class TestScalarLoop:
+    def test_scalar_for_loop_single_out(self):
+        n_steps = ps.int64("n_steps")
+        x0 = ps.float64("x0")
+        const = ps.float64("const")
+        x = x0 + const
+
+        op = ScalarLoop(init=[x0], constant=[const], update=[x])
+        x = op(n_steps, x0, const)
+
+        fn = function([n_steps, x0, const], [x], mode=numba_mode)
+
+        res_x = fn(n_steps=5, x0=0, const=1)
+        np.testing.assert_allclose(res_x, 5)
+
+        res_x = fn(n_steps=5, x0=0, const=2)
+        np.testing.assert_allclose(res_x, 10)
+
+        res_x = fn(n_steps=4, x0=3, const=-1)
+        np.testing.assert_allclose(res_x, -1)
+
+    def test_scalar_for_loop_multiple_outs(self):
+        n_steps = ps.int64("n_steps")
+        x0 = ps.float64("x0")
+        y0 = ps.int64("y0")
+        const = ps.float64("const")
+        x = x0 + const
+        y = y0 + 1
+
+        op = ScalarLoop(init=[x0, y0], constant=[const], update=[x, y])
+        x, y = op(n_steps, x0, y0, const)
+
+        fn = function([n_steps, x0, y0, const], [x, y], mode=numba_mode)
+
+        res_x, res_y = fn(n_steps=5, x0=0, y0=0, const=1)
+        np.testing.assert_allclose(res_x, 5)
+        np.testing.assert_allclose(res_y, 5)
+
+        res_x, res_y = fn(n_steps=5, x0=0, y0=0, const=2)
+        np.testing.assert_allclose(res_x, 10)
+        np.testing.assert_allclose(res_y, 5)
+
+        res_x, res_y = fn(n_steps=4, x0=3, y0=2, const=-1)
+        np.testing.assert_allclose(res_x, -1)
+        np.testing.assert_allclose(res_y, 6)
+
+    def test_scalar_while_loop(self):
+        n_steps = ps.int64("n_steps")
+        x0 = ps.float64("x0")
+        x = x0 + 1
+        until = x >= 10
+
+        op = ScalarLoop(init=[x0], update=[x], until=until)
+        fn = function([n_steps, x0], op(n_steps, x0), mode=numba_mode)
+        np.testing.assert_allclose(fn(n_steps=20, x0=0), [10, True])
+        np.testing.assert_allclose(fn(n_steps=20, x0=1), [10, True])
+        np.testing.assert_allclose(fn(n_steps=5, x0=1), [6, False])
+        np.testing.assert_allclose(fn(n_steps=0, x0=1), [1, False])
+
+    def test_loop_with_cython_wrapped_op(self):
+        x = ps.float64("x")
+        op = ScalarLoop(init=[x], update=[ps.psi(x)])
+        out = op(1, x)
+
+        fn = function([x], out, mode=numba_mode)
+        x_test = np.float64(0.5)
+        res = fn(x_test)
+        expected_res = ps.psi(x).eval({x: x_test})
+        np.testing.assert_allclose(res, expected_res)

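One detail worth spelling out from test_scalar_while_loop: a while-mode ScalarLoop returns the carry plus one extra trailing output, the final value of the until flag, which is why the assertions compare against pairs like [10, True]. A plain-Python sketch of that convention (run_while_loop and the lambda are illustrative stand-ins, not part of the commit):

def run_while_loop(inner_fn, n_update, n_steps, *inputs):
    carry, constant = inputs[:n_update], inputs[n_update:]
    until = False
    for _ in range(n_steps):
        # inner_fn returns the new carry values plus the until condition.
        *carry, until = inner_fn(*carry, *constant)
        if until:
            break  # stop as soon as the until condition fires
    # The final until flag is returned as an extra trailing output.
    return (*carry, until)

# x <- x + 1 until x >= 10, capped at 20 steps: stops early at x == 10.
print(run_while_loop(lambda x: (x + 1, x + 1 >= 10), 1, 20, 0.0))  # (10.0, True)
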