Skip to content

Commit fa177fb

Browse files
committed
Workaround numba/numba#9554
1 parent 39fa1fc commit fa177fb

File tree

2 files changed

+28
-0
lines changed

2 files changed

+28
-0
lines changed

pytensor/link/numba/dispatch/scalar.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
Composite,
2424
Identity,
2525
Mul,
26+
Pow,
2627
Reciprocal,
2728
ScalarOp,
2829
Second,
@@ -171,6 +172,23 @@ def {binary_op_name}({input_signature}):
171172
return nary_fn
172173

173174

175+
@register_funcify_and_cache_key(Pow)
176+
def numba_funcify_Pow(op, node, **kwargs):
177+
pow_dtype = node.inputs[1].type.dtype
178+
if pow_dtype.startswith("int"):
179+
# Numba power fails when exponents are non 64-bit discrete integers and fasthmath=True
180+
# https://github.com/numba/numba/issues/9554
181+
182+
def pow(x, y):
183+
return x ** np.asarray(y, dtype=np.int64).item()
184+
else:
185+
186+
def pow(x, y):
187+
return x**y
188+
189+
return numba_basic.numba_njit(pow), scalar_op_cache_key(op)
190+
191+
174192
@register_funcify_and_cache_key(Add)
175193
def numba_funcify_Add(op, node, **kwargs):
176194
nary_add_fn = binary_to_nary_func(node.inputs, "add", "+")

tests/link/numba/test_scalar.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,3 +184,13 @@ def test_Softplus(dtype):
184184
strict=True,
185185
err_msg=f"Failed for value {value}",
186186
)
187+
188+
189+
def test_discrete_power():
190+
# Test we don't fail to compile power with discrete exponents due to https://github.com/numba/numba/issues/9554
191+
x = pt.scalar("x", dtype="float64")
192+
exponent = pt.scalar("exponent", dtype="int8")
193+
out = pt.power(x, exponent)
194+
compare_numba_and_py(
195+
[x, exponent], [out], [np.array(0.5), np.array(2, dtype="int8")]
196+
)

0 commit comments

Comments
 (0)