Skip to content

Commit ae182f0

Browse files
Support read-only Numba types
1 parent 255ade4 commit ae182f0

File tree

6 files changed

+62
-22
lines changed

6 files changed

+62
-22
lines changed

aesara/link/numba/dispatch/basic.py

+25-10
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949

5050

5151
if TYPE_CHECKING:
52+
from aesara.graph.basic import Variable
5253
from aesara.graph.op import StorageMapType
5354

5455

@@ -79,13 +80,21 @@ def numba_vectorize(*args, **kwargs):
7980

8081

8182
@singledispatch
82-
def get_numba_type(aesara_type: Type, **kwargs) -> numba.types.Type:
83-
r"""Create a Numba type object for a :class:`Type`."""
83+
def get_numba_type(aesara_type: Type, var: "Variable", **kwargs) -> numba.types.Type:
84+
r"""Create a Numba type object for a :class:`Type`.
85+
86+
Parameters
87+
----------
88+
aesara_type
89+
The :class:`Type` to convert.
90+
var
91+
The :class:`Variable` corresponding to `aesara_type`.
92+
"""
8493
return numba.types.pyobject
8594

8695

8796
@get_numba_type.register(ScalarType)
88-
def get_numba_type_ScalarType(aesara_type, **kwargs):
97+
def get_numba_type_ScalarType(aesara_type, var, **kwargs):
8998
dtype = np.dtype(aesara_type.dtype)
9099
numba_dtype = numba.from_dtype(dtype)
91100
return numba_dtype
@@ -94,6 +103,7 @@ def get_numba_type_ScalarType(aesara_type, **kwargs):
94103
@get_numba_type.register(TensorType)
95104
def get_numba_type_TensorType(
96105
aesara_type,
106+
var: "Variable",
97107
layout: str = "A",
98108
force_scalar: bool = False,
99109
reduce_to_scalar: bool = False,
@@ -103,6 +113,8 @@ def get_numba_type_TensorType(
103113
----------
104114
aesara_type
105115
The :class:`Type` to convert.
116+
var
117+
The :class:`Variable` corresponding to `aesara_type`.
106118
layout
107119
The :class:`numpy.ndarray` layout to use.
108120
force_scalar
@@ -114,7 +126,10 @@ def get_numba_type_TensorType(
114126
numba_dtype = numba.from_dtype(dtype)
115127
if force_scalar or (reduce_to_scalar and getattr(aesara_type, "ndim", None) == 0):
116128
return numba_dtype
117-
return numba.types.Array(numba_dtype, aesara_type.ndim, layout)
129+
130+
readonly = getattr(var.tag, "indestructible", False)
131+
132+
return numba.types.Array(numba_dtype, aesara_type.ndim, layout, readonly=readonly)
118133

119134

120135
def create_numba_signature(
@@ -123,11 +138,11 @@ def create_numba_signature(
123138
"""Create a Numba type for the signature of an `Apply` node or `FunctionGraph`."""
124139
input_types = []
125140
for inp in node_or_fgraph.inputs:
126-
input_types.append(get_numba_type(inp.type, **kwargs))
141+
input_types.append(get_numba_type(inp.type, inp, **kwargs))
127142

128143
output_types = []
129144
for out in node_or_fgraph.outputs:
130-
output_types.append(get_numba_type(out.type, **kwargs))
145+
output_types.append(get_numba_type(out.type, inp, **kwargs))
131146

132147
if isinstance(node_or_fgraph, FunctionGraph):
133148
return numba.types.Tuple(output_types)(*input_types)
@@ -379,9 +394,9 @@ def numba_funcify_perform(op, node, storage_map=None, **kwargs) -> Callable:
379394
n_outputs = len(node.outputs)
380395

381396
if n_outputs > 1:
382-
ret_sig = numba.types.Tuple([get_numba_type(o.type) for o in node.outputs])
397+
ret_sig = numba.types.Tuple([get_numba_type(o.type, o) for o in node.outputs])
383398
else:
384-
ret_sig = get_numba_type(node.outputs[0].type)
399+
ret_sig = get_numba_type(node.outputs[0].type, node.outputs[0])
385400

386401
output_types = tuple(out.type for out in node.outputs)
387402
params = node.run_params()
@@ -821,7 +836,7 @@ def cholesky(a):
821836
UserWarning,
822837
)
823838

824-
ret_sig = get_numba_type(node.outputs[0].type)
839+
ret_sig = get_numba_type(node.outputs[0].type, node.outputs[0])
825840

826841
@numba_njit
827842
def cholesky(a):
@@ -850,7 +865,7 @@ def numba_funcify_Solve(op, node, **kwargs):
850865
UserWarning,
851866
)
852867

853-
ret_sig = get_numba_type(node.outputs[0].type)
868+
ret_sig = get_numba_type(node.outputs[0].type, node.outputs[0])
854869

855870
@numba_njit
856871
def solve(a, b):

aesara/link/numba/dispatch/extra_ops.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ def numba_funcify_Repeat(op, node, **kwargs):
184184
UserWarning,
185185
)
186186

187-
ret_sig = get_numba_type(node.outputs[0].type)
187+
ret_sig = get_numba_type(node.outputs[0].type, node.outputs[0])
188188

189189
@numba_basic.numba_njit
190190
def repeatop(x, repeats):
@@ -243,9 +243,11 @@ def unique(x):
243243
)
244244

245245
if returns_multi:
246-
ret_sig = numba.types.Tuple([get_numba_type(o.type) for o in node.outputs])
246+
ret_sig = numba.types.Tuple(
247+
[get_numba_type(o.type, o) for o in node.outputs]
248+
)
247249
else:
248-
ret_sig = get_numba_type(node.outputs[0].type)
250+
ret_sig = get_numba_type(node.outputs[0].type, node.outputs[0])
249251

250252
@numba_basic.numba_njit
251253
def unique(x):
@@ -308,7 +310,7 @@ def numba_funcify_Searchsorted(op, node, **kwargs):
308310
UserWarning,
309311
)
310312

