diff --git a/aesara/configdefaults.py b/aesara/configdefaults.py
index d327cb5c80..5bce7b067a 100644
--- a/aesara/configdefaults.py
+++ b/aesara/configdefaults.py
@@ -1252,12 +1252,6 @@ def add_numba_configvars():
         BoolParam(True),
         in_c_key=False,
     )
-    config.add(
-        "numba_scipy",
-        ("Enable usage of the numba_scipy package for special functions",),
-        BoolParam(True),
-        in_c_key=False,
-    )
 
 
 def _default_compiledirname():
diff --git a/aesara/link/numba/dispatch/cython_support.py b/aesara/link/numba/dispatch/cython_support.py
new file mode 100644
index 0000000000..cd82bc39e5
--- /dev/null
+++ b/aesara/link/numba/dispatch/cython_support.py
@@ -0,0 +1,209 @@
+import ctypes
+import importlib
+import re
+from dataclasses import dataclass
+from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, cast
+
+import numba
+import numpy as np
+from numpy.typing import DTypeLike
+from scipy import LowLevelCallable
+
+
+_C_TO_NUMPY: Dict[str, DTypeLike] = {
+    "bool": np.bool_,
+    "signed char": np.byte,
+    "unsigned char": np.ubyte,
+    "short": np.short,
+    "unsigned short": np.ushort,
+    "int": np.intc,
+    "unsigned int": np.uintc,
+    "long": np.int_,
+    "unsigned long": np.uint,
+    "long long": np.longlong,
+    "float": np.single,
+    "double": np.double,
+    "long double": np.longdouble,
+    "float complex": np.csingle,
+    "double complex": np.cdouble,
+}
+
+
+@dataclass
+class Signature:
+    res_dtype: DTypeLike
+    res_c_type: str
+    arg_dtypes: List[DTypeLike]
+    arg_c_types: List[str]
+    arg_names: List[Optional[str]]
+
+    @property
+    def arg_numba_types(self) -> List[DTypeLike]:
+        return [numba.from_dtype(dtype) for dtype in self.arg_dtypes]
+
+    def can_cast_args(self, args: List[DTypeLike]) -> bool:
+        ok = True
+        count = 0
+        for name, dtype in zip(self.arg_names, self.arg_dtypes):
+            if name == "__pyx_skip_dispatch":
+                continue
+            if len(args) <= count:
+                raise ValueError("Incorrect number of arguments")
+            ok &= np.can_cast(args[count], dtype)
+            count += 1
+        if count != len(args):
+            return False
+        return ok
+
+    def provides(self, restype: DTypeLike, arg_dtypes: List[DTypeLike]) -> bool:
+        args_ok = self.can_cast_args(arg_dtypes)
+        if np.issubdtype(restype, np.inexact):
+            result_ok = np.can_cast(self.res_dtype, restype, casting="same_kind")
+            # We do not want to provide less accuracy than advertised
+            result_ok &= np.dtype(self.res_dtype).itemsize >= np.dtype(restype).itemsize
+        else:
+            result_ok = np.can_cast(self.res_dtype, restype)
+        return args_ok and result_ok
+
+    @staticmethod
+    def from_c_types(signature: bytes) -> "Signature":
+        # Match strings like "double(int, double)"
+        # and extract the return type and the joined arguments
+        expr = re.compile(rb"\s*(?P<restype>[\w ]*\w+)\s*\((?P<args>[\w\s,]*)\)")
+        re_match = re.fullmatch(expr, signature)
+
+        if re_match is None:
+            raise ValueError(f"Invalid signature: {signature.decode()}")
+
+        groups = re_match.groupdict()
+        res_c_type = groups["restype"].decode()
+        res_dtype: DTypeLike = _C_TO_NUMPY[res_c_type]
+
+        raw_args = groups["args"]
+
+        decl_expr = re.compile(
+            rb"\s*(?P<type>((long )|(unsigned )|(signed )|(double )|)"
+            rb"((double)|(float)|(int)|(short)|(char)|(long)|(bool)|(complex)))"
+            rb"(\s(?P<name>[\w_]*))?\s*"
+        )
+
+        arg_dtypes = []
+        arg_names: List[Optional[str]] = []
+        arg_c_types = []
+        for raw_arg in raw_args.split(b","):
+            re_match = re.fullmatch(decl_expr, raw_arg)
+            if re_match is None:
+                raise ValueError(f"Invalid signature: {signature.decode()}")
+            groups = re_match.groupdict()
+            arg_c_type = groups["type"].decode()
+            try:
+                arg_dtype = _C_TO_NUMPY[arg_c_type]
+            except KeyError:
+                raise ValueError(f"Unknown C type: {arg_c_type}")
+
+            arg_c_types.append(arg_c_type)
+            arg_dtypes.append(arg_dtype)
+            name = groups["name"]
+            if not name:
+                arg_names.append(None)
+            else:
+                arg_names.append(name.decode())
+
+        return Signature(res_dtype, res_c_type, arg_dtypes, arg_c_types, arg_names)
+
+
+def _available_impls(func: Callable) -> List[Tuple[Signature, Any]]:
+    """Find all available implementations for a fused cython function."""
+    impls = []
+    mod = importlib.import_module(func.__module__)
+
+    signatures = getattr(func, "__signatures__", None)
+    if signatures is not None:
+        # Cython function with __signatures__ should be fused and thus
+        # indexable
+        func_map = cast(Mapping, func)
+        candidates = [func_map[key] for key in signatures]
+    else:
+        candidates = [func]
+    for candidate in candidates:
+        name = candidate.__name__
+        capsule = mod.__pyx_capi__[name]
+        llc = LowLevelCallable(capsule)
+        try:
+            signature = Signature.from_c_types(llc.signature.encode())
+        except KeyError:
+            continue
+        impls.append((signature, capsule))
+    return impls
+
+
+class _CythonWrapper(numba.types.WrapperAddressProtocol):
+    def __init__(self, pyfunc, signature, capsule):
+        self._keep_alive = capsule
+        get_name = ctypes.pythonapi.PyCapsule_GetName
+        get_name.restype = ctypes.c_char_p
+        get_name.argtypes = (ctypes.py_object,)
+
+        raw_signature = get_name(capsule)
+
+        get_pointer = ctypes.pythonapi.PyCapsule_GetPointer
+        get_pointer.restype = ctypes.c_void_p
+        get_pointer.argtypes = (ctypes.py_object, ctypes.c_char_p)
+        self._func_ptr = get_pointer(capsule, raw_signature)
+
+        self._signature = signature
+        self._pyfunc = pyfunc
+
+    def signature(self):
+        return numba.from_dtype(self._signature.res_dtype)(
+            *self._signature.arg_numba_types
+        )
+
+    def __wrapper_address__(self):
+        return self._func_ptr
+
+    def __call__(self, *args, **kwargs):
+        if self.has_pyx_skip_dispatch():
+            return self._pyfunc(*args[:-1], **kwargs)
+        else:
+            return self._pyfunc(*args, **kwargs)
+
+    def has_pyx_skip_dispatch(self):
+        if not self._signature.arg_names:
+            return False
+        if any(
+            name == "__pyx_skip_dispatch" for name in self._signature.arg_names[:-1]
+        ):
+            raise ValueError("skip_dispatch parameter must be last")
+        return self._signature.arg_names[-1] == "__pyx_skip_dispatch"
+
+    def numpy_arg_dtypes(self):
+        return self._signature.arg_dtypes
+
+    def numpy_output_dtype(self):
+        return self._signature.res_dtype
+
+
+def wrap_cython_function(func, restype, arg_types):
+    impls = _available_impls(func)
+    compatible = []
+    for sig, capsule in impls:
+        if sig.provides(restype, arg_types):
+            compatible.append((sig, capsule))
+
+    def sort_key(args):
+        sig, _ = args
+
+        # Prefer functions with fewer input bytes
+        argsize = sum(np.dtype(dtype).itemsize for dtype in sig.arg_dtypes)
+
+        # Prefer functions with more exact (integer) arguments
+        num_inexact = sum(np.issubdtype(dtype, np.inexact) for dtype in sig.arg_dtypes)
+        return (num_inexact, argsize)
+
+    compatible.sort(key=sort_key)
+
+    if not compatible:
+        raise NotImplementedError(f"Could not find a compatible impl of {func}")
+    sig, capsule = compatible[0]
+    return _CythonWrapper(func, sig, capsule)
diff --git a/aesara/link/numba/dispatch/scalar.py b/aesara/link/numba/dispatch/scalar.py
index 06c681bab3..a278efaad5 100644
--- a/aesara/link/numba/dispatch/scalar.py
+++ b/aesara/link/numba/dispatch/scalar.py
@@ -1,11 +1,7 @@
 import math
-import warnings
-from functools import reduce
 from typing import List
 
 import numpy as np
-import scipy
-import scipy.special
 
 from aesara import config
 from aesara.compile.ops import ViewOp
@@ -16,6 +12,7 @@
     generate_fallback_impl,
     numba_funcify,
 )
+from aesara.link.numba.dispatch.cython_support import wrap_cython_function
 from aesara.link.utils import (
     compile_function_src,
     get_name_for_object,
@@ -41,86 +38,83 @@ def numba_funcify_ScalarOp(op, node, **kwargs):
     # TODO: Do we need to cache these functions so that we don't end up
     # compiling the same Numba function over and over again?
 
-    scalar_func_name = op.nfunc_spec[0]
-    scalar_func = None
-
-    if scalar_func_name.startswith("scipy."):
-        func_package = scipy
-        scalar_func_name = scalar_func_name.split(".", 1)[-1]
-
-        use_numba_scipy = config.numba_scipy
-        if use_numba_scipy:
-            try:
-                import numba_scipy  # noqa: F401
-            except ImportError:
-                use_numba_scipy = False
-        if not use_numba_scipy:
-            warnings.warn(
-                "Native numba versions of scipy functions might be "
-                "avalable if numba-scipy is installed.",
-                UserWarning,
+    scalar_func_path = op.nfunc_spec[0]
+    scalar_func_numba = None
+
+    *module_path, scalar_func_name = scalar_func_path.split(".")
+    if not module_path:
+        # Assume it is numpy, and numba has an implementation
+        scalar_func_numba = getattr(np, scalar_func_name)
+
+    input_dtypes = [np.dtype(input.type.dtype) for input in node.inputs]
+    output_dtypes = [np.dtype(output.type.dtype) for output in node.outputs]
+
+    if len(output_dtypes) != 1:
+        raise ValueError("ScalarOps with more than one output are not supported")
+
+    output_dtype = output_dtypes[0]
+
+    input_inner_dtypes = None
+    output_inner_dtype = None
+
+    # Cython functions might have an additional argument
+    has_pyx_skip_dispatch = False
+
+    if scalar_func_path.startswith("scipy.special"):
+        import scipy.special.cython_special
+
+        cython_func = getattr(scipy.special.cython_special, scalar_func_name, None)
+        if cython_func is not None:
+            # try:
+            scalar_func_numba = wrap_cython_function(
+                cython_func, output_dtype, input_dtypes
             )
-            scalar_func = generate_fallback_impl(op, node, **kwargs)
-    else:
-        func_package = np
+            has_pyx_skip_dispatch = scalar_func_numba.has_pyx_skip_dispatch()
+            input_inner_dtypes = scalar_func_numba.numpy_arg_dtypes()
+            output_inner_dtype = scalar_func_numba.numpy_output_dtype()
+            # except NotImplementedError:
+            #     pass
 
-    if scalar_func is not None:
-        pass
-    elif "." in scalar_func_name:
-        scalar_func = reduce(getattr, [scipy] + scalar_func_name.split("."))
-    else:
-        scalar_func = getattr(func_package, scalar_func_name)
+    if scalar_func_numba is None:
+        scalar_func_numba = generate_fallback_impl(op, node, **kwargs)
 
-    scalar_op_fn_name = get_name_for_object(scalar_func)
+    scalar_op_fn_name = get_name_for_object(scalar_func_numba)
     unique_names = unique_name_generator(
-        [scalar_op_fn_name, "scalar_func"], suffix_sep="_"
+        [scalar_op_fn_name, "scalar_func_numba"], suffix_sep="_"
    )
 
-    global_env = {"scalar_func": scalar_func}
+    global_env = {"scalar_func_numba": scalar_func_numba}
 
-    input_tmp_dtypes = None
-    if func_package == scipy and hasattr(scalar_func, "types"):
-        # The `numba-scipy` bindings don't provide implementations for all
-        # inputs types, so we need to convert the inputs to floats and back.
-        inp_dtype_kinds = tuple(np.dtype(inp.type.dtype).kind for inp in node.inputs)
-        accepted_inp_kinds = tuple(
-            sig_type.split("->")[0] for sig_type in scalar_func.types
-        )
-        if not any(
-            all(dk == ik for dk, ik in zip(inp_dtype_kinds, ok_kinds))
-            for ok_kinds in accepted_inp_kinds
-        ):
-            # They're usually ordered from lower-to-higher precision, so
-            # we pick the last acceptable input types
-            #
-            # XXX: We should pick the first acceptable float/int types in
-            # reverse, excluding all the incompatible ones (e.g. `"0"`).
-            # The assumption is that this is only used by `numba-scipy`-exposed
-            # functions, although it's possible for this to be triggered by
-            # something else from the `scipy` package
-            input_tmp_dtypes = tuple(np.dtype(k) for k in accepted_inp_kinds[-1])
-
-    if input_tmp_dtypes is None:
+    if input_inner_dtypes is None and output_inner_dtype is None:
         unique_names = unique_name_generator(
-            [scalar_op_fn_name, "scalar_func"], suffix_sep="_"
+            [scalar_op_fn_name, "scalar_func_numba"], suffix_sep="_"
         )
         input_names = ", ".join(
             [unique_names(v, force_unique=True) for v in node.inputs]
         )
-        scalar_op_src = f"""
+        if not has_pyx_skip_dispatch:
+            scalar_op_src = f"""
+def {scalar_op_fn_name}({input_names}):
+    return scalar_func_numba({input_names})
+    """
+        else:
+            scalar_op_src = f"""
 def {scalar_op_fn_name}({input_names}):
-    return scalar_func({input_names})
-    """
+    return scalar_func_numba({input_names}, np.intc(1))
+    """
+
     else:
         global_env["direct_cast"] = numba_basic.direct_cast
-        global_env["output_dtype"] = np.dtype(node.outputs[0].type.dtype)
+        global_env["output_dtype"] = np.dtype(output_inner_dtype)
         input_tmp_dtype_names = {
-            f"inp_tmp_dtype_{i}": i_dtype for i, i_dtype in enumerate(input_tmp_dtypes)
+            f"inp_tmp_dtype_{i}": i_dtype
+            for i, i_dtype in enumerate(input_inner_dtypes)
         }
         global_env.update(input_tmp_dtype_names)
 
         unique_names = unique_name_generator(
-            [scalar_op_fn_name, "scalar_func"] + list(global_env.keys()), suffix_sep="_"
+            [scalar_op_fn_name, "scalar_func_numba"] + list(global_env.keys()),
+            suffix_sep="_",
         )
 
         input_names = [unique_names(v, force_unique=True) for v in node.inputs]
@@ -132,10 +126,16 @@ def {scalar_op_fn_name}({input_names}):
                 )
             ]
         )
-        scalar_op_src = f"""
+        if not has_pyx_skip_dispatch:
+            scalar_op_src = f"""
+def {scalar_op_fn_name}({', '.join(input_names)}):
+    return direct_cast(scalar_func_numba({converted_call_args}), output_dtype)
+    """
+        else:
+            scalar_op_src = f"""
 def {scalar_op_fn_name}({', '.join(input_names)}):
-    return direct_cast(scalar_func({converted_call_args}), output_dtype)
-    """
+    return direct_cast(scalar_func_numba({converted_call_args}, np.intc(1)), output_dtype)
+    """
 
     scalar_op_fn = compile_function_src(
         scalar_op_src, scalar_op_fn_name, {**globals(), **global_env}
diff --git a/tests/link/test_cython_support.py b/tests/link/test_cython_support.py
new file mode 100644
index 0000000000..c119b2fb62
--- /dev/null
+++ b/tests/link/test_cython_support.py
@@ -0,0 +1,92 @@
+import numpy as np
+import pytest
+import scipy.special.cython_special
+from numba.types import float32, float64, int32, int64
+
+from aesara.link.numba.dispatch.cython_support import Signature, wrap_cython_function
+
+
+@pytest.mark.parametrize(
+    "sig, expected_result, expected_args",
+    [
+        (b"double(double)", np.float64, [np.float64]),
+        (b"float(unsigned int)", np.float32, [np.uintc]),
+        (b"unsigned char(unsigned short foo)", np.ubyte, [np.ushort]),
+        (
+            b"unsigned char(unsigned short foo, double bar)",
+            np.ubyte,
+            [np.ushort, np.float64],
+        ),
+    ],
+)
+def test_parse_signature(sig, expected_result, expected_args):
+    actual = Signature.from_c_types(sig)
+    assert actual.res_dtype == expected_result
+    assert actual.arg_dtypes == expected_args
+
+
+@pytest.mark.parametrize(
+    "have, want, should_provide",
+    [
+        (b"double(int)", b"float(int)", True),
+        (b"float(int)", b"double(int)", False),
+        (b"double(unsigned short)", b"double(unsigned char)", True),
+        (b"double(unsigned char)", b"double(short)", False),
+        (b"short(double)", b"int(double)", True),
+        (b"int(double)", b"short(double)", False),
+        (b"float(double, int)", b"float(double, short)", True),
+    ],
+)
+def test_signature_provides(have, want, should_provide):
+    have = Signature.from_c_types(have)
+    want = Signature.from_c_types(want)
+    provides = have.provides(want.res_dtype, want.arg_dtypes)
+    assert provides == should_provide
+
+
+@pytest.mark.parametrize(
+    "func, output, inputs, expected",
+    [
+        (
+            scipy.special.cython_special.agm,
+            np.float64,
+            [np.float64, np.float64],
+            float64(float64, float64, int32),
+        ),
+        (
+            scipy.special.cython_special.erfc,
+            np.float64,
+            [np.float64],
+            float64(float64, int32),
+        ),
+        (
+            scipy.special.cython_special.expit,
+            np.float32,
+            [np.float32],
+            float32(float32, int32),
+        ),
+        (
+            scipy.special.cython_special.expit,
+            np.float64,
+            [np.float64],
+            float64(float64, int32),
+        ),
+        (
+            # expn doesn't have a float32 implementation
+            scipy.special.cython_special.expn,
+            np.float32,
+            [np.float32, np.float32],
+            float64(float64, float64, int32),
+        ),
+        (
+            # We choose the integer implementation if possible
+            scipy.special.cython_special.expn,
+            np.float32,
+            [np.int64, np.float32],
+            float64(int64, float64, int32),
+        ),
+    ],
+)
+def test_choose_signature(func, output, inputs, expected):
+    wrapper = wrap_cython_function(func, output, inputs)
+    assert wrapper.signature() == expected
diff --git a/tests/link/test_numba.py b/tests/link/test_numba.py
index 8d23ea888f..0fd1c4bbc3 100644
--- a/tests/link/test_numba.py
+++ b/tests/link/test_numba.py
@@ -319,10 +319,6 @@ def test_box_unbox(input, wrapper_fn, check_fn):
     assert check_fn(res, input)
 
 
-@pytest.mark.parametrize(
-    "numba_scipy",
-    [True, False],
-)
 @pytest.mark.parametrize(
     "inputs, input_vals, output_fn, exc",
     [
@@ -403,17 +399,16 @@ def test_box_unbox(input, wrapper_fn, check_fn):
         ),
     ],
 )
-def test_Elemwise(numba_scipy, inputs, input_vals, output_fn, exc):
-    with config.change_flags(numba_scipy=numba_scipy):
-        outputs = output_fn(*inputs)
+def test_Elemwise(inputs, input_vals, output_fn, exc):
+    outputs = output_fn(*inputs)
 
-        out_fg = FunctionGraph(
-            outputs=[outputs] if not isinstance(outputs, list) else outputs
-        )
+    out_fg = FunctionGraph(
+        outputs=[outputs] if not isinstance(outputs, list) else outputs
+    )
 
-        cm = contextlib.suppress() if exc is None else pytest.raises(exc)
-        with cm:
-            compare_numba_and_py(out_fg, input_vals)
+    cm = contextlib.suppress() if exc is None else pytest.raises(exc)
+    with cm:
+        compare_numba_and_py(out_fg, input_vals)
 
 
 @pytest.mark.parametrize(
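
Usage sketch (not part of the diff): how the new wrap_cython_function helper can be exercised by hand, assuming scipy exposes expit in scipy.special.cython_special with a float64 signature, as test_choose_signature above expects. The names below come from the diff itself; the printed values are what the tests imply, not guaranteed on every SciPy build.

import numpy as np
import scipy.special.cython_special

from aesara.link.numba.dispatch.cython_support import wrap_cython_function

# Select the cython_special implementation of expit that accepts and returns float64.
expit = wrap_cython_function(
    scipy.special.cython_special.expit, np.float64, [np.float64]
)

print(expit.signature())              # expected: float64(float64, int32)
print(expit.has_pyx_skip_dispatch())  # True: the trailing int is __pyx_skip_dispatch
print(expit.numpy_arg_dtypes(), expit.numpy_output_dtype())

# The wrapper implements numba's WrapperAddressProtocol, which is how the
# generated scalar-op sources in scalar.py call it from njit-compiled code
# (passing np.intc(1) for the trailing __pyx_skip_dispatch flag).  Calling it
# from plain Python falls back to the Python function; the trailing flag
# mirrors the C signature and is dropped before the call.
value = expit(0.5, np.intc(1))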