Skip to content

Commit 4c6bfd6

Browse files
Add support for SciPy CSC and CSR sparse types to Numba
1 parent e8c042b commit 4c6bfd6

File tree

3 files changed

+183
-0
lines changed

3 files changed

+183
-0
lines changed

aesara/link/numba/dispatch/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -13,5 +13,6 @@
1313
import aesara.link.numba.dispatch.random
1414
import aesara.link.numba.dispatch.elemwise
1515
import aesara.link.numba.dispatch.scan
16+
import aesara.link.numba.dispatch.sparse
1617

1718
# isort: on

aesara/link/numba/dispatch/sparse.py

+142
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
import scipy as sp
2+
import scipy.sparse
3+
from numba.core import cgutils, types
4+
from numba.extending import (
5+
NativeValue,
6+
box,
7+
make_attribute_wrapper,
8+
models,
9+
register_model,
10+
typeof_impl,
11+
unbox,
12+
)
13+
14+
15+
class CSMatrixType(types.Type):
16+
"""A Numba `Type` modeled after the base class `scipy.sparse.compressed._cs_matrix`."""
17+
18+
name: str
19+
instance_class: type
20+
21+
def __init__(self, dtype):
22+
self.dtype = dtype
23+
self.data = types.Array(dtype, 1, "A")
24+
self.indices = types.Array(types.int32, 1, "A")
25+
self.indptr = types.Array(types.int32, 1, "A")
26+
self.shape = types.UniTuple(types.int64, 2)
27+
super().__init__(self.name)
28+
29+
30+
make_attribute_wrapper(CSMatrixType, "data", "data")
31+
make_attribute_wrapper(CSMatrixType, "indices", "indices")
32+
make_attribute_wrapper(CSMatrixType, "indptr", "indptr")
33+
make_attribute_wrapper(CSMatrixType, "shape", "shape")
34+
35+
36+
class CSRMatrixType(CSMatrixType):
37+
name = "csr_matrix"
38+
39+
@staticmethod
40+
def instance_class(data, indices, indptr, shape):
41+
return sp.sparse.csr_matrix((data, indices, indptr), shape, copy=False)
42+
43+
44+
class CSCMatrixType(CSMatrixType):
45+
name = "csc_matrix"
46+
47+
@staticmethod
48+
def instance_class(data, indices, indptr, shape):
49+
return sp.sparse.csc_matrix((data, indices, indptr), shape, copy=False)
50+
51+
52+
@typeof_impl.register(sp.sparse.csc_matrix)
53+
def typeof_csc_matrix(val, c):
54+
data = typeof_impl(val.data, c)
55+
return CSCMatrixType(data.dtype)
56+
57+
58+
@typeof_impl.register(sp.sparse.csr_matrix)
59+
def typeof_csr_matrix(val, c):
60+
data = typeof_impl(val.data, c)
61+
return CSRMatrixType(data.dtype)
62+
63+
64+
@register_model(CSRMatrixType)
65+
class CSRMatrixModel(models.StructModel):
66+
def __init__(self, dmm, fe_type):
67+
members = [
68+
("data", fe_type.data),
69+
("indices", fe_type.indices),
70+
("indptr", fe_type.indptr),
71+
("shape", fe_type.shape),
72+
]
73+
super().__init__(dmm, fe_type, members)
74+
75+
76+
@register_model(CSCMatrixType)
77+
class CSCMatrixModel(models.StructModel):
78+
def __init__(self, dmm, fe_type):
79+
members = [
80+
("data", fe_type.data),
81+
("indices", fe_type.indices),
82+
("indptr", fe_type.indptr),
83+
("shape", fe_type.shape),
84+
]
85+
super().__init__(dmm, fe_type, members)
86+
87+
88+
@unbox(CSCMatrixType)
89+
@unbox(CSRMatrixType)
90+
def unbox_matrix(typ, obj, c):
91+
92+
struct_ptr = cgutils.create_struct_proxy(typ)(c.context, c.builder)
93+
94+
data = c.pyapi.object_getattr_string(obj, "data")
95+
indices = c.pyapi.object_getattr_string(obj, "indices")
96+
indptr = c.pyapi.object_getattr_string(obj, "indptr")
97+
shape = c.pyapi.object_getattr_string(obj, "shape")
98+
99+
struct_ptr.data = c.unbox(typ.data, data).value
100+
struct_ptr.indices = c.unbox(typ.indices, indices).value
101+
struct_ptr.indptr = c.unbox(typ.indptr, indptr).value
102+
struct_ptr.shape = c.unbox(typ.shape, shape).value
103+
104+
c.pyapi.decref(data)
105+
c.pyapi.decref(indices)
106+
c.pyapi.decref(indptr)
107+
c.pyapi.decref(shape)
108+
109+
is_error_ptr = cgutils.alloca_once_value(c.builder, cgutils.false_bit)
110+
is_error = c.builder.load(is_error_ptr)
111+
112+
res = NativeValue(struct_ptr._getvalue(), is_error=is_error)
113+
114+
return res
115+
116+
117+
@box(CSCMatrixType)
118+
@box(CSRMatrixType)
119+
def box_matrix(typ, val, c):
120+
struct_ptr = cgutils.create_struct_proxy(typ)(c.context, c.builder, value=val)
121+
122+
data_obj = c.box(typ.data, struct_ptr.data)
123+
indices_obj = c.box(typ.indices, struct_ptr.indices)
124+
indptr_obj = c.box(typ.indptr, struct_ptr.indptr)
125+
shape_obj = c.box(typ.shape, struct_ptr.shape)
126+
127+
c.pyapi.incref(data_obj)
128+
c.pyapi.incref(indices_obj)
129+
c.pyapi.incref(indptr_obj)
130+
c.pyapi.incref(shape_obj)
131+
132+
cls_obj = c.pyapi.unserialize(c.pyapi.serialize_object(typ.instance_class))
133+
obj = c.pyapi.call_function_objargs(
134+
cls_obj, (data_obj, indices_obj, indptr_obj, shape_obj)
135+
)
136+
137+
c.pyapi.decref(data_obj)
138+
c.pyapi.decref(indices_obj)
139+
c.pyapi.decref(indptr_obj)
140+
c.pyapi.decref(shape_obj)
141+
142+
return obj

