|
2 | 2 | import scipy as sp
|
3 | 3 | import scipy.sparse
|
4 | 4 | from numba.core import cgutils, types
|
| 5 | +from numba.core.imputils import impl_ret_borrowed |
5 | 6 | from numba.extending import (
|
6 | 7 | NativeValue,
|
7 | 8 | box,
|
| 9 | + intrinsic, |
8 | 10 | make_attribute_wrapper,
|
9 | 11 | models,
|
10 | 12 | overload,
|
11 | 13 | overload_attribute,
|
| 14 | + overload_method, |
12 | 15 | register_model,
|
13 | 16 | typeof_impl,
|
14 | 17 | unbox,
|
@@ -174,6 +177,42 @@ def ndim(inst):
|
174 | 177 | return ndim
|
175 | 178 |
|
176 | 179 |
|
| 180 | +@intrinsic |
| 181 | +def _sparse_copy(typingctx, inst, data, indices, indptr, shape): |
| 182 | + def _construct(context, builder, sig, args): |
| 183 | + typ = sig.return_type |
| 184 | + struct = cgutils.create_struct_proxy(typ)(context, builder) |
| 185 | + _, data, indices, indptr, shape = args |
| 186 | + struct.data = data |
| 187 | + struct.indices = indices |
| 188 | + struct.indptr = indptr |
| 189 | + struct.shape = shape |
| 190 | + return impl_ret_borrowed( |
| 191 | + context, |
| 192 | + builder, |
| 193 | + sig.return_type, |
| 194 | + struct._getvalue(), |
| 195 | + ) |
| 196 | + |
| 197 | + sig = inst(inst, inst.data, inst.indices, inst.indptr, inst.shape) |
| 198 | + |
| 199 | + return sig, _construct |
| 200 | + |
| 201 | + |
| 202 | +@overload_method(CSMatrixType, "copy") |
| 203 | +def overload_sparse_copy(inst): |
| 204 | + |
| 205 | + if not isinstance(inst, CSMatrixType): |
| 206 | + return |
| 207 | + |
| 208 | + def copy(inst): |
| 209 | + return _sparse_copy( |
| 210 | + inst, inst.data.copy(), inst.indices.copy(), inst.indptr.copy(), inst.shape |
| 211 | + ) |
| 212 | + |
| 213 | + return copy |
| 214 | + |
| 215 | + |
177 | 216 | @get_numba_type.register(SparseTensorType)
|
178 | 217 | def get_numba_type_SparseType(aesara_type, **kwargs):
|
179 | 218 | dtype = from_dtype(np.dtype(aesara_type.dtype))
|
|
0 commit comments