311-
ret_sig = get_numba_type(node.outputs[0].type)
313+
ret_sig = get_numba_type(node.outputs[0].type, node.outputs[0])
312314

313315
@numba_basic.numba_njit
314316
def searchsorted(a, v, sorter):

aesara/link/numba/dispatch/nlinalg.py

+9-4
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def numba_funcify_SVD(op, node, **kwargs):
3636
UserWarning,
3737
)
3838

39-
ret_sig = get_numba_type(node.outputs[0].type)
39+
ret_sig = get_numba_type(node.outputs[0].type, node.outputs[0])
4040

4141
@numba_basic.numba_njit
4242
def svd(x):
@@ -101,7 +101,10 @@ def numba_funcify_Eigh(op, node, **kwargs):
101101

102102
out_dtypes = tuple(o.type.numpy_dtype for o in node.outputs)
103103
ret_sig = numba.types.Tuple(
104-
[get_numba_type(node.outputs[0].type), get_numba_type(node.outputs[1].type)]
104+
[
105+
get_numba_type(node.outputs[0].type, node.outputs[0]),
106+
get_numba_type(node.outputs[1].type, node.outputs[1]),
107+
]
105108
)
106109

107110
@numba_basic.numba_njit
@@ -173,9 +176,11 @@ def numba_funcify_QRFull(op, node, **kwargs):
173176
)
174177

175178
if len(node.outputs) > 1:
176-
ret_sig = numba.types.Tuple([get_numba_type(o.type) for o in node.outputs])
179+
ret_sig = numba.types.Tuple(
180+
[get_numba_type(o.type, o) for o in node.outputs]
181+
)
177182
else:
178-
ret_sig = get_numba_type(node.outputs[0].type)
183+
ret_sig = get_numba_type(node.outputs[0].type, node.outputs[0])
179184

180185
@numba_basic.numba_njit
181186
def qr_full(x):

aesara/link/numba/dispatch/sparse.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,7 @@ def copy(inst):
214214

215215

216216
@get_numba_type.register(SparseTensorType)
217-
def get_numba_type_SparseType(aesara_type, **kwargs):
217+
def get_numba_type_SparseType(aesara_type, var, **kwargs):
218218
dtype = from_dtype(np.dtype(aesara_type.dtype))
219219

220220
if aesara_type.format == "csr":

tests/link/numba/test_basic.py

+11-3
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@ def assert_fn(x, y):
246246

247247

248248
@pytest.mark.parametrize(
249-
"v, expected, force_scalar",
249+
"typ, expected, force_scalar",
250250
[
251251
(MyType(), numba.types.pyobject, False),
252252
(
@@ -267,11 +267,19 @@ def assert_fn(x, y):
267267
(at.dmatrix, numba.types.float64, True),
268268
],
269269
)
270-
def test_get_numba_type(v, expected, force_scalar):
271-
res = numba_basic.get_numba_type(v, force_scalar=force_scalar)
270+
def test_get_numba_type(typ, expected, force_scalar):
271+
res = numba_basic.get_numba_type(typ, typ(), force_scalar=force_scalar)
272272
assert res == expected
273273

274274

275+
def test_get_numba_type_readonly():
276+
typ = at.dmatrix
277+
var = typ()
278+
var.tag.indestructible = True
279+
res = numba_basic.get_numba_type(typ, var)
280+
assert not res.mutable
281+
282+
275283
@pytest.mark.parametrize(
276284
"v, expected, force_scalar",
277285
[

tests/link/numba/test_elemwise.py

+10
Original file line numberDiff line numberDiff line change
@@ -507,3 +507,13 @@ def test_MaxAndArgmax(x, axes, exc):
507507
if not isinstance(i, (SharedVariable, Constant))
508508
],
509509
)
510+
511+
512+
def test_sum_broadcast_to():
513+
"""Make sure that we handle the writability of `BroadcastTo` results correctly."""
514+
515+
x = at.vector("x")
516+
out = at.broadcast_to(x, (2, 2)).sum()
517+
518+
x_val = np.array([1, 2], dtype=config.floatX)
519+
compare_numba_and_py(((x,), (out,)), [x_val])

0 commit comments

Comments
 (0)