49
49
50
50
51
51
if TYPE_CHECKING :
52
+ from aesara .graph .basic import Variable
52
53
from aesara .graph .op import StorageMapType
53
54
54
55
@@ -79,13 +80,21 @@ def numba_vectorize(*args, **kwargs):
79
80
80
81
81
82
@singledispatch
82
- def get_numba_type (aesara_type : Type , ** kwargs ) -> numba .types .Type :
83
- r"""Create a Numba type object for a :class:`Type`."""
83
+ def get_numba_type (aesara_type : Type , var : "Variable" , ** kwargs ) -> numba .types .Type :
84
+ r"""Create a Numba type object for a :class:`Type`.
85
+
86
+ Parameters
87
+ ----------
88
+ aesara_type
89
+ The :class:`Type` to convert.
90
+ var
91
+ The :class:`Variable` corresponding to `aesara_type`.
92
+ """
84
93
return numba .types .pyobject
85
94
86
95
87
96
@get_numba_type .register (ScalarType )
88
- def get_numba_type_ScalarType (aesara_type , ** kwargs ):
97
+ def get_numba_type_ScalarType (aesara_type , var , ** kwargs ):
89
98
dtype = np .dtype (aesara_type .dtype )
90
99
numba_dtype = numba .from_dtype (dtype )
91
100
return numba_dtype
@@ -94,6 +103,7 @@ def get_numba_type_ScalarType(aesara_type, **kwargs):
94
103
@get_numba_type .register (TensorType )
95
104
def get_numba_type_TensorType (
96
105
aesara_type ,
106
+ var : "Variable" ,
97
107
layout : str = "A" ,
98
108
force_scalar : bool = False ,
99
109
reduce_to_scalar : bool = False ,
@@ -103,6 +113,8 @@ def get_numba_type_TensorType(
103
113
----------
104
114
aesara_type
105
115
The :class:`Type` to convert.
116
+ var
117
+ The :class:`Variable` corresponding to `aesara_type`.
106
118
layout
107
119
The :class:`numpy.ndarray` layout to use.
108
120
force_scalar
@@ -114,7 +126,10 @@ def get_numba_type_TensorType(
114
126
numba_dtype = numba .from_dtype (dtype )
115
127
if force_scalar or (reduce_to_scalar and getattr (aesara_type , "ndim" , None ) == 0 ):
116
128
return numba_dtype
117
- return numba .types .Array (numba_dtype , aesara_type .ndim , layout )
129
+
130
+ readonly = getattr (var .tag , "indestructible" , False )
131
+
132
+ return numba .types .Array (numba_dtype , aesara_type .ndim , layout , readonly = readonly )
118
133
119
134
120
135
def create_numba_signature (
@@ -123,11 +138,11 @@ def create_numba_signature(
123
138
"""Create a Numba type for the signature of an `Apply` node or `FunctionGraph`."""
124
139
input_types = []
125
140
for inp in node_or_fgraph .inputs :
126
- input_types .append (get_numba_type (inp .type , ** kwargs ))
141
+ input_types .append (get_numba_type (inp .type , inp , ** kwargs ))
127
142
128
143
output_types = []
129
144
for out in node_or_fgraph .outputs :
130
- output_types .append (get_numba_type (out .type , ** kwargs ))
145
+ output_types .append (get_numba_type (out .type , inp , ** kwargs ))
131
146
132
147
if isinstance (node_or_fgraph , FunctionGraph ):
133
148
return numba .types .Tuple (output_types )(* input_types )
@@ -379,9 +394,9 @@ def numba_funcify_perform(op, node, storage_map=None, **kwargs) -> Callable:
379
394
n_outputs = len (node .outputs )
380
395
381
396
if n_outputs > 1 :
382
- ret_sig = numba .types .Tuple ([get_numba_type (o .type ) for o in node .outputs ])
397
+ ret_sig = numba .types .Tuple ([get_numba_type (o .type , o ) for o in node .outputs ])
383
398
else :
384
- ret_sig = get_numba_type (node .outputs [0 ].type )
399
+ ret_sig = get_numba_type (node .outputs [0 ].type , node . outputs [ 0 ] )
385
400
386
401
output_types = tuple (out .type for out in node .outputs )
387
402
params = node .run_params ()
@@ -821,7 +836,7 @@ def cholesky(a):
821
836
UserWarning ,
822
837
)
823
838
824
- ret_sig = get_numba_type (node .outputs [0 ].type )
839
+ ret_sig = get_numba_type (node .outputs [0 ].type , node . outputs [ 0 ] )
825
840
826
841
@numba_njit
827
842
def cholesky (a ):
@@ -850,7 +865,7 @@ def numba_funcify_Solve(op, node, **kwargs):
850
865
UserWarning ,
851
866
)
852
867
853
- ret_sig = get_numba_type (node .outputs [0 ].type )
868
+ ret_sig = get_numba_type (node .outputs [0 ].type , node . outputs [ 0 ] )
854
869
855
870
@numba_njit
856
871
def solve (a , b ):
0 commit comments