diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a33b812fd0..207171eeaa 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -7,7 +7,7 @@ exclude: | )$ repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.3.0 + rev: v4.4.0 hooks: - id: debug-statements exclude: | @@ -25,9 +25,10 @@ repos: - id: black language_version: python3 - repo: https://github.com/pycqa/flake8 - rev: 5.0.4 + rev: 6.0.0 hooks: - id: flake8 + language_version: python39 - repo: https://github.com/pycqa/isort rev: 5.10.1 hooks: @@ -47,7 +48,7 @@ repos: )$ args: ['--in-place', '--remove-all-unused-imports', '--remove-unused-variable'] - repo: https://github.com/pre-commit/mirrors-mypy - rev: v0.982 + rev: v0.991 hooks: - id: mypy additional_dependencies: diff --git a/aesara/link/numba/dispatch/__init__.py b/aesara/link/numba/dispatch/__init__.py index 31a0467483..11f47aa8f9 100644 --- a/aesara/link/numba/dispatch/__init__.py +++ b/aesara/link/numba/dispatch/__init__.py @@ -13,5 +13,6 @@ import aesara.link.numba.dispatch.random import aesara.link.numba.dispatch.elemwise import aesara.link.numba.dispatch.scan +import aesara.link.numba.dispatch.sparse # isort: on diff --git a/aesara/link/numba/dispatch/basic.py b/aesara/link/numba/dispatch/basic.py index b6ee96436e..bd7f1a9c79 100644 --- a/aesara/link/numba/dispatch/basic.py +++ b/aesara/link/numba/dispatch/basic.py @@ -371,6 +371,7 @@ def _numba_funcify( A `Callable` that can be JIT-compiled in Numba using `numba.jit`. """ + raise NotImplementedError() @_numba_funcify.register(Op) diff --git a/aesara/link/numba/dispatch/sparse.py b/aesara/link/numba/dispatch/sparse.py new file mode 100644 index 0000000000..d07e029501 --- /dev/null +++ b/aesara/link/numba/dispatch/sparse.py @@ -0,0 +1,142 @@ +import scipy as sp +import scipy.sparse +from numba.core import cgutils, types +from numba.extending import ( + NativeValue, + box, + make_attribute_wrapper, + models, + register_model, + typeof_impl, + unbox, +) + + +class CSMatrixType(types.Type): + """A Numba `Type` modeled after the base class `scipy.sparse.compressed._cs_matrix`.""" + + name: str + instance_class: type + + def __init__(self, dtype): + self.dtype = dtype + self.data = types.Array(dtype, 1, "A") + self.indices = types.Array(types.int32, 1, "A") + self.indptr = types.Array(types.int32, 1, "A") + self.shape = types.UniTuple(types.int64, 2) + super().__init__(self.name) + + +make_attribute_wrapper(CSMatrixType, "data", "data") +make_attribute_wrapper(CSMatrixType, "indices", "indices") +make_attribute_wrapper(CSMatrixType, "indptr", "indptr") +make_attribute_wrapper(CSMatrixType, "shape", "shape") + + +class CSRMatrixType(CSMatrixType): + name = "csr_matrix" + + @staticmethod + def instance_class(data, indices, indptr, shape): + return sp.sparse.csr_matrix((data, indices, indptr), shape, copy=False) + + +class CSCMatrixType(CSMatrixType): + name = "csc_matrix" + + @staticmethod + def instance_class(data, indices, indptr, shape): + return sp.sparse.csc_matrix((data, indices, indptr), shape, copy=False) + + +@typeof_impl.register(sp.sparse.csc_matrix) +def typeof_csc_matrix(val, c): + data = typeof_impl(val.data, c) + return CSCMatrixType(data.dtype) + + +@typeof_impl.register(sp.sparse.csr_matrix) +def typeof_csr_matrix(val, c): + data = typeof_impl(val.data, c) + return CSRMatrixType(data.dtype) + + +@register_model(CSRMatrixType) +class CSRMatrixModel(models.StructModel): + def __init__(self, dmm, fe_type): + members = [ + ("data", fe_type.data), + ("indices", fe_type.indices), + ("indptr", fe_type.indptr), + ("shape", fe_type.shape), + ] + super().__init__(dmm, fe_type, members) + + +@register_model(CSCMatrixType) +class CSCMatrixModel(models.StructModel): + def __init__(self, dmm, fe_type): + members = [ + ("data", fe_type.data), + ("indices", fe_type.indices), + ("indptr", fe_type.indptr), + ("shape", fe_type.shape), + ] + super().__init__(dmm, fe_type, members) + + +@unbox(CSCMatrixType) +@unbox(CSRMatrixType) +def unbox_matrix(typ, obj, c): + + struct_ptr = cgutils.create_struct_proxy(typ)(c.context, c.builder) + + data = c.pyapi.object_getattr_string(obj, "data") + indices = c.pyapi.object_getattr_string(obj, "indices") + indptr = c.pyapi.object_getattr_string(obj, "indptr") + shape = c.pyapi.object_getattr_string(obj, "shape") + + struct_ptr.data = c.unbox(typ.data, data).value + struct_ptr.indices = c.unbox(typ.indices, indices).value + struct_ptr.indptr = c.unbox(typ.indptr, indptr).value + struct_ptr.shape = c.unbox(typ.shape, shape).value + + c.pyapi.decref(data) + c.pyapi.decref(indices) + c.pyapi.decref(indptr) + c.pyapi.decref(shape) + + is_error_ptr = cgutils.alloca_once_value(c.builder, cgutils.false_bit) + is_error = c.builder.load(is_error_ptr) + + res = NativeValue(struct_ptr._getvalue(), is_error=is_error) + + return res + + +@box(CSCMatrixType) +@box(CSRMatrixType) +def box_matrix(typ, val, c): + struct_ptr = cgutils.create_struct_proxy(typ)(c.context, c.builder, value=val) + + data_obj = c.box(typ.data, struct_ptr.data) + indices_obj = c.box(typ.indices, struct_ptr.indices) + indptr_obj = c.box(typ.indptr, struct_ptr.indptr) + shape_obj = c.box(typ.shape, struct_ptr.shape) + + c.pyapi.incref(data_obj) + c.pyapi.incref(indices_obj) + c.pyapi.incref(indptr_obj) + c.pyapi.incref(shape_obj) + + cls_obj = c.pyapi.unserialize(c.pyapi.serialize_object(typ.instance_class)) + obj = c.pyapi.call_function_objargs( + cls_obj, (data_obj, indices_obj, indptr_obj, shape_obj) + ) + + c.pyapi.decref(data_obj) + c.pyapi.decref(indices_obj) + c.pyapi.decref(indptr_obj) + c.pyapi.decref(shape_obj) + + return obj diff --git a/aesara/link/vm.py b/aesara/link/vm.py index e1cefb33b5..e79e9ebbb4 100644 --- a/aesara/link/vm.py +++ b/aesara/link/vm.py @@ -1056,7 +1056,7 @@ def make_vm( callback=self.callback, callback_input=self.callback_input, ) - elif self.use_cloop and CVM: + elif self.use_cloop and CVM is not None: # create a map from nodes to ints and vars to ints nodes_idx = {} diff --git a/aesara/scan/op.py b/aesara/scan/op.py index 24e0e867a7..4914e041fe 100644 --- a/aesara/scan/op.py +++ b/aesara/scan/op.py @@ -54,6 +54,7 @@ import numpy as np import aesara +import aesara.link.utils as link_utils from aesara import tensor as at from aesara.compile.builders import construct_nominal_fgraph, infer_shape from aesara.compile.function.pfunc import pfunc @@ -75,7 +76,6 @@ from aesara.graph.utils import InconsistencyError, MissingInputError from aesara.link.c.basic import CLinker from aesara.link.c.exceptions import MissingGXX -from aesara.link.utils import raise_with_op from aesara.printing import op_debug_information from aesara.scan.utils import ScanProfileStats, Validator, forced_replace, safe_new from aesara.tensor.basic import as_tensor_variable @@ -1629,7 +1629,7 @@ def p(node, inputs, outputs): if hasattr(self.fn.vm, "position_of_error") and hasattr( self.fn.vm, "thunks" ): - raise_with_op( + link_utils.raise_with_op( self.fn.maker.fgraph, self.fn.vm.nodes[self.fn.vm.position_of_error], self.fn.vm.thunks[self.fn.vm.position_of_error], @@ -1932,7 +1932,7 @@ def perform(self, node, inputs, output_storage, params=None): # done by raise_with_op is not implemented in C. if hasattr(vm, "thunks"): # For the CVM - raise_with_op( + link_utils.raise_with_op( self.fn.maker.fgraph, vm.nodes[vm.position_of_error], vm.thunks[vm.position_of_error], @@ -1942,7 +1942,7 @@ def perform(self, node, inputs, output_storage, params=None): # We don't have access from python to all the # temps values So for now, we just don't print # the extra shapes/strides info - raise_with_op( + link_utils.raise_with_op( self.fn.maker.fgraph, vm.nodes[vm.position_of_error] ) else: @@ -3427,7 +3427,7 @@ def profile_printer( ) -@op_debug_information.register(Scan) # type: ignore[has-type] +@op_debug_information.register(Scan) def _op_debug_information_Scan(op, node): from typing import Sequence diff --git a/tests/link/numba/test_sparse.py b/tests/link/numba/test_sparse.py new file mode 100644 index 0000000000..af49752f3b --- /dev/null +++ b/tests/link/numba/test_sparse.py @@ -0,0 +1,40 @@ +import numba +import numpy as np +import scipy as sp + +# Load Numba customizations +import aesara.link.numba.dispatch.sparse # noqa: F401 + + +def test_sparse_unboxing(): + @numba.njit + def test_unboxing(x, y): + return x.shape, y.shape + + x_val = sp.sparse.csr_matrix(np.eye(100)) + y_val = sp.sparse.csc_matrix(np.eye(101)) + + res = test_unboxing(x_val, y_val) + + assert res == (x_val.shape, y_val.shape) + + +def test_sparse_boxing(): + @numba.njit + def test_boxing(x, y): + return x, y + + x_val = sp.sparse.csr_matrix(np.eye(100)) + y_val = sp.sparse.csc_matrix(np.eye(101)) + + res_x_val, res_y_val = test_boxing(x_val, y_val) + + assert np.array_equal(res_x_val.data, x_val.data) + assert np.array_equal(res_x_val.indices, x_val.indices) + assert np.array_equal(res_x_val.indptr, x_val.indptr) + assert res_x_val.shape == x_val.shape + + assert np.array_equal(res_y_val.data, y_val.data) + assert np.array_equal(res_y_val.indices, y_val.indices) + assert np.array_equal(res_y_val.indptr, y_val.indptr) + assert res_y_val.shape == y_val.shape