From af55348f5cbe338a86ed032812ee11e8714be673 Mon Sep 17 00:00:00 2001 From: Sebastian Berg Date: Tue, 10 Oct 2023 19:22:47 +0200 Subject: [PATCH] API: Allow comparisons with and between any python integers This implements comparisons between NumPy integer arrays and arbitrary valued Python integers when weak promotion is enabled. To achieve this: * I allow abstract DTypes (with small bug fixes) to register as loops (`ArrayMethods`). This is fine, you just need to take more care. It does muddy the waters between promotion and not a bit if the result DType would also be abstract. (For the specific case it doesn't, but in general it does.) * A new `resolve_descriptors_raw` function, which does the same job as `resolve_descriptors` but I pass it this scalar argument (can be expanded, but starting small). * This only happens *when available*, so there are some niche paths were this cannot be used (`ufunc.at` and the explicit resolution function right now), we can deal with those by keeping the previous rules (things will just raise trying to convert). * The function also gets the actual `arrays.dtype` instances while I normally ensure that we pass in dtypes already cast to the correct DType class. (The reason is that we don't define how to cast the abstract DTypes as of now, and even if we did, it would not be what we need unless the dtype instance actually had the value information.) * There are new loops added (for combinations!), which: * Use the new `resolve_descriptors_raw` (a single function dealing with everything) * Return the current legacy loop when that makes sense. * Return an always true/false loop when that makes sense. * To achieve this, they employ a hack/trick: `get_loop()` needs to know the value, but only `resolve_descriptors_raw()` does right now, so this is encoded on whether we use the `np.dtype("object")` singleton or a fresh instance! (Yes, probably ugly, but avoids channeling things to more places.) Additionally, there is a promoter to say that Python integer comparisons can just use `object` dtype (in theory weird if the input then wasn't a Python int, but that is breaking promises). --- numpy/_core/include/numpy/_dtype_api.h | 39 +- numpy/_core/meson.build | 1 + numpy/_core/src/multiarray/array_method.c | 9 +- numpy/_core/src/multiarray/array_method.h | 1 + numpy/_core/src/umath/dispatching.c | 6 + numpy/_core/src/umath/legacy_array_method.h | 8 +- numpy/_core/src/umath/scalarmath.c.src | 22 + numpy/_core/src/umath/special_comparisons.cpp | 464 ++++++++++++++++++ numpy/_core/src/umath/special_comparisons.h | 15 + numpy/_core/src/umath/ufunc_object.c | 104 +++- numpy/_core/src/umath/umathmodule.c | 5 + numpy/_core/tests/test_half.py | 8 +- 12 files changed, 634 insertions(+), 48 deletions(-) create mode 100644 numpy/_core/src/umath/special_comparisons.cpp create mode 100644 numpy/_core/src/umath/special_comparisons.h diff --git a/numpy/_core/include/numpy/_dtype_api.h b/numpy/_core/include/numpy/_dtype_api.h index 2bd06e1a1158..500b3e384d1d 100644 --- a/numpy/_core/include/numpy/_dtype_api.h +++ b/numpy/_core/include/numpy/_dtype_api.h @@ -5,7 +5,7 @@ #ifndef NUMPY_CORE_INCLUDE_NUMPY___DTYPE_API_H_ #define NUMPY_CORE_INCLUDE_NUMPY___DTYPE_API_H_ -#define __EXPERIMENTAL_DTYPE_API_VERSION 13 +#define __EXPERIMENTAL_DTYPE_API_VERSION 14 struct PyArrayMethodObject_tag; @@ -129,16 +129,17 @@ typedef struct { * SLOTS IDs For the ArrayMethod creation, once fully public, IDs are fixed * but can be deprecated and arbitrarily extended. */ -#define NPY_METH_resolve_descriptors 1 +#define NPY_METH_resolve_descriptors_raw 1 +#define NPY_METH_resolve_descriptors 2 /* We may want to adapt the `get_loop` signature a bit: */ -#define _NPY_METH_get_loop 2 -#define NPY_METH_get_reduction_initial 3 +#define _NPY_METH_get_loop 4 +#define NPY_METH_get_reduction_initial 5 /* specific loops for constructions/default get_loop: */ -#define NPY_METH_strided_loop 4 -#define NPY_METH_contiguous_loop 5 -#define NPY_METH_unaligned_strided_loop 6 -#define NPY_METH_unaligned_contiguous_loop 7 -#define NPY_METH_contiguous_indexed_loop 8 +#define NPY_METH_strided_loop 6 +#define NPY_METH_contiguous_loop 7 +#define NPY_METH_unaligned_strided_loop 8 +#define NPY_METH_unaligned_contiguous_loop 9 +#define NPY_METH_contiguous_indexed_loop 10 /* * The resolve descriptors function, must be able to handle NULL values for @@ -162,6 +163,26 @@ typedef NPY_CASTING (resolve_descriptors_function)( npy_intp *view_offset); +/* + * Rarely needed, slightly more powerful version of `resolve_descriptors`. + * See also `resolve_descriptors_function` for details on shared arguments. + */ +typedef NPY_CASTING (resolve_descriptors_raw_function)( + struct PyArrayMethodObject_tag *method, + PyArray_DTypeMeta **dtypes, + /* Unlike above, these can have any DType and we may allow NULL. */ + PyArray_Descr **given_descrs, + /* + * Input scalars or NULL. Only ever passed for python scalars. + * WARNING: In some cases, a loop may be explicitly selected and the + * value passed is not available (NULL) or does not have the + * expected type. + */ + PyObject **input_scalars, + PyArray_Descr **loop_descrs, + npy_intp *view_offset); + + typedef int (PyArrayMethod_StridedLoop)(PyArrayMethod_Context *context, char *const *data, const npy_intp *dimensions, const npy_intp *strides, NpyAuxData *transferdata); diff --git a/numpy/_core/meson.build b/numpy/_core/meson.build index 13b32adc7290..9ee37179848e 100644 --- a/numpy/_core/meson.build +++ b/numpy/_core/meson.build @@ -1097,6 +1097,7 @@ src_umath = umath_gen_headers + [ src_file.process('src/umath/scalarmath.c.src'), 'src/umath/ufunc_object.c', 'src/umath/umathmodule.c', + 'src/umath/special_comparisons.cpp', 'src/umath/string_ufuncs.cpp', 'src/umath/wrapping_array_method.c', # For testing. Eventually, should use public API and be separate: diff --git a/numpy/_core/src/multiarray/array_method.c b/numpy/_core/src/multiarray/array_method.c index 87515bb03aa8..bd537ccf40b3 100644 --- a/numpy/_core/src/multiarray/array_method.c +++ b/numpy/_core/src/multiarray/array_method.c @@ -221,12 +221,6 @@ validate_spec(PyArrayMethod_Spec *spec) "(method: %s)", spec->dtypes[i], spec->name); return -1; } - if (NPY_DT_is_abstract(spec->dtypes[i])) { - PyErr_Format(PyExc_TypeError, - "abstract DType %S are currently not supported." - "(method: %s)", spec->dtypes[i], spec->name); - return -1; - } } return 0; } @@ -261,6 +255,9 @@ fill_arraymethod_from_slots( */ for (PyType_Slot *slot = &spec->slots[0]; slot->slot != 0; slot++) { switch (slot->slot) { + case NPY_METH_resolve_descriptors_raw: + meth->resolve_descriptors_raw = slot->pfunc; + continue; case NPY_METH_resolve_descriptors: meth->resolve_descriptors = slot->pfunc; continue; diff --git a/numpy/_core/src/multiarray/array_method.h b/numpy/_core/src/multiarray/array_method.h index c82a968cd136..4a790ab76a3e 100644 --- a/numpy/_core/src/multiarray/array_method.h +++ b/numpy/_core/src/multiarray/array_method.h @@ -45,6 +45,7 @@ typedef struct PyArrayMethodObject_tag { NPY_CASTING casting; /* default flags. The get_strided_loop function can override these */ NPY_ARRAYMETHOD_FLAGS flags; + resolve_descriptors_raw_function *resolve_descriptors_raw; resolve_descriptors_function *resolve_descriptors; get_loop_function *get_strided_loop; get_reduction_initial_function *get_reduction_initial; diff --git a/numpy/_core/src/umath/dispatching.c b/numpy/_core/src/umath/dispatching.c index 9556ec0b8a9b..ba3cfb693c4f 100644 --- a/numpy/_core/src/umath/dispatching.c +++ b/numpy/_core/src/umath/dispatching.c @@ -275,6 +275,12 @@ resolve_implementation_info(PyUFuncObject *ufunc, == PyTuple_GET_ITEM(curr_dtypes, 2)) { continue; } + /* + * This should be a reduce, but doesn't follow the reduce + * pattern. So (for now?) consider this not a match. + */ + matches = NPY_FALSE; + continue; } if (resolver_dtype == (PyArray_DTypeMeta *)Py_None) { diff --git a/numpy/_core/src/umath/legacy_array_method.h b/numpy/_core/src/umath/legacy_array_method.h index 498fb1aa27c2..750de06c7992 100644 --- a/numpy/_core/src/umath/legacy_array_method.h +++ b/numpy/_core/src/umath/legacy_array_method.h @@ -5,13 +5,14 @@ #include "numpy/ufuncobject.h" #include "array_method.h" +#ifdef __cplusplus +extern "C" { +#endif NPY_NO_EXPORT PyArrayMethodObject * PyArray_NewLegacyWrappingArrayMethod(PyUFuncObject *ufunc, PyArray_DTypeMeta *signature[]); - - /* * The following two symbols are in the header so that other places can use * them to probe for special cases (or whether an ArrayMethod is a "legacy" @@ -29,5 +30,8 @@ NPY_NO_EXPORT NPY_CASTING wrapped_legacy_resolve_descriptors(PyArrayMethodObject *, PyArray_DTypeMeta **, PyArray_Descr **, PyArray_Descr **, npy_intp *); +#ifdef __cplusplus +} +#endif #endif /*_NPY_LEGACY_ARRAY_METHOD_H */ diff --git a/numpy/_core/src/umath/scalarmath.c.src b/numpy/_core/src/umath/scalarmath.c.src index 743ecc128659..f43e1493db2d 100644 --- a/numpy/_core/src/umath/scalarmath.c.src +++ b/numpy/_core/src/umath/scalarmath.c.src @@ -1842,6 +1842,7 @@ static PyObject * * LONG, ULONG, LONGLONG, ULONGLONG, * HALF, FLOAT, DOUBLE, LONGDOUBLE, * CFLOAT, CDOUBLE, CLONGDOUBLE# + * #isint = 1*10, 0*7# * #simp = def*10, def_half, def*3, fcmplx, cmplx, lcmplx# */ #define IS_@name@ @@ -1852,6 +1853,27 @@ static PyObject* npy_@name@ arg1, arg2; int out = 0; +#if @isint@ + /* Special case comparison with python integers */ + if (PyLong_CheckExact(other)) { + PyObject *self_val = PyNumber_Index(self); + if (self_val == NULL) { + return NULL; + } + int res = PyObject_RichCompareBool(self_val, other, cmp_op); + Py_DECREF(self_val); + if (res < 0) { + return NULL; + } + else if (res) { + PyArrayScalar_RETURN_TRUE; + } + else { + PyArrayScalar_RETURN_FALSE; + } + } +#endif + /* * Extract the other value (if it is compatible). */ diff --git a/numpy/_core/src/umath/special_comparisons.cpp b/numpy/_core/src/umath/special_comparisons.cpp new file mode 100644 index 000000000000..747e9008a2cd --- /dev/null +++ b/numpy/_core/src/umath/special_comparisons.cpp @@ -0,0 +1,464 @@ +#include + +#define NPY_NO_DEPRECATED_API NPY_API_VERSION +#define _MULTIARRAYMODULE +#define _UMATHMODULE + +#include "numpy/ndarraytypes.h" +#include "numpy/npy_math.h" +#include "numpy/ufuncobject.h" + +#include "abstractdtypes.h" +#include "dispatching.h" +#include "dtypemeta.h" +#include "common_dtype.h" +#include "convert_datatype.h" + +#include "legacy_array_method.h" /* For `get_wrapped_legacy_ufunc_loop`. */ +#include "special_comparisons.h" + + +/* + * Helper for templating, avoids warnings about uncovered switch paths. + */ +enum class COMP { + EQ, NE, LT, LE, GT, GE, +}; + +static char const * +comp_name(COMP comp) { + switch(comp) { + case COMP::EQ: return "equal"; + case COMP::NE: return "not_equal"; + case COMP::LT: return "less"; + case COMP::LE: return "less_equal"; + case COMP::GT: return "greater"; + case COMP::GE: return "greater_equal"; + default: + assert(0); + return nullptr; + } +} + + +template +static int +fixed_result_loop(PyArrayMethod_Context *NPY_UNUSED(context), + char *const data[], npy_intp const dimensions[], + npy_intp const strides[], NpyAuxData *NPY_UNUSED(auxdata)) +{ + npy_intp N = dimensions[0]; + char *out = data[2]; + npy_intp stride = strides[2]; + + while (N--) { + *reinterpret_cast(out) = result; + out += stride; + } + return 0; +} + +static inline void +get_min_max(int typenum, long long *min, unsigned long long *max) +{ + *min = 0; + switch (typenum) { + case NPY_BYTE: + *min = NPY_MIN_BYTE; + *max = NPY_MAX_BYTE; + break; + case NPY_UBYTE: + *max = NPY_MAX_UBYTE; + break; + case NPY_SHORT: + *min = NPY_MIN_SHORT; + *max = NPY_MAX_SHORT; + break; + case NPY_USHORT: + *max = NPY_MAX_USHORT; + break; + case NPY_INT: + *min = NPY_MIN_INT; + *max = NPY_MAX_INT; + break; + case NPY_UINT: + *max = NPY_MAX_UINT; + break; + case NPY_LONG: + *min = NPY_MIN_LONG; + *max = NPY_MAX_LONG; + break; + case NPY_ULONG: + *max = NPY_MAX_ULONG; + break; + case NPY_LONGLONG: + *min = NPY_MIN_LONGLONG; + *max = NPY_MAX_LONGLONG; + break; + case NPY_ULONGLONG: + *max = NPY_MAX_ULONGLONG; + break; + default: + assert(0); + } +} + + +/* + * Determine if a Python long is within the typenums range, smaller, or larger. + * + * Function returns -2 for errors. + */ +static int +get_value_range(PyObject *value, int type_num) +{ + long long min; + unsigned long long max; + get_min_max(type_num, &min, &max); + + int overflow; + long long val = PyLong_AsLongLongAndOverflow(value, &overflow); + if (val == -1 && overflow == 0 && PyErr_Occurred()) { + return (NPY_CASTING)-1; + } + + if (overflow == 0) { + if (val < min) { + return -1; + } + else if (val > 0 && (unsigned long long)val > max) { + return 1; + } + return 0; + } + else if (overflow < 0) { + return -1; + } + else if (max <= NPY_MAX_LONGLONG) { + return 1; + } + + /* + * If we are checking for unisgned long long, the value may be larger + * then long long, but within range of unsigned long long. Check this + * by doing the normal Python integer comparison. + */ + PyObject *obj = PyLong_FromUnsignedLongLong(max); + if (obj == NULL) { + return -2; + } + int cmp = PyObject_RichCompareBool(value, obj, Py_GT); + Py_DECREF(obj); + if (cmp < 0) { + return -2; + } + if (cmp) { + return 1; + } + else { + return 0; + } +} + + +/* + * Find the type resolution for any numpy_int with pyint comparison. This + * function supports *both* directions for all types. + */ +static NPY_CASTING +resolve_descriptors_raw( + PyArrayMethodObject *self, PyArray_DTypeMeta **dtypes, + PyArray_Descr **given_descrs, PyObject **input_scalars, + PyArray_Descr **loop_descrs, npy_intp *view_offset) +{ + int value_range = 0; + int arr_idx = 0; + int scalar_idx = 1; + + if (dtypes[0] == &PyArray_PyIntAbstractDType) { + arr_idx = 1; + scalar_idx = 0; + } + assert(PyTypeNum_ISINTEGER(dtypes[arr_idx]->type_num)); + PyArray_DTypeMeta *arr_dtype = dtypes[arr_idx]; + + /* + * Three way decision (with hack) on value range: + * 0: The value fits within the range of the dtype. + * 1: The value came second and is larger or came first and is smaller. + * -1: The value came second and is smaller or came first and is larger + */ + if (input_scalars[scalar_idx] != NULL + && PyLong_CheckExact(input_scalars[scalar_idx])) { + value_range = get_value_range(input_scalars[scalar_idx], arr_dtype->type_num); + if (value_range == -2) { + return (NPY_CASTING)-1; + } + if (arr_idx == 1) { + value_range *= -1; + } + } + + /* + * Very small/large values always need to be encoded as `object` dtype + * in order to never fail casting. + * TRICK: We encode the value range by whether or not we use the object + * singleton! This information is then available in `get_loop()`. + */ + if (value_range == 0) { + Py_INCREF(arr_dtype->singleton); + loop_descrs[scalar_idx] = arr_dtype->singleton; + } + else if (value_range < 0) { + loop_descrs[scalar_idx] = PyArray_DescrFromType(NPY_OBJECT); + } + else { + loop_descrs[scalar_idx] = PyArray_DescrNewFromType(NPY_OBJECT); + if (loop_descrs[scalar_idx] == NULL) { + return (NPY_CASTING)-1; + } + } + Py_INCREF(arr_dtype->singleton); + loop_descrs[arr_idx] = arr_dtype->singleton; + loop_descrs[2] = PyArray_DescrFromType(NPY_BOOL); + + return NPY_NO_CASTING; +} + + +template +static int +get_loop(PyArrayMethod_Context *context, + int aligned, int move_references, const npy_intp *strides, + PyArrayMethod_StridedLoop **out_loop, NpyAuxData **out_transferdata, + NPY_ARRAYMETHOD_FLAGS *flags) +{ + if (context->descriptors[1]->type_num == context->descriptors[0]->type_num) { + /* + * Fall back to the current implementation, which wraps legacy loops. + */ + return get_wrapped_legacy_ufunc_loop( + context, aligned, move_references, strides, + out_loop, out_transferdata, flags); + } + else { + PyArray_Descr *other_descr; + if (context->descriptors[1]->type_num == NPY_OBJECT) { + other_descr = context->descriptors[1]; + } + else { + assert(context->descriptors[0]->type_num == NPY_OBJECT); + other_descr = context->descriptors[0]; + } + /* HACK: If the descr is the singleton the result is smaller */ + PyArray_Descr *obj_singleton = PyArray_DescrFromType(NPY_OBJECT); + if (other_descr == obj_singleton) { + /* Singleton came second and is smaller, or first and is larger */ + switch (comp) { + case COMP::EQ: + case COMP::LT: + case COMP::LE: + *out_loop = &fixed_result_loop; + break; + case COMP::NE: + case COMP::GT: + case COMP::GE: + *out_loop = &fixed_result_loop; + break; + } + } + else { + /* Singleton came second and is larger, or first and is smaller */ + switch (comp) { + case COMP::EQ: + case COMP::GT: + case COMP::GE: + *out_loop = &fixed_result_loop; + break; + case COMP::NE: + case COMP::LT: + case COMP::LE: + *out_loop = &fixed_result_loop; + break; + } + } + Py_DECREF(obj_singleton); + } + *flags = NPY_METH_NO_FLOATINGPOINT_ERRORS; + return 0; +} + + +/* + * Machinery to add the string loops to the existing ufuncs. + */ + +/* + * Simple promoter that ensures we use the object loop when the input + * is python integers only. + * Note that if a user would pass the Python `int` abstract DType explicitly + * they promise to actually pass a Python int and we accept that we never + * check for that. + */ +static int +pyint_comparison_promoter(PyUFuncObject *NPY_UNUSED(ufunc), + PyArray_DTypeMeta *op_dtypes[], PyArray_DTypeMeta *signature[], + PyArray_DTypeMeta *new_op_dtypes[]) +{ + new_op_dtypes[0] = PyArray_DTypeFromTypeNum(NPY_OBJECT); + new_op_dtypes[1] = PyArray_DTypeFromTypeNum(NPY_OBJECT); + new_op_dtypes[2] = PyArray_DTypeFromTypeNum(NPY_BOOL); + return 0; +} + + +/* + * This function replaces the strided loop with the passed in one, + * and registers it with the given ufunc. + * It additionally adds a promoter for (pyint, pyint, bool) to use the + * (object, object, bool) implementation. + */ +template +static int +add_dtype_loops(PyObject *umath, PyArrayMethod_Spec *spec, PyObject *info) +{ + PyArray_DTypeMeta *PyInt = &PyArray_PyIntAbstractDType; + + PyObject *name = PyUnicode_FromString(comp_name(comp)); + if (name == nullptr) { + return -1; + } + PyUFuncObject *ufunc = (PyUFuncObject *)PyObject_GetItem(umath, name); + Py_DECREF(name); + if (ufunc == nullptr) { + return -1; + } + if (Py_TYPE(ufunc) != &PyUFunc_Type) { + PyErr_SetString(PyExc_RuntimeError, + "internal NumPy error: comparison not a ufunc"); + goto fail; + } + + /* + * NOTE: Iterates all type numbers, it would be nice to reduce this. + * (that would be easier if we consolidate int DTypes in general.) + */ + for (int typenum = NPY_BYTE; typenum <= NPY_ULONGLONG; typenum++) { + spec->slots[0].pfunc = (void *)get_loop; + + PyArray_DTypeMeta *Int = PyArray_DTypeFromTypeNum(typenum); + + /* Register the spec/loop for both forward and backward direction */ + spec->dtypes[0] = Int; + spec->dtypes[1] = PyInt; + int res = PyUFunc_AddLoopFromSpec((PyObject *)ufunc, spec); + if (res < 0) { + Py_DECREF(Int); + goto fail; + } + spec->dtypes[0] = PyInt; + spec->dtypes[1] = Int; + res = PyUFunc_AddLoopFromSpec((PyObject *)ufunc, spec); + Py_DECREF(Int); + if (res < 0) { + goto fail; + } + } + + /* + * Install the promoter info to allow two Python integers to work. + */ + return PyUFunc_AddLoop((PyUFuncObject *)ufunc, info, 0); + + Py_DECREF(ufunc); + return 0; + + fail: + Py_DECREF(ufunc); + return -1; +} + + +template +struct add_loops; + +template<> +struct add_loops<> { + int operator()(PyObject*, PyArrayMethod_Spec*, PyObject *) { + return 0; + } +}; + + +template +struct add_loops { + int operator()(PyObject* umath, PyArrayMethod_Spec* spec, PyObject *info) { + if (add_dtype_loops(umath, spec, info) < 0) { + return -1; + } + else { + return add_loops()(umath, spec, info); + } + } +}; + + +NPY_NO_EXPORT int +init_special_int_comparisons(PyObject *umath) +{ + int res = -1; + PyObject *info = NULL, *promoter = NULL; + /* NOTE: This should receive global symbols? */ + PyArray_DTypeMeta *Bool = PyArray_DTypeFromTypeNum(NPY_BOOL); + + /* We start with the string loops: */ + PyArray_DTypeMeta *dtypes[] = {NULL, NULL, Bool}; + /* + * We only have one loop right now, the strided one. The default type + * resolver ensures native byte order/canonical representation. + */ + PyType_Slot slots[] = { + {_NPY_METH_get_loop, nullptr}, + {NPY_METH_resolve_descriptors_raw, (void *)&resolve_descriptors_raw}, + {0, NULL}, + }; + + PyArrayMethod_Spec spec = {}; + spec.name = "templated_pyint_comp"; + spec.nin = 2; + spec.nout = 1; + spec.dtypes = dtypes; + spec.slots = slots; + spec.flags = NPY_METH_NO_FLOATINGPOINT_ERRORS; + + PyObject *dtype_tuple = PyTuple_Pack(3, + &PyArray_PyIntAbstractDType, &PyArray_PyIntAbstractDType, Bool); + if (dtype_tuple == NULL) { + goto finish; + } + promoter = PyCapsule_New( + (void *)&pyint_comparison_promoter, "numpy._ufunc_promoter", NULL); + if (promoter == NULL) { + Py_DECREF(promoter); + goto finish; + } + info = PyTuple_Pack(2, dtype_tuple, promoter); + Py_DECREF(dtype_tuple); + Py_DECREF(promoter); + if (info == NULL) { + goto finish; + } + + /* All String loops */ + using comp_looper = add_loops; + if (comp_looper()(umath, &spec, info) < 0) { + goto finish; + } + + res = 0; + finish: + + Py_XDECREF(info); + Py_DECREF(Bool); + return res; +} diff --git a/numpy/_core/src/umath/special_comparisons.h b/numpy/_core/src/umath/special_comparisons.h new file mode 100644 index 000000000000..2312bcae1e65 --- /dev/null +++ b/numpy/_core/src/umath/special_comparisons.h @@ -0,0 +1,15 @@ +#ifndef _NPY_CORE_SRC_UMATH_SPECIAL_COMPARISONS_H_ +#define _NPY_CORE_SRC_UMATH_SPECIAL_COMPARISONS_H_ + +#ifdef __cplusplus +extern "C" { +#endif + +NPY_NO_EXPORT int +init_special_int_comparisons(PyObject *umath); + +#ifdef __cplusplus +} +#endif + +#endif /* _NPY_CORE_SRC_UMATH_SPECIAL_COMPARISONS_H_ */ diff --git a/numpy/_core/src/umath/ufunc_object.c b/numpy/_core/src/umath/ufunc_object.c index 62e7f23ec918..a93856485764 100644 --- a/numpy/_core/src/umath/ufunc_object.c +++ b/numpy/_core/src/umath/ufunc_object.c @@ -115,10 +115,11 @@ static PyObject * prepare_input_arguments_for_outer(PyObject *args, PyUFuncObject *ufunc); static int -resolve_descriptors(int nop, +resolve_descriptors(int nop, int nin, PyUFuncObject *ufunc, PyArrayMethodObject *ufuncimpl, PyArrayObject *operands[], PyArray_Descr *dtypes[], - PyArray_DTypeMeta *signature[], NPY_CASTING casting); + PyArray_DTypeMeta *signature[], PyObject *inputs_tup, + NPY_CASTING casting); /*UFUNC_API*/ @@ -2803,8 +2804,8 @@ reducelike_promote_and_resolve(PyUFuncObject *ufunc, * casting safety could in principle be set to the default same-kind. * (although this should possibly happen through a deprecation) */ - if (resolve_descriptors(3, ufunc, ufuncimpl, - ops, out_descrs, signature, casting) < 0) { + if (resolve_descriptors(3, 2, ufunc, ufuncimpl, + ops, out_descrs, signature, NULL, casting) < 0) { return NULL; } @@ -4475,14 +4476,60 @@ _get_fixed_signature(PyUFuncObject *ufunc, * need to "cast" to string first). */ static int -resolve_descriptors(int nop, +resolve_descriptors(int nop, int nin, PyUFuncObject *ufunc, PyArrayMethodObject *ufuncimpl, PyArrayObject *operands[], PyArray_Descr *dtypes[], - PyArray_DTypeMeta *signature[], NPY_CASTING casting) + PyArray_DTypeMeta *signature[], PyObject *inputs_tup, + NPY_CASTING casting) { int retval = -1; + NPY_CASTING safety; PyArray_Descr *original_dtypes[NPY_MAXARGS]; + NPY_UF_DBG_PRINT("Resolving the descriptors\n"); + + if (NPY_UNLIKELY(ufuncimpl->resolve_descriptors_raw != NULL)) { + /* + * Allow a somewhat more powerful approach which: + * 1. Has access to scalars (currently only ever Python ones) + * 2. Can in principle customize `PyArray_CastDescrToDType()` + * (also because we want to avoid calling it for the scalars). + */ + PyObject *input_scalars[NPY_MAXARGS]; + for (int i = 0; i < nop; i++) { + if (operands[i] == NULL) { + original_dtypes[i] = NULL; + } + else { + /* For abstract DTypes, we might want to change what this is */ + original_dtypes[i] = PyArray_DTYPE(operands[i]); + Py_INCREF(original_dtypes[i]); + } + if (i < nin + && NPY_DT_is_abstract(signature[i]) + && inputs_tup != NULL) { + /* + * TODO: We may wish to allow any scalar here. Checking for + * abstract assumes this works out for Python scalars, + * which is the important case (especially for now). + * + * One possible check would be `DType->type == type(obj)`. + */ + input_scalars[i] = PyTuple_GET_ITEM(inputs_tup, i); + } + else { + input_scalars[i] = NULL; + } + } + + npy_intp view_offset = NPY_MIN_INTP; /* currently ignored */ + safety = ufuncimpl->resolve_descriptors_raw( + ufuncimpl, signature, original_dtypes, input_scalars, + dtypes, &view_offset + ); + goto check_safety; + } + for (int i = 0; i < nop; ++i) { if (operands[i] == NULL) { original_dtypes[i] = NULL; @@ -4501,26 +4548,13 @@ resolve_descriptors(int nop, } } - NPY_UF_DBG_PRINT("Resolving the descriptors\n"); - if (ufuncimpl->resolve_descriptors != &wrapped_legacy_resolve_descriptors) { /* The default: use the `ufuncimpl` as nature intended it */ npy_intp view_offset = NPY_MIN_INTP; /* currently ignored */ - NPY_CASTING safety = ufuncimpl->resolve_descriptors(ufuncimpl, + safety = ufuncimpl->resolve_descriptors(ufuncimpl, signature, original_dtypes, dtypes, &view_offset); - if (safety < 0) { - goto finish; - } - if (NPY_UNLIKELY(PyArray_MinCastSafety(safety, casting) != casting)) { - /* TODO: Currently impossible to reach (specialized unsafe loop) */ - PyErr_Format(PyExc_TypeError, - "The ufunc implementation for %s with the given dtype " - "signature is not possible under the casting rule %s", - ufunc_get_name_cstr(ufunc), npy_casting_to_string(casting)); - goto finish; - } - retval = 0; + goto check_safety; } else { /* @@ -4528,7 +4562,22 @@ resolve_descriptors(int nop, * for datetime64/timedelta64 and custom ufuncs (in pyerfa/astropy). */ retval = ufunc->type_resolver(ufunc, casting, operands, NULL, dtypes); + goto finish; + } + + check_safety: + if (safety < 0) { + goto finish; + } + if (NPY_UNLIKELY(PyArray_MinCastSafety(safety, casting) != casting)) { + /* TODO: Currently impossible to reach (specialized unsafe loop) */ + PyErr_Format(PyExc_TypeError, + "The ufunc implementation for %s with the given dtype " + "signature is not possible under the casting rule %s", + ufunc_get_name_cstr(ufunc), npy_casting_to_string(casting)); + goto finish; } + retval = 0; finish: for (int i = 0; i < nop; i++) { @@ -4856,8 +4905,8 @@ ufunc_generic_fastcall(PyUFuncObject *ufunc, } /* Find the correct descriptors for the operation */ - if (resolve_descriptors(nop, ufunc, ufuncimpl, - operands, operation_descrs, signature, casting) < 0) { + if (resolve_descriptors(nop, nin, ufunc, ufuncimpl, + operands, operation_descrs, signature, full_args.in, casting) < 0) { goto fail; } @@ -6228,8 +6277,8 @@ ufunc_at(PyUFuncObject *ufunc, PyObject *args) } /* Find the correct operation_descrs for the operation */ - int resolve_result = resolve_descriptors(nop, ufunc, ufuncimpl, - tmp_operands, operation_descrs, signature, NPY_UNSAFE_CASTING); + int resolve_result = resolve_descriptors(nop, ufunc->nin, ufunc, ufuncimpl, + tmp_operands, operation_descrs, signature, NULL, NPY_UNSAFE_CASTING); for (int i = 0; i < 3; i++) { Py_XDECREF(signature[i]); Py_XDECREF(operand_DTypes[i]); @@ -6557,8 +6606,9 @@ py_resolve_dtypes_generic(PyUFuncObject *ufunc, npy_bool return_context, } /* Find the correct descriptors for the operation */ - if (resolve_descriptors(ufunc->nargs, ufunc, ufuncimpl, - dummy_arrays, operation_descrs, signature, casting) < 0) { + if (resolve_descriptors(ufunc->nargs, ufunc->nin, ufunc, ufuncimpl, + dummy_arrays, operation_descrs, signature, + NULL, casting) < 0) { goto finish; } diff --git a/numpy/_core/src/umath/umathmodule.c b/numpy/_core/src/umath/umathmodule.c index 07a9159b0dcc..d0daa6fffaae 100644 --- a/numpy/_core/src/umath/umathmodule.c +++ b/numpy/_core/src/umath/umathmodule.c @@ -27,6 +27,7 @@ #include "number.h" #include "dispatching.h" #include "string_ufuncs.h" +#include "special_comparisons.h" #include "extobj.h" /* for _extobject_contextvar exposure */ /* Automatically generated code to define all ufuncs: */ @@ -334,5 +335,9 @@ int initumath(PyObject *m) return -1; } + if (init_special_int_comparisons(d) < 0) { + return -1; + } + return 0; } diff --git a/numpy/_core/tests/test_half.py b/numpy/_core/tests/test_half.py index 89bed2215357..954ba5987689 100644 --- a/numpy/_core/tests/test_half.py +++ b/numpy/_core/tests/test_half.py @@ -266,8 +266,8 @@ def test_half_correctness(self): if len(a32_fail) != 0: bad_index = a32_fail[0] assert_equal(self.finite_f32, a_manual, - "First non-equal is half value %x -> %g != %g" % - (self.finite_f16[bad_index], + "First non-equal is half value 0x%x -> %g != %g" % + (a_bits[bad_index], self.finite_f32[bad_index], a_manual[bad_index])) @@ -275,8 +275,8 @@ def test_half_correctness(self): if len(a64_fail) != 0: bad_index = a64_fail[0] assert_equal(self.finite_f64, a_manual, - "First non-equal is half value %x -> %g != %g" % - (self.finite_f16[bad_index], + "First non-equal is half value 0x%x -> %g != %g" % + (a_bits[bad_index], self.finite_f64[bad_index], a_manual[bad_index]))