tests/link/numba/test_sparse.py

+40
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
import numba
2+
import numpy as np
3+
import scipy as sp
4+
5+
# Load Numba customizations
6+
import aesara.link.numba.dispatch.sparse # noqa: F401
7+
8+
9+
def test_sparse_unboxing():
10+
@numba.njit
11+
def test_unboxing(x, y):
12+
return x.shape, y.shape
13+
14+
x_val = sp.sparse.csr_matrix(np.eye(100))
15+
y_val = sp.sparse.csc_matrix(np.eye(101))
16+
17+
res = test_unboxing(x_val, y_val)
18+
19+
assert res == (x_val.shape, y_val.shape)
20+
21+
22+
def test_sparse_boxing():
23+
@numba.njit
24+
def test_boxing(x, y):
25+
return x, y
26+
27+
x_val = sp.sparse.csr_matrix(np.eye(100))
28+
y_val = sp.sparse.csc_matrix(np.eye(101))
29+
30+
res_x_val, res_y_val = test_boxing(x_val, y_val)
31+
32+
assert np.array_equal(res_x_val.data, x_val.data)
33+
assert np.array_equal(res_x_val.indices, x_val.indices)
34+
assert np.array_equal(res_x_val.indptr, x_val.indptr)
35+
assert res_x_val.shape == x_val.shape
36+
37+
assert np.array_equal(res_y_val.data, y_val.data)
38+
assert np.array_equal(res_y_val.indices, y_val.indices)
39+
assert np.array_equal(res_y_val.indptr, y_val.indptr)
40+
assert res_y_val.shape == y_val.shape

0 commit comments

Comments
 (0)