Skip to content

Commit

Permalink
API: Allow comparisons with and between any python integers
Browse files Browse the repository at this point in the history
This implements comparisons between NumPy integer arrays and arbitrary valued
Python integers when weak promotion is enabled.

To achieve this:
* I allow abstract DTypes (with small bug fixes) to register as loops (`ArrayMethods`).
  This is fine, you just need to take more care.
  It does muddy the waters between promotion and not a bit if the
  result DType would also be abstract.
  (For the specific case it doesn't, but in general it does.)
* A new `resolve_descriptors_raw` function, which does the same job as
  `resolve_descriptors` but I pass it this scalar argument
  (can be expanded, but starting small).
  * This only happens *when available*, so there are some niche paths were this cannot
    be used (`ufunc.at` and the explicit resolution function right now),
    we can deal with those by keeping the previous rules (things will just raise
    trying to convert).
  * The function also gets the actual `arrays.dtype` instances while I normally ensure that
    we pass in dtypes already cast to the correct DType class.
    (The reason is that we don't define how to cast the abstract DTypes as of now,
    and even if we did, it would not be what we need unless the dtype instance actually had
    the value information.)
* There are new loops added (for combinations!), which:
  * Use the new `resolve_descriptors_raw` (a single function dealing with everything)
  * Return the current legacy loop when that makes sense.
  * Return an always true/false loop when that makes sense.
  * To achieve this, they employ a hack/trick: `get_loop()` needs to know the value,
    but only `resolve_descriptors_raw()` does right now, so this is encoded on whether we use
    the `np.dtype("object")` singleton or a fresh instance!
    (Yes, probably ugly, but avoids channeling things to more places.)

Additionally, there is a promoter to say that Python integer comparisons can just use
`object` dtype (in theory weird if the input then wasn't a Python int,
but that is breaking promises).
  • Loading branch information
seberg committed Oct 18, 2023
1 parent 8e14541 commit af55348
Show file tree
Hide file tree
Showing 12 changed files with 634 additions and 48 deletions.
39 changes: 30 additions & 9 deletions numpy/_core/include/numpy/_dtype_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#ifndef NUMPY_CORE_INCLUDE_NUMPY___DTYPE_API_H_
#define NUMPY_CORE_INCLUDE_NUMPY___DTYPE_API_H_

#define __EXPERIMENTAL_DTYPE_API_VERSION 13
#define __EXPERIMENTAL_DTYPE_API_VERSION 14

struct PyArrayMethodObject_tag;

Expand Down Expand Up @@ -129,16 +129,17 @@ typedef struct {
* SLOTS IDs For the ArrayMethod creation, once fully public, IDs are fixed
* but can be deprecated and arbitrarily extended.
*/
#define NPY_METH_resolve_descriptors 1
#define NPY_METH_resolve_descriptors_raw 1
#define NPY_METH_resolve_descriptors 2
/* We may want to adapt the `get_loop` signature a bit: */
#define _NPY_METH_get_loop 2
#define NPY_METH_get_reduction_initial 3
#define _NPY_METH_get_loop 4
#define NPY_METH_get_reduction_initial 5
/* specific loops for constructions/default get_loop: */
#define NPY_METH_strided_loop 4
#define NPY_METH_contiguous_loop 5
#define NPY_METH_unaligned_strided_loop 6
#define NPY_METH_unaligned_contiguous_loop 7
#define NPY_METH_contiguous_indexed_loop 8
#define NPY_METH_strided_loop 6
#define NPY_METH_contiguous_loop 7
#define NPY_METH_unaligned_strided_loop 8
#define NPY_METH_unaligned_contiguous_loop 9
#define NPY_METH_contiguous_indexed_loop 10

/*
* The resolve descriptors function, must be able to handle NULL values for
Expand All @@ -162,6 +163,26 @@ typedef NPY_CASTING (resolve_descriptors_function)(
npy_intp *view_offset);


/*
* Rarely needed, slightly more powerful version of `resolve_descriptors`.
* See also `resolve_descriptors_function` for details on shared arguments.
*/
typedef NPY_CASTING (resolve_descriptors_raw_function)(
struct PyArrayMethodObject_tag *method,
PyArray_DTypeMeta **dtypes,
/* Unlike above, these can have any DType and we may allow NULL. */
PyArray_Descr **given_descrs,
/*
* Input scalars or NULL. Only ever passed for python scalars.
* WARNING: In some cases, a loop may be explicitly selected and the
* value passed is not available (NULL) or does not have the
* expected type.
*/
PyObject **input_scalars,
PyArray_Descr **loop_descrs,
npy_intp *view_offset);


typedef int (PyArrayMethod_StridedLoop)(PyArrayMethod_Context *context,
char *const *data, const npy_intp *dimensions, const npy_intp *strides,
NpyAuxData *transferdata);
Expand Down
1 change: 1 addition & 0 deletions numpy/_core/meson.build
Original file line number Diff line number Diff line change
Expand Up @@ -1097,6 +1097,7 @@ src_umath = umath_gen_headers + [
src_file.process('src/umath/scalarmath.c.src'),
'src/umath/ufunc_object.c',
'src/umath/umathmodule.c',
'src/umath/special_comparisons.cpp',
'src/umath/string_ufuncs.cpp',
'src/umath/wrapping_array_method.c',
# For testing. Eventually, should use public API and be separate:
Expand Down
9 changes: 3 additions & 6 deletions numpy/_core/src/multiarray/array_method.c
Original file line number Diff line number Diff line change
Expand Up @@ -221,12 +221,6 @@ validate_spec(PyArrayMethod_Spec *spec)
"(method: %s)", spec->dtypes[i], spec->name);
return -1;
}
if (NPY_DT_is_abstract(spec->dtypes[i])) {
PyErr_Format(PyExc_TypeError,
"abstract DType %S are currently not supported."
"(method: %s)", spec->dtypes[i], spec->name);
return -1;
}
}
return 0;
}
Expand Down Expand Up @@ -261,6 +255,9 @@ fill_arraymethod_from_slots(
*/
for (PyType_Slot *slot = &spec->slots[0]; slot->slot != 0; slot++) {
switch (slot->slot) {
case NPY_METH_resolve_descriptors_raw:
meth->resolve_descriptors_raw = slot->pfunc;
continue;
case NPY_METH_resolve_descriptors:
meth->resolve_descriptors = slot->pfunc;
continue;
Expand Down
1 change: 1 addition & 0 deletions numpy/_core/src/multiarray/array_method.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ typedef struct PyArrayMethodObject_tag {
NPY_CASTING casting;
/* default flags. The get_strided_loop function can override these */
NPY_ARRAYMETHOD_FLAGS flags;
resolve_descriptors_raw_function *resolve_descriptors_raw;
resolve_descriptors_function *resolve_descriptors;
get_loop_function *get_strided_loop;
get_reduction_initial_function *get_reduction_initial;
Expand Down
6 changes: 6 additions & 0 deletions numpy/_core/src/umath/dispatching.c
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,12 @@ resolve_implementation_info(PyUFuncObject *ufunc,
== PyTuple_GET_ITEM(curr_dtypes, 2)) {
continue;
}
/*
* This should be a reduce, but doesn't follow the reduce
* pattern. So (for now?) consider this not a match.
*/
matches = NPY_FALSE;
continue;
}

if (resolver_dtype == (PyArray_DTypeMeta *)Py_None) {
Expand Down
8 changes: 6 additions & 2 deletions numpy/_core/src/umath/legacy_array_method.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,14 @@
#include "numpy/ufuncobject.h"
#include "array_method.h"

#ifdef __cplusplus
extern "C" {
#endif

NPY_NO_EXPORT PyArrayMethodObject *
PyArray_NewLegacyWrappingArrayMethod(PyUFuncObject *ufunc,
PyArray_DTypeMeta *signature[]);



/*
* The following two symbols are in the header so that other places can use
* them to probe for special cases (or whether an ArrayMethod is a "legacy"
Expand All @@ -29,5 +30,8 @@ NPY_NO_EXPORT NPY_CASTING
wrapped_legacy_resolve_descriptors(PyArrayMethodObject *,
PyArray_DTypeMeta **, PyArray_Descr **, PyArray_Descr **, npy_intp *);

#ifdef __cplusplus
}
#endif

#endif /*_NPY_LEGACY_ARRAY_METHOD_H */
22 changes: 22 additions & 0 deletions numpy/_core/src/umath/scalarmath.c.src
Original file line number Diff line number Diff line change
Expand Up @@ -1842,6 +1842,7 @@ static PyObject *
* LONG, ULONG, LONGLONG, ULONGLONG,
* HALF, FLOAT, DOUBLE, LONGDOUBLE,
* CFLOAT, CDOUBLE, CLONGDOUBLE#
* #isint = 1*10, 0*7#
* #simp = def*10, def_half, def*3, fcmplx, cmplx, lcmplx#
*/
#define IS_@name@
Expand All @@ -1852,6 +1853,27 @@ static PyObject*
npy_@name@ arg1, arg2;
int out = 0;

#if @isint@
/* Special case comparison with python integers */
if (PyLong_CheckExact(other)) {
PyObject *self_val = PyNumber_Index(self);
if (self_val == NULL) {
return NULL;
}
int res = PyObject_RichCompareBool(self_val, other, cmp_op);
Py_DECREF(self_val);
if (res < 0) {
return NULL;
}
else if (res) {
PyArrayScalar_RETURN_TRUE;
}
else {
PyArrayScalar_RETURN_FALSE;
}
}
#endif
/*
* Extract the other value (if it is compatible).
*/
Expand Down
Loading

0 comments on commit af55348

Please sign in to comment.