Skip to content

Commit 5dee3c7

Browse files
Implement a copy method for Numba sparse types
1 parent e9bba0f commit 5dee3c7

File tree

2 files changed

+52
-0
lines changed

2 files changed

+52
-0
lines changed

aesara/link/numba/dispatch/sparse.py

+39
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,16 @@
22
import scipy as sp
33
import scipy.sparse
44
from numba.core import cgutils, types
5+
from numba.core.imputils import impl_ret_borrowed
56
from numba.extending import (
67
NativeValue,
78
box,
9+
intrinsic,
810
make_attribute_wrapper,
911
models,
1012
overload,
1113
overload_attribute,
14+
overload_method,
1215
register_model,
1316
typeof_impl,
1417
unbox,
@@ -174,6 +177,42 @@ def ndim(inst):
174177
return ndim
175178

176179

180+
@intrinsic
181+
def _sparse_copy(typingctx, inst, data, indices, indptr, shape):
182+
def _construct(context, builder, sig, args):
183+
typ = sig.return_type
184+
struct = cgutils.create_struct_proxy(typ)(context, builder)
185+
_, data, indices, indptr, shape = args
186+
struct.data = data
187+
struct.indices = indices
188+
struct.indptr = indptr
189+
struct.shape = shape
190+
return impl_ret_borrowed(
191+
context,
192+
builder,
193+
sig.return_type,
194+
struct._getvalue(),
195+
)
196+
197+
sig = inst(inst, inst.data, inst.indices, inst.indptr, inst.shape)
198+
199+
return sig, _construct
200+
201+
202+
@overload_method(CSMatrixType, "copy")
203+
def overload_sparse_copy(inst):
204+
205+
if not isinstance(inst, CSMatrixType):
206+
return
207+
208+
def copy(inst):
209+
return _sparse_copy(
210+
inst, inst.data.copy(), inst.indices.copy(), inst.indptr.copy(), inst.shape
211+
)
212+
213+
return copy
214+
215+
177216
@get_numba_type.register(SparseTensorType)
178217
def get_numba_type_SparseType(aesara_type, **kwargs):
179218
dtype = from_dtype(np.dtype(aesara_type.dtype))

tests/link/numba/test_sparse.py

+13
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,19 @@ def test_fn(x):
7171
assert res == 2
7272

7373

74+
def test_sparse_copy():
75+
@numba.njit
76+
def test_fn(x):
77+
y = x.copy()
78+
return (
79+
y is not x and np.all(x.data == y.data) and np.all(x.indices == y.indices)
80+
)
81+
82+
x_val = sp.sparse.csr_matrix(np.eye(100))
83+
84+
assert test_fn(x_val)
85+
86+
7487
def test_sparse_objmode():
7588

7689
x = SparseTensorType("csc", dtype=config.floatX)()

0 commit comments

Comments
 (0)